==== python-skytools-3.4/.coveragerc ====

[report]
exclude_lines =
    ^try:
    ^except
    pragma: no cover

==== python-skytools-3.4/.gitignore ====

__pycache__
*.pyc
*.swp
*.o
*.so
*.egg-info
*.debhelper
*.log
*.substvars
*-stamp
debian/files
debian/python-skytools
debian/python3-skytools
.tox
.coverage
.pybuild
MANIFEST
build
tmp

==== python-skytools-3.4/.pylintrc ====

[MASTER]

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,tmp,dist

# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=

# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1

# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=

# Pickle collected data for later comparisons.
persistent=yes

# Specify a configuration file.
#rcfile=

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=bad-continuation,
        bad-whitespace,
        bare-except,
        broad-except,
        consider-using-in,
        consider-using-ternary,
        fixme,
        global-statement,
        invalid-name,
        missing-docstring,
        no-else-raise,
        no-else-return,
        no-self-use,
        trailing-newlines,
        unused-argument,
        unused-variable,
        using-constant-test,
        useless-object-inheritance,
        arguments-differ,
        multiple-statements,
        len-as-condition,
        chained-comparison,
        unnecessary-pass,
        cyclic-import,
        too-many-ancestors,
        import-outside-toplevel,
        protected-access,
        try-except-raise,
        deprecated-module,
        no-else-break,
        no-else-continue

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member


[REPORTS]

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=

# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=text

# Tells whether to display a full report or only the messages.
reports=no

# Activate the evaluation score.
score=no


[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=10

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit


[LOGGING]

# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old

# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX,
      TODO


[SPELLING]

# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4

# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package..
#spelling-dict=en_US

# List of comma separated words that should not be checked.
spelling-ignore-words=usr,bin,env

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=.local.dict

# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no


[BASIC]

# Naming style matching correct argument names.
argument-naming-style=snake_case

# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=

# Naming style matching correct attribute names.
attr-naming-style=snake_case

# Regular expression matching correct attribute names.
# Overrides attr-naming-style.
#attr-rgx=

# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata

# Naming style matching correct class attribute names.
class-attribute-naming-style=any

# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=

# Naming style matching correct class names.
class-naming-style=PascalCase

# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=

# Naming style matching correct constant names.
const-naming-style=UPPER_CASE

# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1

# Naming style matching correct function names.
function-naming-style=snake_case

# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=

# Good variable names which should always be accepted, separated by a comma.
good-names=i,
           j,
           k,
           ex,
           Run,
           _

# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no

# Naming style matching correct inline iteration names.
inlinevar-naming-style=any

# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=

# Naming style matching correct method names.
method-naming-style=snake_case

# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=

# Naming style matching correct module names.
module-naming-style=snake_case

# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty

# Naming style matching correct variable names.
variable-naming-style=snake_case

# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=


[STRING]

# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no


[SIMILARITIES]

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no

# Minimum lines number of a similarity.
min-similarity-lines=4


[VARIABLES]

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb

# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_

# Tells whether we should check for unused import in __init__ files.
init-import=no

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes

# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes

# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1

# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=LF

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '

# Maximum number of characters on a single line.
max-line-length=160

# Maximum number of lines in a module.
max-module-lines=10000

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
               dict-separator

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls


[DESIGN]

# Maximum number of arguments for function / method.
max-args=15

# Maximum number of attributes for a class (see R0902).
max-attributes=37

# Maximum number of boolean expressions in an if statement.
max-bool-expr=5

# Maximum number of branch for function / method body.
max-branches=50

# Maximum number of locals for function / method body.
max-locals=45

# Maximum number of parents for a class (see R0901).
max-parents=7

# Maximum number of public methods for a class (see R0904).
max-public-methods=420

# Maximum number of return / yield for function / method body.
max-returns=16

# Maximum number of statements in function / method body.
max-statements=150

# Minimum number of public methods for a class (see R0903).
min-public-methods=0


[IMPORTS]

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
                       Exception

==== python-skytools-3.4/AUTHORS ====

Maintainers
-----------

Marko Kreen
Petr Jelinek
Sasha Aliashkevich

Contributors
------------

Aleksei Plotnikov
André Malo
Andrew Dunstan
Artyom Nosov
Asko Oja
Asko Tiidumaa
Cédric Villemain
Charles Duffy
Devrim Gündüz
Dimitri Fontaine
Dmitriy V'jukov
Doug Gorley
Eero Oja
Egon Valdmees
Emiel van de Laar
Erik Jones
Glenn Davy
Götz Lange
Hannu Krosing
Hans-Juergen Schoenig
Jason Buberel
Juta Vaks
Kaarel Kitsemets
Kristo Kaiv
Luc Van Hoeylandt
Lukáš Lalinský
Marcin Stępnicki
Mark Kirkwood
Martin Otto
Martin Pihlak
Nico Mandery
Petr Jelinek
Pierre-Emmanuel André
Priit Kustala
Sasha Aliashkevich
Sébastien Lardière
Sergey Burladyan
Sergey Konoplev
Shoaib Mir
Steve Singer
Tarvi Pillessaar
Tony Arkles
Zoltán Böszörményi

==== python-skytools-3.4/COPYRIGHT ====

Copyright (c) 2007-2017 Marko Kreen

Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.

THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER
RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE
USE OR PERFORMANCE OF THIS SOFTWARE.

==== python-skytools-3.4/MANIFEST.in ====

include modules/*.h
include tests/*.py tests/*.ini
include tox.ini .prospector.yaml .coveragerc
include MANIFEST.in
include README.rst

==== python-skytools-3.4/Makefile ====

all:

clean:
	rm -rf build *.egg-info */__pycache__ tests/*.pyc
	rm -rf debian/python-* debian/files debian/*.log
	rm -rf debian/*.substvars debian/*.debhelper debian/*-stamp
	rm -rf .pybuild MANIFEST

deb:
	debuild -us -uc -b

xclean: clean
	rm -rf .tox dist

==== python-skytools-3.4/NEWS.rst ====

Skytools 3.4 (2019-11-14)
-------------------------

* Support Postgres 10 sequences
* Make full_copy text-based
* Allow None fields in magic_insert
* Fix iterator use in magic_insert
* Fix Python 3 bugs
* Switch off Python 2 tests, to avoid wasting time.

Skytools 3.3 (2017-09-21)
-------------------------

* Separate 'skytools' module out from the big package
* Python 3 support

Skytools 3.2 and older
----------------------

See old changes here:
https://github.com/pgq/skytools-legacy/blob/master/NEWS

==== python-skytools-3.4/README.rst ====

Skytools - Utilities for writing Python scripts
===============================================

This is the low-level utility module split out from the old Skytools
meta-package. It contains various utilities for writing database scripts.
The database-specific utilities are mainly meant for PostgreSQL.
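A minimal usage sketch (illustrative only; it assumes a reachable
PostgreSQL instance, since ``skytools.connect_database`` returns a
tuned psycopg2 connection)::

    import skytools

    # quoting helpers need no database
    sql = "select * from tbl where name = %s" % skytools.quote_literal("O'Reilly")

    # database helpers work on a psycopg2-style connection
    db = skytools.connect_database("host=127.0.0.1 dbname=testdb")
    curs = db.cursor()
    curs.execute(sql)
    print(curs.fetchall())
    db.commit()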
Features
--------

* Support for background scripts

  - Daemonizing
  - Logging
  - Config parsing

* Database tools

  - Tuned connection
  - DB structure examining
  - SQL parsing
  - COPY I/O

* Time utilities

  - ISO timestamp parsing
  - datetime to timestamp

* Text utilities

  - Natural sort
  - Fast urlencode I/O

TODO
----

* Move from optparse to argparse
* Doc cleanup

==== python-skytools-3.4/debian/changelog ====

python-skytools (3.4-1) unstable; urgency=low

  * v3.4

 -- Marko Kreen <markokr@gmail.com>  Thu, 14 Nov 2019 11:37:44 +0200

python-skytools (3.3.0-8) unstable; urgency=low

  * v3.3.0

 -- Marko Kreen <markokr@gmail.com>  Fri, 04 Dec 2015 17:00:23 +0200

==== python-skytools-3.4/debian/compat ====

9

==== python-skytools-3.4/debian/control ====

Source: python-skytools
Section: python
Priority: optional
Maintainer: Marko Kreen <markokr@gmail.com>
Standards-Version: 3.9.2
Build-Depends: debhelper (>= 9), dh-python,
 python-all-dev, python3-all-dev,
 python-setuptools, python3-setuptools
X-Python-Version: >= 2.7
X-Python3-Version: >= 3.5

Package: python-skytools
Architecture: any
Depends: ${shlibs:Depends}, ${misc:Depends}, ${python:Depends}
Conflicts: python-skytools3
Description: Scripting tools
 .

Package: python3-skytools
Architecture: any
Depends: ${shlibs:Depends}, ${misc:Depends}, ${python3:Depends}
Description: Scripting tools
 .

==== python-skytools-3.4/debian/copyright ====

Format: http://www.debian.org/doc/packaging-manuals/copyright-format/1.0/
Source: https://github.com/pgq/skytools

Files: *
Copyright: Copyright (c) 2007-2016, Skytools Authors
License: ISC

==== python-skytools-3.4/debian/py3dist-overrides ====

psycopg2 python3-psycopg2

==== python-skytools-3.4/debian/pydist-overrides ====

psycopg2 python-psycopg2

==== python-skytools-3.4/debian/rules ====

#! /usr/bin/make -f

#export DH_VERBOSE = 1
export DEB_BUILD_OPTIONS = nocheck
export PYBUILD_NAME = skytools

%:
	dh $@ --with python2,python3 --buildsystem=pybuild

==== python-skytools-3.4/debian/source/format ====

3.0 (quilt)

==== python-skytools-3.4/modules/cquoting.c ====

/*
 * Fast quoting functions for Python.
 */

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdbool.h>

#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
#endif

#ifdef _MSC_VER
#define inline __inline
#define strcasecmp stricmp
#endif

#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize(s,l) PyUnicode_FromStringAndSize(s,l)
#define PyString_FromString(s) PyUnicode_FromString(s)
#define PyString_InternInPlace(p) PyUnicode_InternInPlace(p)
#endif

#include "get_buffer.h"

/*
 * Common buffer management.
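 *
 * struct Buf below is a growable byte buffer: buf_enlarge() doubles the
 * allocation, or jumps straight to the required size when doubling would
 * still be too small.  The quoting functions build their result in such
 * a buffer and hand it to Python via buf_pystr(), which also frees it.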
 */

struct Buf {
	unsigned char *ptr;
	unsigned long pos;
	unsigned long alloc;
};

static unsigned char *buf_init(struct Buf *buf, unsigned init_size)
{
	if (init_size < 256)
		init_size = 256;
	buf->ptr = PyMem_Malloc(init_size);
	if (buf->ptr) {
		buf->pos = 0;
		buf->alloc = init_size;
	}
	return buf->ptr;
}

/* return new pos */
static unsigned char *buf_enlarge(struct Buf *buf, unsigned need_room)
{
	unsigned alloc = buf->alloc;
	unsigned need_size = buf->pos + need_room;
	unsigned char *ptr;

	/* no alloc needed */
	if (need_size < alloc)
		return buf->ptr + buf->pos;

	if (alloc <= need_size / 2)
		alloc = need_size;
	else
		alloc = alloc * 2;

	ptr = PyMem_Realloc(buf->ptr, alloc);
	if (!ptr)
		return NULL;
	buf->ptr = ptr;
	buf->alloc = alloc;
	return buf->ptr + buf->pos;
}

static void buf_free(struct Buf *buf)
{
	PyMem_Free(buf->ptr);
	buf->ptr = NULL;
	buf->pos = buf->alloc = 0;
}

static inline unsigned char *buf_get_target_for(struct Buf *buf, unsigned len)
{
	if (buf->pos + len <= buf->alloc)
		return buf->ptr + buf->pos;
	else
		return buf_enlarge(buf, len);
}

static inline void buf_set_target(struct Buf *buf, unsigned char *newpos)
{
	assert(buf->ptr + buf->pos <= newpos);
	assert(buf->ptr + buf->alloc >= newpos);
	buf->pos = newpos - buf->ptr;
}

static inline int buf_put(struct Buf *buf, unsigned char c)
{
	if (buf->pos < buf->alloc) {
		buf->ptr[buf->pos++] = c;
		return 1;
	} else if (buf_enlarge(buf, 1)) {
		buf->ptr[buf->pos++] = c;
		return 1;
	}
	return 0;
}

static PyObject *buf_pystr(struct Buf *buf, unsigned start_pos, unsigned char *newpos)
{
	PyObject *res;
	if (newpos)
		buf_set_target(buf, newpos);
	res = PyString_FromStringAndSize((char *)buf->ptr + start_pos, buf->pos - start_pos);
	buf_free(buf);
	return res;
}

/*
 * Common argument parsing.
 */

typedef PyObject *(*quote_fn)(unsigned char *src, Py_ssize_t src_len);

static PyObject *common_quote(PyObject *args, quote_fn qfunc)
{
	unsigned char *src = NULL;
	Py_ssize_t src_len = 0;
	PyObject *arg, *res, *strtmp = NULL;

	if (!PyArg_ParseTuple(args, "O", &arg))
		return NULL;
	if (arg != Py_None) {
		src_len = get_buffer(arg, &src, &strtmp);
		if (src_len < 0)
			return NULL;
	}
	res = qfunc(src, src_len);
	Py_CLEAR(strtmp);
	return res;
}

/*
 * Simple quoting functions.
 */

static const char doc_quote_literal[] =
	"Quote a literal value for SQL.\n"
	"\n"
	"If string contains '\\', it is quoted and result is prefixed with E.\n"
	"Input value of None results in string \"null\" without quotes.\n"
	"\n"
	"C implementation.\n";

static PyObject *quote_literal_body(unsigned char *src, Py_ssize_t src_len)
{
	struct Buf buf;
	unsigned char *esc, *dst, *src_end = src + src_len;
	unsigned int start_ofs = 1;

	if (src == NULL)
		return PyString_FromString("null");

	esc = dst = buf_init(&buf, src_len * 2 + 2 + 1);
	if (!dst)
		return NULL;

	*dst++ = ' ';
	*dst++ = '\'';
	while (src < src_end) {
		if (*src == '\\') {
			*dst++ = '\\';
			start_ofs = 0;
		} else if (*src == '\'') {
			*dst++ = '\'';
		}
		*dst++ = *src++;
	}
	*dst++ = '\'';
	if (start_ofs == 0)
		*esc = 'E';
	return buf_pystr(&buf, start_ofs, dst);
}

static PyObject *quote_literal(PyObject *self, PyObject *args)
{
	return common_quote(args, quote_literal_body);
}

/* COPY field */
static const char doc_quote_copy[] =
	"Quoting for COPY data. "
	"None is converted to \\N.\n\n"
	"C implementation.";

static PyObject *quote_copy_body(unsigned char *src, Py_ssize_t src_len)
{
	unsigned char *dst, *src_end = src + src_len;
	struct Buf buf;

	if (src == NULL)
		return PyString_FromString("\\N");

	dst = buf_init(&buf, src_len * 2);
	if (!dst)
		return NULL;

	while (src < src_end) {
		switch (*src) {
		case '\t': *dst++ = '\\'; *dst++ = 't'; src++; break;
		case '\n': *dst++ = '\\'; *dst++ = 'n'; src++; break;
		case '\r': *dst++ = '\\'; *dst++ = 'r'; src++; break;
		case '\\': *dst++ = '\\'; *dst++ = '\\'; src++; break;
		default: *dst++ = *src++; break;
		}
	}
	return buf_pystr(&buf, 0, dst);
}

static PyObject *quote_copy(PyObject *self, PyObject *args)
{
	return common_quote(args, quote_copy_body);
}

/* raw bytea for byteain() */
static const char doc_quote_bytea_raw[] =
	"Quoting for bytea parser. Returns None as None.\n"
	"\n"
	"C implementation.";

static PyObject *quote_bytea_raw_body(unsigned char *src, Py_ssize_t src_len)
{
	unsigned char *dst, *src_end = src + src_len;
	struct Buf buf;

	if (src == NULL) {
		Py_INCREF(Py_None);
		return Py_None;
	}

	dst = buf_init(&buf, src_len * 4);
	if (!dst)
		return NULL;

	while (src < src_end) {
		if (*src < 0x20 || *src >= 0x7F) {
			*dst++ = '\\';
			*dst++ = '0' + (*src >> 6);
			*dst++ = '0' + ((*src >> 3) & 7);
			*dst++ = '0' + (*src & 7);
			src++;
		} else {
			if (*src == '\\')
				*dst++ = '\\';
			*dst++ = *src++;
		}
	}
	return buf_pystr(&buf, 0, dst);
}

static PyObject *quote_bytea_raw(PyObject *self, PyObject *args)
{
	return common_quote(args, quote_bytea_raw_body);
}

/* SQL unquote */
static const char doc_unquote_literal[] =
	"Unquote SQL value.\n\n"
	"E'..' -> extended quoting.\n"
	"'..' -> standard or extended quoting\n"
	"null -> None\n"
	"other -> returned as-is\n\n"
	"C implementation.\n";

static PyObject *do_sql_ext(unsigned char *src, Py_ssize_t src_len)
{
	unsigned char *dst, *src_end = src + src_len;
	struct Buf buf;

	dst = buf_init(&buf, src_len);
	if (!dst)
		return NULL;

	while (src < src_end) {
		if (*src == '\'') {
			src++;
			if (src < src_end && *src == '\'') {
				*dst++ = *src++;
				continue;
			}
			goto failed;
		}
		if (*src != '\\') {
			*dst++ = *src++;
			continue;
		}
		if (++src >= src_end)
			goto failed;
		switch (*src) {
		case 't': *dst++ = '\t'; src++; break;
		case 'n': *dst++ = '\n'; src++; break;
		case 'r': *dst++ = '\r'; src++; break;
		case 'a': *dst++ = '\a'; src++; break;
		case 'b': *dst++ = '\b'; src++; break;
		default:
			if (*src >= '0' && *src <= '7') {
				unsigned char c = *src++ - '0';
				if (src < src_end && *src >= '0' && *src <= '7') {
					c = (c << 3) | ((*src++) - '0');
					if (src < src_end && *src >= '0' && *src <= '7')
						c = (c << 3) | ((*src++) - '0');
				}
				*dst++ = c;
			} else {
				*dst++ = *src++;
			}
		}
	}
	return buf_pystr(&buf, 0, dst);

failed:
	PyErr_Format(PyExc_ValueError, "Broken extended SQL string");
	return NULL;
}

static PyObject *do_sql_std(unsigned char *src, Py_ssize_t src_len)
{
	unsigned char *dst, *src_end = src + src_len;
	struct Buf buf;

	dst = buf_init(&buf, src_len);
	if (!dst)
		return NULL;

	while (src < src_end) {
		if (*src != '\'') {
			*dst++ = *src++;
			continue;
		}
		src++;
		if (src >= src_end || *src != '\'')
			goto failed;
		*dst++ = *src++;
	}
	return buf_pystr(&buf, 0, dst);

failed:
	PyErr_Format(PyExc_ValueError, "Broken standard SQL string");
	return NULL;
}

static PyObject *do_dolq(unsigned char *src, Py_ssize_t src_len)
{
	/* src_len >= 2, '$' in start and end */
	unsigned char *src_end = src + src_len;
	unsigned char *p1 = src + 1, *p2 = src_end - 2;

	while (p1 < src_end && *p1 != '$')
		p1++;
	while (p2 > src && *p2 != '$')
		p2--;
	if (p2 <= p1)
		goto failed;

	p1++; /* position after '$' */
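	/*
	 * Layout example for "$tag$data$tag$": src points at the first '$',
	 * p1 now points just past the opening "$tag$", and p2 at the '$'
	 * that starts the closing "$tag$".  The checks below verify that
	 * both delimiters have the same length and content.
	 */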
	if ((p1 - src) != (src_end - p2))
		goto failed;
	if (memcmp(src, p2, p1 - src) != 0)
		goto failed;

	return PyString_FromStringAndSize((char *)p1, p2 - p1);

failed:
	PyErr_Format(PyExc_ValueError, "Broken dollar-quoted string");
	return NULL;
}

static PyObject *unquote_literal(PyObject *self, PyObject *args)
{
	unsigned char *src = NULL;
	Py_ssize_t src_len = 0;
	int stdstr = 0;
	PyObject *value = NULL;
	PyObject *tmp = NULL;
	PyObject *res = NULL;

	if (!PyArg_ParseTuple(args, "O|i", &value, &stdstr))
		return NULL;

	src_len = get_buffer(value, &src, &tmp);
	if (src_len < 0)
		return NULL;

	if (src_len == 4 && strcasecmp((char *)src, "null") == 0) {
		Py_INCREF(Py_None);
		res = Py_None;
	} else if (src_len >= 2 && src[0] == '$' && src[src_len - 1] == '$') {
		res = do_dolq(src, src_len);
	} else if (src_len < 2 || src[src_len - 1] != '\'') {
		/* seems invalid, return as-is */
		Py_INCREF(value);
		res = value;
	} else if (src[0] == '\'') {
		src++; src_len -= 2;
		res = stdstr ? do_sql_std(src, src_len) : do_sql_ext(src, src_len);
	} else if (src_len > 2 && (src[0] | 0x20) == 'e' && src[1] == '\'') {
		src += 2; src_len -= 3;
		res = do_sql_ext(src, src_len);
	}
	if (tmp)
		Py_CLEAR(tmp);
	return res;
}

/* C unescape */
static const char doc_unescape[] =
	"Unescape C-style escaped string.\n\n"
	"C implementation.";

static PyObject *unescape_body(unsigned char *src, Py_ssize_t src_len)
{
	unsigned char *dst, *src_end = src + src_len;
	struct Buf buf;

	if (src == NULL) {
		PyErr_Format(PyExc_TypeError, "None not allowed");
		return NULL;
	}

	dst = buf_init(&buf, src_len);
	if (!dst)
		return NULL;

	while (src < src_end) {
		if (*src != '\\') {
			*dst++ = *src++;
			continue;
		}
		if (++src >= src_end)
			goto failed;
		switch (*src) {
		case 't': *dst++ = '\t'; src++; break;
		case 'n': *dst++ = '\n'; src++; break;
		case 'r': *dst++ = '\r'; src++; break;
		case 'a': *dst++ = '\a'; src++; break;
		case 'b': *dst++ = '\b'; src++; break;
		default:
			if (*src >= '0' && *src <= '7') {
				unsigned char c = *src++ - '0';
				if (src < src_end && *src >= '0' && *src <= '7') {
					c = (c << 3) | ((*src++) - '0');
					if (src < src_end && *src >= '0' && *src <= '7')
						c = (c << 3) | ((*src++) - '0');
				}
				*dst++ = c;
			} else {
				*dst++ = *src++;
			}
		}
	}
	return buf_pystr(&buf, 0, dst);

failed:
	PyErr_Format(PyExc_ValueError, "Broken string - \\ at the end");
	return NULL;
}

static PyObject *unescape(PyObject *self, PyObject *args)
{
	return common_quote(args, unescape_body);
}

/*
 * urlencode of dict
 */

static bool urlenc(struct Buf *buf, PyObject *obj)
{
	Py_ssize_t len;
	unsigned char *src, *dst;
	PyObject *strtmp = NULL;
	static const unsigned char hextbl[] = "0123456789abcdef";
	bool ok = false;

	len = get_buffer(obj, &src, &strtmp);
	if (len < 0)
		goto failed;

	dst = buf_get_target_for(buf, len * 3);
	if (!dst)
		goto failed;

	while (len--) {
		if ((*src >= 'a' && *src <= 'z')
		    || (*src >= 'A' && *src <= 'Z')
		    || (*src >= '0' && *src <= '9')
		    || (*src == '.' || *src == '_' || *src == '-')) {
			*dst++ = *src++;
		} else if (*src == ' ') {
			*dst++ = '+'; src++;
		} else {
			*dst++ = '%';
			*dst++ = hextbl[*src >> 4];
			*dst++ = hextbl[*src & 0xF];
			src++;
		}
	}
	buf_set_target(buf, dst);
	ok = true;
failed:
	Py_CLEAR(strtmp);
	return ok;
}

/*
 * urlencode key+val pair.
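 * Emits '&' first when needAmp is set, so callers can chain pairs.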
 * val can be None, in which case only the key is emitted.
 */
static bool urlenc_keyval(struct Buf *buf, PyObject *key, PyObject *value, bool needAmp)
{
	if (needAmp && !buf_put(buf, '&'))
		return false;

	if (!urlenc(buf, key))
		return false;

	if (value != Py_None) {
		if (!buf_put(buf, '='))
			return false;
		if (!urlenc(buf, value))
			return false;
	}
	return true;
}

/* encode native dict using PyDict_Next */
static PyObject *encode_dict(PyObject *data)
{
	PyObject *key, *value;
	Py_ssize_t pos = 0;
	bool needAmp = false;
	struct Buf buf;

	if (!buf_init(&buf, 1024))
		return NULL;

	while (PyDict_Next(data, &pos, &key, &value)) {
		if (!urlenc_keyval(&buf, key, value, needAmp))
			goto failed;
		needAmp = true;
	}
	return buf_pystr(&buf, 0, NULL);
failed:
	buf_free(&buf);
	return NULL;
}

/* encode custom object using .iteritems() */
static PyObject *encode_dictlike(PyObject *data)
{
	PyObject *key = NULL, *value = NULL, *tup, *iter;
	struct Buf buf;
	bool needAmp = false;

	if (!buf_init(&buf, 1024))
		return NULL;

#if PY_MAJOR_VERSION >= 3
	iter = PyObject_CallMethod(data, "items", NULL);
#else
	iter = PyObject_CallMethod(data, "iteritems", NULL);
#endif
	if (iter == NULL) {
		buf_free(&buf);
		return NULL;
	}

	while ((tup = PyIter_Next(iter))) {
		key = PySequence_GetItem(tup, 0);
		value = key ? PySequence_GetItem(tup, 1) : NULL;
		Py_CLEAR(tup);
		if (!key || !value)
			goto failed;
		if (!urlenc_keyval(&buf, key, value, needAmp))
			goto failed;
		needAmp = true;
		Py_CLEAR(key);
		Py_CLEAR(value);
	}

	/* allow error from iterator */
	if (PyErr_Occurred())
		goto failed;

	Py_CLEAR(iter);
	return buf_pystr(&buf, 0, NULL);
failed:
	buf_free(&buf);
	Py_CLEAR(iter);
	Py_CLEAR(key);
	Py_CLEAR(value);
	return NULL;
}

static const char doc_db_urlencode[] =
	"Urlencode for database records.\n"
	"If a value is None the key is output without '='.\n"
	"\n"
	"C implementation.";

static PyObject *db_urlencode(PyObject *self, PyObject *args)
{
	PyObject *data;

	if (!PyArg_ParseTuple(args, "O", &data))
		return NULL;

	if (PyDict_Check(data)) {
		return encode_dict(data);
	} else {
		return encode_dictlike(data);
	}
}

/*
 * urldecode to dict
 */

static inline int gethex(unsigned char c)
{
	if (c >= '0' && c <= '9')
		return c - '0';
	c |= 0x20;
	if (c >= 'a' && c <= 'f')
		return c - 'a' + 10;
	return -1;
}

static PyObject *get_elem(unsigned char *buf, unsigned char **src_p, unsigned char *src_end)
{
	int c1, c2;
	unsigned char *src = *src_p;
	unsigned char *dst = buf;

	while (src < src_end) {
		switch (*src) {
		case '%':
			if (++src + 2 > src_end)
				goto hex_incomplete;
			if ((c1 = gethex(*src++)) < 0)
				goto hex_invalid;
			if ((c2 = gethex(*src++)) < 0)
				goto hex_invalid;
			*dst++ = (c1 << 4) | c2;
			break;
		case '+':
			*dst++ = ' '; src++;
			break;
		case '&':
		case '=':
			goto gotit;
		default:
			*dst++ = *src++;
		}
	}
gotit:
	*src_p = src;
	return PyString_FromStringAndSize((char *)buf, dst - buf);
hex_incomplete:
	PyErr_Format(PyExc_ValueError, "Incomplete hex code");
	return NULL;
hex_invalid:
	PyErr_Format(PyExc_ValueError, "Invalid hex code");
	return NULL;
}

static const char doc_db_urldecode[] =
	"Urldecode from string to dict.\n"
	"NULL are detected by missing '='.\n"
	"Duplicate keys are ignored - only latest is kept.\n"
	"\n"
	"C implementation.";

static PyObject *db_urldecode(PyObject *self, PyObject *args)
{
	unsigned char *src, *src_end;
	Py_ssize_t src_len;
	PyObject *dict = NULL, *key = NULL, *value = NULL;
	struct Buf buf;

#if PY_MAJOR_VERSION >= 3
	if (!PyArg_ParseTuple(args, "s#", &src, &src_len))
		return NULL;
#else
	if (!PyArg_ParseTuple(args, "t#", &src, &src_len))
		return NULL;
#endif

	if (!buf_init(&buf, src_len))
		return NULL;

	dict = PyDict_New();
	if (!dict) {
		buf_free(&buf);
		return NULL;
	}

	src_end = src + src_len;
	while (src < src_end) {
		if (*src == '&') {
			src++;
			continue;
		}
		key = get_elem(buf.ptr, &src, src_end);
		if (!key)
			goto failed;
		if (src < src_end && *src == '=') {
			src++;
			value = get_elem(buf.ptr, &src, src_end);
			if (value == NULL)
				goto failed;
		} else {
			Py_INCREF(Py_None);
			value = Py_None;
		}
		/* lessen memory usage by interning */
		PyString_InternInPlace(&key);
		if (PyDict_SetItem(dict, key, value) < 0)
			goto failed;
		Py_CLEAR(key);
		Py_CLEAR(value);
	}
	buf_free(&buf);
	return dict;
failed:
	buf_free(&buf);
	Py_CLEAR(key);
	Py_CLEAR(value);
	Py_CLEAR(dict);
	return NULL;
}

/*
 * Module initialization
 */

static PyMethodDef cquoting_methods[] = {
	{ "quote_literal", quote_literal, METH_VARARGS, doc_quote_literal },
	{ "quote_copy", quote_copy, METH_VARARGS, doc_quote_copy },
	{ "quote_bytea_raw", quote_bytea_raw, METH_VARARGS, doc_quote_bytea_raw },
	{ "unescape", unescape, METH_VARARGS, doc_unescape },
	{ "db_urlencode", db_urlencode, METH_VARARGS, doc_db_urlencode },
	{ "db_urldecode", db_urldecode, METH_VARARGS, doc_db_urldecode },
	{ "unquote_literal", unquote_literal, METH_VARARGS, doc_unquote_literal },
	{ NULL }
};

#if PY_MAJOR_VERSION < 3

PyMODINIT_FUNC init_cquoting(void)
{
	PyObject *module;
	module = Py_InitModule("_cquoting", cquoting_methods);
	PyModule_AddStringConstant(module, "__doc__", "fast quoting for skytools");
}

#else

static struct PyModuleDef modInfo = {
	PyModuleDef_HEAD_INIT,
	"_cquoting",
	NULL,
	-1,
	cquoting_methods,
	NULL, NULL, NULL, NULL
};

PyMODINIT_FUNC PyInit__cquoting(void)
{
	PyObject *module;
	module = PyModule_Create(&modInfo);
	PyModule_AddStringConstant(module, "__doc__", "fast quoting for skytools");
	return module;
}

#endif

==== python-skytools-3.4/modules/get_buffer.h ====

/*
 * Get string data from Python object.
 */

static Py_ssize_t get_buffer(PyObject *obj, unsigned char **buf_p, PyObject **tmp_obj_p)
{
	PyObject *str = NULL;
	Py_ssize_t res;

	/* check for None */
	if (obj == Py_None) {
		PyErr_Format(PyExc_TypeError, "None is not allowed");
		return -1;
	}

	/* is string or unicode ? */
#if PY_MAJOR_VERSION < 3
	if (PyString_Check(obj) || PyUnicode_Check(obj)) {
		if (PyString_AsStringAndSize(obj, (char**)buf_p, &res) < 0)
			return -1;
		return res;
	}
#else /* python 3 */
	if (PyUnicode_Check(obj)) {
#if PY_VERSION_HEX >= 0x03030000
		*buf_p = (unsigned char *)PyUnicode_AsUTF8AndSize(obj, &res);
		return res;
#else
		/* convert to utf8 bytes */
		*tmp_obj_p = PyUnicode_AsUTF8String(obj);
		if (*tmp_obj_p == NULL)
			return -1;
		/* obj is now bytes */
		obj = *tmp_obj_p;
		if (PyBytes_AsStringAndSize(obj, (char**)buf_p, &res) < 0)
			return -1;
		return res;
#endif
	} else if (PyBytes_Check(obj)) {
		if (PyBytes_AsStringAndSize(obj, (char**)buf_p, &res) < 0)
			return -1;
		return res;
	}
#endif

#if PY_MAJOR_VERSION < 3
	{
		/* try to get buffer */
		PyBufferProcs *bfp = obj->ob_type->tp_as_buffer;
		if (bfp && bfp->bf_getsegcount && bfp->bf_getreadbuffer) {
			if (bfp->bf_getsegcount(obj, NULL) == 1)
				return bfp->bf_getreadbuffer(obj, 0, (void**)buf_p);
		}
	}
#endif

	/*
	 * Not a string-like object, run str() on it.
	 */

	/* are we in recursion? */
	if (tmp_obj_p == NULL) {
		PyErr_Format(PyExc_TypeError, "Cannot convert to string - get_buffer() recursively failed");
		return -1;
	}

	/* do str() then */
	str = PyObject_Str(obj);
	res = -1;
#if PY_VERSION_HEX >= 0x03000000 && PY_VERSION_HEX < 0x03030000
	if (str != NULL) {
		/*
		 * Immediately convert to utf8 obj,
		 * otherwise we don't have enough temp vars.
		 */
		obj = PyUnicode_AsUTF8String(str);
		Py_CLEAR(str);
		str = obj;
		obj = NULL;
	}
#endif
	if (str != NULL) {
		res = get_buffer(str, buf_p, NULL);
		if (res >= 0) {
			*tmp_obj_p = str;
		} else {
			Py_CLEAR(str);
		}
	}
	return res;
}

==== python-skytools-3.4/modules/hashtext.c ====

/*
 * Postgres hashes for Python.
 */

#define PY_SSIZE_T_CLEAN
#include <Python.h>

#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
typedef int Py_ssize_t;
#define PY_SSIZE_T_MAX INT_MAX
#define PY_SSIZE_T_MIN INT_MIN
#endif

#if PY_MAJOR_VERSION >= 3
#define PyInt_FromLong(v) PyLong_FromLong(v)
#endif

#include <stdint.h>
#include <string.h>

#include "get_buffer.h"

typedef uint32_t (*hash_fn_t)(const void *src, unsigned src_len);

typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;

#define rot(x, k) (((x)<<(k)) | ((x)>>(32-(k))))

/*
 * Old Postgres hashtext()
 */

#define mix_old(a,b,c) \
{ \
	a -= b; a -= c; a ^= ((c)>>13); \
	b -= c; b -= a; b ^= ((a)<<8); \
	c -= a; c -= b; c ^= ((b)>>13); \
	a -= b; a -= c; a ^= ((c)>>12); \
	b -= c; b -= a; b ^= ((a)<<16); \
	c -= a; c -= b; c ^= ((b)>>5); \
	a -= b; a -= c; a ^= ((c)>>3); \
	b -= c; b -= a; b ^= ((a)<<10); \
	c -= a; c -= b; c ^= ((b)>>15); \
}

static uint32_t hash_old_hashtext(const void *_k, unsigned keylen)
{
	const unsigned char *k = _k;
	register uint32 a, b, c, len;

	/* Set up the internal state */
	len = keylen;
	a = b = 0x9e3779b9;	/* the golden ratio; an arbitrary value */
	c = 3923095;		/* initialize with an arbitrary value */

	/* handle most of the key */
	while (len >= 12) {
		a += (k[0] + ((uint32) k[1] << 8) + ((uint32) k[2] << 16) + ((uint32) k[3] << 24));
		b += (k[4] + ((uint32) k[5] << 8) + ((uint32) k[6] << 16) + ((uint32) k[7] << 24));
		c += (k[8] + ((uint32) k[9] << 8) + ((uint32) k[10] << 16) + ((uint32) k[11] << 24));
		mix_old(a, b, c);
		k += 12;
		len -= 12;
	}

	/* handle the last 11 bytes */
	c += keylen;
	switch (len)		/* all the case statements fall through */
	{
		case 11: c += ((uint32) k[10] << 24);
		case 10: c += ((uint32) k[9] << 16);
		case 9: c += ((uint32) k[8] << 8);
			/* the first byte of c is reserved for the length */
		case 8: b += ((uint32) k[7] << 24);
		case 7: b += ((uint32) k[6] << 16);
		case 6: b += ((uint32) k[5] << 8);
		case 5: b += k[4];
		case 4: a += ((uint32) k[3] << 24);
		case 3: a += ((uint32) k[2] << 16);
		case 2: a += ((uint32) k[1] << 8);
		case 1: a += k[0];
			/* case 0: nothing left to add */
	}
	mix_old(a, b, c);

	/* report the result */
	return c;
}

/*
 * New Postgres hashtext()
 */

#define UINT32_ALIGN_MASK 3

#define mix_new(a,b,c) \
{ \
	a -= c;  a ^= rot(c, 4);	c += b; \
	b -= a;  b ^= rot(a, 6);	a += c; \
	c -= b;  c ^= rot(b, 8);	b += a; \
	a -= c;  a ^= rot(c,16);	c += b; \
	b -= a;  b ^= rot(a,19);	a += c; \
	c -= b;  c ^= rot(b, 4);	b += a; \
}

#define final_new(a,b,c) \
{ \
	c ^= b; c -= rot(b,14); \
	a ^= c; a -= rot(c,11); \
	b ^= a; b -= rot(a,25); \
	c ^= b; c -= rot(b,16); \
	a ^= c; a -= rot(c, 4); \
	b ^= a; b -= rot(a,14); \
	c ^= b; c -= rot(b,24); \
}

static uint32_t hash_new_hashtext(const void *_k, unsigned keylen)
{
	const unsigned char *k = _k;
	uint32_t a, b, c, len;

	/* Set up the internal state */
	len = keylen;
	a = b = c = 0x9e3779b9 + len + 3923095;

	/* If the source pointer is word-aligned, we use word-wide fetches */
	if (((long) k & UINT32_ALIGN_MASK) == 0) {
		/* Code path for aligned source data */
		register const uint32_t *ka = (const uint32_t *) k;

		/* handle most of the key */
		while (len >= 12) {
			a += ka[0];
			b += ka[1];
			c += ka[2];
			mix_new(a, b, c);
			ka += 3;
			len -= 12;
		}

		/* handle the last 11 bytes */
		k =
		    (const unsigned char *) ka;
#ifdef WORDS_BIGENDIAN
		switch (len)
		{
			case 11: c += ((uint32) k[10] << 8);
				/* fall through */
			case 10: c += ((uint32) k[9] << 16);
				/* fall through */
			case 9: c += ((uint32) k[8] << 24);
				/* the lowest byte of c is reserved for the length */
				/* fall through */
			case 8: b += ka[1]; a += ka[0]; break;
			case 7: b += ((uint32) k[6] << 8);
				/* fall through */
			case 6: b += ((uint32) k[5] << 16);
				/* fall through */
			case 5: b += ((uint32) k[4] << 24);
				/* fall through */
			case 4: a += ka[0]; break;
			case 3: a += ((uint32) k[2] << 8);
				/* fall through */
			case 2: a += ((uint32) k[1] << 16);
				/* fall through */
			case 1: a += ((uint32) k[0] << 24);
				/* case 0: nothing left to add */
		}
#else /* !WORDS_BIGENDIAN */
		switch (len)
		{
			case 11: c += ((uint32) k[10] << 24);
				/* fall through */
			case 10: c += ((uint32) k[9] << 16);
				/* fall through */
			case 9: c += ((uint32) k[8] << 8);
				/* the lowest byte of c is reserved for the length */
				/* fall through */
			case 8: b += ka[1]; a += ka[0]; break;
			case 7: b += ((uint32) k[6] << 16);
				/* fall through */
			case 6: b += ((uint32) k[5] << 8);
				/* fall through */
			case 5: b += k[4];
				/* fall through */
			case 4: a += ka[0]; break;
			case 3: a += ((uint32) k[2] << 16);
				/* fall through */
			case 2: a += ((uint32) k[1] << 8);
				/* fall through */
			case 1: a += k[0];
				/* case 0: nothing left to add */
		}
#endif /* WORDS_BIGENDIAN */
	} else {
		/* Code path for non-aligned source data */

		/* handle most of the key */
		while (len >= 12) {
#ifdef WORDS_BIGENDIAN
			a += (k[3] + ((uint32) k[2] << 8) + ((uint32) k[1] << 16) + ((uint32) k[0] << 24));
			b += (k[7] + ((uint32) k[6] << 8) + ((uint32) k[5] << 16) + ((uint32) k[4] << 24));
			c += (k[11] + ((uint32) k[10] << 8) + ((uint32) k[9] << 16) + ((uint32) k[8] << 24));
#else /* !WORDS_BIGENDIAN */
			a += (k[0] + ((uint32) k[1] << 8) + ((uint32) k[2] << 16) + ((uint32) k[3] << 24));
			b += (k[4] + ((uint32) k[5] << 8) + ((uint32) k[6] << 16) + ((uint32) k[7] << 24));
			c += (k[8] + ((uint32) k[9] << 8) + ((uint32) k[10] << 16) + ((uint32) k[11] << 24));
#endif /* WORDS_BIGENDIAN */
			mix_new(a, b, c);
			k += 12;
			len -= 12;
		}

		/* handle the last 11 bytes */
#ifdef WORDS_BIGENDIAN
		switch (len)	/* all the case statements fall through */
		{
			case 11: c += ((uint32) k[10] << 8);
			case 10: c += ((uint32) k[9] << 16);
			case 9: c += ((uint32) k[8] << 24);
				/* the lowest byte of c is reserved for the length */
			case 8: b += k[7];
			case 7: b += ((uint32) k[6] << 8);
			case 6: b += ((uint32) k[5] << 16);
			case 5: b += ((uint32) k[4] << 24);
			case 4: a += k[3];
			case 3: a += ((uint32) k[2] << 8);
			case 2: a += ((uint32) k[1] << 16);
			case 1: a += ((uint32) k[0] << 24);
				/* case 0: nothing left to add */
		}
#else /* !WORDS_BIGENDIAN */
		switch (len)	/* all the case statements fall through */
		{
			case 11: c += ((uint32) k[10] << 24);
			case 10: c += ((uint32) k[9] << 16);
			case 9: c += ((uint32) k[8] << 8);
				/* the lowest byte of c is reserved for the length */
			case 8: b += ((uint32) k[7] << 24);
			case 7: b += ((uint32) k[6] << 16);
			case 6: b += ((uint32) k[5] << 8);
			case 5: b += k[4];
			case 4: a += ((uint32) k[3] << 24);
			case 3: a += ((uint32) k[2] << 16);
			case 2: a += ((uint32) k[1] << 8);
			case 1: a += k[0];
				/* case 0: nothing left to add */
		}
#endif /* WORDS_BIGENDIAN */
	}

	final_new(a, b, c);

	/* report the result */
	return c;
}

/*
 * Common argument parsing.
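 *
 * run_hash() accepts anything get_buffer() can expose as bytes, runs the
 * chosen hash over it, and returns the value as a signed 32-bit Python
 * int, matching the int4 result of the corresponding Postgres function.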
 */

static PyObject *run_hash(PyObject *args, hash_fn_t real_hash)
{
	unsigned char *src = NULL;
	Py_ssize_t src_len;
	PyObject *arg, *strtmp = NULL;
	int32_t hash;

	if (!PyArg_ParseTuple(args, "O", &arg))
		return NULL;
	src_len = get_buffer(arg, &src, &strtmp);
	if (src_len < 0)
		return NULL;
	hash = real_hash(src, src_len);
	Py_CLEAR(strtmp);
	return PyInt_FromLong(hash);
}

/*
 * Python wrappers around actual hash functions.
 */

static PyObject *hashtext_old(PyObject *self, PyObject *args)
{
	return run_hash(args, hash_old_hashtext);
}

static PyObject *hashtext_new(PyObject *self, PyObject *args)
{
	return run_hash(args, hash_new_hashtext);
}

/*
 * Module initialization
 */

static PyMethodDef methods[] = {
	{ "hashtext_old", hashtext_old, METH_VARARGS, "Old Postgres hashtext().\n" },
	{ "hashtext_new", hashtext_new, METH_VARARGS, "New Postgres hashtext().\n" },
	{ NULL }
};

#if PY_MAJOR_VERSION < 3

PyMODINIT_FUNC init_chashtext(void)
{
	PyObject *module;
	module = Py_InitModule("_chashtext", methods);
	PyModule_AddStringConstant(module, "__doc__", "String hash functions");
}

#else

static struct PyModuleDef modInfo = {
	PyModuleDef_HEAD_INIT,
	"_chashtext",
	NULL,
	-1,
	methods
};

PyMODINIT_FUNC PyInit__chashtext(void)
{
	PyObject *module;
	module = PyModule_Create(&modInfo);
	PyModule_AddStringConstant(module, "__doc__", "String hash functions");
	return module;
}

#endif

==== python-skytools-3.4/setup.py ====

"""Setup for skytools module.
"""

from setuptools import setup, Extension
import sys

# don't build C module on win32 as it's unlikely to have dev env
BUILD_C_MOD = 1
if sys.platform == 'win32':
    BUILD_C_MOD = 0

# check if building C is allowed
c_modules = []
if BUILD_C_MOD:
    c_modules = [
        Extension("skytools._cquoting", ['modules/cquoting.c']),
        Extension("skytools._chashtext", ['modules/hashtext.c']),
    ]

# run actual setup
setup(
    name = "skytools",
    license = "ISC",
    version = '3.4',
    url = "https://github.com/pgq/python-skytools",
    maintainer = "Marko Kreen",
    maintainer_email = "markokr@gmail.com",
    description = "Utilities for database scripts",
    packages = ['skytools'],
    ext_modules = c_modules,
    zip_safe = False,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Environment :: Console",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: ISC License (ISCL)",
        "Operating System :: MacOS :: MacOS X",
        "Operating System :: Microsoft :: Windows",
        "Operating System :: POSIX",
        "Programming Language :: Python :: 2",
        "Programming Language :: Python :: 3",
        "Topic :: Database",
        "Topic :: Software Development :: Libraries :: Python Modules",
        "Topic :: Utilities",
    ],
)

==== python-skytools-3.4/skytools/__init__.py ====

"""Tools for Python database scripts."""

# pylint:disable=redefined-builtin,unused-wildcard-import,wildcard-import

from __future__ import division, absolute_import, print_function

try:
    import skytools.apipkg as _apipkg
except ImportError:
    # make pylint think everything is imported immediately
    from skytools.quoting import *
    from skytools.sqltools import *
    from skytools.scripting import *
    from skytools.adminscript import *
    from skytools.config import *
    from skytools.dbservice import *
    from skytools.dbstruct import *
    from skytools.fileutil import *
    from skytools.gzlog import *
    from skytools.hashtext import *
    from skytools.natsort import *
    from skytools.parsing import *
    from skytools.psycopgwrapper import *
    from skytools.querybuilder import *
    from skytools.skylog import *
    from skytools.sockutil import *
    from skytools.timeutil import *
    from skytools.utf8 import *

_symbols = {
    # skytools.adminscript
    'AdminScript': 'skytools.adminscript:AdminScript',
    # skytools.config
    'Config': 'skytools.config:Config',
    # skytools.dbservice
    'DBService': 'skytools.dbservice:DBService',
    'ServiceContext': 'skytools.dbservice:ServiceContext',
    'TableAPI': 'skytools.dbservice:TableAPI',
    'get_record': 'skytools.dbservice:get_record',
    'get_record_list': 'skytools.dbservice:get_record_list',
    'make_record': 'skytools.dbservice:make_record',
    'make_record_array': 'skytools.dbservice:make_record_array',
    # skytools.dbstruct
    'SeqStruct': 'skytools.dbstruct:SeqStruct',
    'TableStruct': 'skytools.dbstruct:TableStruct',
    'T_ALL': 'skytools.dbstruct:T_ALL',
    'T_CONSTRAINT': 'skytools.dbstruct:T_CONSTRAINT',
    'T_DEFAULT': 'skytools.dbstruct:T_DEFAULT',
    'T_GRANT': 'skytools.dbstruct:T_GRANT',
    'T_INDEX': 'skytools.dbstruct:T_INDEX',
    'T_OWNER': 'skytools.dbstruct:T_OWNER',
    'T_PARENT': 'skytools.dbstruct:T_PARENT',
    'T_PKEY': 'skytools.dbstruct:T_PKEY',
    'T_RULE': 'skytools.dbstruct:T_RULE',
    'T_SEQUENCE': 'skytools.dbstruct:T_SEQUENCE',
    'T_TABLE': 'skytools.dbstruct:T_TABLE',
    'T_TRIGGER': 'skytools.dbstruct:T_TRIGGER',
    # skytools.fileutil
    'signal_pidfile': 'skytools.fileutil:signal_pidfile',
    'write_atomic': 'skytools.fileutil:write_atomic',
    # skytools.gzlog
    'gzip_append': 'skytools.gzlog:gzip_append',
    # skytools.hashtext
    'hashtext_old': 'skytools.hashtext:hashtext_old',
    'hashtext_new': 'skytools.hashtext:hashtext_new',
    # skytools.natsort
    'natsort': 'skytools.natsort:natsort',
    'natsort_icase': 'skytools.natsort:natsort_icase',
    'natsorted': 'skytools.natsort:natsorted',
    'natsorted_icase': 'skytools.natsort:natsorted_icase',
    'natsort_key': 'skytools.natsort:natsort_key',
    'natsort_key_icase': 'skytools.natsort:natsort_key_icase',
    # skytools.parsing
    'dedent': 'skytools.parsing:dedent',
    'hsize_to_bytes': 'skytools.parsing:hsize_to_bytes',
    'merge_connect_string': 'skytools.parsing:merge_connect_string',
    'parse_acl': 'skytools.parsing:parse_acl',
    'parse_connect_string': 'skytools.parsing:parse_connect_string',
    'parse_logtriga_sql': 'skytools.parsing:parse_logtriga_sql',
    'parse_pgarray': 'skytools.parsing:parse_pgarray',
    'parse_sqltriga_sql': 'skytools.parsing:parse_sqltriga_sql',
    'parse_statements': 'skytools.parsing:parse_statements',
    'parse_tabbed_table': 'skytools.parsing:parse_tabbed_table',
    'sql_tokenizer': 'skytools.parsing:sql_tokenizer',
    # skytools.psycopgwrapper
    'connect_database': 'skytools.psycopgwrapper:connect_database',
    'DBError': 'skytools.psycopgwrapper:DBError',
    'I_AUTOCOMMIT': 'skytools.psycopgwrapper:I_AUTOCOMMIT',
    'I_READ_COMMITTED': 'skytools.psycopgwrapper:I_READ_COMMITTED',
    'I_REPEATABLE_READ': 'skytools.psycopgwrapper:I_REPEATABLE_READ',
    'I_SERIALIZABLE': 'skytools.psycopgwrapper:I_SERIALIZABLE',
    # skytools.querybuilder
    'PLPyQuery': 'skytools.querybuilder:PLPyQuery',
    'PLPyQueryBuilder': 'skytools.querybuilder:PLPyQueryBuilder',
    'QueryBuilder': 'skytools.querybuilder:QueryBuilder',
    'plpy_exec': 'skytools.querybuilder:plpy_exec',
    'run_exists': 'skytools.querybuilder:run_exists',
    'run_lookup': 'skytools.querybuilder:run_lookup',
    'run_query': 'skytools.querybuilder:run_query',
    'run_query_row': 'skytools.querybuilder:run_query_row',
    # skytools.quoting
    'db_urldecode': 'skytools.quoting:db_urldecode',
    'db_urlencode': 'skytools.quoting:db_urlencode',
    'json_decode': 'skytools.quoting:json_decode',
    'json_encode': 'skytools.quoting:json_encode',
    'make_pgarray': 'skytools.quoting:make_pgarray',
    'quote_bytea_copy': 'skytools.quoting:quote_bytea_copy',
    'quote_bytea_literal': 'skytools.quoting:quote_bytea_literal',
    'quote_bytea_raw': 'skytools.quoting:quote_bytea_raw',
    'quote_copy': 'skytools.quoting:quote_copy',
    'quote_fqident': 'skytools.quoting:quote_fqident',
    'quote_ident': 'skytools.quoting:quote_ident',
    'quote_json': 'skytools.quoting:quote_json',
    'quote_literal': 'skytools.quoting:quote_literal',
    'quote_statement': 'skytools.quoting:quote_statement',
    'unescape': 'skytools.quoting:unescape',
    'unescape_copy': 'skytools.quoting:unescape_copy',
    'unquote_fqident': 'skytools.quoting:unquote_fqident',
    'unquote_ident': 'skytools.quoting:unquote_ident',
    'unquote_literal': 'skytools.quoting:unquote_literal',
    # skytools.scripting
    'BaseScript': 'skytools.scripting:BaseScript',
    'daemonize': 'skytools.scripting:daemonize',
    'DBScript': 'skytools.scripting:DBScript',
    'UsageError': 'skytools.scripting:UsageError',
    # skytools.skylog
    'getLogger': 'skytools.skylog:getLogger',
    # skytools.sockutil
    'set_cloexec': 'skytools.sockutil:set_cloexec',
    'set_nonblocking': 'skytools.sockutil:set_nonblocking',
    'set_tcp_keepalive': 'skytools.sockutil:set_tcp_keepalive',
    # skytools.sqltools
    'dbdict': 'skytools.sqltools:dbdict',
    'CopyPipe': 'skytools.sqltools:CopyPipe',
    'DBFunction': 'skytools.sqltools:DBFunction',
    'DBLanguage': 'skytools.sqltools:DBLanguage',
    'DBObject': 'skytools.sqltools:DBObject',
    'DBSchema': 'skytools.sqltools:DBSchema',
    'DBTable': 'skytools.sqltools:DBTable',
    'Snapshot': 'skytools.sqltools:Snapshot',
    'db_install': 'skytools.sqltools:db_install',
    'exists_function': 'skytools.sqltools:exists_function',
    'exists_language': 'skytools.sqltools:exists_language',
    'exists_schema': 'skytools.sqltools:exists_schema',
    'exists_sequence': 'skytools.sqltools:exists_sequence',
    'exists_table': 'skytools.sqltools:exists_table',
    'exists_temp_table': 'skytools.sqltools:exists_temp_table',
    'exists_type': 'skytools.sqltools:exists_type',
    'exists_view': 'skytools.sqltools:exists_view',
    'fq_name': 'skytools.sqltools:fq_name',
    'fq_name_parts': 'skytools.sqltools:fq_name_parts',
    'full_copy': 'skytools.sqltools:full_copy',
    'get_table_columns': 'skytools.sqltools:get_table_columns',
    'get_table_oid': 'skytools.sqltools:get_table_oid',
    'get_table_pkeys': 'skytools.sqltools:get_table_pkeys',
    'installer_apply_file': 'skytools.sqltools:installer_apply_file',
    'installer_find_file': 'skytools.sqltools:installer_find_file',
    'magic_insert': 'skytools.sqltools:magic_insert',
    'mk_delete_sql': 'skytools.sqltools:mk_delete_sql',
    'mk_insert_sql': 'skytools.sqltools:mk_insert_sql',
    'mk_update_sql': 'skytools.sqltools:mk_update_sql',
    # skytools.timeutil
    'FixedOffsetTimezone': 'skytools.timeutil:FixedOffsetTimezone',
    'datetime_to_timestamp': 'skytools.timeutil:datetime_to_timestamp',
    'parse_iso_timestamp': 'skytools.timeutil:parse_iso_timestamp',
    # skytools.utf8
    'safe_utf8_decode': 'skytools.utf8:safe_utf8_decode',
}

__all__ = _symbols.keys()

_symbols['__version__'] = 'skytools.installer_config:package_version'

# lazy-import exported vars
_apipkg.initpkg(__name__, _symbols, {'apipkg': _apipkg})

==== python-skytools-3.4/skytools/_pyquoting.py ====

# _pyquoting.py

"""Various helpers for string quoting/unquoting.

Here is pure Python that should match the C code in _cquoting.
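
Expected behaviour, per the docstrings below (illustrative):

    quote_literal("py's")      => "'py''s'"
    quote_literal(None)        => "null"
    db_urlencode({'a': 'x y'}) => "a=x+y"
    db_urldecode("a=x+y&nul")  => {'a': 'x y', 'nul': None}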
""" from __future__ import division, absolute_import, print_function import re try: from urllib.parse import quote_plus, unquote_plus # noqa def _bytes_val(c): return c except ImportError: from urllib import quote_plus, unquote_plus # noqa _bytes_val = chr __all__ = [ "quote_literal", "quote_copy", "quote_bytea_raw", "db_urlencode", "db_urldecode", "unescape", "unquote_literal", ] # # SQL quoting # def quote_literal(s): r"""Quote a literal value for SQL. If string contains '\\', extended E'' quoting is used, otherwise standard quoting. Input value of None results in string "null" without quotes. Python implementation. """ if s is None: return "null" s = str(s).replace("'", "''") s2 = s.replace("\\", "\\\\") if len(s) != len(s2): return "E'" + s2 + "'" return "'" + s2 + "'" def quote_copy(s): """Quoting for copy command. None is converted to \\N. Python implementation. """ if s is None: return "\\N" s = str(s) s = s.replace("\\", "\\\\") s = s.replace("\t", "\\t") s = s.replace("\n", "\\n") s = s.replace("\r", "\\r") return s _bytea_map = None def quote_bytea_raw(s): """Quoting for bytea parser. Returns None as None. Python implementation. """ global _bytea_map if s is None: return None if not isinstance(s, bytes): raise TypeError("Expect bytes") if 1 and _bytea_map is None: _bytea_map = {} for i in range(256): c = _bytes_val(i) if i < 0x20 or i >= 0x7F: _bytea_map[c] = "\\%03o" % i elif i == ord("\\"): _bytea_map[c] = "\\\\" else: _bytea_map[c] = '%c' % i return "".join([_bytea_map[b] for b in s]) # # Database specific urlencode and urldecode. # def db_urlencode(dict_val): """Database specific urlencode. Encode None as key without '='. That means that in "foo&bar=", foo is NULL and bar is empty string. Python implementation. """ elem_list = [] for k, v in dict_val.items(): if v is None: elem = quote_plus(str(k)) else: elem = quote_plus(str(k)) + '=' + quote_plus(str(v)) elem_list.append(elem) return '&'.join(elem_list) def db_urldecode(qs): """Database specific urldecode. Decode key without '=' as None. This also does not support one key several times. Python implementation. """ res = {} for elem in qs.split('&'): if not elem: continue pair = elem.split('=', 1) name = unquote_plus(pair[0]) if len(pair) == 1: res[name] = None else: res[name] = unquote_plus(pair[1]) return res # # Remove C-like backslash escapes # _esc_re = r"\\([0-7]{1,3}|.)" _esc_rc = re.compile(_esc_re) _esc_map = { 't': '\t', 'n': '\n', 'r': '\r', 'a': '\a', 'b': '\b', "'": "'", '"': '"', '\\': '\\', } def _sub_unescape_c(m): """unescape single escape seq.""" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: return _esc_map[v] except KeyError: return v else: return chr(int(v, 8)) def unescape(val): """Removes C-style escapes from string. Python implementation. """ return _esc_rc.sub(_sub_unescape_c, val) _esql_re = r"''|\\([0-7]{1,3}|.)" _esql_rc = re.compile(_esql_re) def _sub_unescape_sqlext(m): """Unescape extended-quoted string.""" if m.group() == "''": return "'" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: return _esc_map[v] except KeyError: return v return chr(int(v, 8)) def unquote_literal(val, stdstr=False): """Unquotes SQL string. E'..' -> extended quoting. '..' 
-> standard or extended quoting null -> None other -> returned as-is """ if val[0] == "'" and val[-1] == "'": if stdstr: return val[1:-1].replace("''", "'") else: return _esql_rc.sub(_sub_unescape_sqlext, val[1:-1]) elif len(val) > 2 and val[0] in ('E', 'e') and val[1] == "'" and val[-1] == "'": return _esql_rc.sub(_sub_unescape_sqlext, val[2:-1]) elif len(val) >= 2 and val[0] == '$' and val[-1] == '$': p1 = val.find('$', 1) p2 = val.rfind('$', 1, -1) if p1 > 0 and p2 > p1: t1 = val[:p1+1] t2 = val[p2:] if t1 == t2: return val[len(t1):-len(t1)] raise ValueError("Bad dollar-quoted string") elif val.lower() == "null": return None return val python-skytools-3.4/skytools/adminscript.py000066400000000000000000000100131356323561300213460ustar00rootroot00000000000000"""Admin scripting. """ # allow getargspec # pylint:disable=deprecated-method from __future__ import division, absolute_import, print_function import sys import inspect import skytools __all__ = ['AdminScript'] class AdminScript(skytools.DBScript): """Contains common admin script tools. Second argument (first is .ini file) is taken as command name. If class method 'cmd_' + arg exists, it is called, otherwise error is given. """ commands_without_pidfile = {} def __init__(self, service_name, args): """AdminScript init.""" super(AdminScript, self).__init__(service_name, args) if len(self.args) < 2: self.log.error("need command") sys.exit(1) cmd = self.args[1] if cmd in self.commands_without_pidfile: self.pidfile = None if self.pidfile: self.pidfile = self.pidfile + ".admin" def work(self): """Non-looping work function, calls command function.""" self.set_single_loop(1) cmd = self.args[1] cmdargs = self.args[2:] # find function fname = "cmd_" + cmd.replace('-', '_') if not hasattr(self, fname): self.log.error('bad subcommand, see --help for usage') sys.exit(1) fn = getattr(self, fname) # check if correct number of arguments (args, varargs, ___varkw, ___defaults) = inspect.getargspec(fn) n_args = len(args) - 1 # drop 'self' if varargs is None and n_args != len(cmdargs): helpstr = "" if n_args: helpstr = ": " + " ".join(args[1:]) self.log.error("command '%s' got %d args, but expects %d%s", cmd, len(cmdargs), n_args, helpstr) sys.exit(1) # run command fn(*cmdargs) def fetch_list(self, db, sql, args, keycol=None): """Fetch a resultset from db, optionally turning it into value list.""" curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() db.commit() if not keycol: res = rows else: res = [r[keycol] for r in rows] return res def display_table(self, db, desc, sql, args=(), fields=(), fieldfmt=None): """Display multirow query as a table.""" self.log.debug("display_table: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() db.commit() if len(rows) == 0: return 0 if not fieldfmt: fieldfmt = {} if not fields: fields = [f[0] for f in curs.description] widths = [15] * len(fields) for row in rows: for i, k in enumerate(fields): rlen = row[k] and len(str(row[k])) or 0 widths[i] = widths[i] > rlen and widths[i] or rlen widths = [w + 2 for w in widths] fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) if desc: print(desc) print(fmt % tuple(fields)) print(fmt % tuple(['-' * (w - 2) for w in widths])) #print(fmt % tuple(['-'*15] * len(fields))) for row in rows: vals = [] for field in fields: val = row[field] if field in fieldfmt: val = fieldfmt[field](val) vals.append(val) print(fmt % tuple(vals)) print('\n') return 1 def exec_stmt(self, db, sql, args): """Run 
regular non-query SQL on db.""" self.log.debug("exec_stmt: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) db.commit() def exec_query(self, db, sql, args): """Run regular query SQL on db.""" self.log.debug("exec_query: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) res = curs.fetchall() db.commit() return res python-skytools-3.4/skytools/apipkg.py000066400000000000000000000145521356323561300203200ustar00rootroot00000000000000""" apipkg: control the exported namespace of a python package. see http://pypi.python.org/pypi/apipkg (c) holger krekel, 2009 - MIT license """ #pylint: skip-file import os import sys from types import ModuleType __version__ = '1.4' def _py_abspath(path): """ special version of abspath that will leave paths from jython jars alone """ if path.startswith('__pyclasspath__'): return path else: return os.path.abspath(path) def distribution_version(name): """try to get the version of the named distribution, returs None on failure""" from pkg_resources import get_distribution, DistributionNotFound try: dist = get_distribution(name) except DistributionNotFound: pass else: return dist.version def initpkg(pkgname, exportdefs, attr=None, eager=False): """ initialize given package from the export definitions. """ oldmod = sys.modules.get(pkgname) d = {} f = getattr(oldmod, '__file__', None) if f: f = _py_abspath(f) d['__file__'] = f if hasattr(oldmod, '__version__'): d['__version__'] = oldmod.__version__ if hasattr(oldmod, '__loader__'): d['__loader__'] = oldmod.__loader__ if hasattr(oldmod, '__path__'): d['__path__'] = [_py_abspath(p) for p in oldmod.__path__] if '__doc__' not in exportdefs and getattr(oldmod, '__doc__', None): d['__doc__'] = oldmod.__doc__ if attr: d.update(attr) if hasattr(oldmod, "__dict__"): oldmod.__dict__.update(d) mod = ApiModule(pkgname, exportdefs, implprefix=pkgname, attr=d) sys.modules[pkgname] = mod # eagerload in bypthon to avoid their monkeypatching breaking packages if 'bpython' in sys.modules or eager: for module in sys.modules.values(): if isinstance(module, ApiModule): len(module.__dict__) def importobj(modpath, attrname): module = __import__(modpath, None, None, ['__doc__']) if not attrname: return module retval = module names = attrname.split(".") for x in names: retval = getattr(retval, x) return retval class ApiModule(ModuleType): __doc = None def __docget(self): try: return self.__doc except AttributeError: if '__doc__' in self.__map__: return self.__makeattr('__doc__') def __docset(self, value): self.__doc = value __doc__ = property(__docget, __docset) def __init__(self, name, importspec, implprefix=None, attr=None): super(ApiModule, self).__init__(name) self.__all__ = [x for x in importspec if x != '__onfirstaccess__'] self.__map__ = {} self.__implprefix__ = implprefix or name if attr: for name, val in attr.items(): # print "setting", self.__name__, name, val setattr(self, name, val) for name, importspec in importspec.items(): if isinstance(importspec, dict): subname = '%s.%s' % (self.__name__, name) apimod = ApiModule(subname, importspec, implprefix) sys.modules[subname] = apimod setattr(self, name, apimod) else: parts = importspec.split(':') modpath = parts.pop(0) attrname = parts and parts[0] or "" if modpath[0] == '.': modpath = implprefix + modpath if not attrname: subname = '%s.%s' % (self.__name__, name) apimod = makeAliasModule(subname, modpath) sys.modules[subname] = apimod if '.' 
not in name: setattr(self, name, apimod) else: self.__map__[name] = (modpath, attrname) def __repr__(self): l = [] if hasattr(self, '__version__'): l.append("version=" + repr(self.__version__)) if hasattr(self, '__file__'): l.append('from ' + repr(self.__file__)) if l: return '<ApiModule %r %s>' % (self.__name__, " ".join(l)) return '<ApiModule %r>' % (self.__name__,) def __makeattr(self, name): """lazily compute value for name or raise AttributeError if unknown.""" # print "makeattr", self.__name__, name target = None if '__onfirstaccess__' in self.__map__: target = self.__map__.pop('__onfirstaccess__') importobj(*target)() try: modpath, attrname = self.__map__[name] except KeyError: if target is not None and name != '__onfirstaccess__': # retry, onfirstaccess might have set attrs return getattr(self, name) raise AttributeError(name) else: result = importobj(modpath, attrname) setattr(self, name, result) try: del self.__map__[name] except KeyError: pass # in a recursive-import situation a double-del can happen return result __getattr__ = __makeattr @property def __dict__(self): # force all the content of the module # to be loaded when __dict__ is read dictdescr = ModuleType.__dict__['__dict__'] mdict = dictdescr.__get__(self) if mdict is not None: hasattr(self, 'some') for name in self.__all__: try: self.__makeattr(name) except AttributeError: pass return mdict def makeAliasModule(modname, modpath, attrname=None): mod = [] def getmod(): if not mod: x = importobj(modpath, None) if attrname is not None: x = getattr(x, attrname) mod.append(x) return mod[0] class AliasModule(ModuleType): def __repr__(self): x = modpath if attrname: x += "." + attrname return '<AliasModule %r for %r>' % (modname, x) def __getattribute__(self, name): try: return getattr(getmod(), name) except ImportError: return None def __setattr__(self, name, value): setattr(getmod(), name, value) def __delattr__(self, name): delattr(getmod(), name) return AliasModule(str(modname)) python-skytools-3.4/skytools/checker.py000066400000000000000000000503151356323561300204460ustar00rootroot00000000000000"""Catch moment when tables are in sync on master and slave.
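This module also works as a command-line tool: the Checker class at the end
of the module reads its settings from an ini file. A minimal programmatic
sketch (the config file name here is hypothetical):

    from skytools.checker import Checker

    script = Checker(['data_checker.ini'])  # hypothetical ini file, see Checker docstring
    script.start()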
""" from __future__ import division, absolute_import, print_function import sys import time import os import subprocess import skytools class TableRepair(object): """Checks that tables in two databases are in sync.""" def __init__(self, table_name, log): self.table_name = table_name self.fq_table_name = skytools.quote_fqident(table_name) self.log = log self.pkey_list = [] self.common_fields = [] self.apply_fixes = False self.apply_cursor = None self.reset() def reset(self): self.cnt_insert = 0 self.cnt_update = 0 self.cnt_delete = 0 self.total_src = 0 self.total_dst = 0 self.pkey_list = [] self.common_fields = [] self.apply_fixes = False self.apply_cursor = None def do_repair(self, src_db, dst_db, where, pfx='repair', apply_fixes=False): """Actual comparison.""" self.reset() src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.apply_fixes = apply_fixes if apply_fixes: self.apply_cursor = dst_curs self.log.info('Checking %s', self.table_name) copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where) dump_src = "%s.%s.src" % (pfx, self.table_name) dump_dst = "%s.%s.dst" % (pfx, self.table_name) fix = "%s.%s.fix" % (pfx, self.table_name) self.log.info("Dumping src table: %s", self.table_name) self.dump_table(copy_tbl, src_curs, dump_src) src_db.commit() self.log.info("Dumping dst table: %s", self.table_name) self.dump_table(copy_tbl, dst_curs, dump_dst) dst_db.commit() self.log.info("Sorting src table: %s", self.table_name) self.do_sort(dump_src, dump_src + '.sorted') self.log.info("Sorting dst table: %s", self.table_name) self.do_sort(dump_dst, dump_dst + '.sorted') self.dump_compare(dump_src + ".sorted", dump_dst + ".sorted", fix) os.unlink(dump_src) os.unlink(dump_dst) os.unlink(dump_src + ".sorted") os.unlink(dump_dst + ".sorted") if apply_fixes: dst_db.commit() def do_sort(self, src, dst): p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) s_ver = p.communicate()[0] del p xenv = os.environ.copy() xenv['LANG'] = 'C' xenv['LC_ALL'] = 'C' cmdline = ['sort', '-T', '.'] if s_ver.find("coreutils") > 0: cmdline.append('-S') cmdline.append('30%') cmdline.append('-o') cmdline.append(dst) cmdline.append(src) p = subprocess.Popen(cmdline, env=xenv) if p.wait() != 0: raise Exception('sort failed') def gen_copy_tbl(self, src_curs, dst_curs, where): """Create COPY expession from common fields.""" self.pkey_list = skytools.get_table_pkeys(src_curs, self.table_name) dst_pkey = skytools.get_table_pkeys(dst_curs, self.table_name) if dst_pkey != self.pkey_list: self.log.error('pkeys do not match') sys.exit(1) src_cols = skytools.get_table_columns(src_curs, self.table_name) dst_cols = skytools.get_table_columns(dst_curs, self.table_name) field_list = [] for f in self.pkey_list: field_list.append(f) for f in src_cols: if f in self.pkey_list: continue if f in dst_cols: field_list.append(f) self.common_fields = field_list fqlist = [skytools.quote_ident(col) for col in field_list] tbl_expr = "select %s from %s" % (",".join(fqlist), self.fq_table_name) if where: tbl_expr += ' where ' + where tbl_expr = "COPY (%s) TO STDOUT" % tbl_expr self.log.debug("using copy expr: %s", tbl_expr) return tbl_expr def dump_table(self, copy_cmd, curs, fn): """Dump table to disk.""" f = open(fn, "w", 64*1024) curs.copy_expert(copy_cmd, f) self.log.info('%s: Got %d bytes', self.table_name, f.tell()) f.close() def get_row(self, ln): """Parse a row into dict.""" if not ln: return None t = ln[:-1].split('\t') row = {} for i in range(len(self.common_fields)): row[self.common_fields[i]] = 
t[i] return row def dump_compare(self, src_fn, dst_fn, fix): """Dump + compare single table.""" self.log.info("Comparing dumps: %s", self.table_name) f1 = open(src_fn, "r", 64*1024) f2 = open(dst_fn, "r", 64*1024) src_ln = f1.readline() dst_ln = f2.readline() if src_ln: self.total_src += 1 if dst_ln: self.total_dst += 1 if os.path.isfile(fix): os.unlink(fix) while src_ln or dst_ln: keep_src = keep_dst = 0 if src_ln != dst_ln: src_row = self.get_row(src_ln) dst_row = self.get_row(dst_ln) diff = self.cmp_keys(src_row, dst_row) if diff > 0: # src > dst self.got_missed_delete(dst_row, fix) keep_src = 1 elif diff < 0: # src < dst self.got_missed_insert(src_row, fix) keep_dst = 1 else: if self.cmp_data(src_row, dst_row) != 0: self.got_missed_update(src_row, dst_row, fix) if not keep_src: src_ln = f1.readline() if src_ln: self.total_src += 1 if not keep_dst: dst_ln = f2.readline() if dst_ln: self.total_dst += 1 self.log.info("finished %s: src: %d rows, dst: %d rows,"\ " missed: %d inserts, %d updates, %d deletes", self.table_name, self.total_src, self.total_dst, self.cnt_insert, self.cnt_update, self.cnt_delete) def got_missed_insert(self, src_row, fn): """Create sql for missed insert.""" self.cnt_insert += 1 fld_list = self.common_fields fq_list = [] val_list = [] for f in fld_list: fq_list.append(skytools.quote_ident(f)) v = skytools.unescape_copy(src_row[f]) val_list.append(skytools.quote_literal(v)) q = "insert into %s (%s) values (%s);" % ( self.fq_table_name, ", ".join(fq_list), ", ".join(val_list)) self.show_fix(q, 'insert', fn) def got_missed_update(self, src_row, dst_row, fn): """Create sql for missed update.""" self.cnt_update += 1 fld_list = self.common_fields set_list = [] whe_list = [] for f in self.pkey_list: self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(src_row[f])) for f in fld_list: v1 = src_row[f] v2 = dst_row[f] if self.cmp_value(v1, v2) == 0: continue self.addeq(set_list, skytools.quote_ident(f), skytools.unescape_copy(v1)) self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(v2)) q = "update only %s set %s where %s;" % ( self.fq_table_name, ", ".join(set_list), " and ".join(whe_list)) self.show_fix(q, 'update', fn) def got_missed_delete(self, dst_row, fn): """Create sql for missed delete.""" self.cnt_delete += 1 whe_list = [] for f in self.pkey_list: self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(dst_row[f])) q = "delete from only %s where %s;" % (self.fq_table_name, " and ".join(whe_list)) self.show_fix(q, 'delete', fn) def show_fix(self, q, desc, fn): """Print/write/apply repair sql.""" self.log.debug("missed %s: %s", desc, q) open(fn, "a").write("%s\n" % q) if self.apply_fixes: self.apply_cursor.execute(q) def addeq(self, dst_list, f, v): """Add quoted SET.""" vq = skytools.quote_literal(v) s = "%s = %s" % (f, vq) dst_list.append(s) def addcmp(self, dst_list, f, v): """Add quoted comparison.""" if v is None: s = "%s is null" % f else: vq = skytools.quote_literal(v) s = "%s = %s" % (f, vq) dst_list.append(s) def cmp_data(self, src_row, dst_row): """Compare data field-by-field.""" for k in self.common_fields: v1 = src_row[k] v2 = dst_row[k] if self.cmp_value(v1, v2) != 0: return -1 return 0 def cmp_value(self, v1, v2): """Compare single field, tolerates tz vs notz dates.""" if v1 == v2: return 0 # try to work around tz vs. 
notz z1 = len(v1) z2 = len(v2) if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+': v1 = v1[:-3] if v1 == v2: return 0 elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+': v2 = v2[:-3] if v1 == v2: return 0 return -1 def cmp_keys(self, src_row, dst_row): """Compare primary keys of the rows. Returns 1 if src > dst, -1 if src < dst and 0 if src == dst""" # None means table is done. tag it larger than any existing row. if src_row is None: if dst_row is None: return 0 return 1 elif dst_row is None: return -1 for k in self.pkey_list: v1 = src_row[k] v2 = dst_row[k] if v1 < v2: return -1 elif v1 > v2: return 1 return 0 class Syncer(skytools.DBScript): """Checks that tables in two databases are in sync.""" lock_timeout = 10 ticker_lag_limit = 20 consumer_lag_limit = 20 def sync_table(self, cstr1, cstr2, queue_name, consumer_name, table_name): """Syncer main function. Returns (src_db, dst_db) that are in transaction where table should be in sync. """ setup_db = self.get_database('setup_db', connstr=cstr1, autocommit=1) lock_db = self.get_database('lock_db', connstr=cstr1) src_db = self.get_database('src_db', connstr=cstr1, isolation_level=skytools.I_REPEATABLE_READ) dst_db = self.get_database('dst_db', connstr=cstr2, isolation_level=skytools.I_REPEATABLE_READ) lock_curs = lock_db.cursor() setup_curs = setup_db.cursor() src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.check_consumer(setup_curs, queue_name, consumer_name) # lock table in separate connection self.log.info('Locking %s', table_name) self.set_lock_timeout(lock_curs) lock_time = time.time() lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(table_name)) # now wait until consumer has updated target table until locking self.log.info('Syncing %s', table_name) # consumer must get further than this tick self.force_tick(setup_curs, queue_name) # try to force second tick also self.force_tick(setup_curs, queue_name) # take server time setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')") tpos = setup_curs.fetchone()[0] # now wait while 1: time.sleep(0.5) q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)" setup_curs.execute(q, [tpos, queue_name, consumer_name]) res = setup_curs.fetchall() if len(res) == 0: raise Exception('No such consumer: %s/%s' % (queue_name, consumer_name)) row = res[0] self.log.debug("tpos=%s now=%s lag=%s ok=%s", tpos, row[1], row[2], row[0]) if row[0]: break # limit lock time if time.time() > lock_time + self.lock_timeout: self.log.error('Consumer lagging too much, exiting') lock_db.rollback() sys.exit(1) # take snapshot on provider side src_db.commit() src_curs.execute("SELECT 1") # take snapshot on subscriber side dst_db.commit() dst_curs.execute("SELECT 1") # release lock lock_db.commit() self.close_database('setup_db') self.close_database('lock_db') return (src_db, dst_db) def set_lock_timeout(self, curs): ms = int(1000 * self.lock_timeout) if ms > 0: q = "SET LOCAL statement_timeout = %d" % ms self.log.debug(q) curs.execute(q) def check_consumer(self, curs, queue_name, consumer_name): """ Before locking anything check if consumer is working ok. 
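The rule applied below is, roughly: proceed only while consumer lag stays
within about 10 seconds of the ticker lag; a consumer that is further behind
makes the script exit before any table gets locked.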
""" self.log.info("Queue: %s Consumer: %s", queue_name, consumer_name) curs.execute('select current_database()') self.log.info('Actual db: %s', curs.fetchone()[0]) # get ticker lag q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);" curs.execute(q, [queue_name]) ticker_lag = curs.fetchone()[0] self.log.info("Ticker lag: %s", ticker_lag) # get consumer lag q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);" curs.execute(q, [queue_name, consumer_name]) res = curs.fetchall() if len(res) == 0: self.log.error('check_consumer: No such consumer: %s/%s', queue_name, consumer_name) sys.exit(1) consumer_lag = res[0][0] # check that lag is acceptable self.log.info("Consumer lag: %s", consumer_lag) if consumer_lag > ticker_lag + 10: self.log.error('Consumer lagging too much, cannot proceed') sys.exit(1) def force_tick(self, curs, queue_name): """ Force tick into source queue so that consumer can move on faster """ q = "select pgq.force_tick(%s)" curs.execute(q, [queue_name]) res = curs.fetchone() cur_pos = res[0] start = time.time() while 1: time.sleep(0.5) curs.execute(q, [queue_name]) res = curs.fetchone() if res[0] != cur_pos: # new pos return res[0] # dont loop more than 10 secs dur = time.time() - start if dur > 10 and not self.options.force: raise Exception("Ticker seems dead") class Checker(Syncer): """Checks that tables in two databases are in sync. Config options:: ## data_checker ## confdb = dbname=confdb host=confdb.service extra_connstr = user=marko # one of: compare, repair, repair-apply, compare-repair-apply check_type = compare # random params used in queries cluster_name = instance_name = proxy_host = proxy_db = # list of tables to be compared table_list = foo, bar, baz where_expr = (hashtext(key_user_name) & %%(max_slot)s) in (%%(slots)s) # gets no args source_query = select h.hostname, d.db_name from dba.cluster c join dba.cluster_host ch on (ch.key_cluster = c.id_cluster) join conf.host h on (h.id_host = ch.key_host) join dba.database d on (d.key_host = ch.key_host) where c.db_name = '%(cluster_name)s' and c.instance_name = '%(instance_name)s' and d.mk_db_type = 'partition' and d.mk_db_status = 'active' order by d.db_name, h.hostname target_query = select db_name, hostname, slots, max_slot from dba.get_cross_targets(%%(hostname)s, %%(db_name)s, '%(proxy_host)s', '%(proxy_db)s') consumer_query = select q.queue_name, c.consumer_name from conf.host h join dba.database d on (d.key_host = h.id_host) join dba.pgq_queue q on (q.key_database = d.id_database) join dba.pgq_consumer c on (c.key_queue = q.id_queue) where h.hostname = %%(hostname)s and d.db_name = %%(db_name)s and q.queue_name like 'xm%%%%' """ def __init__(self, args): """Checker init.""" super(Checker, self).__init__('data_checker', args) self.set_single_loop(1) self.log.info('Checker starting %s', str(args)) self.lock_timeout = self.cf.getfloat('lock_timeout', 10) self.table_list = self.cf.getlist('table_list') def work(self): """Syncer main function.""" source_query = self.cf.get('source_query') target_query = self.cf.get('target_query') consumer_query = self.cf.get('consumer_query') where_expr = self.cf.get('where_expr') extra_connstr = self.cf.get('extra_connstr') check = self.cf.get('check_type', 'compare') confdb = self.get_database('confdb', autocommit=1) curs = confdb.cursor() curs.execute(source_query) for src_row in curs.fetchall(): s_host = src_row['hostname'] s_db = src_row['db_name'] curs.execute(consumer_query, src_row) r = curs.fetchone() consumer_name = 
r['consumer_name'] queue_name = r['queue_name'] curs.execute(target_query, src_row) for dst_row in curs.fetchall(): d_db = dst_row['db_name'] d_host = dst_row['hostname'] cstr1 = "dbname=%s host=%s %s" % (s_db, s_host, extra_connstr) cstr2 = "dbname=%s host=%s %s" % (d_db, d_host, extra_connstr) where = where_expr % dst_row self.log.info('Source: db=%s host=%s queue=%s consumer=%s', s_db, s_host, queue_name, consumer_name) self.log.info('Target: db=%s host=%s where=%s', d_db, d_host, where) for tbl in self.table_list: src_db, dst_db = self.sync_table(cstr1, cstr2, queue_name, consumer_name, tbl) if check == 'compare': self.do_compare(tbl, src_db, dst_db, where) elif check == 'repair': r = TableRepair(tbl, self.log) r.do_repair(src_db, dst_db, where, 'fix.' + tbl, False) elif check == 'repair-apply': r = TableRepair(tbl, self.log) r.do_repair(src_db, dst_db, where, 'fix.' + tbl, True) elif check == 'compare-repair-apply': ok = self.do_compare(tbl, src_db, dst_db, where) if not ok: r = TableRepair(tbl, self.log) r.do_repair(src_db, dst_db, where, 'fix.' + tbl, True) else: raise Exception('unknown check type') self.reset() def do_compare(self, tbl, src_db, dst_db, where): """Actual comparison.""" src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.log.info('Counting %s', tbl) q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" % where q = self.cf.get('compare_sql', q) q = q.replace('_TABLE_', skytools.quote_fqident(tbl)) f = "%(cnt)d rows, checksum=%(chksum)s" f = self.cf.get('compare_fmt', f) self.log.debug("srcdb: %s", q) src_curs.execute(q) src_row = src_curs.fetchone() src_str = f % src_row self.log.info("srcdb: %s", src_str) self.log.debug("dstdb: %s", q) dst_curs.execute(q) dst_row = dst_curs.fetchone() dst_str = f % dst_row self.log.info("dstdb: %s", dst_str) src_db.commit() dst_db.commit() if src_str != dst_str: self.log.warning("%s: Results do not match!", tbl) return False else: self.log.info("%s: OK!", tbl) return True if __name__ == '__main__': script = Checker(sys.argv[1:]) script.start() python-skytools-3.4/skytools/config.py000066400000000000000000000273051356323561300203120ustar00rootroot00000000000000 """Nicer config class.""" from __future__ import division, absolute_import, print_function import os import os.path import re import socket import skytools try: from configparser import ( # noqa NoOptionError, NoSectionError, InterpolationError, InterpolationDepthError, Error as ConfigError, ConfigParser, MAX_INTERPOLATION_DEPTH, ExtendedInterpolation, Interpolation) except ImportError: from ConfigParser import ( # noqa NoOptionError, NoSectionError, InterpolationError, InterpolationDepthError, Error as ConfigError, SafeConfigParser, MAX_INTERPOLATION_DEPTH) class Interpolation(object): """Define Interpolation API from Python3.""" def before_get(self, parser, section, option, value, defaults): return value def before_set(self, parser, section, option, value): return value def before_read(self, parser, section, option, value): return value def before_write(self, parser, section, option, value): return value class ConfigParser(SafeConfigParser): """Default Python's ConfigParser that uses _DEFAULT_INTERPOLATION""" _DEFAULT_INTERPOLATION = None def _interpolate(self, section, option, rawval, defs): if self._DEFAULT_INTERPOLATION is None: return SafeConfigParser._interpolate(self, section, option, rawval, defs) return self._DEFAULT_INTERPOLATION.before_get(self, section, option, rawval, defs) __all__ = [ 'Config', 'NoOptionError', 
'ConfigError', 'ConfigParser', 'ExtendedConfigParser', 'ExtendedCompatConfigParser' ] class Config(object): """Bit improved ConfigParser. Additional features: - Remembers section. - Accepts defaults in get() functions. - List value support. """ def __init__(self, main_section, filename, sane_config=None, user_defs=None, override=None, ignore_defs=False): """Initialize Config and read from file. """ # use config file name as default job_name if filename: job_name = os.path.splitext(os.path.basename(filename))[0] else: job_name = main_section # initialize defaults, make them usable in config file if ignore_defs: self.defs = {} else: self.defs = { 'job_name': job_name, 'service_name': main_section, 'host_name': socket.gethostname(), } if filename: self.defs['config_dir'] = os.path.dirname(filename) self.defs['config_file'] = filename if user_defs: self.defs.update(user_defs) self.main_section = main_section self.filename = filename self.override = override or {} self.cf = ConfigParser() if filename is None: self.cf.add_section(main_section) elif not os.path.isfile(filename): raise ConfigError('Config file not found: '+filename) self.reload() def reload(self): """Re-reads config file.""" if self.filename: self.cf.read(self.filename) if not self.cf.has_section(self.main_section): raise NoSectionError(self.main_section) # apply default if key not set for k, v in self.defs.items(): if not self.cf.has_option(self.main_section, k): self.cf.set(self.main_section, k, v) # apply overrides if self.override: for k, v in self.override.items(): self.cf.set(self.main_section, k, v) def get(self, key, default=None): """Reads string value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return str(self.cf.get(self.main_section, key)) def getint(self, key, default=None): """Reads int value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getint(self.main_section, key) def getboolean(self, key, default=None): """Reads boolean value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getboolean(self.main_section, key) def getfloat(self, key, default=None): """Reads float value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getfloat(self.main_section, key) def getlist(self, key, default=None): """Reads comma-separated list from key.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default s = self.get(key).strip() res = [] if not s: return res for v in s.split(","): res.append(v.strip()) return res def getdict(self, key, default=None): """Reads key-value dict from parameter. Key and value are separated with ':'. If missing, key itself is taken as value. """ if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default s = self.get(key).strip() res = {} if not s: return res for kv in s.split(","): tmp = kv.split(':', 1) if len(tmp) > 1: k = tmp[0].strip() v = tmp[1].strip() else: k = kv.strip() v = k res[k] = v return res def getfile(self, key, default=None): """Reads filename from config. 
In addition to reading string value, expands ~ to user directory. """ fn = self.get(key, default) if fn == "" or fn == "-": return fn # simulate that the cwd is script location #path = os.path.dirname(sys.argv[0]) # seems bad idea, cwd should be cwd fn = os.path.expanduser(fn) return fn def getbytes(self, key, default=None): """Reads a size value in human format, if not set then default. Examples: 1, 2 B, 3K, 4 MB """ if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) s = default else: s = self.cf.get(self.main_section, key) return skytools.hsize_to_bytes(s) def get_wildcard(self, key, values=(), default=None): """Reads a wildcard property from conf and returns its string value, if not set then default.""" orig_key = key keys = [key] for wild in values: key = key.replace('*', wild, 1) keys.append(key) keys.reverse() for k in keys: if self.cf.has_option(self.main_section, k): return self.cf.get(self.main_section, k) if default is None: raise NoOptionError(orig_key, self.main_section) return default def sections(self): """Returns list of sections in config file, excluding DEFAULT.""" return self.cf.sections() def has_section(self, section): """Checks if section is present in config file, excluding DEFAULT.""" return self.cf.has_section(section) def clone(self, main_section): """Return new Config() instance with new main section on same config file.""" return Config(main_section, self.filename) def options(self): """Return list of options in main section.""" return self.cf.options(self.main_section) def has_option(self, opt): """Checks if option exists in main section.""" return self.cf.has_option(self.main_section, opt) def items(self): """Returns list of (name, value) for each option in main section.""" return self.cf.items(self.main_section) # define some aliases (short-cuts / backward compatibility cruft) getbool = getboolean class ExtendedInterpolationCompat(Interpolation): _EXT_VAR_RX = r'\$\$|\$\{[^(){}]+\}' _OLD_VAR_RX = r'%%|%\([^(){}]+\)s' _var_rc = re.compile('(%s|%s)' % (_EXT_VAR_RX, _OLD_VAR_RX)) _bad_rc = re.compile('[%$]') def before_get(self, parser, section, option, rawval, defaults): dst = [] self._interpolate_ext(dst, parser, section, option, rawval, defaults, set()) return ''.join(dst) def before_set(self, parser, section, option, value): sub = self._var_rc.sub('', value) if self._bad_rc.search(sub): raise ValueError("invalid interpolation syntax in %r" % value) return value def _interpolate_ext(self, dst, parser, section, option, rawval, defaults, loop_detect): if not rawval: return if len(loop_detect) > MAX_INTERPOLATION_DEPTH: raise InterpolationDepthError(option, section, rawval) xloop = (section, option) if xloop in loop_detect: raise InterpolationError(section, option, 'Loop detected: %r in %r' % (xloop, loop_detect)) loop_detect.add(xloop) parts = self._var_rc.split(rawval) for i, frag in enumerate(parts): fullkey = None use_vars = defaults if i % 2 == 0: dst.append(frag) continue if frag in ('$$', '%%'): dst.append(frag[0]) continue if frag.startswith('${') and frag.endswith('}'): fullkey = frag[2:-1] # use section access only for new-style keys if ':' in fullkey: ksect, key = fullkey.split(':', 1) use_vars = None else: ksect, key = section, fullkey elif frag.startswith('%(') and frag.endswith(')s'): fullkey = frag[2:-2] ksect, key = section, fullkey else: raise InterpolationError(section, option, 'Internal parse error: %r' % frag) key = parser.optionxform(key) newpart = parser.get(ksect, key, 
raw=True, vars=use_vars) if newpart is None: raise InterpolationError(ksect, key, 'Key referenced is None') self._interpolate_ext(dst, parser, ksect, key, newpart, defaults, loop_detect) loop_detect.remove(xloop) try: ExtendedInterpolation except NameError: class ExtendedInterpolationPy2(ExtendedInterpolationCompat): _var_rc = re.compile('(%s)' % ExtendedInterpolationCompat._EXT_VAR_RX) _bad_rc = re.compile('[$]') ExtendedInterpolation = ExtendedInterpolationPy2 class ExtendedConfigParser(ConfigParser): """ConfigParser that uses Python3-style extended interpolation by default. Syntax: ${var} and ${section:var} """ _DEFAULT_INTERPOLATION = ExtendedInterpolation() class ExtendedCompatConfigParser(ExtendedConfigParser): r"""Support both extended "${}" syntax from python3 and old "%()s" too. New ${} syntax allows ${key} to refer key in same section, and ${sect:key} to refer key in other sections. """ _DEFAULT_INTERPOLATION = ExtendedInterpolationCompat() python-skytools-3.4/skytools/dbservice.py000066400000000000000000000601051356323561300210060ustar00rootroot00000000000000""" Class used to handle multiset receiving and returning PL/Python procedures """ from __future__ import division, absolute_import, print_function import skytools from skytools import dbdict try: import plpy except ImportError: pass try: basestring except NameError: basestring = str # noqa __all__ = ['DBService', 'ServiceContext', 'get_record', 'get_record_list', 'make_record', 'make_record_array', 'TableAPI', #'log_result', 'transform_fields' ] def transform_fields(rows, key_fields, name_field, data_field): """Convert multiple-rows per key input array to one-row, multiple-column output array. The input arrays must be sorted by the key fields. >>> from skytools.testing import ordered_dict >>> rows = [] >>> rows.append({'time': '22:00', 'metric': 'count', 'value': 100}) >>> rows.append({'time': '22:00', 'metric': 'dur', 'value': 7}) >>> rows.append({'time': '23:00', 'metric': 'count', 'value': 200}) >>> rows.append({'time': '23:00', 'metric': 'dur', 'value': 5}) >>> res = [ordered_dict(row) for row in transform_fields(rows, ['time'], 'metric', 'value')] >>> res[0] OrderedDict([('count', 100), ('dur', 7), ('time', '22:00')]) >>> res[1] OrderedDict([('count', 200), ('dur', 5), ('time', '23:00')]) """ cur_key = None cur_row = None res = [] for r in rows: k = [r[f] for f in key_fields] if k != cur_key: cur_key = k cur_row = {} for f in key_fields: cur_row[f] = r[f] res.append(cur_row) cur_row[r[name_field]] = r[data_field] return res # render_table def render_table(rows, fields): """ Render result rows as a table. Returns array of lines. """ widths = [15] * len(fields) for row in rows: for i, k in enumerate(fields): rlen = len(str(row.get(k))) widths[i] = widths[i] > rlen and widths[i] or rlen widths = [w + 2 for w in widths] fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) lines = [] lines.append(fmt % tuple(fields)) lines.append(fmt % tuple(['-'*15] * len(fields))) for row in rows: lines.append(fmt % tuple([str(row.get(k)) for k in fields])) return lines # data conversion to and from url def get_record(arg): """ Parse data for one urlencoded record. Useful for turning incoming serialized data into structure usable for manipulation. 
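Illustrative doctest (dbdict allows both key and attribute access):

>>> r = get_record('id=1&name=foo')
>>> r['name']
'foo'
>>> r.name
'foo'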
""" if not arg: return dbdict() # allow array of single record if arg[0] in ('{', '['): lst = skytools.parse_pgarray(arg) if len(lst) != 1: raise ValueError('get_record() expects exactly 1 row, got %d' % len(lst)) arg = lst[0] # parse record return dbdict(skytools.db_urldecode(arg)) def get_record_list(array): """ Parse array of urlencoded records. Useful for turning incoming serialized data into structure usable for manipulation. """ if array is None: return [] if not isinstance(array, list): array = skytools.parse_pgarray(array) return [get_record(el) for el in array] def get_record_lists(tbl, field): """ Create dictionary of lists from given list using field as grouping criteria Used for master detail operatons to group detail records according to master id """ records = dbdict() for rec in tbl: master_id = str(rec[field]) records.setdefault(master_id, []).append(rec) return records def _make_record_convert(row): """Converts complex values.""" d = row.copy() for k, v in d.items(): if isinstance(v, list): d[k] = skytools.make_pgarray(v) return skytools.db_urlencode(d) def make_record(row): """ Takes record as dict and returns it as urlencoded string. Used to send data out of db service layer.or to fake incoming calls """ for v in row.values(): if isinstance(v, list): return _make_record_convert(row) return skytools.db_urlencode(row) def make_record_array(rowlist): """ Takes list of records got from plpy execute and turns it into postgers aray string. Used to send data out of db service layer. """ return '{' + ','.join([make_record(row) for row in rowlist]) + '}' def get_result_items(rec_list, name): """ Get return values from result """ for r in rec_list: if r['res_code'] == name: return get_record_list(r['res_rows']) return None def log_result(log, rec_list): """ Sends dbservice execution logs to logfile """ msglist = get_result_items(rec_list, "_status") if msglist is None: if rec_list: log.warning('Unhandled output result: _status res_code not present.') else: for msg in msglist: log.debug(msg['_message']) class DBService(object): """ Wrap parameterized query handling and multiset stored procedure writing """ ROW = "_row" # name of the fake field where internal record id is stored FIELD = "_field" # parameter name for the field in record that is related to current message PARAM = "_param" # name of the parameter to which message relates SKIP = "skip" # used when record is needed for it's data but is not been updated INSERT = "insert" UPDATE = "update" DELETE = "delete" INFO = "info" # just informative message for the user NOTICE = "notice" # more than info less than warning WARNING = "warning" # warning message, something is out of ordinary ERROR = "error" # error found but execution continues until check then error is raised FATAL = "fatal" # execution is terminated at once and all found errors returned rows_found = 0 def __init__(self, context, global_dict=None): """ This object must be initiated in the beginning of each db service """ rec = skytools.db_urldecode(context) self._context = context # used to run dbservice in retval self.global_dict = global_dict # used for cacheing query plans self._retval = [] # used to collect return resultsets self._is_test = 'is_test' in rec # used to convert output into human readable form self.sqls = None # if sqls stays None then no recording of sqls is done if "show_sql" in rec: # api must add exected sql to resultset self.sqls = [] # sql's executed by dbservice, used for dubugging self.can_save = True # used to keep value most severe error 
found so far self.messages = [] # used to hold list of messages to be returned to the user # error and message handling def tell_user(self, severity, code, message, params=None, **kvargs): """ Adds another message to the set of messages to be sent back to user If error message then can_save is set to False If fatal message then error or found errors are raised at once """ params = params or kvargs #plpy.notice("%s %s: %s %s" % (severity, code, message, str(params))) params["_severity"] = severity params["_code"] = code params["_message"] = message self.messages.append(params) if severity == self.ERROR: self.can_save = False if severity == self.FATAL: self.can_save = False self.raise_if_errors() def raise_if_errors(self): """ To be used in places where, before continuing, it must be checked if errors have been found Raises found errors, packing them into error message as urlencoded string """ if not self.can_save: msgs = "Dbservice error(s): " + make_record_array(self.messages) plpy.error(msgs) # run sql meant mostly for select but not limited to def create_query(self, sql, params=None, **kvargs): """ Returns initialized querybuilder object for building complex dynamic queries """ params = params or kvargs return skytools.PLPyQueryBuilder(sql, params, self.global_dict, self.sqls) def run_query(self, sql, params=None, **kvargs): """ Helper function if everything you need is just parameterized execute Sets rows_found that is convenient to use when you don't need result, just want to know how many rows were affected """ params = params or kvargs rows = skytools.plpy_exec(self.global_dict, sql, params) # convert result rows to dbdict if rows: rows = [dbdict(r) for r in rows] self.rows_found = len(rows) else: self.rows_found = 0 return rows def run_query_row(self, sql, params=None, **kvargs): """ Helper function if everything you need is just parameterized execute to fetch one row only. If not found, None is returned """ params = params or kvargs rows = self.run_query(sql, params) if len(rows) == 0: return None return rows[0] def run_exists(self, sql, params=None, **kvargs): """ Helper function to find out whether a record in given table exists, using values in dict as criteria. Takes away all the hassle of preparing statements and processing returned result, giving out just one boolean """ params = params or kvargs self.run_query(sql, params) return self.rows_found def run_lookup(self, sql, params=None, **kvargs): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result, giving out just one value.
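Illustrative call (table and column names are hypothetical):

    cnt = dbs.run_lookup("select count(*) from mytable where status = {status}", status='new')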
Uses plan cache if used inside db service """ params = params or kvargs rows = self.run_query(sql, params) if len(rows) == 0: return None row = rows[0] return row.values()[0] # resultset handling def return_next(self, rows, res_name, severity=None): """ Adds given set of rows to resultset """ self._retval.append([res_name, rows]) if severity is not None and len(rows) == 0: self.tell_user(severity, "dbsXXXX", "No matching records found") return rows def return_next_sql(self, sql, params, res_name, severity=None): """ Executes query and adds records to resultset """ rows = self.run_query(sql, params) return self.return_next(rows, res_name, severity) def retval(self, service_name=None, params=None, **kvargs): """ Return collected resultsets, appending messages for the user at the end Method is usually called as last statement in dbservice to return the results Also converts results into desired format """ params = params or kvargs self.raise_if_errors() if len(self.messages): self.return_next(self.messages, "_status") if self.sqls is not None and len(self.sqls): self.return_next(self.sqls, "_sql") results = [] for r in self._retval: res_name = r[0] rows = r[1] res_count = str(len(rows)) if self._is_test and len(rows) > 0: results.append([res_name, res_count, res_name]) n = 1 for trow in render_table(rows, rows[0].keys()): results.append([res_name, n, trow]) n += 1 else: res_rows = make_record_array(rows) results.append([res_name, res_count, res_rows]) if service_name: sql = "select * from %s( {i_context}, {i_params} );" % skytools.quote_fqident(service_name) par = dbdict(i_context=self._context, i_params=make_record(params)) res = self.run_query(sql, par) for r in res: results.append((r.res_code, r.res_text, r.res_rows)) return results # miscellaneous def check_required(self, record_name, record, severity, *fields): """ Checks if all required fields are present in record Used to validate incoming data Returns list of field names that are missing or empty """ missing = [] params = {self.PARAM: record_name} if self.ROW in record: params[self.ROW] = record[self.ROW] for field in fields: params[self.FIELD] = field if field in record: if record[field] is None or (isinstance(record[field], basestring) and len(record[field]) == 0): self.tell_user(severity, "dbsXXXX", "Required value missing: {%s}.{%s}" % ( self.PARAM, self.FIELD), **params) missing.append(field) else: self.tell_user(severity, "dbsXXXX", "Required field missing: {%s}.{%s}" % ( self.PARAM, self.FIELD), **params) missing.append(field) return missing # TableAPI class TableAPI(object): """ Class for managing single record updates using primary key """ _table = None # schema name and table name _where = None # where condition used for update and delete _id = None # name of the primary key field _id_type = None # column type of primary key _op = None # operation currently carried out _ctx = None # context object for username and version _logging = True # should tapi log data changes _row = None # row identifier from calling program def __init__(self, ctx, table, create_log=True, id_type='int8'): """ Table name is used to construct insert, update and delete statements Table must have primary key field whose name is in format id_<tablename> Table name should be given in format schema.tablename """ self._ctx = ctx self._table = skytools.quote_fqident(table) self._id = "id_" + skytools.fq_name_parts(table)[1] self._id_type = id_type self._where = '%s = {%s:%s}' % (skytools.quote_ident(self._id), self._id, self._id_type) self._logging = create_log def _log(self, result,
original=None): """ Log change into table log.changelog """ if not self._logging: return changes = [] for key in result.keys(): if self._op == 'update': if key in original: if str(original[key]) != str(result[key]): changes.append(key + ": " + str(original[key]) + " -> " + str(result[key])) else: changes.append(key + ": " + str(result[key])) self._ctx.log(self._table, result[self._id], self._op, "\n".join(changes)) def _version_check(self, original, version): if original is None: self._ctx.tell_user(self._ctx.INFO, "dbsXXXX", "Record ({table}.{field}={id}) has been deleted by other user "\ "while you were editing. Check version ({ver}) in changelog for details.", table=self._table, field=self._id, id=original[self._id], ver=original.version, _row=self._row) if version is not None and original.version is not None: if int(version) != int(original.version): self._ctx.tell_user(self._ctx.INFO, "dbsXXXX", "Record ({table}.{field}={id}) has been changed by other user while you were editing. "\ "Version in db: ({db_ver}) and version sent by caller ({caller_ver}). "\ "See changelog for details.", table=self._table, field=self._id, id=original[self._id], db_ver=original.version, caller_ver=version, _row=self._row) def _insert(self, data): fields = [] values = [] for key in data.keys(): if data[key] is not None: # ignore empty fields.append(skytools.quote_ident(key)) values.append("{" + key + "}") sql = "insert into %s (%s) values (%s) returning *;" % (self._table, ",".join(fields), ",".join(values)) result = self._ctx.run_query_row(sql, data) self._log(result) return result def _update(self, data, version): sql = "select * from %s where %s" % (self._table, self._where) original = self._ctx.run_query_row(sql, data) self._version_check(original, version) pairs = [] for key in data.keys(): if data[key] is None: pairs.append(key + " = NULL") else: pairs.append(key + " = {" + key + "}") sql = "update %s set %s where %s returning *;" % (self._table, ", ".join(pairs), self._where) result = self._ctx.run_query_row(sql, data) self._log(result, original) return result def _delete(self, data, version): sql = "delete from %s where %s returning *;" % (self._table, self._where) result = self._ctx.run_query_row(sql, data) self._version_check(result, version) self._log(result) return result def do(self, data): """ Do DML according to special field _op that must be given together with data """ result = data # so it is initialized for skip self._op = data.pop(self._ctx.OP) # determines operation done self._row = data.pop(self._ctx.ROW, None) # internal record id used for error reporting if self._row is None: # if no _row variable was provided self._row = data.get(self._id, None) # use id instead if self._id in data and data[self._id]: # if _id field is given if int(data[self._id]) < 0: # and it is fake key generated by ui data.pop(self._id) # remove fake key so real one can be assigned version = data.get('version', None) # version sent from caller data['version'] = self._ctx.version # current transaction id is stored in each record if self._op == self._ctx.INSERT: result = self._insert(data) elif self._op == self._ctx.UPDATE: result = self._update(data, version) elif self._op == self._ctx.DELETE: result = self._delete(data, version) elif self._op == self._ctx.SKIP: pass else: self._ctx.tell_user(self._ctx.ERROR, "dbsXXXX", "Unhandled _op='{op}' value in TableAPI (table={table}, id={id})", op=self._op, table=self._table, id=data[self._id]) result[self._ctx.OP] = self._op result[self._ctx.ROW] = self._row return
result # ServiceContext class ServiceContext(DBService): OP = "_op" # name of the fake field where record modification operation is stored def __init__(self, context, global_dict=None): """ This object must be initialized at the beginning of each db service """ super(ServiceContext, self).__init__(context, global_dict) rec = skytools.db_urldecode(context) if "username" not in rec: plpy.error("Username must be provided in db service context parameter") self.username = rec['username'] # used for logging purposes res = plpy.execute("select txid_current() as txid;") row = res[0] self.version = row["txid"] self.rows_found = 0 # Flag set by run_query to indicate number of rows got # logging def log(self, _object_type, _key_object, _change_op, _payload): """ Log stuff into the changelog, whatever seems relevant to be logged """ self.run_query( "select log.log_change( {version}, {username}, {object_type}, {key_object}, {change_op}, {payload} );", version=self.version, username=self.username, object_type=_object_type, key_object=_key_object, change_op=_change_op, payload=_payload) # data conversion to and from url def get_record(self, arg): """ Parse data for one urlencoded record. Useful for turning incoming serialized data into structure usable for manipulation. """ return get_record(arg) def get_record_list(self, array): """ Parse array of urlencoded records. Useful for turning incoming serialized data into structure usable for manipulation. """ return get_record_list(array) def get_list_groups(self, tbl, field): """ Create dictionary of lists from given list using field as grouping criteria Used for master detail operations to group detail records according to master id """ return get_record_lists(tbl, field) def make_record(self, row): """ Takes record as dict and returns it as urlencoded string. Used to send data out of db service layer or to fake incoming calls """ return make_record(row) def make_record_array(self, rowlist): """ Takes list of records got from plpy execute and turns it into postgres array string. Used to send data out of db service layer. """ return make_record_array(rowlist) # tapi based dml functions def _changelog(self, fields): log = True if fields: if '_log' in fields: if not fields.pop('_log'): log = False if '_log_id' in fields: fields.pop('_log_id') if '_log_field' in fields: fields.pop('_log_field') return log def tapi_do(self, tablename, row, **fields): """ Convenience function for just doing the change without creating tapi object first Fields object may contain additional overriding values that are applied before do """ tapi = TableAPI(self, tablename, self._changelog(fields)) row = row or dbdict() if fields: row.update(fields) return tapi.do(row) def tapi_do_set(self, tablename, rows, **fields): """ Does changes to list of detail rows Used for normal foreign keys in master detail relationships Does deletes first, then updates, then inserts, to avoid uniqueness problems """ tapi = TableAPI(self, tablename, self._changelog(fields)) results, updates, inserts = [], [], [] for row in rows: if fields: row.update(fields) if row[self.OP] == self.DELETE: results.append(tapi.do(row)) elif row[self.OP] == self.UPDATE: updates.append(row) else: inserts.append(row) for row in updates: results.append(tapi.do(row)) for row in inserts: results.append(tapi.do(row)) return results # resultset handling def retval_dbservice(self, service_name, ctx, **params): """ Runs service with standard interface.
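Sketch of typical use at the end of a save-type service (service name and
parameter are hypothetical):

    return dbs.retval_dbservice('userstore.get_user', ctx, id_user=uid)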
Convenient to use for calling select services from other services, for example to return data after doing a save """ self.raise_if_errors() service_sql = "select * from %s( {i_context}, {i_params} );" % skytools.quote_fqident(service_name) service_params = {"i_context": ctx, "i_params": self.make_record(params)} results = self.run_query(service_sql, service_params) retval = self.retval() for r in results: retval.append((r.res_code, r.res_text, r.res_rows)) return retval # miscellaneous def field_copy(self, rec, *keys): """ Used to copy subset of fields from one record into another example: dbs.field_copy(record, "start_date", "key_colo", "key_rack") """ retval = dbdict() for key in keys: if key in rec: retval[key] = rec[key] return retval def field_set(self, **fields): """ Fills dict with given values and returns resulting dict If dict was not provided with call, it is created """ return fields python-skytools-3.4/skytools/dbstruct.py000066400000000000000000000570611356323561300206750ustar00rootroot00000000000000"""Find table structure and allow CREATE/DROP elements from it. """ from __future__ import division, absolute_import, print_function import re import skytools from skytools import quote_ident, quote_fqident __all__ = ['TableStruct', 'SeqStruct', 'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER', 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL', 'T_SEQUENCE', 'T_PARENT', 'T_DEFAULT'] T_TABLE = 1 << 0 T_CONSTRAINT = 1 << 1 T_INDEX = 1 << 2 T_TRIGGER = 1 << 3 T_RULE = 1 << 4 T_GRANT = 1 << 5 T_OWNER = 1 << 6 T_SEQUENCE = 1 << 7 T_PARENT = 1 << 8 T_DEFAULT = 1 << 9 T_PKEY = 1 << 20 # special, one of constraints T_ALL = (T_TABLE | T_CONSTRAINT | T_INDEX | T_SEQUENCE | T_TRIGGER | T_RULE | T_GRANT | T_OWNER | T_DEFAULT) # # Utility functions # def find_new_name(curs, name): """Create new object name for case the old exists. Needed when creating a new table besides old one.
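For example, if relation 'foo' is taken, the first free name of the form
'foo_1', 'foo_2', ... is returned (illustrative; see the loop below).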
""" # cut off previous numbers m = re.search('_[0-9]+$', name) if m: name = name[:m.start()] # now loop for i in range(1, 1000): tname = "%s_%d" % (name, i) q = "select count(1) from pg_class where relname = %s" curs.execute(q, [tname]) if curs.fetchone()[0] == 0: return tname # failed raise Exception('find_new_name failed') def rx_replace(rx, sql, new_part): """Find a regex match and replace that part with new_part.""" m = re.search(rx, sql, re.I) if not m: raise Exception('rx_replace failed: rx=%r sql=%r new=%r' % (rx, sql, new_part)) p1 = sql[:m.start()] p2 = sql[m.end():] return p1 + new_part + p2 # # Schema objects # class TElem(object): """Keeps info about one metadata object.""" SQL = "" type = 0 def get_create_sql(self, curs, new_name=None): """Return SQL statement for creating or None if not supported.""" return None def get_drop_sql(self, curs): """Return SQL statement for dropping or None of not supported.""" return None @classmethod def get_load_sql(cls, pgver): """Return SQL statement for finding objects.""" return cls.SQL class TConstraint(TElem): """Info about constraint.""" type = T_CONSTRAINT SQL = """ SELECT c.conname as name, pg_get_constraintdef(c.oid) as def, c.contype, i.indisclustered as is_clustered FROM pg_constraint c LEFT JOIN pg_index i ON c.conrelid = i.indrelid AND c.conname = (SELECT r.relname FROM pg_class r WHERE r.oid = i.indexrelid) WHERE c.conrelid = %(oid)s AND c.contype != 'f' """ def __init__(self, table_name, row): """Init constraint.""" self.table_name = table_name self.name = row['name'] self.defn = row['def'] self.contype = row['contype'] self.is_clustered = row['is_clustered'] # tag pkeys if self.contype == 'p': self.type += T_PKEY def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" # no ONLY here as table with childs (only case that matters) # cannot have contraints that childs do not have fmt = "ALTER TABLE %s ADD CONSTRAINT %s\n %s;" if new_table_name: name = self.name if self.contype in ('p', 'u'): name = find_new_name(curs, self.name) qtbl = quote_fqident(new_table_name) qname = quote_ident(name) else: qtbl = quote_fqident(self.table_name) qname = quote_ident(self.name) sql = fmt % (qtbl, qname, self.defn) if self.is_clustered: sql += ' ALTER TABLE ONLY %s\n CLUSTER ON %s;' % (qtbl, qname) return sql def get_drop_sql(self, curs): """Generate removal sql.""" fmt = "ALTER TABLE ONLY %s\n DROP CONSTRAINT %s;" sql = fmt % (quote_fqident(self.table_name), quote_ident(self.name)) return sql class TIndex(TElem): """Info about index.""" type = T_INDEX SQL = """ SELECT n.nspname || '.' 
|| c.relname as name, pg_get_indexdef(i.indexrelid) as defn, c.relname as local_name, i.indisclustered as is_clustered FROM pg_index i, pg_class c, pg_namespace n WHERE c.oid = i.indexrelid AND i.indrelid = %(oid)s AND n.oid = c.relnamespace AND NOT EXISTS (select objid from pg_depend where classid = %(pg_class_oid)s and objid = c.oid and deptype = 'i') """ def __init__(self, table_name, row): self.name = row['name'] self.defn = row['defn'].replace(' USING ', '\n USING ', 1) + ';' self.is_clustered = row['is_clustered'] self.table_name = table_name self.local_name = row['local_name'] def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" if new_table_name: # fixme: seems broken iname = find_new_name(curs, self.name) tname = new_table_name pnew = "INDEX %s ON %s " % (quote_ident(iname), quote_fqident(tname)) rx = r"\bINDEX[ ][a-z0-9._]+[ ]ON[ ][a-z0-9._]+[ ]" sql = rx_replace(rx, self.defn, pnew) else: sql = self.defn iname = self.local_name tname = self.table_name if self.is_clustered: sql += ' ALTER TABLE ONLY %s\n CLUSTER ON %s;' % ( quote_fqident(tname), quote_ident(iname)) return sql def get_drop_sql(self, curs): return 'DROP INDEX %s;' % quote_fqident(self.name) class TRule(TElem): """Info about rule.""" type = T_RULE SQL = """SELECT rw.*, pg_get_ruledef(rw.oid) as def FROM pg_rewrite rw WHERE rw.ev_class = %(oid)s AND rw.rulename <> '_RETURN'::name """ def __init__(self, table_name, row, new_name=None): self.table_name = table_name self.name = row['rulename'] self.defn = row['def'] self.enabled = row.get('ev_enabled', 'O') def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" if not new_table_name: sql = self.defn table = self.table_name else: idrx = r'''([a-z0-9._]+|"([^"]+|"")+")+''' # fixme: broken / quoting rx = r"\bTO[ ]" + idrx rc = re.compile(rx, re.X) m = rc.search(self.defn) if not m: raise Exception('Cannot find table name in rule') old_tbl = m.group(1) new_tbl = quote_fqident(new_table_name) sql = self.defn.replace(old_tbl, new_tbl) table = new_table_name if self.enabled != 'O': # O - rule fires in origin and local modes # D - rule is disabled # R - rule fires in replica mode # A - rule fires always action = {'R': 'ENABLE REPLICA', 'A': 'ENABLE ALWAYS', 'D': 'DISABLE'}[self.enabled] sql += ('\nALTER TABLE %s %s RULE %s;' % (table, action, self.name)) return sql def get_drop_sql(self, curs): return 'DROP RULE %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) class TTrigger(TElem): """Info about trigger.""" type = T_TRIGGER def __init__(self, table_name, row): self.table_name = table_name self.name = row['name'] self.defn = row['def'] + ';' self.defn = self.defn.replace('FOR EACH', '\n FOR EACH', 1) def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" if not new_table_name: return self.defn # fixme: broken / quoting rx = r"\bON[ ][a-z0-9._]+[ ]" pnew = "ON %s " % new_table_name return rx_replace(rx, self.defn, pnew) def get_drop_sql(self, curs): return 'DROP TRIGGER %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) @classmethod def get_load_sql(cls, pg_vers): """Return SQL statement for finding objects.""" sql = "SELECT tgname as name, pg_get_triggerdef(oid) as def "\ " FROM pg_trigger "\ " WHERE tgrelid = %(oid)s AND " if pg_vers >= 90000: sql += "NOT tgisinternal" else: sql += "NOT tgisconstraint" return sql class TParent(TElem): """Info about parent table.""" type = T_PARENT SQL = """ SELECT n.nspname||'.'||c.relname AS name FROM pg_inherits i JOIN pg_class
c ON i.inhparent = c.oid JOIN pg_namespace n ON c.relnamespace = n.oid WHERE i.inhrelid = %(oid)s """ def __init__(self, table_name, row): self.name = table_name self.parent_name = row['name'] def get_create_sql(self, curs, new_table_name=None): return 'ALTER TABLE ONLY %s\n INHERIT %s' % (quote_fqident(self.name), quote_fqident(self.parent_name)) def get_drop_sql(self, curs): return 'ALTER TABLE ONLY %s\n NO INHERIT %s' % (quote_fqident(self.name), quote_fqident(self.parent_name)) class TOwner(TElem): """Info about table owner.""" type = T_OWNER SQL = """ SELECT pg_get_userbyid(relowner) as owner FROM pg_class WHERE oid = %(oid)s """ def __init__(self, table_name, row, new_name=None): self.table_name = table_name self.name = 'Owner' self.owner = row['owner'] def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" if not new_name: new_name = self.table_name return 'ALTER TABLE %s\n OWNER TO %s;' % (quote_fqident(new_name), quote_ident(self.owner)) class TGrant(TElem): """Info about permissions.""" type = T_GRANT SQL = "SELECT relacl FROM pg_class where oid = %(oid)s" # Sync with: src/include/utils/acl.h acl_map = { 'a': 'INSERT', 'r': 'SELECT', 'w': 'UPDATE', 'd': 'DELETE', 'D': 'TRUNCATE', 'x': 'REFERENCES', 't': 'TRIGGER', 'X': 'EXECUTE', 'U': 'USAGE', 'C': 'CREATE', 'T': 'TEMPORARY', 'c': 'CONNECT', # old 'R': 'RULE', } def acl_to_grants(self, acl): if acl == "arwdRxt": # ALL for tables return "ALL", "" i = 0 lst1 = [] lst2 = [] while i < len(acl): a = self.acl_map[acl[i]] if i+1 < len(acl) and acl[i+1] == '*': lst2.append(a) i += 2 else: lst1.append(a) i += 1 return ", ".join(lst1), ", ".join(lst2) def parse_relacl(self, relacl): """Parse ACL to tuple of (user, acl, who)""" if relacl is None: return [] tup_list = [] for sacl in skytools.parse_pgarray(relacl): acl = skytools.parse_acl(sacl) if not acl: continue tup_list.append(acl) return tup_list def __init__(self, table_name, row, new_name=None): self.name = table_name self.acl_list = self.parse_relacl(row['relacl']) def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" if not new_name: new_name = self.name qtarget = quote_fqident(new_name) sql_list = [] for role, acl, ___who in self.acl_list: qrole = quote_ident(role) astr1, astr2 = self.acl_to_grants(acl) if astr1: sql = "GRANT %s ON %s\n TO %s;" % (astr1, qtarget, qrole) sql_list.append(sql) if astr2: sql = "GRANT %s ON %s\n TO %s WITH GRANT OPTION;" % (astr2, qtarget, qrole) sql_list.append(sql) return "\n".join(sql_list) def get_drop_sql(self, curs): sql_list = [] for user, ___acl, ___who in self.acl_list: sql = "REVOKE ALL ON %s FROM %s;" % (quote_fqident(self.name), quote_ident(user)) sql_list.append(sql) return "\n".join(sql_list) class TColumnDefault(TElem): """Info about table column default value.""" type = T_DEFAULT SQL = """ select a.attname as name, pg_get_expr(d.adbin, d.adrelid) as expr from pg_attribute a left join pg_attrdef d on (d.adrelid = a.attrelid and d.adnum = a.attnum) where a.attrelid = %(oid)s and not a.attisdropped and a.atthasdef and a.attnum > 0 order by a.attnum; """ def __init__(self, table_name, row): self.table_name = table_name self.name = row['name'] self.expr = row['expr'] def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" tbl = new_name or self.table_name sql = "ALTER TABLE ONLY %s ALTER COLUMN %s\n SET DEFAULT %s;" % ( quote_fqident(tbl), quote_ident(self.name), self.expr) return sql def get_drop_sql(self, curs): return "ALTER TABLE %s ALTER COLUMN %s\n DROP DEFAULT;" % (
quote_fqident(self.table_name), quote_ident(self.name)) class TColumn(TElem): """Info about table column.""" SQL = """ select a.attname as name, quote_ident(a.attname) as qname, format_type(a.atttypid, a.atttypmod) as dtype, a.attnotnull, (select max(char_length(aa.attname)) from pg_attribute aa where aa.attrelid = %(oid)s) as maxcol, pg_get_serial_sequence(%(fq2name)s, a.attname) as seqname from pg_attribute a left join pg_attrdef d on (d.adrelid = a.attrelid and d.adnum = a.attnum) where a.attrelid = %(oid)s and not a.attisdropped and a.attnum > 0 order by a.attnum; """ seqname = None def __init__(self, table_name, row): self.name = row['name'] fname = row['qname'].ljust(row['maxcol'] + 3) self.column_def = fname + ' ' + row['dtype'] if row['attnotnull']: self.column_def += ' not null' self.sequence = None if row['seqname']: self.seqname = skytools.unquote_fqident(row['seqname']) class TGPDistKey(TElem): """Info about GreenPlum table distribution keys""" SQL = """ select a.attname as name from pg_attribute a, gp_distribution_policy p where p.localoid = %(oid)s and a.attrelid = %(oid)s and a.attnum = any(p.attrnums) order by a.attnum; """ def __init__(self, table_name, row): self.name = row['name'] class TTable(TElem): """Info about table only (columns).""" type = T_TABLE def __init__(self, table_name, col_list, dist_key_list=None): self.name = table_name self.col_list = col_list self.dist_key_list = dist_key_list def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" if not new_name: new_name = self.name sql = "CREATE TABLE %s (" % quote_fqident(new_name) sep = "\n " for c in self.col_list: sql += sep + c.column_def sep = ",\n " sql += "\n)" if self.dist_key_list is not None: if self.dist_key_list != []: sql += "\ndistributed by(%s)" % ','.join(c.name for c in self.dist_key_list) else: sql += '\ndistributed randomly' sql += ";" return sql def get_drop_sql(self, curs): return "DROP TABLE %s;" % quote_fqident(self.name) class TSeq(TElem): """Info about sequence.""" type = T_SEQUENCE SQL_PG10 = """ SELECT %(fq2name)s::name AS sequence_name, s.last_value, p.seqstart AS start_value, p.seqincrement AS increment_by, p.seqmax AS max_value, p.seqmin AS min_value, p.seqcache AS cache_value, s.log_cnt, s.is_called, p.seqcycle AS is_cycled, %(owner)s as owner FROM pg_catalog.pg_sequence p, %(fqname)s s WHERE p.seqrelid = %(fq2name)s::regclass::oid """ SQL_PG9 = """ SELECT %(fq2name)s AS sequence_name, last_value, start_value, increment_by, max_value, min_value, cache_value, log_cnt, is_called, is_cycled, %(owner)s AS "owner" FROM %(fqname)s """ @classmethod def get_load_sql(cls, pg_vers): """Return SQL statement for finding objects.""" if pg_vers < 100000: return cls.SQL_PG9 return cls.SQL_PG10 def __init__(self, seq_name, row): self.name = seq_name defn = '' self.owner = row['owner'] if row.get('increment_by', 1) != 1: defn += ' INCREMENT BY %d' % row['increment_by'] if row.get('min_value', 1) != 1: defn += ' MINVALUE %d' % row['min_value'] if row.get('max_value', 9223372036854775807) != 9223372036854775807: defn += ' MAXVALUE %d' % row['max_value'] last_value = row['last_value'] if row['is_called']: last_value += row.get('increment_by', 1) if last_value >= row.get('max_value', 9223372036854775807): raise Exception('duh, seq passed max_value') if last_value != 1: defn += ' START %d' % last_value if row.get('cache_value', 1) != 1: defn += ' CACHE %d' % row['cache_value'] if row.get('is_cycled'): defn += ' CYCLE ' if self.owner: defn += ' OWNED BY %s' % self.owner self.defn = 
defn def get_create_sql(self, curs, new_seq_name=None): """Generate creation SQL.""" # we are in table def, forget full def if self.owner: sql = "ALTER SEQUENCE %s\n OWNED BY %s;" % ( quote_fqident(self.name), self.owner) return sql name = self.name if new_seq_name: name = new_seq_name sql = 'CREATE SEQUENCE %s %s;' % (quote_fqident(name), self.defn) return sql def get_drop_sql(self, curs): if self.owner: return '' return 'DROP SEQUENCE %s;' % quote_fqident(self.name) # # Main table object, loads all the others # class BaseStruct(object): """Collects and manages all info about a higher-level db object. Allows issuing CREATE/DROP statements for any group of elements. """ object_list = [] def __init__(self, curs, name): """Initializes class by loading info about table_name from database.""" self.name = name self.fqname = quote_fqident(name) def _load_elem(self, curs, name, args, eclass): """Fetch element(s) from db.""" elem_list = [] #print "Loading %s, name=%s, args=%s" % (repr(eclass), repr(name), repr(args)) sql = eclass.get_load_sql(curs.connection.server_version) curs.execute(sql % args) for row in curs.fetchall(): elem_list.append(eclass(name, row)) return elem_list def create(self, curs, objs, new_table_name=None, log=None): """Issues CREATE statements for requested set of objects. If new_table_name is given, creates table under that name and also tries to rename all indexes/constraints that conflict with existing table. """ for o in self.object_list: if o.type & objs: sql = o.get_create_sql(curs, new_table_name) if not sql: continue if log: log.info('Creating %s' % o.name) log.debug(sql) curs.execute(sql) def drop(self, curs, objs, log=None): """Issues DROP statements for requested set of objects.""" # make sure the creating & dropping happen in reverse order olist = self.object_list[:] olist.reverse() for o in olist: if o.type & objs: sql = o.get_drop_sql(curs) if not sql: continue if log: log.info('Dropping %s' % o.name) log.debug(sql) curs.execute(sql) def get_create_sql(self, objs): res = [] for o in self.object_list: if o.type & objs: sql = o.get_create_sql(None, None) if sql: res.append(sql) return "".join(res) class TableStruct(BaseStruct): """Collects and manages all info about table. Allows issuing CREATE/DROP statements for any group of elements. """ def __init__(self, curs, table_name): """Initializes class by loading info about table_name from database.""" super(TableStruct, self).__init__(curs, table_name) self.table_name = table_name # fill args schema, name = skytools.fq_name_parts(table_name) args = { 'schema': schema, 'table': name, 'fqname': self.fqname, 'fq2name': skytools.quote_literal(self.fqname), 'oid': skytools.get_table_oid(curs, table_name), 'pg_class_oid': skytools.get_table_oid(curs, 'pg_catalog.pg_class'), } # load table struct self.col_list = self._load_elem(curs, self.name, args, TColumn) # if db is GP then read also table distribution keys if skytools.exists_table(curs, "pg_catalog.gp_distribution_policy"): self.dist_key_list = self._load_elem(curs, self.name, args, TGPDistKey) else: self.dist_key_list = None self.object_list = [TTable(table_name, self.col_list, self.dist_key_list)] self.seq_list = [] # load seqs for col in self.col_list: if col.seqname: fqname = quote_fqident(col.seqname) owner = self.fqname + '.'
+ quote_ident(col.name) seq_args = { 'fqname': fqname, 'fq2name': skytools.quote_literal(fqname), 'owner': skytools.quote_literal(owner), } self.seq_list += self._load_elem(curs, col.seqname, seq_args, TSeq) self.object_list += self.seq_list # load additional objects to_load = [TColumnDefault, TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner, TParent] for eclass in to_load: self.object_list += self._load_elem(curs, self.name, args, eclass) def get_column_list(self): """Returns list of column names the table has.""" res = [] for c in self.col_list: res.append(c.name) return res class SeqStruct(BaseStruct): """Collects and manages all info about sequence. Allows issuing CREATE/DROP statements for any group of elements. """ def __init__(self, curs, seq_name): """Initializes class by loading info about seq_name from database.""" super(SeqStruct, self).__init__(curs, seq_name) # fill args args = { 'fqname': self.fqname, 'fq2name': skytools.quote_literal(self.fqname), 'owner': 'null', } # load table struct self.object_list = self._load_elem(curs, seq_name, args, TSeq) def manual_check(): from skytools import connect_database db = connect_database("dbname=fooz") curs = db.cursor() s = TableStruct(curs, "public.data1") s.drop(curs, T_ALL) s.create(curs, T_ALL) s.create(curs, T_ALL, "data1_new") s.create(curs, T_PKEY) if __name__ == '__main__': manual_check() python-skytools-3.4/skytools/fileutil.py000066400000000000000000000103111356323561300206500ustar00rootroot00000000000000"""File utilities >>> import tempfile, os >>> pidfn = tempfile.mktemp('.pid') >>> write_atomic(pidfn, "1") >>> write_atomic(pidfn, "2") >>> os.remove(pidfn) >>> write_atomic(pidfn, "1", '.bak') >>> write_atomic(pidfn, "2", '.bak') >>> os.remove(pidfn) """ from __future__ import division, absolute_import, print_function import sys import os import errno __all__ = ['write_atomic', 'signal_pidfile'] try: unicode except NameError: unicode = str # noqa # non-win32 def write_atomic_unix(fn, data, bakext=None, mode='b'): """Write file with rename.""" if mode not in ['', 'b', 't']: raise ValueError("unsupported fopen mode") if mode == 'b' and isinstance(data, unicode): data = data.encode('utf8') # write new data to tmp file fn2 = fn + '.new' f = open(fn2, 'w' + mode) f.write(data) f.close() # link old data to bak file if bakext: if bakext.find('/') >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: os.unlink(fnb) except OSError as e: if e.errno != errno.ENOENT: raise try: os.link(fn, fnb) except OSError as e: if e.errno != errno.ENOENT: raise # win32 does not like replace if sys.platform == 'win32': try: os.remove(fn) except: pass # atomically replace file os.rename(fn2, fn) def signal_pidfile(pidfile, sig): """Send a signal to process whose ID is located in pidfile. Read only first line of pidfile to support multiline pidfiles like postmaster.pid. Returns True if successful, False if pidfile does not exist or the process itself is dead. Any other errors will be passed on as exceptions.""" ln = '' try: f = open(pidfile, 'r') ln = f.readline().strip() f.close() pid = int(ln) if sig == 0 and sys.platform == 'win32': return win32_detect_pid(pid) os.kill(pid, sig) return True except (IOError, OSError) as ex: if ex.errno not in (errno.ESRCH, errno.ENOENT): raise except ValueError as ex: # this leaves slight race when someone is just creating the file, # but the more common case is an old empty file.
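# (an empty pidfile is treated below as a dead process and returns False;
#  any other non-integer content is reported as a corrupt pidfile)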
if not ln: return False raise ValueError('Corrupt pidfile: %s' % pidfile) return False def win32_detect_pid(pid): """Process detection for win32.""" # avoid pywin32 dependency, use ctypes instead import ctypes # win32 constants PROCESS_QUERY_INFORMATION = 1024 STILL_ACTIVE = 259 ERROR_INVALID_PARAMETER = 87 ERROR_ACCESS_DENIED = 5 # Load kernel32.dll k = ctypes.windll.kernel32 OpenProcess = k.OpenProcess OpenProcess.restype = ctypes.c_void_p # query pid exit code h = OpenProcess(PROCESS_QUERY_INFORMATION, 0, pid) if h is None: err = k.GetLastError() if err == ERROR_INVALID_PARAMETER: return False if err == ERROR_ACCESS_DENIED: return True raise OSError(errno.EFAULT, "Unknown win32 error: " + str(err)) code = ctypes.c_int() k.GetExitCodeProcess(h, ctypes.byref(code)) k.CloseHandle(h) return code.value == STILL_ACTIVE def win32_write_atomic(fn, data, bakext=None, mode='b'): """Write file with rename for win32.""" if mode not in ['', 'b', 't']: raise ValueError("unsupported fopen mode") # write new data to tmp file fn2 = fn + '.new' f = open(fn2, 'w' + mode) f.write(data) f.close() # move old data to bak file if bakext: if bakext.find('/') >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: os.remove(fnb) except OSError as e: if e.errno != errno.ENOENT: raise try: os.rename(fn, fnb) except OSError as e: if e.errno != errno.ENOENT: raise else: try: os.remove(fn) except: pass # replace file os.rename(fn2, fn) if sys.platform == 'win32': write_atomic = win32_write_atomic else: write_atomic = write_atomic_unix python-skytools-3.4/skytools/gzlog.py000066400000000000000000000014771356323561300201870ustar00rootroot00000000000000 """Atomic append of gzipped data. The point is - if several gzip streams are concatenated, they are read back as one whole stream. """ from __future__ import division, absolute_import, print_function import gzip from io import BytesIO __all__ = ['gzip_append'] # # gzip storage # def gzip_append(filename, data, level=6): """Append a block of data to file with safety checks.""" # compress data buf = BytesIO() g = gzip.GzipFile(fileobj=buf, compresslevel=level, mode="w") g.write(data) g.close() zdata = buf.getvalue() # append, safely f = open(filename, "ab+", 0) f.seek(0, 2) pos = f.tell() try: f.write(zdata) f.close() except Exception as ex: # rollback on error f.seek(pos, 0) f.truncate() f.close() raise ex python-skytools-3.4/skytools/hashtext.py000066400000000000000000000103131356323561300206640ustar00rootroot00000000000000""" Implementation of Postgres hashing function. hashtext_old() - used up to PostgreSQL 8.3 hashtext_new() - used since PostgreSQL 8.4 >>> import skytools._chashtext >>> for i in range(3): ... print([hashtext_new_py(b'x' * (i*5 + j)) for j in range(5)]) [-1477818771, 1074944137, -1086392228, -1992236649, -1379736791] [-370454118, 1489915569, -66683019, -2126973000, 1651296771] [755764456, -1494243903, 631527812, 28686851, -9498641] >>> for i in range(3): ...
print([hashtext_old_py(b'x' * (i*5 + j)) for j in range(5)]) [-863449762, 37835117, 294739542, -320432768, 1007638138] [1422906842, -261065348, 59863994, -162804943, 1736144510] [-682756517, 317827663, -495599455, -1411793989, 1739997714] >>> data = b'HypficUjFitraxlumCitcemkiOkIkthi' >>> p = [hashtext_old_py(data[:l]) for l in range(len(data)+1)] >>> c = [hashtext_old(data[:l]) for l in range(len(data)+1)] >>> assert p == c, '%s <> %s' % (p, c) >>> p == c True >>> p = [hashtext_new_py(data[:l]) for l in range(len(data)+1)] >>> c = [hashtext_new(data[:l]) for l in range(len(data)+1)] >>> assert p == c, '%s <> %s' % (p, c) >>> p == c True """ from __future__ import division, absolute_import, print_function import sys import struct try: from skytools._chashtext import hashtext_old, hashtext_new except ImportError: def hashtext_old(v): return hashtext_old_py(v) def hashtext_new(v): return hashtext_new_py(v) __all__ = ["hashtext_old", "hashtext_new"] # pad for last partial block PADDING = b'\0' * 12 def uint32(x): """python does not have 32 bit integer so we need this hack to produce uint32 after bit operations""" return x & 0xffffffff # # Old Postgres hashtext() - lookup2 with custom initval # FMT_OLD = struct.Struct("<LLL") def mix_old(a, b, c): a -= b; a -= c; a = uint32(a ^ (c>>13)) b -= c; b -= a; b = uint32(b ^ (a<<8)) c -= a; c -= b; c = uint32(c ^ (b>>13)) a -= b; a -= c; a = uint32(a ^ (c>>12)) b -= c; b -= a; b = uint32(b ^ (a<<16)) c -= a; c -= b; c = uint32(c ^ (b>>5)) a -= b; a -= c; a = uint32(a ^ (c>>3)) b -= c; b -= a; b = uint32(b ^ (a<<10)) c -= a; c -= b; c = uint32(c ^ (b>>15)) return a, b, c def hashtext_old_py(k): """Old Postgres hashtext()""" remain = len(k) pos = 0 a = b = 0x9e3779b9 c = 3923095 # handle most of the key while remain >= 12: a2, b2, c2 = FMT_OLD.unpack_from(k, pos) a, b, c = mix_old(a + a2, b + b2, c + c2) pos += 12 remain -= 12 # handle the last 11 bytes a2, b2, c2 = FMT_OLD.unpack_from(k[pos:] + PADDING, 0) # the lowest byte of c is reserved for the length c2 = (c2 << 8) + len(k) a, b, c = mix_old(a + a2, b + b2, c + c2) # convert to signed int if c & 0x80000000: c = -0x100000000 + c return int(c) # # New Postgres hashtext() - hacked lookup3: # - custom initval # - calls mix() when len=12 # - shifted c in last block on little-endian # FMT_NEW = struct.Struct("=LLL") def rol32(x, k): return ((x)<<(k)) | (uint32(x)>>(32-(k))) def mix_new(a, b, c): a -= c; a ^= rol32(c, 4); c += b b -= a; b ^= rol32(a, 6); a += c c -= b; c ^= rol32(b, 8); b += a a -= c; a ^= rol32(c, 16); c += b b -= a; b ^= rol32(a, 19); a += c c -= b; c ^= rol32(b, 4); b += a return uint32(a), uint32(b), uint32(c) def final_new(a, b, c): c ^= b; c -= rol32(b, 14) a ^= c; a -= rol32(c, 11) b ^= a; b -= rol32(a, 25) c ^= b; c -= rol32(b, 16) a ^= c; a -= rol32(c, 4) b ^= a; b -= rol32(a, 14) c ^= b; c -= rol32(b, 24) return uint32(a), uint32(b), uint32(c) def hashtext_new_py(k): """New Postgres hashtext()""" remain = len(k) pos = 0 a = b = c = 0x9e3779b9 + len(k) + 3923095 # handle most of the key while remain >= 12: a2, b2, c2 = FMT_NEW.unpack_from(k, pos) a, b, c = mix_new(a + a2, b + b2, c + c2) pos += 12 remain -= 12 # handle the last 11 bytes a2, b2, c2 = FMT_NEW.unpack_from(k[pos:] + PADDING, 0) if sys.byteorder == 'little': c2 = c2 << 8 a, b, c = final_new(a + a2, b + b2, c + c2) # convert to signed int if c & 0x80000000: c = -0x100000000 + c return int(c) python-skytools-3.4/skytools/installer_config.py000066400000000000000000000002211356323561300223530ustar00rootroot00000000000000 """SQL script locations.""" __all__ = ['sql_locations']
sql_locations = [ "/usr/share/skytools3", ] package_version = "3.4" skylog = 0 python-skytools-3.4/skytools/natsort.py000066400000000000000000000027041356323561300205330ustar00rootroot00000000000000"""Natural sort. Compares numeric parts numerically. """ from __future__ import division, absolute_import, print_function # Based on idea at http://code.activestate.com/recipes/285264/ # Works with both Python 2.x and 3.x # Ignores leading zeroes: 001 and 01 are considered equal import re as _re _rc = _re.compile(r'\d+|\D+') __all__ = ['natsort_key', 'natsort', 'natsorted', 'natsort_key_icase', 'natsort_icase', 'natsorted_icase'] def natsort_key(s): """Split string to numeric and non-numeric fragments.""" return [not f[0].isdigit() and f or int(f, 10) for f in _rc.findall(s)] def natsort(lst): """Natural in-place sort, case-sensitive.""" lst.sort(key=natsort_key) def natsorted(lst): """Return copy of list, sorted in natural order, case-sensitive. >>> natsorted(['ver-1.1', 'ver-1.11', '', 'ver-1.0']) ['', 'ver-1.0', 'ver-1.1', 'ver-1.11'] """ lst = lst[:] natsort(lst) return lst # case-insensitive api def natsort_key_icase(s): """Split lowercased string to numeric and non-numeric fragments.""" return natsort_key(s.lower()) def natsort_icase(lst): """Natural in-place sort, case-insensitive.""" lst.sort(key=natsort_key_icase) def natsorted_icase(lst): """Return copy of list, sorted in natural order, case-insensitive. >>> natsorted_icase(['Ver-1.1', 'vEr-1.11', '', 'veR-1.0']) ['', 'veR-1.0', 'Ver-1.1', 'vEr-1.11'] """ lst = lst[:] natsort_icase(lst) return lst python-skytools-3.4/skytools/parsing.py000066400000000000000000000436741356323561300205170ustar00rootroot00000000000000 """Various parsers for Postgres-specific data formats.""" from __future__ import division, absolute_import, print_function import re import skytools __all__ = [ "parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements", 'sql_tokenizer', 'parse_sqltriga_sql', "parse_acl", "dedent", "hsize_to_bytes", "parse_connect_string", "merge_connect_string"] _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X) def parse_pgarray(array): r"""Parse Postgres array and return list of items inside it. Examples: >>> parse_pgarray('{}') [] >>> parse_pgarray('{a,b,null,"null"}') ['a', 'b', None, 'null'] >>> parse_pgarray(r'{"a,a","b\"b","c\\c"}') ['a,a', 'b"b', 'c\\c'] >>> parse_pgarray("[0,3]={1,2,3}") ['1', '2', '3'] >>> parse_pgarray(None) is None True >>> from nose.tools import * >>> assert_raises(ValueError, parse_pgarray, '}{') >>> assert_raises(ValueError, parse_pgarray, '[1]=}') >>> assert_raises(ValueError, parse_pgarray, '{"..."
, }') """ if array is None: return None if not array or array[0] not in ("{", "[") or array[-1] != '}': raise ValueError("bad array format: must be surrounded with {}") res = [] pos = 1 # skip optional dimensions descriptor "[a,b]={...}" if array[0] == "[": pos = array.find('{') + 1 if pos < 1: raise ValueError("bad array format 2: must be surrounded with {}") while 1: m = _rc_listelem.search(array, pos) if not m: break pos2 = m.end() item = array[pos:pos2] if len(item) == 4 and item.upper() == "NULL": val = None else: if len(item) > 0 and item[0] == '"': if len(item) == 1 or item[-1] != '"': raise ValueError("bad array format: broken '\"'") item = item[1:-1] val = skytools.unescape(item) res.append(val) pos = pos2 + 1 if array[pos2] == "}": break elif array[pos2] != ",": raise ValueError("bad array format: expected ,} got " + repr(array[pos2])) if pos < len(array) - 1: raise ValueError("bad array format: failed to parse completely (pos=%d len=%d)" % (pos, len(array))) return res # # parse logtriga partial sql # class _logtriga_parser(object): """Parses logtriga/sqltriga partial SQL to values.""" pklist = None def tokenizer(self, sql): """Token generator.""" for ___typ, tok in sql_tokenizer(sql, ignore_whitespace=True): yield tok def parse_insert(self, tk, fields, values, key_fields, key_values): """Handler for inserts.""" # (col1, col2) values ('data', null) if next(tk) != "(": raise Exception("syntax error") while 1: fields.append(next(tk)) t = next(tk) if t == ")": break elif t != ",": raise Exception("syntax error") if next(tk).lower() != "values": raise Exception("syntax error, expected VALUES") if next(tk) != "(": raise Exception("syntax error, expected (") while 1: values.append(next(tk)) t = next(tk) if t == ")": break if t == ",": continue raise Exception("expected , or ) got "+t) t = next(tk) raise Exception("expected EOF, got " + repr(t)) def parse_update(self, tk, fields, values, key_fields, key_values): """Handler for updates.""" # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2' while 1: fields.append(next(tk)) if next(tk) != "=": raise Exception("syntax error") values.append(next(tk)) t = next(tk) if t == ",": continue elif t.lower() == "where": break else: raise Exception("syntax error, expected WHERE or , got "+repr(t)) while 1: fld = next(tk) key_fields.append(fld) self.pklist.append(fld) if next(tk) != "=": raise Exception("syntax error") key_values.append(next(tk)) t = next(tk) if t.lower() != "and": raise Exception("syntax error, expected AND got "+repr(t)) def parse_delete(self, tk, fields, values, key_fields, key_values): """Handler for deletes.""" # pk1 = 'pk1' and pk2 = 'pk2' while 1: fld = next(tk) key_fields.append(fld) self.pklist.append(fld) if next(tk) != "=": raise Exception("syntax error") key_values.append(next(tk)) t = next(tk) if t.lower() != "and": raise Exception("syntax error, expected AND, got "+repr(t)) def _create_dbdict(self, fields, values): fields = [skytools.unquote_ident(f) for f in fields] values = [skytools.unquote_literal(f) for f in values] return skytools.dbdict(zip(fields, values)) def parse_sql(self, op, sql, pklist=None, splitkeys=False): """Main entry point.""" if pklist is None: self.pklist = [] else: self.pklist = pklist tk = self.tokenizer(sql) fields = [] values = [] key_fields = [] key_values = [] try: if op == "I": self.parse_insert(tk, fields, values, key_fields, key_values) elif op == "U": self.parse_update(tk, fields, values, key_fields, key_values) elif op == "D": self.parse_delete(tk, fields, values, key_fields, 
key_values) raise Exception("syntax error") except StopIteration: # last sanity check if (len(fields) + len(key_fields) == 0 or len(fields) != len(values) or len(key_fields) != len(key_values)): raise Exception("syntax error, fields do not match values") if splitkeys: return (self._create_dbdict(key_fields, key_values), self._create_dbdict(fields, values)) return self._create_dbdict(fields + key_fields, values + key_values) def parse_logtriga_sql(op, sql, splitkeys=False): return parse_sqltriga_sql(op, sql, splitkeys=splitkeys) def parse_sqltriga_sql(op, sql, pklist=None, splitkeys=False): """Parse partial SQL used by pgq.sqltriga() back to data values. The parser has the following limitations: - Expects standard_quoted_strings = off - Does not support dollar quoting. - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1)) - WHERE expression must not contain IS (NOT) NULL - Does not support updating pk value, unless you use the splitkeys parameter. Returns dict of col->data pairs. >>> from skytools.testing import ordered_dict Insert event: >>> row = parse_logtriga_sql('I', '(id, data) values (1, null)') >>> ordered_dict(row) OrderedDict([('data', None), ('id', '1')]) Update event: >>> row = parse_logtriga_sql('U', "data='foo' where id = 1") >>> ordered_dict(row) OrderedDict([('data', 'foo'), ('id', '1')]) Delete event: >>> row = parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'") >>> ordered_dict(row) OrderedDict([('id', '1'), ('id2', "str'val")]) If you set the splitkeys parameter, it will return two dicts, one for key fields and one for data fields. Insert event: >>> keys, row = parse_logtriga_sql('I', '(id, data) values (1, null)', splitkeys=True) >>> keys, ordered_dict(row) ({}, OrderedDict([('data', None), ('id', '1')])) Update event: >>> parse_logtriga_sql('U', "data='foo' where id = 1", splitkeys=True) ({'id': '1'}, {'data': 'foo'}) Delete event: >>> keys, row = parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'", splitkeys=True) >>> (ordered_dict(keys), row) (OrderedDict([('id', '1'), ('id2', "str'val")]), {}) """ return _logtriga_parser().parse_sql(op, sql, pklist, splitkeys=splitkeys) def parse_tabbed_table(txt): r"""Parse a tab-separated table into list of dicts. Expect first row to be column names. Very primitive. Test: >>> from skytools.testing import ordered_dict >>> [ordered_dict(d) for d in parse_tabbed_table('col1\tcol2\nval1\tval2\n')] [OrderedDict([('col1', 'val1'), ('col2', 'val2')])] """ txt = txt.replace("\r\n", "\n") fields = None data = [] for ln in txt.split("\n"): if not ln: continue if not fields: fields = ln.split("\t") continue cols = ln.split("\t") if len(cols) != len(fields): continue row = dict(zip(fields, cols)) data.append(row) return data _extstr = r""" ['] (?: [^'\\]+ | \\. | [']['] )* ['] """ _stdstr = r""" ['] (?: [^']+ | [']['] )* ['] """ _name = r""" (?: [a-z_][a-z0-9_$]* | " (?: [^"]+ | "" )* " ) """ _ident = r""" (?P<ident> %s ) """ % _name _fqident = r""" (?P<ident> %s (?: \. %s )* ) """ % (_name, _name) _base_sql = r""" (?P<dolq> (?P<dname> [$] (?: [_a-z][_a-z0-9]*)? [$] ) .*? (?P=dname) ) | (?P<num> [0-9][0-9.e]* ) | (?P<numarg> [$] [0-9]+ ) | (?P<pyold> [%][(] [a-z_][a-z0-9_]* [)] [s] ) | (?P<pynew> [{] [^{}]+ [}] ) | (?P<ws> (?: \s+ | [/][*] .*? [*][/] | [-][-][^\n]* )+ ) | (?P<sym> (?: [-+*~!@#^&|?/%<>=]+ | [,()\[\].:;] ) ) | (?P<error> .
)""" _base_sql_fq = r"%s | %s" % (_fqident, _base_sql) _base_sql = r"%s | %s" % (_ident, _base_sql) _std_sql = r"""(?: (?P [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql) _std_sql_fq = r"""(?: (?P [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql_fq) _ext_sql = r"""(?: (?P [E]? %s ) | %s )""" % (_extstr, _base_sql) _ext_sql_fq = r"""(?: (?P [E]? %s ) | %s )""" % (_extstr, _base_sql_fq) _std_sql_rc = _ext_sql_rc = None _std_sql_fq_rc = _ext_sql_fq_rc = None def sql_tokenizer(sql, standard_quoting=False, ignore_whitespace=False, fqident=False, show_location=False): r"""Parser SQL to tokens. Iterator, returns (toktype, tokstr) tuples. Example >>> [x for x in sql_tokenizer("select * from a.b", ignore_whitespace=True)] [('ident', 'select'), ('sym', '*'), ('ident', 'from'), ('ident', 'a'), ('sym', '.'), ('ident', 'b')] >>> [x for x in sql_tokenizer("\"c olumn\",'str''val'")] [('ident', '"c olumn"'), ('sym', ','), ('str', "'str''val'")] >>> list(sql_tokenizer('a.b a."b "" c" a.1', fqident=True, ignore_whitespace=True)) [('ident', 'a.b'), ('ident', 'a."b "" c"'), ('ident', 'a'), ('sym', '.'), ('num', '1')] >>> list(sql_tokenizer(r"set 'a''\' + E'\''", standard_quoting=True, ignore_whitespace=True)) [('ident', 'set'), ('str', "'a''\\'"), ('sym', '+'), ('str', "E'\\''")] >>> list(sql_tokenizer('a.b a."b "" c" a.1', fqident=True, standard_quoting=True, ignore_whitespace=True)) [('ident', 'a.b'), ('ident', 'a."b "" c"'), ('ident', 'a'), ('sym', '.'), ('num', '1')] >>> list(sql_tokenizer('a.b\nc;', show_location=True, ignore_whitespace=True)) [('ident', 'a', 1), ('sym', '.', 2), ('ident', 'b', 3), ('ident', 'c', 5), ('sym', ';', 6)] """ global _std_sql_rc, _ext_sql_rc, _std_sql_fq_rc, _ext_sql_fq_rc if not _std_sql_rc: _std_sql_rc = re.compile(_std_sql, re.X | re.I | re.S) _ext_sql_rc = re.compile(_ext_sql, re.X | re.I | re.S) _std_sql_fq_rc = re.compile(_std_sql_fq, re.X | re.I | re.S) _ext_sql_fq_rc = re.compile(_ext_sql_fq, re.X | re.I | re.S) if standard_quoting: if fqident: rc = _std_sql_fq_rc else: rc = _std_sql_rc else: if fqident: rc = _ext_sql_fq_rc else: rc = _ext_sql_rc pos = 0 while 1: m = rc.match(sql, pos) if not m: break pos = m.end() typ = m.lastgroup if ignore_whitespace and typ == "ws": continue tk = m.group() if show_location: yield (typ, tk, pos) else: yield (typ, tk) _copy_from_stdin_re = r"copy.*from\s+stdin" _copy_from_stdin_rc = None def parse_statements(sql, standard_quoting=False): """Parse multi-statement string into separate statements. Returns list of statements. >>> [sql for sql in parse_statements("begin; select 1; select 'foo'; end;")] ['begin;', 'select 1;', "select 'foo';", 'end;'] >>> [sql for sql in parse_statements("select (select 2+(select 3;);) ; select 4;")] ['select (select 2+(select 3;);) ;', 'select 4;'] >>> [sql for sql in parse_statements('select ());')] Traceback (most recent call last): ... ValueError: syntax error - unbalanced parenthesis >>> [sql for sql in parse_statements('copy from stdin;')] Traceback (most recent call last): ... 
ValueError: copy from stdin not supported """ global _copy_from_stdin_rc if not _copy_from_stdin_rc: _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I) tokens = [] pcount = 0 # '(' level for typ, t in sql_tokenizer(sql, standard_quoting=standard_quoting): # skip whitespace and comments before statement if len(tokens) == 0 and typ == "ws": continue # keep the rest tokens.append(t) if t == "(": pcount += 1 elif t == ")": pcount -= 1 elif t == ";" and pcount == 0: sql = "".join(tokens) if _copy_from_stdin_rc.match(sql): raise ValueError("copy from stdin not supported") yield "".join(tokens) tokens = [] if len(tokens) > 0: yield "".join(tokens) if pcount != 0: raise ValueError("syntax error - unbalanced parenthesis") _acl_name = r'(?: [0-9a-z_]+ | " (?: [^"]+ | "" )* " )' _acl_re = r''' \s* (?: group \s+ | user \s+ )? (?P<tgt> %s )? (?P<perm> = [a-z*]* )? (?P<owner> / %s )? \s* $ ''' % (_acl_name, _acl_name) _acl_rc = None def parse_acl(acl): """Parse ACL entry. >>> parse_acl('user=rwx/owner') ('user', 'rwx', 'owner') >>> parse_acl('" ""user"=rwx/" ""owner"') (' "user', 'rwx', ' "owner') >>> parse_acl('user=rwx') ('user', 'rwx', None) >>> parse_acl('=/f') (None, '', 'f') On error (is this ok?): >>> parse_acl('?') is None True """ global _acl_rc if not _acl_rc: _acl_rc = re.compile(_acl_re, re.I | re.X) m = _acl_rc.match(acl) if not m: return None target = m.group('tgt') perm = m.group('perm') owner = m.group('owner') if target: target = skytools.unquote_ident(target) if perm: perm = perm[1:] if owner: owner = skytools.unquote_ident(owner[1:]) return (target, perm, owner) def dedent(doc): r"""Relaxed dedent. - takes whitespace to be removed from first indented line. - allows empty or non-indented lines at the start - allows first line to be unindented - skips empty lines at the start - ignores indent of empty lines - if line does not match common indent, it stays unchanged >>> dedent(' Line1:\n Line 2\n') 'Line1:\n Line 2\n' >>> dedent(' \nLine1:\n Line 2\n Line 3\n Line 4') 'Line1:\nLine 2\n Line 3\n Line 4\n' """ pfx = None res = [] for ln in doc.splitlines(): ln = ln.rstrip() if not pfx and len(res) < 2: if not ln: continue wslen = len(ln) - len(ln.lstrip()) pfx = ln[ : wslen] if pfx: if ln.startswith(pfx): ln = ln[len(pfx):] res.append(ln) res.append('') return '\n'.join(res) def hsize_to_bytes(input_str): """ Convert sizes from human format to bytes (string to integer) >>> hsize_to_bytes('10G'), hsize_to_bytes('12k') (10737418240, 12288) """ m = re.match(r"^([0-9]+) *([KMGTPEZY]?)B?$", input_str.strip(), re.IGNORECASE) if not m: raise ValueError("cannot parse: %s" % input_str) units = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] nbytes = int(m.group(1)) * 1024 ** units.index(m.group(2).upper()) return nbytes # # Connect string parsing # _cstr_rx = r""" \s* (\w+) \s* = \s* ( ' ( \\.| [^'\\] )* ' | \S+ ) \s* """ _cstr_unesc_rx = r"\\(.)" _cstr_badval_rx = r"[\s'\\]" _cstr_rc = None _cstr_unesc_rc = None _cstr_badval_rc = None def parse_connect_string(cstr): r"""Parse Postgres connect string. >>> parse_connect_string("host=foo") [('host', 'foo')] >>> parse_connect_string(r" host = foo password = ' f\\\o\'o ' ") [('host', 'foo'), ('password', "' f\\o'o '")] >>> parse_connect_string(r" host = ") Traceback (most recent call last): ...
ValueError: Invalid connect string """ global _cstr_rc, _cstr_unesc_rc if not _cstr_rc: _cstr_rc = re.compile(_cstr_rx, re.X) _cstr_unesc_rc = re.compile(_cstr_unesc_rx) pos = 0 res = [] while pos < len(cstr): m = _cstr_rc.match(cstr, pos) if not m: raise ValueError('Invalid connect string') pos = m.end() k = m.group(1) v = m.group(2) if v[0] == "'": v = _cstr_unesc_rc.sub(r"\1", v) res.append((k, v)) return res def merge_connect_string(cstr_arg_list): """Put fragments back together. >>> merge_connect_string([('host', 'ip'), ('pass', ''), ('x', ' ')]) "host=ip pass='' x=' '" """ global _cstr_badval_rc if not _cstr_badval_rc: _cstr_badval_rc = re.compile(_cstr_badval_rx) buf = [] for k, v in cstr_arg_list: if not v or _cstr_badval_rc.search(v): v = v.replace('\\', r'\\') v = v.replace("'", r"\'") v = "'" + v + "'" buf.append("%s=%s" % (k, v)) return ' '.join(buf) python-skytools-3.4/skytools/plpy_applyrow.py000066400000000000000000000151101356323561300217550ustar00rootroot00000000000000 """ PLPY helper module for applying row events from pgq.logutriga(). """ from __future__ import division, absolute_import, print_function try: import plpy except ImportError: pass import skytools ## TODO: automatic fkey detection # find FK columns FK_SQL = """ SELECT (SELECT array_agg( (SELECT attname::text FROM pg_attribute WHERE attrelid = conrelid AND attnum = conkey[i])) FROM generate_series(1, array_upper(conkey, 1)) i) AS kcols, (SELECT array_agg( (SELECT attname::text FROM pg_attribute WHERE attrelid = confrelid AND attnum = confkey[i])) FROM generate_series(1, array_upper(confkey, 1)) i) AS fcols, confrelid::regclass::text AS ftable FROM pg_constraint WHERE conrelid = {tbl}::regclass AND contype='f' """ class DataError(Exception): "Invalid data" def colfilter_full(rnew, rold): return rnew def colfilter_changed(rnew, rold): res = {} for k in rnew: if rnew[k] != rold[k]: res[k] = rnew[k] return res def canapply_dummy(rnew, rold): return True def canapply_tstamp_helper(rnew, rold, tscol): tnew = rnew[tscol] told = rold[tscol] if not tnew[0].isdigit(): raise DataError('invalid timestamp') if not told[0].isdigit(): raise DataError('invalid timestamp') return tnew > told def applyrow(tblname, ev_type, new_row, backup_row=None, alt_pkey_cols=None, fkey_cols=None, fkey_ref_table=None, fkey_ref_cols=None, fn_canapply=canapply_dummy, fn_colfilter=colfilter_full): """Core logic. Actual decisions will be done in callback functions. - [IUD]: If row referenced by fkey does not exist, event is not applied - If pkey does not exist but alt_pkey does, row is not applied.
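An illustrative call (would run under PL/Python; the table, event type
and urlencoded row below are made-up examples, not from the original
docs)::

    applyrow('public.orders', 'U:id', 'id=1&total=45')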
@param tblname: table name, schema-qualified @param ev_type: [IUD]:pkey1,pkey2 @param alt_pkey_cols: list of alternative columns to consider @param fkey_cols: columns in this table that refer to other table @param fkey_ref_table: other table referenced here @param fkey_ref_cols: column in other table that must match @param fn_canapply: callback function, gets new and old row, returns whether the row should be applied @param fn_colfilter: callback function, gets new and old row, returns dict of final columns to be applied """ gd = None # parse ev_type tmp = ev_type.split(':', 1) if len(tmp) != 2 or tmp[0] not in ('I', 'U', 'D'): raise DataError('Unsupported ev_type: '+repr(ev_type)) if not tmp[1]: raise DataError('No pkey in event') cmd = tmp[0] pkey_cols = tmp[1].split(',') qtblname = skytools.quote_fqident(tblname) # parse ev_data fields = skytools.db_urldecode(new_row) if ev_type.find('}') >= 0: raise DataError('Really suspicious activity') if ",".join(fields.keys()).find('}') >= 0: raise DataError('Really suspicious activity 2') # generate pkey expressions tmp = ["%s = {%s}" % (skytools.quote_ident(k), k) for k in pkey_cols] pkey_expr = " and ".join(tmp) alt_pkey_expr = None if alt_pkey_cols: tmp = ["%s = {%s}" % (skytools.quote_ident(k), k) for k in alt_pkey_cols] alt_pkey_expr = " and ".join(tmp) log = "data ok" # # Row data seems fine, now apply it # if fkey_ref_table: tmp = [] for k, rk in zip(fkey_cols, fkey_ref_cols): tmp.append("%s = {%s}" % (skytools.quote_ident(rk), k)) fkey_expr = " and ".join(tmp) q = "select 1 from only %s where %s" % ( skytools.quote_fqident(fkey_ref_table), fkey_expr) res = skytools.plpy_exec(gd, q, fields) if not res: return "IGN: parent row does not exist" log += ", fkey ok" # fetch old row if alt_pkey_expr: q = "select * from only %s where %s for update" % (qtblname, alt_pkey_expr) res = skytools.plpy_exec(gd, q, fields) if res: oldrow = res[0] # if altpk matches, but pk not, then delete need_del = 0 for k in pkey_cols: # fixme: proper type cmp?
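# note: values decoded by db_urldecode are strings, which is why the
# old row's native value is passed through str() for the comparison below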
if fields[k] != str(oldrow[k]): need_del = 1 break if need_del: log += ", altpk del" q = "delete from only %s where %s" % (qtblname, alt_pkey_expr) skytools.plpy_exec(gd, q, fields) res = None else: log += ", altpk ok" else: # no altpk q = "select * from only %s where %s for update" % (qtblname, pkey_expr) res = skytools.plpy_exec(None, q, fields) # got old row, with same pk and altpk if res: oldrow = res[0] log += ", old row" ok = fn_canapply(fields, oldrow) if ok: log += ", new row better" if not ok: # ignore the update return "IGN:" + log + ", current row more up-to-date" else: log += ", no old row" oldrow = None if res: if cmd == 'I': cmd = 'U' else: if cmd == 'U': cmd = 'I' # allow column changes if oldrow: fields2 = fn_colfilter(fields, oldrow) for k in pkey_cols: if k not in fields2: fields2[k] = fields[k] fields = fields2 # apply change if cmd == 'I': q = skytools.mk_insert_sql(fields, tblname, pkey_cols) elif cmd == 'U': q = skytools.mk_update_sql(fields, tblname, pkey_cols) elif cmd == 'D': q = skytools.mk_delete_sql(fields, tblname, pkey_cols) else: plpy.error('Huh') plpy.execute(q) return log def ts_conflict_handler(gd, args): """Conflict handling based on timestamp column.""" conf = skytools.db_urldecode(args[0]) timefield = conf['timefield'] ev_type = args[1] ev_data = args[2] ev_extra1 = args[3] ev_extra2 = args[4] #ev_extra3 = args[5] #ev_extra4 = args[6] altpk = None if 'altpk' in conf: altpk = conf['altpk'].split(',') def ts_canapply(rnew, rold): return canapply_tstamp_helper(rnew, rold, timefield) return applyrow(ev_extra1, ev_type, ev_data, backup_row=ev_extra2, alt_pkey_cols=altpk, fkey_ref_table=conf.get('fkey_ref_table'), fkey_ref_cols=conf.get('fkey_ref_cols'), fkey_cols=conf.get('fkey_cols'), fn_canapply=ts_canapply) python-skytools-3.4/skytools/psycopgwrapper.py000066400000000000000000000077611356323561300221360ustar00rootroot00000000000000 """Wrapper around psycopg2. Database connection provides regular DB-API 2.0 interface. Connection object methods:: .cursor() .commit() .rollback() .close() Cursor methods:: .execute(query[, args]) .fetchone() .fetchall() Sample usage:: db = self.get_database('somedb') curs = db.cursor() # query arguments as array q = "select * from table where id = %s and name = %s" curs.execute(q, [1, 'somename']) # query arguments as dict q = "select id, name from table where id = %(id)s and name = %(name)s" curs.execute(q, {'id': 1, 'name': 'somename'}) # loop over resultset for row in curs.fetchall(): # columns can be asked by index: id = row[0] name = row[1] # and by name: id = row['id'] name = row['name'] # now commit the transaction db.commit() Deprecated interface: .dictfetchall/.dictfetchone functions on cursor. Plain .fetchall() / .fetchone() give exact same result. 
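A connection itself can be opened with connect_database(); a minimal
sketch (the connect string here is illustrative)::

    db = connect_database("dbname=testdb host=127.0.0.1")
    curs = db.cursor()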
""" from __future__ import division, absolute_import, print_function import skytools from skytools.sockutil import set_tcp_keepalive import psycopg2.extensions import psycopg2.extras from psycopg2 import Error as DBError __all__ = ['connect_database', 'DBError', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_REPEATABLE_READ', 'I_SERIALIZABLE'] I_AUTOCOMMIT = psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT I_READ_COMMITTED = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED I_REPEATABLE_READ = psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ I_SERIALIZABLE = psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE class _CompatRow(psycopg2.extras.DictRow): """Make DictRow more dict-like.""" __slots__ = ('_index',) def __contains__(self, k): """Returns if such row has such column.""" return k in self._index def copy(self): """Return regular dict.""" return skytools.dbdict(self.items()) def iterkeys(self): return self._index.iterkeys() def itervalues(self): return list.__iter__(self) # obj.foo access def __getattr__(self, k): return self[k] class _CompatCursor(psycopg2.extras.DictCursor): """Regular psycopg2 DictCursor with dict* methods.""" def __init__(self, *args, **kwargs): super(_CompatCursor, self).__init__(*args, **kwargs) self.row_factory = _CompatRow dictfetchone = psycopg2.extras.DictCursor.fetchone dictfetchall = psycopg2.extras.DictCursor.fetchall dictfetchmany = psycopg2.extras.DictCursor.fetchmany class _CompatConnection(psycopg2.extensions.connection): """Connection object that uses _CompatCursor.""" my_name = '?' server_version = None def cursor(self, name=None): if name: return super(_CompatConnection, self).cursor(cursor_factory=_CompatCursor, name=name) else: return super(_CompatConnection, self).cursor(cursor_factory=_CompatCursor) def connect_database(connstr, keepalive=True, tcp_keepidle=4*60, # 7200 tcp_keepcnt=4, # 9 tcp_keepintvl=15): # 75 """Create a db connection with connect_timeout and TCP keepalive. Default connect_timeout is 15, to change put it directly into dsn. The extra tcp_* options are Linux-specific, see `man 7 tcp` for details. """ # allow override if connstr.find("connect_timeout") < 0: connstr += " connect_timeout=15" # create connection db = _CompatConnection(connstr) curs = db.cursor() # tune keepalive fd = hasattr(db, 'fileno') and db.fileno() or curs.fileno() set_tcp_keepalive(fd, keepalive, tcp_keepidle, tcp_keepcnt, tcp_keepintvl) # fill .server_version on older psycopg if not getattr(db, 'server_version'): iso = db.isolation_level db.set_isolation_level(0) curs.execute('show server_version_num') db.server_version = int(curs.fetchone()[0]) db.set_isolation_level(iso) return db python-skytools-3.4/skytools/querybuilder.py000066400000000000000000000327751356323561300215700ustar00rootroot00000000000000#! /usr/bin/env python """Helper classes for complex query generation. Main target is code execution under PL/Python. Query parameters are referenced as C{{key}} or C{{key:type}}. Type will be given to C{plpy.prepare}. If C{type} is missing, C{text} is assumed. See L{plpy_exec} for examples. 
""" from __future__ import division, absolute_import, print_function import skytools try: import plpy except ImportError: plpy = None __all__ = [ 'QueryBuilder', 'PLPyQueryBuilder', 'PLPyQuery', 'plpy_exec', "run_query", "run_query_row", "run_lookup", "run_exists", ] PARAM_INLINE = 0 # quote_literal() PARAM_DBAPI = 1 # %()s PARAM_PLPY = 2 # $n class QArgConf(object): """Per-query arg-type config object.""" param_type = None class QArg(object): """Place-holder for a query parameter.""" def __init__(self, name, value, pos, conf): self.name = name self.value = value self.pos = pos self.conf = conf def __str__(self): if self.conf.param_type == PARAM_INLINE: return skytools.quote_literal(self.value) elif self.conf.param_type == PARAM_DBAPI: return "%s" elif self.conf.param_type == PARAM_PLPY: return "$%d" % self.pos else: raise Exception("bad QArgConf.param_type") # need an structure with fast remove-from-middle # and append operations. class DList(object): """Simple double-linked list.""" __slots__ = ('next', 'prev') def __init__(self): self.next = self self.prev = self def append(self, obj): obj.next = self obj.prev = self.prev self.prev.next = obj self.prev = obj def remove(self, obj): obj.next.prev = obj.prev obj.prev.next = obj.next obj.next = obj.prev = None def empty(self): return self.next is self def pop(self): """Remove and return first element.""" obj = None if not self.empty(): obj = self.next self.remove(obj) return obj class CachedPlan(DList): """Wrapper around prepared plan.""" __slots__ = ('key', 'plan') def __init__(self, key, plan): super(CachedPlan, self).__init__() self.key = key # (sql, (types)) self.plan = plan class PlanCache(object): """Cache for limited amount of plans.""" def __init__(self, maxplans=100): self.maxplans = maxplans self.plan_map = {} self.plan_list = DList() def get_plan(self, sql, types): """Prepare the plan and cache it.""" t = (sql, tuple(types)) if t in self.plan_map: pc = self.plan_map[t] # put to the end self.plan_list.remove(pc) self.plan_list.append(pc) return pc.plan # prepare new plan plan = plpy.prepare(sql, types) # add to cache pc = CachedPlan(t, plan) self.plan_list.append(pc) self.plan_map[t] = pc # remove plans if too much while len(self.plan_map) > self.maxplans: # this is ugly workaround for pylint drop = self.plan_list.pop() del self.plan_map[getattr(drop, 'key')] return plan class QueryBuilderCore(object): """Helper for query building. >>> args = {'success': 't', 'total': 45, 'ccy': 'EEK', 'id': 556} >>> q = QueryBuilder("update orders set total = {total} where id = {id}", args) >>> q.add(" and optional = {non_exist}") >>> q.add(" and final = {success}") >>> print(q.get_sql(PARAM_INLINE)) update orders set total = '45' where id = '556' and final = 't' >>> print(q.get_sql(PARAM_DBAPI)) update orders set total = %s where id = %s and final = %s >>> print(q.get_sql(PARAM_PLPY)) update orders set total = $1 where id = $2 and final = $3 """ def __init__(self, sqlexpr, params): """Init the object. @param sqlexpr: Partial sql fragment. @param params: Dict of parameter values. """ self._params = params self._arg_type_list = [] self._arg_value_list = [] self._sql_parts = [] self._arg_conf = QArgConf() self._nargs = 0 if sqlexpr: self.add(sqlexpr, required=True) def add(self, expr, sql_type="text", required=False): """Add SQL fragment to query. """ self._add_expr('', expr, self._params, sql_type, required) def get_sql(self, param_type=PARAM_INLINE): """Return generated SQL (thus far) as string. 
Possible values for param_type: - 0: Insert values quoted with quote_literal() - 1: Insert %()s in place of parameters. - 2: Insert $n in place of parameters. """ self._arg_conf.param_type = param_type tmp = [str(part) for part in self._sql_parts] return "".join(tmp) def _add_expr(self, pfx, expr, params, sql_type, required): parts = [] types = [] values = [] nargs = self._nargs if pfx: parts.append(pfx) pos = 0 while 1: # find start of next argument a1 = expr.find('{', pos) if a1 < 0: parts.append(expr[pos:]) break # find end of argument name a2 = expr.find('}', a1) if a2 < 0: raise Exception("missing argument terminator: "+expr) # add plain sql if a1 > pos: parts.append(expr[pos:a1]) pos = a2 + 1 # get arg name, check if exists k = expr[a1 + 1 : a2] # split name from type tpos = k.rfind(':') if tpos > 0: kparam = k[:tpos] ktype = k[tpos+1 : ] else: kparam = k ktype = sql_type # params==None means params are checked later if params is not None and kparam not in params: if required: raise Exception("required parameter missing: "+kparam) # optional fragment, param missing, skip it return # got arg nargs += 1 if params is not None: val = params[kparam] else: val = kparam values.append(val) types.append(ktype) arg = QArg(kparam, val, nargs, self._arg_conf) parts.append(arg) # add interesting parts to the main sql self._sql_parts.extend(parts) if types: self._arg_type_list.extend(types) if values: self._arg_value_list.extend(values) self._nargs = nargs class QueryBuilder(QueryBuilderCore): def execute(self, curs): """Client-side query execution on DB-API 2.0 cursor. Calls C{curs.execute()} with proper arguments. Returns result of curs.execute(), although that does not return anything interesting. Later curs.fetch* methods must be called to get result. """ q = self.get_sql(PARAM_DBAPI) args = self._params return curs.execute(q, args) class PLPyQueryBuilder(QueryBuilderCore): def __init__(self, sqlexpr, params, plan_cache=None, sqls=None): """Init the object. @param sqlexpr: Partial sql fragment. @param params: Dict of parameter values. @param plan_cache: (PL/Python) A dict object where to store the plan cache, under the key C{"plan_cache"}. If not given, plan will not be cached and values will be inserted directly to query. Usually either C{GD} or C{SD} should be given here. @param sqls: list object where to append executed sqls (used for debugging) """ super(PLPyQueryBuilder, self).__init__(sqlexpr, params) self._sqls = sqls if plan_cache is not None: if 'plan_cache' not in plan_cache: plan_cache['plan_cache'] = PlanCache() self._plan_cache = plan_cache['plan_cache'] else: self._plan_cache = None def execute(self): """Server-side query execution via plpy. Query can be run either cached or uncached, depending on C{plan_cache} setting given to L{__init__}. Returns result of plpy.execute(). """ args = self._arg_value_list types = self._arg_type_list if self._sqls is not None: self._sqls.append({"sql": self.get_sql(PARAM_INLINE)}) if self._plan_cache is not None: sql = self.get_sql(PARAM_PLPY) plan = self._plan_cache.get_plan(sql, types) res = plpy.execute(plan, args) else: sql = self.get_sql(PARAM_INLINE) res = plpy.execute(sql) if res: res = [skytools.dbdict(r) for r in res] return res class PLPyQuery(object): """Static, cached PL/Python query that uses QueryBuilder formatting. See L{plpy_exec} for simple usage.
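A minimal sketch of direct use (the query is illustrative; this must run
under PL/Python)::

    q = PLPyQuery("select * from orders where id = {id:int8}")
    rows = q.execute({'id': 10})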
""" def __init__(self, sql): qb = QueryBuilder(sql, None) p_sql = qb.get_sql(PARAM_PLPY) p_types = qb._arg_type_list self.plan = plpy.prepare(p_sql, p_types) self.arg_map = qb._arg_value_list self.sql = sql def execute(self, arg_dict, all_keys_required=True): try: if all_keys_required: arg_list = [arg_dict[k] for k in self.arg_map] else: arg_list = [arg_dict.get(k) for k in self.arg_map] return plpy.execute(self.plan, arg_list) except KeyError: need = set(self.arg_map) got = set(arg_dict.keys()) missing = list(need.difference(got)) plpy.error("Missing arguments: [%s] QUERY: %s" % ( ','.join(missing), repr(self.sql))) def __repr__(self): return 'PLPyQuery<%s>' % self.sql def plpy_exec(gd, sql, args, all_keys_required=True): """Cached plan execution for PL/Python. @param gd: dict to store cached plans under. If None, caching is disabled. @param sql: SQL statement to execute. @param args: dict of arguments to query. @param all_keys_required: if False, missing key is taken as NULL, instead of throwing error. >>> res = plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '1', 'arg2': '2'}) DBG: plpy.prepare('select $1, $2, $3', ['text', 'int4', 'text']) DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['1', '2', '1']) >>> res = plpy_exec(None, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '1', 'arg2': '2'}) DBG: plpy.execute("select '1', '2', '1'", ()) >>> res = plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3', 'arg2': '4'}) DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['3', '4', '3']) >>> res = plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3'}) DBG: plpy.error("Missing arguments: [arg2] QUERY: 'select {arg1}, {arg2:int4}, {arg1}'") >>> res = plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3'}, False) DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['3', None, '3']) """ if gd is None: return PLPyQueryBuilder(sql, args).execute() try: sq = gd['plq_cache'][sql] except KeyError: if 'plq_cache' not in gd: gd['plq_cache'] = {} sq = PLPyQuery(sql) gd['plq_cache'][sql] = sq return sq.execute(args, all_keys_required) # some helper functions for convenient sql execution def run_query(cur, sql, params=None, **kwargs): """ Helper function if everything you need is just paramertisized execute Sets rows_found that is coneninet to use when you don't need result just want to know how many rows were affected """ params = params or kwargs sql = QueryBuilder(sql, params).get_sql(0) cur.execute(sql) rows = cur.fetchall() # convert result rows to dbdict if rows: rows = [skytools.dbdict(r) for r in rows] return rows def run_query_row(cur, sql, params=None, **kwargs): """ Helper function if everything you need is just paramertisized execute to fetch one row only. If not found none is returned """ params = params or kwargs rows = run_query(cur, sql, params) if len(rows) == 0: return None return rows[0] def run_lookup(cur, sql, params=None, **kwargs): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. """ params = params or kwargs sql = QueryBuilder(sql, params).get_sql(0) cur.execute(sql) row = cur.fetchone() if row is None: return None return row[0] def run_exists(cur, sql, params=None, **kwargs): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. 
""" params = params or kwargs val = run_lookup(cur, sql, params) return val is not None # fake plpy for testing class fake_plpy(object): def prepare(self, sql, types): print("DBG: plpy.prepare(%s, %s)" % (repr(sql), repr(types))) return ('PLAN', sql, types) def execute(self, plan, args=()): print("DBG: plpy.execute(%s, %s)" % (repr(plan), repr(args))) def error(self, msg): print("DBG: plpy.error(%s)" % repr(msg)) # make plpy available if not plpy: plpy = fake_plpy() GD = {} python-skytools-3.4/skytools/quoting.py000066400000000000000000000144401356323561300205270ustar00rootroot00000000000000# quoting.py """Various helpers for string quoting/unquoting.""" from __future__ import division, absolute_import, print_function import re import json try: from skytools._cquoting import (db_urldecode, db_urlencode, quote_bytea_raw, quote_copy, quote_literal, unescape, unquote_literal) except ImportError: from skytools._pyquoting import (db_urldecode, db_urlencode, quote_bytea_raw, quote_copy, quote_literal, unescape, unquote_literal) __all__ = [ # _pyqoting / _cquoting "db_urldecode", "db_urlencode", "quote_bytea_raw", "quote_copy", "quote_literal", "unescape", "unquote_literal", # local "quote_bytea_literal", "quote_bytea_copy", "quote_statement", "quote_ident", "quote_fqident", "quote_json", "unescape_copy", "unquote_ident", "unquote_fqident", "json_encode", "json_decode", "make_pgarray", ] # # SQL quoting # def quote_bytea_literal(s): """Quote bytea for regular SQL.""" return quote_literal(quote_bytea_raw(s)) def quote_bytea_copy(s): """Quote bytea for COPY.""" return quote_copy(quote_bytea_raw(s)) def quote_statement(sql, dict_or_list): """Quote whole statement. Data values are taken from dict or list or tuple. """ if hasattr(dict_or_list, 'items'): qdict = {} for k, v in dict_or_list.items(): qdict[k] = quote_literal(v) return sql % qdict else: qvals = [quote_literal(v) for v in dict_or_list] return sql % tuple(qvals) # reserved keywords (RESERVED_KEYWORD + TYPE_FUNC_NAME_KEYWORD) _ident_kwmap = { "all":1, "analyse":1, "analyze":1, "and":1, "any":1, "array":1, "as":1, "asc":1, "asymmetric":1, "authorization":1, "between":1, "binary":1, "both":1, "case":1, "cast":1, "check":1, "collate":1, "collation":1, "column":1, "concurrently":1, "constraint":1, "create":1, "cross":1, "current_catalog":1, "current_date":1, "current_role":1, "current_schema":1, "current_time":1, "current_timestamp":1, "current_user":1, "default":1, "deferrable":1, "desc":1, "distinct":1, "do":1, "else":1, "end":1, "errors":1, "except":1, "false":1, "fetch":1, "for":1, "foreign":1, "freeze":1, "from":1, "full":1, "grant":1, "group":1, "having":1, "ilike":1, "in":1, "initially":1, "inner":1, "intersect":1, "into":1, "is":1, "isnull":1, "join":1, "lateral":1, "leading":1, "left":1, "like":1, "limit":1, "localtime":1, "localtimestamp":1, "natural":1, "new":1, "not":1, "notnull":1, "null":1, "off":1, "offset":1, "old":1, "on":1, "only":1, "or":1, "order":1, "outer":1, "over":1, "overlaps":1, "placing":1, "primary":1, "references":1, "returning":1, "right":1, "select":1, "session_user":1, "similar":1, "some":1, "symmetric":1, "table":1, "tablesample":1, "then":1, "to":1, "trailing":1, "true":1, "union":1, "unique":1, "user":1, "using":1, "variadic":1, "verbose":1, "when":1, "where":1, "window":1, "with":1, } _ident_bad = re.compile(r"[^a-z0-9_]|^[0-9]") def quote_ident(s): """Quote SQL identifier. If is checked against weird symbols and keywords. 
""" if _ident_bad.search(s) or s in _ident_kwmap: s = '"%s"' % s.replace('"', '""') elif not s: return '""' return s def quote_fqident(s): """Quote fully qualified SQL identifier. The '.' is taken as namespace separator and all parts are quoted separately Example: >>> quote_fqident('tbl') 'public.tbl' >>> quote_fqident('Baz.Foo.Bar') '"Baz"."Foo.Bar"' """ tmp = s.split('.', 1) if len(tmp) == 1: return 'public.' + quote_ident(s) return '.'.join([quote_ident(name) for name in tmp]) # # quoting for JSON strings # _jsre = re.compile(r'[\x00-\x1F\\/"]') _jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r", "\t": "\\t", "\\": "\\\\", '"': '\\"', "/": "\\/", # to avoid html attacks } def _json_quote_char(m): """Quote single char.""" c = m.group(0) try: return _jsmap[c] except KeyError: return r"\u%04x" % ord(c) def quote_json(s): """JSON style quoting.""" if s is None: return "null" return '"%s"' % _jsre.sub(_json_quote_char, s) def unescape_copy(val): r"""Removes C-style escapes, also converts "\N" to None. Example: >>> unescape_copy(r'baz\tfo\'o') "baz\tfo'o" >>> unescape_copy(r'\N') is None True """ if val == r"\N": return None return unescape(val) def unquote_ident(val): """Unquotes possibly quoted SQL identifier. >>> unquote_ident('Foo') 'foo' >>> unquote_ident('"Wei "" rd"') 'Wei " rd' """ if len(val) > 1 and val[0] == '"' and val[-1] == '"': return val[1:-1].replace('""', '"') if val.find('"') > 0: raise Exception('unsupported syntax') return val.lower() def unquote_fqident(val): """Unquotes fully-qualified possibly quoted SQL identifier. >>> unquote_fqident('foo') 'foo' >>> unquote_fqident('"Foo"."Bar "" z"') 'Foo.Bar " z' """ tmp = val.split('.', 1) return '.'.join([unquote_ident(i) for i in tmp]) def json_encode(val=None, **kwargs): """Creates JSON string from Python object. >>> json_encode({'a': 1}) '{"a": 1}' >>> json_encode('a') '"a"' >>> json_encode(['a']) '["a"]' >>> json_encode(a=1) '{"a": 1}' """ return json.dumps(val or kwargs) def json_decode(s): """Parses JSON string into Python object. >>> json_decode('[1]') [1] """ return json.loads(s) # # Create Postgres array # # any chars not in "good" set? main bad ones: [ ,{}\"] _pgarray_bad_rx = r"[^0-9a-z_.%&=()<>*/+-]" _pgarray_bad_rc = None def _quote_pgarray_elem(s): if s is None: return 'NULL' s = str(s) if _pgarray_bad_rc.search(s): s = s.replace('\\', '\\\\') return '"' + s.replace('"', r'\"') + '"' elif not s: return '""' return s def make_pgarray(lst): r"""Formats Python list as Postgres array. Reverse of parse_pgarray(). >>> make_pgarray([]) '{}' >>> make_pgarray(['foo_3',1,'',None]) '{foo_3,1,"",NULL}' >>> make_pgarray([None,',','\\',"'",'"',"{","}",'_']) '{NULL,",","\\\\","\'","\\"","{","}",_}' """ global _pgarray_bad_rc if _pgarray_bad_rc is None: _pgarray_bad_rc = re.compile(_pgarray_bad_rx) items = [_quote_pgarray_elem(v) for v in lst] return '{' + ','.join(items) + '}' python-skytools-3.4/skytools/scripting.py000066400000000000000000001101701356323561300210400ustar00rootroot00000000000000 """Useful functions and classes for database scripts. 
""" from __future__ import division, absolute_import, print_function import errno import logging import logging.config import logging.handlers import optparse import os import select import signal import sys import time import skytools import skytools.skylog try: import skytools.installer_config default_skylog = skytools.installer_config.skylog except ImportError: default_skylog = 0 __pychecker__ = 'no-badexcept' __all__ = ['BaseScript', 'UsageError', 'daemonize', 'DBScript'] class UsageError(Exception): """User induced error.""" # # daemon mode # def daemonize(): """Turn the process into daemon. Goes background and disables all i/o. """ # launch new process, kill parent pid = os.fork() if pid != 0: os._exit(0) # start new session os.setsid() # stop i/o fd = os.open("/dev/null", os.O_RDWR) os.dup2(fd, 0) os.dup2(fd, 1) os.dup2(fd, 2) if fd > 2: os.close(fd) # # Pidfile locking+cleanup & daemonization combined # def run_single_process(runnable, daemon, pidfile): """Run runnable class, possibly daemonized, locked on pidfile.""" # check if another process is running if pidfile and os.path.isfile(pidfile): if skytools.signal_pidfile(pidfile, 0): print("Pidfile exists, another process running?") sys.exit(1) else: print("Ignoring stale pidfile") # daemonize if needed if daemon: daemonize() # clean only own pidfile own_pidfile = False try: if pidfile: data = str(os.getpid()) skytools.write_atomic(pidfile, data) own_pidfile = True runnable.run() finally: if own_pidfile: try: os.remove(pidfile) except: pass # # logging setup # _log_config_done = 0 _log_init_done = {} def _load_log_config(fn, defs): """Fixed fileConfig.""" # Work around fileConfig default behaviour to disable # not only old handlers on load (which slightly makes sense) # but also old logger objects (which does not make sense). if sys.hexversion >= 0x2060000: logging.config.fileConfig(fn, defs, False) else: logging.config.fileConfig(fn, defs) root = logging.getLogger() for lg in root.manager.loggerDict.values(): lg.disabled = 0 def _init_log(job_name, service_name, cf, log_level, is_daemon): """Logging setup happens here.""" global _log_config_done got_skylog = 0 use_skylog = cf.getint("use_skylog", default_skylog) # if non-daemon, avoid skylog if script is running on console. # set use_skylog=2 to disable. if not is_daemon and use_skylog == 1: # pylint gets spooked by it's own stdout wrapper and refuses to shut down # about it. 'noqa' tells prospector to ignore all warnings here. 
if sys.stdout.isatty(): # noqa use_skylog = 0 # load logging config if needed if use_skylog and not _log_config_done: # python logging.config braindamage: # cannot specify external classess without such hack logging.skylog = skytools.skylog skytools.skylog.set_service_name(service_name, job_name) # load general config flist = cf.getlist('skylog_locations', ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']) for fn in flist: fn = os.path.expanduser(fn) if os.path.isfile(fn): defs = {'job_name': job_name, 'service_name': service_name} _load_log_config(fn, defs) got_skylog = 1 break _log_config_done = 1 if not got_skylog: sys.stderr.write("skylog.ini not found!\n") sys.exit(1) # avoid duplicate logging init for job_name log = logging.getLogger(job_name) if job_name in _log_init_done: return log _log_init_done[job_name] = 1 # tune level on root logger root = logging.getLogger() root.setLevel(log_level) # compatibility: specify ini file in script config def_fmt = '%(asctime)s %(process)s %(levelname)s %(message)s' def_datefmt = '' # None logfile = cf.getfile("logfile", "") if logfile: fstr = cf.get('logfmt_file', def_fmt) fstr_date = cf.get('logdatefmt_file', def_datefmt) if log_level < logging.INFO: fstr = cf.get('logfmt_file_verbose', fstr) fstr_date = cf.get('logdatefmt_file_verbose', fstr_date) fmt = logging.Formatter(fstr, fstr_date) size = cf.getint('log_size', 10*1024*1024) num = cf.getint('log_count', 3) file_hdlr = logging.handlers.RotatingFileHandler( logfile, 'a', size, num) file_hdlr.setFormatter(fmt) root.addHandler(file_hdlr) # if skylog.ini is disabled or not available, log at least to stderr if not got_skylog: fstr = cf.get('logfmt_console', def_fmt) fstr_date = cf.get('logdatefmt_console', def_datefmt) if log_level < logging.INFO: fstr = cf.get('logfmt_console_verbose', fstr) fstr_date = cf.get('logdatefmt_console_verbose', fstr_date) stream_hdlr = logging.StreamHandler() fmt = logging.Formatter(fstr, fstr_date) stream_hdlr.setFormatter(fmt) root.addHandler(stream_hdlr) return log class BaseScript(object): """Base class for service scripts. Handles logging, daemonizing, config, errors. Config template:: ## Parameters for skytools.BaseScript ## # how many seconds to sleep between work loops # if missing or 0, then instead sleeping, the script will exit loop_delay = 1.0 # where to log logfile = ~/log/%(job_name)s.log # where to write pidfile pidfile = ~/pid/%(job_name)s.pid # per-process name to use in logging #job_name = %(config_name)s # whether centralized logging should be used # search-path [ ./skylog.ini, ~/.skylog.ini, /etc/skylog.ini ] # 0 - disabled # 1 - enabled, unless non-daemon on console (os.isatty()) # 2 - always enabled #use_skylog = 0 # where to find skylog.ini #skylog_locations = skylog.ini, ~/.skylog.ini, /etc/skylog.ini # how many seconds to sleep after catching a exception #exception_sleep = 20 """ service_name = None job_name = None cf = None cf_defaults = {} pidfile = None # >0 - sleep time if work() requests sleep # 0 - exit if work requests sleep # <0 - run work() once [same as looping=0] loop_delay = 1.0 # 0 - run work() once # 1 - run work() repeatedly looping = 1 # result from last work() call: # 1 - there is probably more work, don't sleep # 0 - no work, sleep before calling again # -1 - exception was thrown work_state = 1 # setup logger here, this allows override by subclass log = logging.getLogger('skytools.BaseScript') # start time started = 0 def __init__(self, service_name, args): """Script setup. 
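        Minimal subclass sketch (hypothetical service)::

            class Ticker(BaseScript):
                def work(self):
                    self.log.info('tick')
                    return 0   # no more work; sleep loop_delay

            if __name__ == '__main__':
                Ticker('ticker', sys.argv[1:]).start()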
User class should override work() and optionally __init__(), startup(), reload(), reset(), shutdown() and init_optparse(). NB: In case of daemon, __init__() and startup()/work()/shutdown() will be run in different processes. So nothing fancy should be done in __init__(). @param service_name: unique name for script. It will be also default job_name, if not specified in config. @param args: cmdline args (sys.argv[1:]), but can be overridden """ self.service_name = service_name self.go_daemon = 0 self.need_reload = 0 self.exception_count = 0 self.stat_dict = {} self.log_level = logging.INFO # parse command line parser = self.init_optparse() self.options, self.args = parser.parse_args(args) # check args if self.options.version: self.print_version() sys.exit(0) if self.options.daemon: self.go_daemon = 1 if self.options.quiet: self.log_level = logging.WARNING if self.options.verbose: if self.options.verbose > 1: self.log_level = skytools.skylog.TRACE else: self.log_level = logging.DEBUG self.cf_override = {} if self.options.set: for a in self.options.set: k, v = a.split('=', 1) self.cf_override[k.strip()] = v.strip() if self.options.ini: self.print_ini() sys.exit(0) # read config file self.reload() # init logging _init_log(self.job_name, self.service_name, self.cf, self.log_level, self.go_daemon) # send signal, if needed if self.options.cmd == "kill": self.send_signal(signal.SIGTERM) elif self.options.cmd == "stop": self.send_signal(signal.SIGINT) elif self.options.cmd == "reload": self.send_signal(signal.SIGHUP) def print_version(self): service = self.service_name ver = getattr(self, '__version__', None) if ver: service += ' version %s' % ver print('%s, Skytools version %s' % (service, getattr(skytools, '__version__'))) def print_ini(self): """Prints out ini file from doc string of the script of default for dbscript Used by --ini option on command line. """ # current service name print("[%s]\n" % self.service_name) # walk class hierarchy bases = [self.__class__] while len(bases) > 0: parents = [] for c in bases: for p in c.__bases__: if p not in parents: parents.append(p) doc = c.__doc__ if doc: self._print_ini_frag(doc) bases = parents def _print_ini_frag(self, doc): # use last '::' block as config template pos = doc and doc.rfind('::\n') or -1 if pos < 0: return doc = doc[pos+2 : ].rstrip() doc = skytools.dedent(doc) # merge overrided options into output for ln in doc.splitlines(): vals = ln.split('=', 1) if len(vals) != 2: print(ln) continue k = vals[0].strip() v = vals[1].strip() if k and k[0] == '#': print(ln) k = k[1:] if k in self.cf_override: print('%s = %s' % (k, self.cf_override[k])) elif k in self.cf_override: if v: print('#' + ln) print('%s = %s' % (k, self.cf_override[k])) else: print(ln) print('') def load_config(self): """Loads and returns skytools.Config instance. By default it uses first command-line argument as config file name. Can be overridden. """ if len(self.args) < 1: print("need config file, use --help for help.") sys.exit(1) conf_file = self.args[0] return skytools.Config(self.service_name, conf_file, user_defs=self.cf_defaults, override=self.cf_override) def init_optparse(self, parser=None): """Initialize a OptionParser() instance that will be used to parse command line arguments. Note that it can be overridden both directions - either DBScript will initialize an instance and pass it to user code or user can initialize and then pass to DBScript.init_optparse(). @param parser: optional OptionParser() instance, where DBScript should attach its own arguments. 
@return: initialized OptionParser() instance. """ if parser: p = parser else: p = optparse.OptionParser() p.set_usage("%prog [options] INI") # generic options p.add_option("-q", "--quiet", action="store_true", help="log only errors and warnings") p.add_option("-v", "--verbose", action="count", help="log verbosely") p.add_option("-d", "--daemon", action="store_true", help="go background") p.add_option("-V", "--version", action="store_true", help="print version info and exit") p.add_option("", "--ini", action="store_true", help="display sample ini file") p.add_option("", "--set", action="append", help="override config setting (--set 'PARAM=VAL')") # control options g = optparse.OptionGroup(p, 'control running process') g.add_option("-r", "--reload", action="store_const", const="reload", dest="cmd", help="reload config (send SIGHUP)") g.add_option("-s", "--stop", action="store_const", const="stop", dest="cmd", help="stop program safely (send SIGINT)") g.add_option("-k", "--kill", action="store_const", const="kill", dest="cmd", help="kill program immediately (send SIGTERM)") p.add_option_group(g) return p def send_signal(self, sig): if not self.pidfile: self.log.warning("No pidfile in config, nothing to do") elif os.path.isfile(self.pidfile): alive = skytools.signal_pidfile(self.pidfile, sig) if not alive: self.log.warning("pidfile exists, but process not running") else: self.log.warning("No pidfile, process not running") sys.exit(0) def set_single_loop(self, do_single_loop): """Changes whether the script will loop or not.""" if do_single_loop: self.looping = 0 else: self.looping = 1 def _boot_daemon(self): run_single_process(self, self.go_daemon, self.pidfile) def start(self): """This will launch main processing thread.""" if self.go_daemon: if not self.pidfile: self.log.error("Daemon needs pidfile") sys.exit(1) self.run_func_safely(self._boot_daemon) def stop(self): """Safely stops processing loop.""" self.looping = 0 def reload(self): "Reload config." # avoid double loading on startup if not self.cf: self.cf = self.load_config() else: self.cf.reload() self.log.info("Config reloaded") self.job_name = self.cf.get("job_name") self.pidfile = self.cf.getfile("pidfile", '') self.loop_delay = self.cf.getfloat("loop_delay", self.loop_delay) self.exception_sleep = self.cf.getfloat("exception_sleep", 20) self.exception_quiet = self.cf.getlist("exception_quiet", []) self.exception_grace = self.cf.getfloat("exception_grace", 5*60) self.exception_reset = self.cf.getfloat("exception_reset", 15*60) def hook_sighup(self, sig, frame): "Internal SIGHUP handler. Minimal code here." self.need_reload = 1 last_sigint = 0 def hook_sigint(self, sig, frame): "Internal SIGINT handler. Minimal code here." self.stop() t = time.time() if t - self.last_sigint < 1: self.log.warning("Double ^C, fast exit") sys.exit(1) self.last_sigint = t def stat_get(self, key): """Reads a stat value.""" try: value = self.stat_dict[key] except KeyError: value = None return value def stat_put(self, key, value): """Sets a stat value.""" self.stat_dict[key] = value def stat_increase(self, key, increase=1): """Increases a stat value.""" try: self.stat_dict[key] += increase except KeyError: self.stat_dict[key] = increase def send_stats(self): "Send statistics to log." res = [] for k, v in self.stat_dict.items(): res.append("%s: %s" % (k, v)) if len(res) == 0: return logmsg = "{%s}" % ", ".join(res) self.log.info(logmsg) self.stat_dict = {} def reset(self): "Something bad happened, reset all state." pass def run(self): "Thread main loop." 
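        # Loop protocol: work() returning true means more work is pending
        # (loop again without sleeping); false means sleep loop_delay
        # seconds before the next round.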
# run startup, safely self.run_func_safely(self.startup) while 1: # reload config, if needed if self.need_reload: self.reload() self.need_reload = 0 # do some work work = self.run_once() if not self.looping or self.loop_delay < 0: break # remember work state self.work_state = work # should sleep? if not work: if self.loop_delay > 0: self.sleep(self.loop_delay) if not self.looping: break else: break # run shutdown, safely? self.shutdown() def run_once(self): state = self.run_func_safely(self.work, True) # send stats that was added self.send_stats() return state last_func_fail = None def run_func_safely(self, func, prefer_looping=False): "Run users work function, safely." try: r = func() if self.last_func_fail and time.time() > self.last_func_fail + self.exception_reset: self.last_func_fail = None # set exception count to 0 after success self.exception_count = 0 return r except UsageError as d: self.log.error(str(d)) sys.exit(1) except MemoryError as d: try: # complex logging may not succeed self.log.exception("Job %s out of memory, exiting", self.job_name) except MemoryError: self.log.fatal("Out of memory") sys.exit(1) except SystemExit as d: self.send_stats() if prefer_looping and self.looping and self.loop_delay > 0: self.log.info("got SystemExit(%s), exiting", str(d)) self.reset() raise d except KeyboardInterrupt as d: self.send_stats() if prefer_looping and self.looping and self.loop_delay > 0: self.log.info("got KeyboardInterrupt, exiting") self.reset() sys.exit(1) except Exception as d: try: # this may fail too self.send_stats() except: pass if self.last_func_fail is None: self.last_func_fail = time.time() emsg = str(d).rstrip() self.reset() self.exception_hook(d, emsg) # reset and sleep self.reset() if prefer_looping and self.looping and self.loop_delay > 0: # increase exception count & sleep self.exception_count += 1 self.sleep_on_exception() return -1 sys.exit(1) def sleep(self, secs): """Make script sleep for some amount of time.""" try: time.sleep(secs) except IOError as ex: if ex.errno != errno.EINTR: raise def sleep_on_exception(self): """Make script sleep for some amount of time when an exception occurs. To implement more advance exception sleeping like exponential backoff you can override this method. Also note that you can use self.exception_count to track the number of consecutive exceptions. """ self.sleep(self.exception_sleep) def _is_quiet_exception(self, ex): return ((self.exception_quiet == ["ALL"] or ex.__class__.__name__ in self.exception_quiet) and self.last_func_fail and time.time() < self.last_func_fail + self.exception_grace) def exception_hook(self, det, emsg): """Called on after exception processing. Can do additional logging. @param det: exception details @param emsg: exception msg """ lm = "Job %s crashed: %s" % (self.job_name, emsg) if self._is_quiet_exception(det): self.log.warning(lm) else: self.log.exception(lm) def work(self): """Here should user's processing happen. Return value is taken as boolean - if true, the next loop starts immediately. If false, DBScript sleeps for a loop_delay. """ raise Exception("Nothing implemented?") def startup(self): """Will be called just before entering main loop. In case of daemon, if will be called in same process as work(), unlike __init__(). """ self.started = time.time() # set signals if hasattr(signal, 'SIGHUP'): signal.signal(signal.SIGHUP, self.hook_sighup) if hasattr(signal, 'SIGINT'): signal.signal(signal.SIGINT, self.hook_sigint) def shutdown(self): """Will be called just after exiting main loop. 
In case of daemon, if will be called in same process as work(), unlike __init__(). """ pass # define some aliases (short-cuts / backward compatibility cruft) stat_add = stat_put # Old, deprecated function. stat_inc = stat_increase ## ## DBScript ## #: how old connections need to be closed DEF_CONN_AGE = 20*60 # 20 min class DBScript(BaseScript): """Base class for database scripts. Handles database connection state. Config template:: ## Parameters for skytools.DBScript ## # default lifetime for database connections (in seconds) #connection_lifetime = 1200 """ def __init__(self, service_name, args): """Script setup. User class should override work() and optionally __init__(), startup(), reload(), reset() and init_optparse(). NB: in case of daemon, the __init__() and startup()/work() will be run in different processes. So nothing fancy should be done in __init__(). @param service_name: unique name for script. It will be also default job_name, if not specified in config. @param args: cmdline args (sys.argv[1:]), but can be overridden """ self.db_cache = {} self._db_defaults = {} self._listen_map = {} # dbname: channel_list super(DBScript, self).__init__(service_name, args) def connection_hook(self, dbname, conn): pass def set_database_defaults(self, dbname, **kwargs): self._db_defaults[dbname] = kwargs def add_connect_string_profile(self, connstr, profile): """Add extra profile info to connect string. """ if profile: extra = self.cf.get("%s_extra_connstr" % profile, '') if extra: connstr += ' ' + extra return connstr def get_database(self, dbname, autocommit=0, isolation_level=-1, cache=None, connstr=None, profile=None): """Load cached database connection. User must not store it permanently somewhere, as all connections will be invalidated on reset. """ max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE) if not cache: cache = dbname params = {} defs = self._db_defaults.get(cache, {}) params.update(defs) if isolation_level >= 0: params['isolation_level'] = isolation_level elif autocommit: params['isolation_level'] = 0 elif params.get('autocommit', 0): params['isolation_level'] = 0 elif 'isolation_level' not in params: params['isolation_level'] = skytools.I_READ_COMMITTED if 'max_age' not in params: params['max_age'] = max_age if cache in self.db_cache: dbc = self.db_cache[cache] if connstr is None: connstr = self.cf.get(dbname, '') if connstr: connstr = self.add_connect_string_profile(connstr, profile) dbc.check_connstr(connstr) else: if not connstr: connstr = self.cf.get(dbname) connstr = self.add_connect_string_profile(connstr, profile) # connstr might contain password, it is not a good idea to log it filtered_connstr = connstr pos = connstr.lower().find('password') if pos >= 0: filtered_connstr = connstr[:pos] + ' [...]' self.log.debug("Connect '%s' to '%s'", cache, filtered_connstr) dbc = DBCachedConn(cache, connstr, params['max_age'], setup_func=self.connection_hook) self.db_cache[cache] = dbc clist = [] if cache in self._listen_map: clist = self._listen_map[cache] return dbc.get_connection(params['isolation_level'], clist) def close_database(self, dbname): """Explicitly close a cached connection. Next call to get_database() will reconnect. """ if dbname in self.db_cache: dbc = self.db_cache[dbname] dbc.reset() del self.db_cache[dbname] def reset(self): "Something bad happened, reset all connections." 
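        # invalidate every cached connection; get_database() reconnects lazily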
for dbc in self.db_cache.values(): dbc.reset() self.db_cache = {} super(DBScript, self).reset() def run_once(self): state = super(DBScript, self).run_once() # reconnect if needed for dbc in self.db_cache.values(): dbc.refresh() return state def exception_hook(self, d, emsg): """Log database and query details from exception.""" curs = getattr(d, 'cursor', None) conn = getattr(curs, 'connection', None) cname = getattr(conn, 'my_name', None) if cname: # Properly named connection cname = d.cursor.connection.my_name sql = getattr(curs, 'query', None) or '?' if isinstance(sql, bytes): sql = sql.decode('utf8') if len(sql) > 200: # avoid logging londiste huge batched queries sql = sql[:60] + " ..." lm = "Job %s got error on connection '%s': %s. Query: %s" % ( self.job_name, cname, emsg, sql) if self._is_quiet_exception(d): self.log.warning(lm) else: self.log.exception(lm) else: super(DBScript, self).exception_hook(d, emsg) def sleep(self, secs): """Make script sleep for some amount of time.""" fdlist = [] for dbname in self._listen_map: if dbname not in self.db_cache: continue fd = self.db_cache[dbname].fileno() if fd is None: continue fdlist.append(fd) if not fdlist: return super(DBScript, self).sleep(secs) try: if hasattr(select, 'poll'): p = select.poll() for fd in fdlist: p.register(fd, select.POLLIN) p.poll(int(secs * 1000)) else: select.select(fdlist, [], [], secs) except select.error: self.log.info('wait canceled') return None def _exec_cmd(self, curs, sql, args, quiet=False, prefix=None): """Internal tool: Run SQL on cursor.""" if self.options.verbose: self.log.debug("exec_cmd: %s", skytools.quote_statement(sql, args)) _pfx = "" if prefix: _pfx = "[%s] " % prefix curs.execute(sql, args) ok = True rows = curs.fetchall() for row in rows: try: code = row['ret_code'] msg = row['ret_note'] except KeyError: self.log.error("Query does not conform to exec_cmd API:") self.log.error("SQL: %s", skytools.quote_statement(sql, args)) self.log.error("Row: %s", repr(row.copy())) sys.exit(1) level = code // 100 if level == 1: self.log.debug("%s%d %s", _pfx, code, msg) elif level == 2: if quiet: self.log.debug("%s%d %s", _pfx, code, msg) else: self.log.info("%s%s", _pfx, msg) elif level == 3: self.log.warning("%s%s", _pfx, msg) else: self.log.error("%s%s", _pfx, msg) self.log.debug("Query was: %s", skytools.quote_statement(sql, args)) ok = False return (ok, rows) def _exec_cmd_many(self, curs, sql, baseargs, extra_list, quiet=False, prefix=None): """Internal tool: Run SQL on cursor multiple times.""" ok = True rows = [] for a in extra_list: (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a], quiet, prefix) if not tmp_ok: ok = False rows += tmp_rows return (ok, rows) def exec_cmd(self, db_or_curs, q, args, commit=True, quiet=False, prefix=None): """Run SQL on db with code/value error handling.""" if hasattr(db_or_curs, 'cursor'): db = db_or_curs curs = db.cursor() else: db = None curs = db_or_curs (ok, rows) = self._exec_cmd(curs, q, args, quiet, prefix) if ok: if commit and db: db.commit() return rows else: if db: db.rollback() if self.options.verbose: raise Exception("db error") # error is already logged sys.exit(1) def exec_cmd_many(self, db_or_curs, sql, baseargs, extra_list, commit=True, quiet=False, prefix=None): """Run SQL on db multiple times.""" if hasattr(db_or_curs, 'cursor'): db = db_or_curs curs = db.cursor() else: db = None curs = db_or_curs (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list, quiet, prefix) if ok: if commit and db: db.commit() return rows else: if db: 
db.rollback() if self.options.verbose: raise Exception("db error") # error is already logged sys.exit(1) def execute_with_retry(self, dbname, stmt, args, exceptions=None): """ Execute SQL and retry if it fails. Return number of retries and current valid cursor, or raise an exception. """ sql_retry = self.cf.getbool("sql_retry", False) sql_retry_max_count = self.cf.getint("sql_retry_max_count", 10) sql_retry_max_time = self.cf.getint("sql_retry_max_time", 300) sql_retry_formula_a = self.cf.getint("sql_retry_formula_a", 1) sql_retry_formula_b = self.cf.getint("sql_retry_formula_b", 5) sql_retry_formula_cap = self.cf.getint("sql_retry_formula_cap", 60) elist = exceptions or tuple() stime = time.time() tried = 0 dbc = None while True: try: if dbc is None: if dbname not in self.db_cache: self.get_database(dbname, autocommit=1) dbc = self.db_cache[dbname] if dbc.isolation_level != skytools.I_AUTOCOMMIT: raise skytools.UsageError("execute_with_retry: autocommit required") else: dbc.reset() curs = dbc.get_connection(dbc.isolation_level).cursor() curs.execute(stmt, args) break except elist as e: if not sql_retry or tried >= sql_retry_max_count or time.time() - stime >= sql_retry_max_time: raise self.log.info("Job %s got error on connection %s: %s", self.job_name, dbname, e) except: raise # y = a + bx , apply cap y = sql_retry_formula_a + sql_retry_formula_b * tried if sql_retry_formula_cap is not None and y > sql_retry_formula_cap: y = sql_retry_formula_cap tried += 1 self.log.info("Retry #%i in %i seconds ...", tried, y) self.sleep(y) return tried, curs def listen(self, dbname, channel): """Make connection listen for specific event channel. Listening will be activated on next .get_database() call. Basically this means that DBScript.sleep() will poll for events on that db connection, so when event appears, script will be woken up. """ if dbname not in self._listen_map: self._listen_map[dbname] = [] clist = self._listen_map[dbname] if channel not in clist: clist.append(channel) def unlisten(self, dbname, channel='*'): """Stop connection for listening on specific event channel. Listening will stop on next .get_database() call. """ if dbname not in self._listen_map: return if channel == '*': del self._listen_map[dbname] return clist = self._listen_map[dbname] try: clist.remove(channel) except ValueError: pass class DBCachedConn(object): """Cache a db connection.""" def __init__(self, name, loc, max_age=DEF_CONN_AGE, verbose=False, setup_func=None, channels=()): self.name = name self.loc = loc self.conn = None self.conn_time = 0 self.max_age = max_age self.isolation_level = -1 self.verbose = verbose self.setup_func = setup_func self.listen_channel_list = [] def fileno(self): if not self.conn: return None return self.conn.cursor().fileno() def get_connection(self, isolation_level=-1, listen_channel_list=()): # default isolation_level is READ COMMITTED if isolation_level < 0: isolation_level = skytools.I_READ_COMMITTED # new conn? 
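        # (first connect also records conn_time, so refresh() can expire
        # the connection once it is older than max_age seconds)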
if not self.conn: self.isolation_level = isolation_level self.conn = skytools.connect_database(self.loc) self.conn.my_name = self.name self.conn.set_isolation_level(isolation_level) self.conn_time = time.time() if self.setup_func: self.setup_func(self.name, self.conn) else: if self.isolation_level != isolation_level: raise Exception("Conflict in isolation_level") self._sync_listen(listen_channel_list) # done return self.conn def _sync_listen(self, new_clist): if not new_clist and not self.listen_channel_list: return curs = self.conn.cursor() for ch in self.listen_channel_list: if ch not in new_clist: curs.execute("UNLISTEN %s" % skytools.quote_ident(ch)) for ch in new_clist: if ch not in self.listen_channel_list: curs.execute("LISTEN %s" % skytools.quote_ident(ch)) if self.isolation_level != skytools.I_AUTOCOMMIT: self.conn.commit() self.listen_channel_list = new_clist[:] def refresh(self): if not self.conn: return #for row in self.conn.notifies(): # if row[0].lower() == "reload": # self.reset() # return if not self.max_age: return if time.time() - self.conn_time >= self.max_age: self.reset() def reset(self): if not self.conn: return # drop reference conn = self.conn self.conn = None self.listen_channel_list = [] # close try: conn.close() except: pass def check_connstr(self, connstr): """Drop connection if connect string has changed. """ if self.loc != connstr: self.reset() python-skytools-3.4/skytools/skylog.py000066400000000000000000000273671356323561300203650ustar00rootroot00000000000000"""Our log handlers for Python's logging package. """ from __future__ import division, absolute_import, print_function import os import socket import time import logging import logging.handlers from logging import LoggerAdapter import skytools import skytools.tnetstrings try: unicode except NameError: unicode = str # noqa __all__ = ['getLogger'] # add TRACE level TRACE = 5 logging.TRACE = TRACE logging.addLevelName(TRACE, 'TRACE') # extra info to be added to each log record _service_name = 'unknown_svc' _job_name = 'unknown_job' _hostname = socket.gethostname() try: _hostaddr = socket.gethostbyname(_hostname) except: _hostaddr = "0.0.0.0" _log_extra = { 'job_name': _job_name, 'service_name': _service_name, 'hostname': _hostname, 'hostaddr': _hostaddr, } def set_service_name(service_name, job_name): """Set info about current script.""" global _service_name, _job_name _service_name = service_name _job_name = job_name _log_extra['job_name'] = _job_name _log_extra['service_name'] = _service_name # # How to make extra fields available to all log records: # 1. Use own getLogger() # - messages logged otherwise (eg. from some libs) # will crash the logging. # 2. Fix record in own handlers # - works only with custom handlers, standard handlers will # crash is used with custom fmt string. # 3. Change root logger # - can't do it after non-root loggers are initialized, # doing it before will depend on import order. # 4. Update LogRecord.__dict__ # - fails, as formatter uses obj.__dict__ directly. # 5. Change LogRecord class # - ugly but seems to work. 
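# Option 5 is what is done below: subclass LogRecord and merge _log_extra
# into each record's __dict__ at construction time.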
# _OldLogRecord = logging.LogRecord class _NewLogRecord(_OldLogRecord): def __init__(self, *args): super(_NewLogRecord, self).__init__(*args) self.__dict__.update(_log_extra) logging.LogRecord = _NewLogRecord # configurable file logger class EasyRotatingFileHandler(logging.handlers.RotatingFileHandler): """Easier setup for RotatingFileHandler.""" def __init__(self, filename, maxBytes=10*1024*1024, backupCount=3): """Args same as for RotatingFileHandler, but in filename '~' is expanded.""" fn = os.path.expanduser(filename) super(EasyRotatingFileHandler, self).__init__(fn, maxBytes=maxBytes, backupCount=backupCount) # send JSON message over UDP class UdpLogServerHandler(logging.handlers.DatagramHandler): """Sends log records over UDP to logserver in JSON format.""" # map logging levels to logserver levels _level_map = { logging.DEBUG : 'DEBUG', logging.INFO : 'INFO', logging.WARNING : 'WARN', logging.ERROR : 'ERROR', logging.CRITICAL: 'FATAL', } # JSON message template _log_template = '{\n\t'\ '"logger": "skytools.UdpLogServer",\n\t'\ '"timestamp": %.0f,\n\t'\ '"level": "%s",\n\t'\ '"thread": null,\n\t'\ '"message": %s,\n\t'\ '"properties": {"application":"%s", "apptype": "%s", "type": "sys", "hostname":"%s", "hostaddr": "%s"}\n'\ '}\n' # cut longer msgs MAXMSG = 1024 def makePickle(self, record): """Create message in JSON format.""" # get & cut msg msg = self.format(record) if len(msg) > self.MAXMSG: msg = msg[:self.MAXMSG] txt_level = self._level_map.get(record.levelno, "ERROR") hostname = _hostname hostaddr = _hostaddr jobname = _job_name svcname = _service_name pkt = self._log_template % (time.time()*1000, txt_level, skytools.quote_json(msg), jobname, svcname, hostname, hostaddr) return pkt def send(self, s): """Disable socket caching.""" sock = self.makeSocket() if not isinstance(s, bytes): s = s.encode('utf8') sock.sendto(s, (self.host, self.port)) sock.close() # send TNetStrings message over UDP class UdpTNetStringsHandler(logging.handlers.DatagramHandler): """ Sends log records in TNetStrings format over UDP. """ # LogRecord fields to send send_fields = [ 'created', 'exc_text', 'levelname', 'levelno', 'message', 'msecs', 'name', 'hostaddr', 'hostname', 'job_name', 'service_name'] _udp_reset = 0 def makePickle(self, record): """ Create message in TNetStrings format. """ msg = {} self.format(record) # render 'message' attribute and others for k in self.send_fields: msg[k] = record.__dict__[k] tnetstr = skytools.tnetstrings.dumps(msg) return tnetstr def send(self, s): """ Cache socket for a moment, then recreate it. """ now = time.time() if now - 1 > self._udp_reset: if self.sock: self.sock.close() self.sock = self.makeSocket() self._udp_reset = now self.sock.sendto(s, (self.host, self.port)) class LogDBHandler(logging.handlers.SocketHandler): """Sends log records into PostgreSQL server. Additionally, does some statistics aggregating, to avoid overloading log server. It subclasses SocketHandler to get throtthling for failed connections. """ # map codes to string _level_map = { logging.DEBUG : 'DEBUG', logging.INFO : 'INFO', logging.WARNING : 'WARNING', logging.ERROR : 'ERROR', logging.CRITICAL: 'FATAL', } def __init__(self, connect_string): """ Initializes the handler with a specific connection string. 
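        The connect_string is a libpq-style connection string, e.g.
        (hypothetical) "host=logdb dbname=logs user=logwriter".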
""" super(LogDBHandler, self).__init__(None, None) self.closeOnError = 1 self.connect_string = connect_string self.stat_cache = {} self.stat_flush_period = 60 # send first stat line immediately self.last_stat_flush = 0 def createSocket(self): try: super(LogDBHandler, self).createSocket() except: self.sock = self.makeSocket() def makeSocket(self, timeout=1): """Create server connection. In this case its not socket but database connection.""" db = skytools.connect_database(self.connect_string) db.set_isolation_level(0) # autocommit return db def emit(self, record): """Process log record.""" # we do not want log debug messages if record.levelno < logging.INFO: return try: self.process_rec(record) except (SystemExit, KeyboardInterrupt): raise except: self.handleError(record) def process_rec(self, record): """Aggregate stats if needed, and send to logdb.""" # render msg msg = self.format(record) # dont want to send stats too ofter if record.levelno == logging.INFO and msg and msg[0] == "{": self.aggregate_stats(msg) if time.time() - self.last_stat_flush >= self.stat_flush_period: self.flush_stats(_job_name) return if record.levelno < logging.INFO: self.flush_stats(_job_name) # dont send more than one line ln = msg.find('\n') if ln > 0: msg = msg[:ln] txt_level = self._level_map.get(record.levelno, "ERROR") self.send_to_logdb(_job_name, txt_level, msg) def aggregate_stats(self, msg): """Sum stats together, to lessen load on logdb.""" msg = msg[1:-1] for rec in msg.split(", "): k, v = rec.split(": ") agg = self.stat_cache.get(k, 0) if v.find('.') >= 0: agg += float(v) else: agg += int(v) self.stat_cache[k] = agg def flush_stats(self, service): """Send acquired stats to logdb.""" res = [] for k, v in self.stat_cache.items(): res.append("%s: %s" % (k, str(v))) if len(res) > 0: logmsg = "{%s}" % ", ".join(res) self.send_to_logdb(service, "INFO", logmsg) self.stat_cache = {} self.last_stat_flush = time.time() def send_to_logdb(self, service, level, msg): """Actual sending is done here.""" if self.sock is None: self.createSocket() if self.sock: logcur = self.sock.cursor() query = "select * from log.add(%s, %s, %s)" logcur.execute(query, [level, service, msg]) # fix unicode bug in SysLogHandler class SysLogHandler(logging.handlers.SysLogHandler): """Fixes unicode bug in logging.handlers.SysLogHandler.""" # be compatible with both 2.6 and 2.7 socktype = socket.SOCK_DGRAM _udp_reset = 0 def _custom_format(self, record): msg = self.format(record) + '\000' # We need to convert record level to lowercase, maybe this will # change in the future. prio = '<%d>' % self.encodePriority(self.facility, self.mapPriority(record.levelname)) msg = prio + msg return msg def emit(self, record): """ Emit a record. The record is formatted, and then sent to the syslog server. If exception information is present, it is NOT sent to the server. """ msg = self._custom_format(record) # Message is a string. 
Convert to bytes as required by RFC 5424 if isinstance(msg, unicode): msg = msg.encode('utf-8') ## this puts BOM in wrong place #if codecs: # msg = codecs.BOM_UTF8 + msg try: if self.unixsocket: try: self.socket.send(msg) except socket.error: self._connect_unixsocket(self.address) self.socket.send(msg) elif self.socktype == socket.SOCK_DGRAM: now = time.time() if now - 1 > self._udp_reset: self.socket.close() self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._udp_reset = now self.socket.sendto(msg, self.address) else: self.socket.sendall(msg) except (KeyboardInterrupt, SystemExit): raise except: self.handleError(record) class SysLogHostnameHandler(SysLogHandler): """Slightly modified standard SysLogHandler - sends also hostname and service type""" def _custom_format(self, record): msg = self.format(record) format_string = '<%d> %s %s %s\000' msg = format_string % (self.encodePriority(self.facility, self.mapPriority(record.levelname)), _hostname, _service_name, msg) return msg # add missing aliases (that are in Logger class) if not hasattr(LoggerAdapter, 'fatal'): LoggerAdapter.fatal = LoggerAdapter.critical if not hasattr(LoggerAdapter, 'warn'): LoggerAdapter.warn = LoggerAdapter.warning class SkyLogger(LoggerAdapter): def __init__(self, logger, extra): super(SkyLogger, self).__init__(logger, extra) self.name = logger.name def trace(self, msg, *args, **kwargs): """Log 'msg % args' with severity 'TRACE'.""" self.log(TRACE, msg, *args, **kwargs) def addHandler(self, hdlr): """Add the specified handler to this logger.""" self.logger.addHandler(hdlr) def isEnabledFor(self, level): """See if the underlying logger is enabled for the specified level.""" return self.logger.isEnabledFor(level) def getLogger(name=None, **kwargs_extra): """Get logger with extra functionality. Adds additional log levels, and extra fields to log record. name - name for logging.getLogger() kwargs_extra - extra fields to add to log record """ log = logging.getLogger(name) return SkyLogger(log, kwargs_extra) python-skytools-3.4/skytools/sockutil.py000066400000000000000000000077251356323561300207060ustar00rootroot00000000000000"""Various low-level utility functions for sockets.""" from __future__ import division, absolute_import, print_function import sys import os import socket try: import fcntl except ImportError: fcntl = None __all__ = ['set_tcp_keepalive', 'set_nonblocking', 'set_cloexec'] def set_tcp_keepalive(fd, keepalive=True, tcp_keepidle=4*60, tcp_keepcnt=4, tcp_keepintvl=15): """Turn on TCP keepalive. The fd can be either numeric or socket object with 'fileno' method. OS defaults for SO_KEEPALIVE=1: - Linux: (7200, 9, 75) - can configure all. - MacOS: (7200, 8, 75) - can configure only tcp_keepidle. - Win32: (7200, 5|10, 1) - can configure tcp_keepidle and tcp_keepintvl. Our defaults: (240, 4, 15). >>> import socket >>> s = socket.socket() >>> set_tcp_keepalive(s) """ # usable on this OS? if not hasattr(socket, 'SO_KEEPALIVE') or not hasattr(socket, 'fromfd'): return # need socket object if isinstance(fd, socket.SocketType): s = fd else: if hasattr(fd, 'fileno'): fd = fd.fileno() s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) # skip if unix socket if not isinstance(s.getsockname(), tuple): return # no keepalive? 
if not keepalive: s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0) return # basic keepalive s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # detect available options TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None) TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None) TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None) TCP_KEEPALIVE = getattr(socket, 'TCP_KEEPALIVE', None) SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None) if TCP_KEEPIDLE is None and TCP_KEEPALIVE is None and sys.platform == 'darwin': TCP_KEEPALIVE = 0x10 # configure if TCP_KEEPCNT is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPCNT, tcp_keepcnt) if TCP_KEEPINTVL is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPINTVL, tcp_keepintvl) if TCP_KEEPIDLE is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPIDLE, tcp_keepidle) elif TCP_KEEPALIVE is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, tcp_keepidle) elif SIO_KEEPALIVE_VALS is not None and fcntl: fcntl.ioctl(s.fileno(), SIO_KEEPALIVE_VALS, (1, tcp_keepidle*1000, tcp_keepintvl*1000)) def set_nonblocking(fd, onoff=True): """Toggle the O_NONBLOCK flag. If onoff==None then return current setting. Actual sockets from 'socket' module should use .setblocking() method, this is for situations where it is not available. Eg. pipes from 'subprocess' module. >>> import socket >>> s = socket.socket() >>> set_nonblocking(s, None) False >>> set_nonblocking(s, 1) 1 >>> set_nonblocking(s, None) True """ flags = fcntl.fcntl(fd, fcntl.F_GETFL) if onoff is None: return (flags & os.O_NONBLOCK) > 0 if onoff: flags |= os.O_NONBLOCK else: flags &= ~os.O_NONBLOCK fcntl.fcntl(fd, fcntl.F_SETFL, flags) return onoff def set_cloexec(fd, onoff=True): """Toggle the FD_CLOEXEC flag. If onoff==None then return current setting. Some libraries do it automatically (eg. libpq). Others do not (Python stdlib). >>> import os >>> f = open(os.devnull, 'rb') >>> set_cloexec(f, None) in (True, False) True >>> set_cloexec(f, True) True >>> set_cloexec(f, None) True >>> import socket >>> s = socket.socket() >>> set_cloexec(s, None) in (True, False) True >>> set_cloexec(s) True >>> set_cloexec(s, None) True """ flags = fcntl.fcntl(fd, fcntl.F_GETFD) if onoff is None: return (flags & fcntl.FD_CLOEXEC) > 0 if onoff: flags |= fcntl.FD_CLOEXEC else: flags &= ~fcntl.FD_CLOEXEC fcntl.fcntl(fd, fcntl.F_SETFD, flags) return onoff python-skytools-3.4/skytools/sqltools.py000066400000000000000000000524171356323561300207270ustar00rootroot00000000000000 """Database tools.""" from __future__ import division, absolute_import, print_function import os import io import skytools __all__ = [ "fq_name_parts", "fq_name", "get_table_oid", "get_table_pkeys", "get_table_columns", "exists_schema", "exists_table", "exists_type", "exists_sequence", "exists_temp_table", "exists_view", "exists_function", "exists_language", "Snapshot", "magic_insert", "CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction", "DBLanguage", "db_install", "installer_find_file", "installer_apply_file", "dbdict", "mk_insert_sql", "mk_update_sql", "mk_delete_sql", ] class dbdict(dict): """Wrapper on actual dict that allows accessing dict keys as attributes. >>> row = dbdict(a=1, b=2) >>> row.a, row.b, row['a'], row['b'] (1, 2, 1, 2) >>> row.c = 3; row['c'] 3 >>> del row.c; row.c Traceback (most recent call last): ... AttributeError: c >>> row['c'] Traceback (most recent call last): ... KeyError: 'c' >>> row.merge({'q': 4}); row.q 4 """ # obj.foo access def __getattr__(self, k): "Return attribute." 
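        # missing keys surface as AttributeError, so hasattr() and
        # getattr() with a default keep working as expected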
try: return self[k] except KeyError: raise AttributeError(k) def __setattr__(self, k, v): "Set attribute." self[k] = v def __delattr__(self, k): "Remove attribute." del self[k] def merge(self, other): for key in other: if key not in self: self[key] = other[key] # # Fully qualified table name # def fq_name_parts(tbl): """Return fully qualified name parts. >>> fq_name_parts('tbl') ['public', 'tbl'] >>> fq_name_parts('foo.tbl') ['foo', 'tbl'] >>> fq_name_parts('foo.tbl.baz') ['foo', 'tbl.baz'] """ tmp = tbl.split('.', 1) if len(tmp) == 1: return ['public', tbl] return tmp def fq_name(tbl): """Return fully qualified name. >>> fq_name('tbl') 'public.tbl' >>> fq_name('foo.tbl') 'foo.tbl' >>> fq_name('foo.tbl.baz') 'foo.tbl.baz' """ return '.'.join(fq_name_parts(tbl)) # # info about table # def get_table_oid(curs, table_name): """Find Postgres OID for table.""" schema, name = fq_name_parts(table_name) q = """select c.oid from pg_namespace n, pg_class c where c.relnamespace = n.oid and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchall() if len(res) == 0: raise Exception('Table not found: '+table_name) return res[0][0] def get_table_pkeys(curs, tbl): """Return list of pkey column names.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_index i, pg_attribute k"\ " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\ " AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped"\ " ORDER BY k.attnum" curs.execute(q, [oid]) return [row[0] for row in curs.fetchall()] def get_table_columns(curs, tbl): """Return list of column names for table.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_attribute k"\ " WHERE k.attrelid = %s"\ " AND k.attnum > 0 AND NOT k.attisdropped"\ " ORDER BY k.attnum" curs.execute(q, [oid]) return [row[0] for row in curs.fetchall()] # # exist checks # def exists_schema(curs, schema): """Does schema exists?""" q = "select count(1) from pg_namespace where nspname = %s" curs.execute(q, [schema]) res = curs.fetchone() return res[0] def exists_table(curs, table_name): """Does table exists?""" schema, name = fq_name_parts(table_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'r' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_sequence(curs, seq_name): """Does sequence exists?""" schema, name = fq_name_parts(seq_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'S' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_view(curs, view_name): """Does view exists?""" schema, name = fq_name_parts(view_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'v' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_type(curs, type_name): """Does type exists?""" schema, name = fq_name_parts(type_name) q = """select count(1) from pg_namespace n, pg_type t where t.typnamespace = n.oid and n.nspname = %s and t.typname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_function(curs, function_name, nargs): """Does function exists?""" # this does not check arg types, so may match several functions schema, name = fq_name_parts(function_name) q = """select count(1) from pg_namespace n, pg_proc p where p.pronamespace = n.oid and 
p.pronargs = %s and n.nspname = %s and p.proname = %s""" curs.execute(q, [nargs, schema, name]) res = curs.fetchone() # if unqualified function, check builtin functions too if not res[0] and function_name.find('.') < 0: name = "pg_catalog." + function_name return exists_function(curs, name, nargs) return res[0] def exists_language(curs, lang_name): """Does PL exists?""" q = """select count(1) from pg_language where lanname = %s""" curs.execute(q, [lang_name]) res = curs.fetchone() return res[0] def exists_temp_table(curs, tbl): """Does temp table exists?""" # correct way, works only on 8.2 q = "select 1 from pg_class where relname = %s and relnamespace = pg_my_temp_schema()" curs.execute(q, [tbl]) tmp = curs.fetchall() return len(tmp) > 0 # # Support for PostgreSQL snapshot # class Snapshot(object): """Represents a PostgreSQL snapshot. Example: >>> sn = Snapshot('11:20:11,12,15') >>> sn.contains(9) True >>> sn.contains(11) False >>> sn.contains(17) True >>> sn.contains(20) False >>> Snapshot(':') Traceback (most recent call last): ... ValueError: Unknown format for snapshot """ def __init__(self, str_val): "Create snapshot from string." self.sn_str = str_val tmp = str_val.split(':') if len(tmp) != 3: raise ValueError('Unknown format for snapshot') self.xmin = int(tmp[0]) self.xmax = int(tmp[1]) self.txid_list = [] if tmp[2] != "": for s in tmp[2].split(','): self.txid_list.append(int(s)) def contains(self, txid): "Is txid visible in snapshot." txid = int(txid) if txid < self.xmin: return True if txid >= self.xmax: return False if txid in self.txid_list: return False return True # # Copy helpers # def _gen_dict_copy(tbl, row, fields, qfields): tmp = [] for f in fields: v = row.get(f) tmp.append(skytools.quote_copy(v)) return "\t".join(tmp) def _gen_dict_insert(tbl, row, fields, qfields): tmp = [] for f in fields: v = row.get(f) tmp.append(skytools.quote_literal(v)) fmt = "insert into %s (%s) values (%s);" return fmt % (tbl, ",".join(qfields), ",".join(tmp)) def _gen_list_copy(tbl, row, fields, qfields): tmp = [] for i in range(len(fields)): try: v = row[i] except IndexError: v = None tmp.append(skytools.quote_copy(v)) return "\t".join(tmp) def _gen_list_insert(tbl, row, fields, qfields): tmp = [] for i in range(len(fields)): try: v = row[i] except IndexError: v = None tmp.append(skytools.quote_literal(v)) fmt = "insert into %s (%s) values (%s);" return fmt % (tbl, ",".join(qfields), ",".join(tmp)) def magic_insert(curs, tablename, data, fields=None, use_insert=False, quoted_table=False): r"""Copy/insert a list of dict/list data to database. If curs is None, then the copy or insert statements are returned as string. For list of dict the field list is optional, as its possible to guess them from dict keys. Example: >>> magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2']) 'COPY public.tbl (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' >>> magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2'], use_insert=True) "insert into public.tbl (col1,col2) values ('1','1');\ninsert into public.tbl (col1,col2) values ('2','2');\n" >>> magic_insert(None, 'tbl', [], ['col1', 'col2']) >>> magic_insert(None, 'tbl."1"', [[1, '1'], [2, '2']], ['col1', 'col2'], quoted_table=True) 'COPY tbl."1" (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' >>> magic_insert(None, 'tbl."1"', [[1, '1'], [2, '2']]) Traceback (most recent call last): ... 
Exception: Non-dict data needs field list >>> magic_insert(None, 'a.tbl', [{'a':1}, {'a':2}]) 'COPY a.tbl (a) FROM STDIN;\n1\n2\n\\.\n' >>> magic_insert(None, 'a.tbl', [{'a':1}, {'a':2}], use_insert=True) "insert into a.tbl (a) values ('1');\ninsert into a.tbl (a) values ('2');\n" More fields than data: >>> magic_insert(None, 'tbl', [[1, 'a']], ['col1', 'col2', 'col3']) 'COPY public.tbl (col1,col2,col3) FROM STDIN;\n1\ta\t\\N\n\\.\n' >>> magic_insert(None, 'tbl', [[1, 'a']], ['col1', 'col2', 'col3'], use_insert=True) "insert into public.tbl (col1,col2,col3) values ('1','a',null);\n" >>> magic_insert(None, 'tbl', [{'a':1}, {'b':2}], ['a', 'b'], use_insert=False) 'COPY public.tbl (a,b) FROM STDIN;\n1\t\\N\n\\N\t2\n\\.\n' >>> magic_insert(None, 'tbl', [{'a':1}, {'b':2}], ['a', 'b'], use_insert=True) "insert into public.tbl (a,b) values ('1',null);\ninsert into public.tbl (a,b) values (null,'2');\n" """ if len(data) == 0: return None if fields is not None: fields = list(fields) # get rid of iterator # decide how to process if hasattr(data[0], 'keys'): if fields is None: fields = data[0].keys() if use_insert: row_func = _gen_dict_insert else: row_func = _gen_dict_copy else: if fields is None: raise Exception("Non-dict data needs field list") if use_insert: row_func = _gen_list_insert else: row_func = _gen_list_copy qfields = [skytools.quote_ident(f) for f in fields] if quoted_table: qtablename = tablename else: qtablename = skytools.quote_fqident(tablename) # init processing buf = io.StringIO() if curs is None and use_insert == 0: fmt = "COPY %s (%s) FROM STDIN;\n" buf.write(fmt % (qtablename, ",".join(qfields))) # process data for row in data: buf.write(row_func(qtablename, row, fields, qfields)) buf.write("\n") # if user needs only string, return it if curs is None: if use_insert == 0: buf.write("\\.\n") return buf.getvalue() # do the actual copy/inserts if use_insert: curs.execute(buf.getvalue()) else: buf.seek(0) hdr = "%s (%s)" % (qtablename, ",".join(qfields)) curs.copy_from(buf, hdr) return None # # Full COPY of table from one db to another # class CopyPipe(io.TextIOBase): """Splits one big COPY to chunks. """ def __init__(self, dstcurs, tablename=None, limit=512*1024, sql_from=None): super(CopyPipe, self).__init__() self.tablename = tablename self.sql_from = sql_from self.dstcurs = dstcurs self.buf = io.StringIO() self.limit = limit #hook for new data, hook func should return new data #def write_hook(obj, data): # return data self.write_hook = None #hook for flush, hook func result is discarded # def flush_hook(obj): # return None self.flush_hook = None self.total_rows = 0 self.total_bytes = 0 def write(self, data): """New row from psycopg """ if self.write_hook: data = self.write_hook(self, data) self.total_bytes += len(data) # it's chars now... self.total_rows += 1 self.buf.write(data) if self.buf.tell() >= self.limit: self.flush() def flush(self): """Send data out. 
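Called from write() whenever the buffer reaches the configured limit, and once more at the end of the copy. Runs flush_hook first (if set), then sends the buffered rows to the destination cursor - copy_expert() when a custom sql_from statement was given, plain copy_from() otherwise - and finally resets the buffer.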
""" if self.flush_hook: self.flush_hook(self) if self.buf.tell() <= 0: return self.buf.seek(0) if self.sql_from: self.dstcurs.copy_expert(self.sql_from, self.buf) else: self.dstcurs.copy_from(self.buf, self.tablename) self.buf.seek(0) self.buf.truncate() def full_copy(tablename, src_curs, dst_curs, column_list=(), condition=None, dst_tablename=None, dst_column_list=None, write_hook=None, flush_hook=None): """COPY table from one db to another.""" # default dst table and dst columns to source ones dst_tablename = dst_tablename or tablename dst_column_list = dst_column_list or column_list[:] if len(dst_column_list) != len(column_list): raise Exception('src and dst column lists must match in length') def build_qfields(cols): if cols: return ",".join([skytools.quote_ident(f) for f in cols]) else: return "*" def build_statement(table, cols): qtable = skytools.quote_fqident(table) if cols: qfields = build_qfields(cols) return "%s (%s)" % (qtable, qfields) else: return qtable dst = build_statement(dst_tablename, dst_column_list) if condition: src = "(SELECT %s FROM %s WHERE %s)" % (build_qfields(column_list), skytools.quote_fqident(tablename), condition) else: src = build_statement(tablename, column_list) sql_to = "COPY %s TO stdout" % src sql_from = "COPY %s FROM stdin" % dst buf = CopyPipe(dst_curs, sql_from=sql_from) buf.write_hook = write_hook buf.flush_hook = flush_hook src_curs.copy_expert(sql_to, buf) buf.flush() return (buf.total_bytes, buf.total_rows) # # SQL installer # class DBObject(object): """Base class for installable DB objects.""" name = None sql = None sql_file = None def __init__(self, name, sql=None, sql_file=None): """Generic dbobject init.""" self.name = name self.sql = sql self.sql_file = sql_file def create(self, curs, log=None): """Create a dbobject.""" if log: log.info('Installing %s' % self.name) if self.sql: sql = self.sql elif self.sql_file: fn = self.find_file() if log: log.info(" Reading from %s" % fn) sql = open(fn, "r").read() else: raise Exception('object not defined') for stmt in skytools.parse_statements(sql): #if log: log.debug(repr(stmt)) curs.execute(stmt) def find_file(self): """Find install script file.""" return installer_find_file(self.sql_file) class DBSchema(DBObject): """Handles db schema.""" def exists(self, curs): """Does schema exists.""" return exists_schema(curs, self.name) class DBTable(DBObject): """Handles db table.""" def exists(self, curs): """Does table exists.""" return exists_table(curs, self.name) class DBFunction(DBObject): """Handles db function.""" def __init__(self, name, nargs, sql=None, sql_file=None): """Function object - number of args is significant.""" super(DBFunction, self).__init__(name, sql, sql_file) self.nargs = nargs def exists(self, curs): """Does function exists.""" return exists_function(curs, self.name, self.nargs) class DBLanguage(DBObject): """Handles db language.""" def __init__(self, name): """PL object - creation happens with CREATE LANGUAGE.""" super(DBLanguage, self).__init__(name, sql="create language %s" % name) def exists(self, curs): """Does PL exists.""" return exists_language(curs, self.name) def db_install(curs, obj_list, log=None): """Installs list of objects into db.""" for obj in obj_list: if not obj.exists(curs): obj.create(curs, log) else: if log: log.info('%s is installed' % obj.name) def installer_find_file(filename): """Find SQL script from pre-defined paths.""" full_fn = None if filename[0] == "/": if os.path.isfile(filename): full_fn = filename else: from skytools.installer_config import 
sql_locations dir_list = sql_locations for fdir in dir_list: fn = os.path.join(fdir, filename) if os.path.isfile(fn): full_fn = fn break if not full_fn: raise Exception('File not found: '+filename) return full_fn def installer_apply_file(db, filename, log): """Find SQL file and apply it to db, statement-by-statement.""" fn = installer_find_file(filename) sql = open(fn, "r").read() if log: log.info("applying %s" % fn) curs = db.cursor() for stmt in skytools.parse_statements(sql): #log.debug(repr(stmt)) curs.execute(stmt) # # Generate INSERT/UPDATE/DELETE statement # def mk_insert_sql(row, tbl, pkey_list=None, field_map=None): """Generate INSERT statement from dict data. >>> from collections import OrderedDict >>> row = OrderedDict([('id',1), ('data', None)]) >>> mk_insert_sql(row, 'tbl') "insert into public.tbl (id, data) values ('1', null);" >>> mk_insert_sql(row, 'tbl', ['x'], OrderedDict([('id', 'id_'), ('data', 'data_')])) "insert into public.tbl (id_, data_) values ('1', null);" """ col_list = [] val_list = [] if field_map: for src, dst in field_map.items(): col_list.append(skytools.quote_ident(dst)) val_list.append(skytools.quote_literal(row[src])) else: for c, v in row.items(): col_list.append(skytools.quote_ident(c)) val_list.append(skytools.quote_literal(v)) col_str = ", ".join(col_list) val_str = ", ".join(val_list) return "insert into %s (%s) values (%s);" % ( skytools.quote_fqident(tbl), col_str, val_str) def mk_update_sql(row, tbl, pkey_list, field_map=None): r"""Generate UPDATE statement from dict data. >>> mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2']) 'update only public."Table" set data = E\'str\\\\\' where id = \'0\' and id2 = \'2\';' >>> mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2'], ... {'id': '_id', 'id2': '_id2', 'data': '_data'}) 'update only public."Table" set _data = E\'str\\\\\' where _id = \'0\' and _id2 = \'2\';' >>> mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', []) Traceback (most recent call last): ... Exception: update needs pkeys """ if len(pkey_list) < 1: raise Exception("update needs pkeys") set_list = [] whe_list = [] pkmap = {} for k in pkey_list: pkmap[k] = 1 new_k = field_map and field_map[k] or k col = skytools.quote_ident(new_k) val = skytools.quote_literal(row[k]) whe_list.append("%s = %s" % (col, val)) if field_map: for src, dst in field_map.items(): if src not in pkmap: col = skytools.quote_ident(dst) val = skytools.quote_literal(row[src]) set_list.append("%s = %s" % (col, val)) else: for col, val in row.items(): if col not in pkmap: col = skytools.quote_ident(col) val = skytools.quote_literal(val) set_list.append("%s = %s" % (col, val)) return "update only %s set %s where %s;" % (skytools.quote_fqident(tbl), ", ".join(set_list), " and ".join(whe_list)) def mk_delete_sql(row, tbl, pkey_list, field_map=None): """Generate DELETE statement from dict data. >>> mk_delete_sql({'a': 1, 'b':2, 'c':3}, 'tablename', ['a','b']) "delete from only public.tablename where a = '1' and b = '2';" >>> mk_delete_sql({'a': 1, 'b':2, 'c':3}, 'tablename', ['a','b'], {'a': 'aa', 'b':'bb'}) "delete from only public.tablename where aa = '1' and bb = '2';" >>> mk_delete_sql({'a': 1, 'b':2, 'c':3}, 'tablename', []) Traceback (most recent call last): ... 
Exception: delete needs pkeys """ if len(pkey_list) < 1: raise Exception("delete needs pkeys") whe_list = [] for k in pkey_list: new_k = field_map and field_map[k] or k col = skytools.quote_ident(new_k) val = skytools.quote_literal(row[k]) whe_list.append("%s = %s" % (col, val)) whe_str = " and ".join(whe_list) return "delete from only %s where %s;" % (skytools.quote_fqident(tbl), whe_str) python-skytools-3.4/skytools/testing.py000066400000000000000000000004461356323561300205170ustar00rootroot00000000000000"""Utilities for tests. """ from __future__ import division, absolute_import, print_function from collections import OrderedDict def ordered_dict(d): """Return OrderedDict with sorted keys. >>> ordered_dict(dict(c=3, a=1, b=2)) OrderedDict([('a', 1), ('b', 2), ('c', 3)]) """ return OrderedDict(sorted(d.items())) python-skytools-3.4/skytools/timeutil.py000066400000000000000000000121151356323561300206720ustar00rootroot00000000000000 """Fill gaps in Python time APIs. parse_iso_timestamp: Parse reasonable subset of ISO_8601 timestamp formats. [ http://en.wikipedia.org/wiki/ISO_8601 ] datetime_to_timestamp: Get POSIX timestamp from datetime() object. """ from __future__ import division, absolute_import, print_function import re import time from datetime import datetime, timedelta, tzinfo __all__ = ['parse_iso_timestamp', 'FixedOffsetTimezone', 'datetime_to_timestamp'] class FixedOffsetTimezone(tzinfo): """Fixed offset in minutes east from UTC.""" __slots__ = ('__offset', '__name') def __init__(self, offset): super(FixedOffsetTimezone, self).__init__() self.__offset = timedelta(minutes=offset) # numeric tz name h, m = divmod(abs(offset), 60) if offset < 0: h = -h if m: self.__name = "%+03d:%02d" % (h, m) else: self.__name = "%+03d" % h def utcoffset(self, dt): return self.__offset def tzname(self, dt): return self.__name def dst(self, dt): return ZERO ZERO = timedelta(0) # # Parse ISO_8601 timestamps. # """ TODO: - support more combinations from ISO 8601 (only reasonable ones) - cache TZ objects - make it faster? """ _iso_regex = r""" \s* (?P<year> \d\d\d\d) [-] (?P<month> \d\d) [-] (?P<day> \d\d) [ T] (?P<hour> \d\d) [:] (?P<min> \d\d) (?: [:] (?P<sec> \d\d ) (?: [.,] (?P<ss> \d+))? )? (?: \s* (?P<tzsign> [-+]) (?P<tzhr> \d\d) (?: [:]? (?P<tzmin> \d\d))? | (?P<tzname> Z ) )? \s* $ """ _iso_rc = None def parse_iso_timestamp(s, default_tz=None): """Parse ISO timestamp to datetime object. YYYY-MM-DD[ T]HH:MM[:SS[.ss]][-+HH[:MM]] Assumes that second fractions are zero-trimmed from the end, so '.15' means 150000 microseconds. If the timezone offset is not present, use default_tz as tzinfo. By default it is None, meaning the datetime object will be without tz. Only fixed offset timezones are supported. >>> str(parse_iso_timestamp('2005-06-01 15:00')) '2005-06-01 15:00:00' >>> str(parse_iso_timestamp(' 2005-06-01T15:00 +02 ')) '2005-06-01 15:00:00+02:00' >>> str(parse_iso_timestamp('2005-06-01 15:00:33+02:00')) '2005-06-01 15:00:33+02:00' >>> d = parse_iso_timestamp('2005-06-01 15:00:59.33 +02') >>> d.strftime("%z %Z") '+0200 +02' >>> str(parse_iso_timestamp(str(d))) '2005-06-01 15:00:59.330000+02:00' >>> parse_iso_timestamp('2005-06-01 15:00-0530').strftime('%Y-%m-%d %H:%M %z %Z') '2005-06-01 15:00 -0530 -05:30' >>> parse_iso_timestamp('2014-10-27T11:59:13Z').strftime('%Y-%m-%d %H:%M:%S %z %Z') '2014-10-27 11:59:13 +0000 +00' >>> parse_iso_timestamp('2014.10.27') Traceback (most recent call last): ...
ValueError: Date not in ISO format: '2014.10.27' """ global _iso_rc if _iso_rc is None: _iso_rc = re.compile(_iso_regex, re.X) m = _iso_rc.match(s) if not m: raise ValueError('Date not in ISO format: %s' % repr(s)) tz = default_tz if m.group('tzsign'): tzofs = int(m.group('tzhr')) * 60 if m.group('tzmin'): tzofs += int(m.group('tzmin')) if m.group('tzsign') == '-': tzofs = -tzofs tz = FixedOffsetTimezone(tzofs) elif m.group('tzname'): tz = UTC return datetime(int(m.group('year')), int(m.group('month')), int(m.group('day')), int(m.group('hour')), int(m.group('min')), m.group('sec') and int(m.group('sec')) or 0, m.group('ss') and int(m.group('ss').ljust(6, '0')) or 0, tz) # # POSIX timestamp from datetime() # UTC = FixedOffsetTimezone(0) TZ_EPOCH = datetime.fromtimestamp(0, UTC) UTC_NOTZ_EPOCH = datetime.utcfromtimestamp(0) def datetime_to_timestamp(dt, local_time=True): """Get posix timestamp from datetime() object. if dt is without timezone, then local_time specifies whether it's UTC or local time. Returns seconds since epoch as float. >>> datetime_to_timestamp(parse_iso_timestamp("2005-06-01 15:00:59.5 +02")) 1117630859.5 >>> datetime_to_timestamp(datetime.fromtimestamp(1117630859.5, UTC)) 1117630859.5 >>> datetime_to_timestamp(datetime.fromtimestamp(1117630859.5)) 1117630859.5 >>> now = datetime.utcnow() >>> now2 = datetime.utcfromtimestamp(datetime_to_timestamp(now, False)) >>> abs(now2.microsecond - now.microsecond) < 100 True >>> now2 = now2.replace(microsecond = now.microsecond) >>> now == now2 True >>> now = datetime.now() >>> now2 = datetime.fromtimestamp(datetime_to_timestamp(now)) >>> abs(now2.microsecond - now.microsecond) < 100 True >>> now2 = now2.replace(microsecond = now.microsecond) >>> now == now2 True """ if dt.tzinfo: delta = dt - TZ_EPOCH return delta.total_seconds() elif local_time: s = time.mktime(dt.timetuple()) return s + (dt.microsecond / 1000000.0) else: delta = dt - UTC_NOTZ_EPOCH return delta.total_seconds() python-skytools-3.4/skytools/tnetstrings.py000066400000000000000000000116731356323561300214320ustar00rootroot00000000000000"""TNetStrings. >>> def ustr(v): return repr(v).replace("u'", "'") >>> def nstr(b): return b.decode('utf8') >>> if isinstance('qwe', bytes): nstr = str >>> vals = (None, False, True, 222, 333.0, "foo", u"bar") >>> tnvals = [nstr(dumps(v)) for v in vals] >>> tnvals ['0:~', '5:false!', '4:true!', '3:222#', '5:333.0^', '3:foo,', '3:bar,'] >>> vals2 = [[], (), {}, [1,2], {'a':'b'}] >>> tnvals2 = [nstr(dumps(v)) for v in vals2] >>> tnvals2 ['0:]', '0:]', '0:}', '8:1:1#1:2#]', '8:1:a,1:b,}'] >>> ustr([parse(dumps(v)) for v in vals]) "[None, False, True, 222, 333.0, 'foo', 'bar']" >>> ustr([parse(dumps(v)) for v in vals2]) "[[], [], {}, [1, 2], {'a': 'b'}]" >>> nstr(dumps([memoryview(b'zzz'),b'qqq'])) '12:3:zzz,3:qqq,]' Error handling: >>> def errtest(bstr, exc): ... try: ... loads(bstr) ... raise Exception('no exception') ... except exc: ... return None >>> errtest(b'4:qwez!', ValueError) >>> errtest(b'4:', ValueError) >>> errtest(b'4:qwez', ValueError) >>> errtest(b'4', ValueError) >>> errtest(b'', ValueError) >>> errtest(b'999999999999999999:z,', ValueError) >>> errtest(u'qweqwe', TypeError) >>> errtest(b'4:true!0:~', ValueError) >>> errtest(b'1:X~', ValueError) >>> errtest(b'8:1:1#1:2#}', ValueError) >>> errtest(b'1:Xz', ValueError) >>> dumps(divmod) Traceback (most recent call last): ... 
TypeError: Object type not supported: <built-in function divmod> """ from __future__ import division, absolute_import, print_function import codecs __all__ = ['loads', 'dumps'] try: unicode except NameError: unicode = str # noqa long = int # noqa _memstr_types = (unicode, bytes, memoryview) _struct_types = (list, tuple, dict) _inttypes = (int, long) def _dumps(dst, val): if isinstance(val, _struct_types): tlenpos = len(dst) tlen = 0 dst.append(None) if isinstance(val, dict): for k in val: tlen += _dumps(dst, k) tlen += _dumps(dst, val[k]) dst.append(b'}') else: for v in val: tlen += _dumps(dst, v) dst.append(b']') dst[tlenpos] = b'%d:' % tlen return len(dst[tlenpos]) + tlen + 1 elif isinstance(val, _memstr_types): if isinstance(val, unicode): bval = val.encode('utf8') elif isinstance(val, memoryview): bval = val.tobytes() else: bval = val tval = b'%d:%s,' % (len(bval), bval) elif isinstance(val, bool): tval = val and b'4:true!' or b'5:false!' elif isinstance(val, _inttypes): bval = b'%d' % val tval = b'%d:%s#' % (len(bval), bval) elif isinstance(val, float): bval = b'%r' % val tval = b'%d:%s^' % (len(bval), bval) elif val is None: tval = b'0:~' else: raise TypeError("Object type not supported: %r" % val) dst.append(tval) return len(tval) _decode_utf8 = codecs.getdecoder('utf8') def _loads(buf): pos = 0 maxlen = min(len(buf), 9) while buf[pos:pos+1] != b':': pos += 1 if pos > maxlen: raise ValueError("Too large length") lenbytes = buf[ : pos].tobytes() tlen = int(lenbytes) ofs = len(lenbytes) + 1 endofs = ofs + tlen val = buf[ofs : endofs] code = buf[endofs : endofs + 1] rest = buf[endofs + 1:] if len(val) + 1 != tlen + len(code): raise ValueError("failed to load value, invalid length") if code == b',': return _decode_utf8(val)[0], rest elif code == b'#': return int(val.tobytes(), 10), rest elif code == b'^': return float(val.tobytes()), rest elif code == b']': listobj = [] while val: elem, val = _loads(val) listobj.append(elem) return listobj, rest elif code == b'}': dictobj = {} while val: k, val = _loads(val) if not isinstance(k, unicode): raise ValueError("failed to load value, invalid key type") dictobj[k], val = _loads(val) return dictobj, rest elif code == b'!': if val == b'true': return True, rest if val == b'false': return False, rest raise ValueError("failed to load value, invalid boolean value") elif code == b'~': if val == b'': return None, rest raise ValueError("failed to load value, invalid null value") else: raise ValueError("failed to load value, invalid value code") # # Public API # def dumps(val): """Dump object tree as TNetString value. """ dst = [] _dumps(dst, val) return b''.join(dst) def loads(binval): """Parse TNetstring from byte string. """ if not isinstance(binval, (bytes, memoryview)): raise TypeError("Bytes or memoryview required") obj, rest = _loads(memoryview(binval)) if rest: raise ValueError("Not all data processed") return obj # old compat? parse = loads dump = dumps python-skytools-3.4/skytools/utf8.py000066400000000000000000000064021356323561300177260ustar00rootroot00000000000000r"""UTF-8 sanitizer. Python's UTF-8 parser is quite relaxed, which creates problems when talking with other software that uses stricter parsers.
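A short usage sketch (illustrative only, not executed as a doctest):

    ok, fixed = safe_utf8_decode(b'X\xf1Y')
    # ok is False, fixed is u'X\ufffdY' - the invalid byte gets replaced with U+FFFD

The doctests below go through the _norm helper so results compare stably across Python versions: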
>>> _norm(safe_utf8_decode(b"foobar")) (True, ['f', 'o', 'o', 'b', 'a', 'r']) >>> _norm(safe_utf8_decode(b'X\0Z')) (False, ['X', 65533, 'Z']) >>> _norm(safe_utf8_decode(b'OK')) (True, ['O', 'K']) >>> _norm(safe_utf8_decode(b'X\xF1Y')) (False, ['X', 65533, 'Y']) >>> _norm_str(sanitize_unicode(u'\uD801\uDC01')) [66561] >>> sanitize_unicode(b'qwe') Traceback (most recent call last): ... TypeError: Need unicode string """ ## these give different results in py27 and py35 # >>> _norm(safe_utf8_decode(b'X\xed\xa0\x80Y\xed\xb0\x89Z')) # (False, ['X', 65533, 65533, 65533, 'Y', 65533, 65533, 65533, 'Z']) # >>> _norm(safe_utf8_decode(b'X\xed\xa0\x80\xed\xb0\x89Z')) # (False, ['X', 65533, 65533, 65533, 65533, 65533, 65533, 'Z']) # from __future__ import division, absolute_import, print_function import re import codecs try: unichr except NameError: unichr = chr # noqa unicode = str # noqa def _norm_char(uchr): code = ord(uchr) if code >= 0x20 and code < 0x7f: return chr(code) return code def _norm_str(ustr): return [_norm_char(c) for c in ustr] def _norm(tup): flg, ustr = tup return (flg, _norm_str(ustr)) __all__ = ['safe_utf8_decode'] # by default, use same symbol as 'replace' REPLACEMENT_SYMBOL = unichr(0xFFFD) # 65533 def _fix_utf8(m): """Merge UTF16 surrogates, replace others""" u = m.group() if len(u) == 2: # merge into single symbol c1 = ord(u[0]) c2 = ord(u[1]) c = 0x10000 + ((c1 & 0x3FF) << 10) + (c2 & 0x3FF) return unichr(c) else: # use replacement symbol return REPLACEMENT_SYMBOL _urc = None def sanitize_unicode(u): """Fix invalid symbols in unicode string.""" global _urc if not isinstance(u, unicode): raise TypeError('Need unicode string') # regex for finding invalid chars, works on unicode string if not _urc: rx = u"[\uD800-\uDBFF] [\uDC00-\uDFFF]? | [\0\uDC00-\uDFFF]" _urc = re.compile(rx, re.X) # now find and fix UTF16 surrogates m = _urc.search(u) if m: u = _urc.sub(_fix_utf8, u) return u def safe_replace(exc): """Replace only one symbol at a time. Builtin .decode('xxx', 'replace') replaces several symbols together, which is unsafe. """ c2 = REPLACEMENT_SYMBOL # we could assume latin1 #if 0: # c1 = exc.object[exc.start] # c2 = unichr(ord(c1)) return c2, exc.start + 1 # register, it will be globally available codecs.register_error("safe_replace", safe_replace) def safe_utf8_decode(s): """Decode UTF-8 safely. Acts like str.decode('utf8', 'replace') but also fixes UTF16 surrogates and NUL bytes, which Python's default decoder does not do. 
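For example (mirroring the module doctests above): valid input such as b'OK' comes back as (True, u'OK'), while input with broken bytes comes back with the flag set to False and each offending byte replaced by U+FFFD.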
@param s: utf8-encoded byte string @return: tuple of (was_valid_utf8, unicode_string) """ # decode with error detection ok = True try: # expect no errors by default u = s.decode('utf8') except UnicodeDecodeError: u = s.decode('utf8', 'safe_replace') ok = False u2 = sanitize_unicode(u) if u is not u2: ok = False return (ok, u2) python-skytools-3.4/tests/000077500000000000000000000000001356323561300157375ustar00rootroot00000000000000python-skytools-3.4/tests/config.ini000066400000000000000000000007171356323561300177120ustar00rootroot00000000000000[base] foo = 1 bar = %(foo)s bool-true1 = 1 bool-true2 = true bool-false1 = 0 bool-false2 = false float-val = 2.0 list-val1 = list-val2 = a, 1, asd, ppp dict-val1 = dict-val2 = a : 1, b : 2, z file-val1 = - file-val2 = ~/foo bytes-val1 = 4 bytes-val2 = 2k wild-*-* = w2 wild-a-* = w.a wild-a-b = w.a.b vars1 = V2=%(vars2)s vars2 = V3=%(vars3)s vars3 = Q3 bad1 = B2=%(bad2)s bad2 = %(missing1)s %(missing2)s [DEFAULT] all = yes [other] test = try python-skytools-3.4/tests/test_api.py000066400000000000000000000002341356323561300201200ustar00rootroot00000000000000 import skytools from nose.tools import * def test_version(): assert_true(skytools.natsort_key(skytools.__version__) >= skytools.natsort_key('3.3')) python-skytools-3.4/tests/test_config.py000066400000000000000000000115701356323561300206210ustar00rootroot00000000000000 import os.path import io from nose.tools import * from skytools.config import (Config, NoOptionError, NoSectionError, ConfigError, InterpolationError, ExtendedConfigParser, ExtendedCompatConfigParser) TOP = os.path.dirname(__file__) CONFIG = os.path.join(TOP, 'config.ini') def test_config_str(): cf = Config('base', CONFIG) eq_(cf.get('foo'), '1') eq_(cf.get('missing', 'q'), 'q') assert_raises(NoOptionError, cf.get, 'missing') def test_config_int(): cf = Config('base', CONFIG) eq_(cf.getint('foo'), 1) eq_(cf.getint('missing', 2), 2) assert_raises(NoOptionError, cf.getint, 'missing') def test_config_float(): cf = Config('base', CONFIG) eq_(cf.getfloat('float-val'), 2.0) eq_(cf.getfloat('missing', 3.0), 3.0) assert_raises(NoOptionError, cf.getfloat, 'missing') def test_config_bool(): cf = Config('base', CONFIG) eq_(cf.getboolean('bool-true1'), True) eq_(cf.getboolean('bool-true2'), True) eq_(cf.getboolean('missing', True), True) assert_raises(NoOptionError, cf.getboolean, 'missing') eq_(cf.getboolean('bool-false1'), False) eq_(cf.getboolean('bool-false2'), False) eq_(cf.getboolean('missing', False), False) assert_raises(NoOptionError, cf.getbool, 'missing') def test_config_list(): cf = Config('base', CONFIG) eq_(cf.getlist('list-val1'), []) eq_(cf.getlist('list-val2'), ['a', '1', 'asd', 'ppp']) eq_(cf.getlist('missing', [1]), [1]) assert_raises(NoOptionError, cf.getlist, 'missing') def test_config_dict(): cf = Config('base', CONFIG) eq_(cf.getdict('dict-val1'), {}) eq_(cf.getdict('dict-val2'), {'a': '1', 'b': '2', 'z': 'z'}) eq_(cf.getdict('missing', {'a':1}), {'a':1}) assert_raises(NoOptionError, cf.getdict, 'missing') def test_config_file(): cf = Config('base', CONFIG) eq_(cf.getfile('file-val1'), '-') eq_(cf.getfile('file-val2')[0], '/') eq_(cf.getfile('missing', 'qwe'), 'qwe') assert_raises(NoOptionError, cf.getfile, 'missing') def test_config_bytes(): cf = Config('base', CONFIG) eq_(cf.getbytes('bytes-val1'), 4) eq_(cf.getbytes('bytes-val2'), 2048) eq_(cf.getbytes('missing', '3k'), 3072) assert_raises(NoOptionError, cf.getbytes, 'missing') def test_config_wildcard(): cf = Config('base', CONFIG) eq_(cf.get_wildcard('wild-*-*', 
['a', 'b']), 'w.a.b') eq_(cf.get_wildcard('wild-*-*', ['a', 'x']), 'w.a') eq_(cf.get_wildcard('wild-*-*', ['q', 'b']), 'w2') eq_(cf.get_wildcard('missing-*-*', ['1', '2'], 'def'), 'def') assert_raises(NoOptionError, cf.get_wildcard, 'missing-*-*', ['1', '2']) def test_config_default(): cf = Config('base', CONFIG) eq_(cf.get('all'), 'yes') def test_config_other(): cf = Config('base', CONFIG) eq_(sorted(cf.sections()), ['base', 'other']) assert_true(cf.has_section('base')) assert_true(cf.has_section('other')) assert_false(cf.has_section('missing')) assert_false(cf.has_section('DEFAULT')) assert_false(cf.has_option('missing')) assert_true(cf.has_option('all')) assert_true(cf.has_option('foo')) cf2 = cf.clone('other') eq_(sorted(cf2.options()), ['all', 'config_dir', 'config_file', 'host_name', 'job_name', 'service_name', 'test']) eq_(len(cf2.items()), len(cf2.options())) def test_loading(): assert_raises(NoSectionError, Config, 'random', CONFIG) assert_raises(ConfigError, Config, 'random', 'random.ini') def test_nofile(): cf = Config('base', None, user_defs = {'a': '1'}) eq_(cf.sections(), ['base']) eq_(cf.get('a'), '1') cf = Config('base', None, user_defs = {'a': '1'}, ignore_defs=True) eq_(cf.get('a', '2'), '2') def test_override(): cf = Config('base', CONFIG, override = {'foo': 'overrided'}) eq_(cf.get('foo'), 'overrided') def test_vars(): cf = Config('base', CONFIG) eq_(cf.get('vars1'), 'V2=V3=Q3') assert_raises(InterpolationError, cf.get, 'bad1') def test_extended_compat(): config = u'[foo]\nkey = ${sub} $${nosub}\nsub = 2\n[bar]\nkey = ${foo:key}\n' cf = ExtendedCompatConfigParser() cf.readfp(io.StringIO(config), 'conf.ini') eq_(cf.get('bar', 'key'), '2 ${nosub}') config = u'[foo]\nloop1= ${loop1}\nloop2 = ${loop3}\nloop3 = ${loop2}\n' cf = ExtendedCompatConfigParser() cf.readfp(io.StringIO(config), 'conf.ini') assert_raises(InterpolationError, cf.get, 'foo', 'loop1') assert_raises(InterpolationError, cf.get, 'foo', 'loop2') config = u'[foo]\nkey = %(sub)s ${sub}\nsub = 2\n[bar]\nkey = %(foo:key)s\nkey2 = ${foo:key}\n' cf = ExtendedCompatConfigParser() cf.readfp(io.StringIO(config), 'conf.ini') eq_(cf.get('bar', 'key2'), '2 2') assert_raises(NoOptionError, cf.get, 'bar', 'key') config = u'[foo]\nkey = ${bad:xxx}\n[bad]\nsub = 1\n' cf = ExtendedCompatConfigParser(); cf.readfp(io.StringIO(config), 'conf.ini') assert_raises(NoOptionError, cf.get, 'foo', 'key') python-skytools-3.4/tests/test_gzlog.py000066400000000000000000000012041356323561300204670ustar00rootroot00000000000000 import os import tempfile import gzip from skytools.gzlog import gzip_append from nose.tools import * def test_gzlog(): fd, tmpname = tempfile.mkstemp(suffix='.gz') os.close(fd) try: blk = b'1234567890'*100 write_total = 0 for i in range(5): gzip_append(tmpname, blk) write_total += len(blk) read_total = 0 with gzip.open(tmpname) as rfd: while 1: blk = rfd.read(512) if not blk: break read_total += len(blk) finally: os.remove(tmpname) eq_(read_total, write_total) python-skytools-3.4/tests/test_kwcheck.py000066400000000000000000000045771356323561300210040ustar00rootroot00000000000000"""Check if SQL keywords are up-to-date. 
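The test scans a list of known source-tree and packaging locations for PostgreSQL's kwlist.h; if none of them exists, full_map stays empty and the check passes vacuously.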
""" from __future__ import division, absolute_import, print_function import sys import re import os.path import skytools.quoting from nose.tools import * locations = [ "/opt/src/pgsql/postgresql/src/include/parser/kwlist.h", "~/src/pgsql/postgres/src/include/parser/kwlist.h", "~/src/pgsql/pg95/src/include/parser/kwlist.h", "~/src/pgsql/pg94/src/include/parser/kwlist.h", "~/src/pgsql/pg93/src/include/parser/kwlist.h", "~/src/pgsql/pg92/src/include/parser/kwlist.h", "~/src/pgsql/pg91/src/include/parser/kwlist.h", "~/src/pgsql/pg90/src/include/parser/kwlist.h", "~/src/pgsql/pg84/src/include/parser/kwlist.h", "~/src/pgsql/pg83/src/include/parser/kwlist.h", "/usr/include/postgresql/9.5/server/parser/kwlist.h", "/usr/include/postgresql/9.4/server/parser/kwlist.h", "/usr/include/postgresql/9.3/server/parser/kwlist.h", "/usr/include/postgresql/9.2/server/parser/kwlist.h", "/usr/include/postgresql/9.1/server/parser/kwlist.h", ] def _load_kwlist(fn, full_map, cur_map): fn = os.path.expanduser(fn) if not os.path.isfile(fn): return data = open(fn, 'rt').read() rc = re.compile(r'PG_KEYWORD[(]"(.*)" , \s* \w+ , \s* (\w+) [)]', re.X) for kw, cat in rc.findall(data): full_map[kw] = cat if cat == 'UNRESERVED_KEYWORD': continue if cat == 'COL_NAME_KEYWORD': continue cur_map[kw] = cat def test_kwcheck(): """Compare keyword list in quoting.py to the one in postgres sources """ kwset = set(skytools.quoting._ident_kwmap) full_map = {} # all types from kwlist.h cur_map = {} # only kwlist.h new_list = [] # missing from kwset obsolete_list = [] # in kwset, but not in cur_map for fn in locations: _load_kwlist(fn, full_map, cur_map) if not full_map: return for kw in sorted(cur_map): if kw not in kwset: new_list.append((kw, cur_map[kw])) kwset.add(kw) for k in sorted(kwset): if k not in full_map: # especially obsolete obsolete_list.append( (k, '!FULL') ) elif k not in cur_map: # slightly obsolete obsolete_list.append( (k, '!CUR') ) eq_(new_list, []) # here we need to keep older keywords around longer #eq_(obsolete_list, []) python-skytools-3.4/tests/test_querybuilder.py000066400000000000000000000022161356323561300220650ustar00rootroot00000000000000 from nose.tools import * from skytools.querybuilder import DList, CachedPlan, PlanCache def test_dlist(): root = DList() assert_true(root.empty()) elem1 = DList() elem2 = DList() elem3 = DList() root.append(elem1) root.append(elem2) root.append(elem3) assert_false(root.empty()) assert_false(elem1.empty()) root.remove(elem2) root.remove(elem3) root.remove(elem1) assert_true(root.empty()) assert_is_none(elem1.next) assert_is_none(elem2.next) assert_is_none(elem3.next) assert_is_none(elem1.prev) assert_is_none(elem2.prev) assert_is_none(elem3.prev) def test_cached_plan(): cache = PlanCache(3) p1 = cache.get_plan('sql1', ['text']) assert_is(p1, cache.get_plan('sql1', ['text'])) p2 = cache.get_plan('sql1', ['int']) assert_is(p2, cache.get_plan('sql1', ['int'])) assert_is_not(p1, p2) p3 = cache.get_plan('sql3', ['text']) assert_is(p3, cache.get_plan('sql3', ['text'])) p4 = cache.get_plan('sql4', ['text']) assert_is(p4, cache.get_plan('sql4', ['text'])) p1x = cache.get_plan('sql1', ['text']) assert_is_not(p1, p1x) python-skytools-3.4/tests/test_quoting.py000066400000000000000000000154241356323561300210440ustar00rootroot00000000000000"""Extra tests for quoting module. 
""" from __future__ import division, absolute_import, print_function import sys, time import skytools.psycopgwrapper import skytools._cquoting import skytools._pyquoting from decimal import Decimal from skytools.testing import ordered_dict from nose.tools import * # create a DictCursor row class fake_cursor: index = ordered_dict({'id': 0, 'data': 1}) description = ['x', 'x'] dbrow = skytools.psycopgwrapper._CompatRow(fake_cursor()) dbrow[0] = '123' dbrow[1] = 'value' def try_func(qfunc, data_list): for val, exp in data_list: got = qfunc(val) eq_(got, exp) def try_catch(qfunc, data_list, exc): for d in data_list: assert_raises(exc, qfunc, d) def test_quote_literal(): sql_literal = [ [None, "null"], ["", "''"], ["a'b", "'a''b'"], [r"a\'b", r"E'a\\''b'"], [1, "'1'"], [True, "'True'"], [Decimal(1), "'1'"], [u'qwe', "'qwe'"] ] try_func(skytools._cquoting.quote_literal, sql_literal) try_func(skytools._pyquoting.quote_literal, sql_literal) try_func(skytools.quote_literal, sql_literal) qliterals_common = [ (r"""null""", None), (r"""NULL""", None), (r"""123""", "123"), (r"""''""", r""""""), (r"""'a''b''c'""", r"""a'b'c"""), (r"""'foo'""", r"""foo"""), (r"""E'foo'""", r"""foo"""), (r"""E'a\n\t\a\b\0\z\'b'""", "a\n\t\x07\x08\x00z'b"), (r"""$$$$""", r""), (r"""$$qw$e$z$$""", r"qw$e$z"), (r"""$qq$$aa$$$'"\\$qq$""", '$aa$$$\'"\\\\'), (u"'qwe'", 'qwe'), ] bad_dol_literals = [ ('$$', '$$'), #('$$q', '$$q'), ('$$q$', '$$q$'), ('$q$q$', '$q$q$'), ('$q$q$x$', '$q$q$x$'), ] def test_unquote_literal(): qliterals_nonstd = qliterals_common + [ (r"""'a\\b\\c'""", r"""a\b\c"""), (r"""e'a\\b\\c'""", r"""a\b\c"""), ] try_func(skytools._cquoting.unquote_literal, qliterals_nonstd) try_func(skytools._pyquoting.unquote_literal, qliterals_nonstd) try_func(skytools.unquote_literal, qliterals_nonstd) for v1, v2 in bad_dol_literals: assert_raises(ValueError, skytools._pyquoting.unquote_literal, v1) assert_raises(ValueError, skytools._cquoting.unquote_literal, v1) assert_raises(ValueError, skytools.unquote_literal, v1) def test_unquote_literal_std(): qliterals_std = qliterals_common + [ (r"''", r""), (r"'foo'", r"foo"), (r"E'foo'", r"foo"), (r"'\\''z'", r"\\'z"), ] for val, exp in qliterals_std: eq_(skytools._cquoting.unquote_literal(val, True), exp) eq_(skytools._pyquoting.unquote_literal(val, True), exp) eq_(skytools.unquote_literal(val, True), exp) def test_quote_copy(): sql_copy = [ [None, "\\N"], ["", ""], ["a'\tb", "a'\\tb"], [r"a\'b", r"a\\'b"], [1, "1"], [True, "True"], [u"qwe", "qwe"], [Decimal(1), "1"], ] try_func(skytools._cquoting.quote_copy, sql_copy) try_func(skytools._pyquoting.quote_copy, sql_copy) try_func(skytools.quote_copy, sql_copy) def test_quote_bytea_raw(): sql_bytea_raw = [ [None, None], [b"", ""], [b"a'\tb", "a'\\011b"], [b"a\\'b", r"a\\'b"], [b"\t\344", r"\011\344"], ] try_func(skytools._cquoting.quote_bytea_raw, sql_bytea_raw) try_func(skytools._pyquoting.quote_bytea_raw, sql_bytea_raw) try_func(skytools.quote_bytea_raw, sql_bytea_raw) def test_quote_bytea_raw_fail(): assert_raises(TypeError, skytools._pyquoting.quote_bytea_raw, u'qwe') #assert_raises(TypeError, skytools._cquoting.quote_bytea_raw, u'qwe') #assert_raises(TypeError, skytools.quote_bytea_raw, 'qwe') def test_quote_ident(): sql_ident = [ ['', '""'], ["a'\t\\\"b", '"a\'\t\\""b"'], ['abc_19', 'abc_19'], ['from', '"from"'], ['0foo', '"0foo"'], ['mixCase', '"mixCase"'], [u'utf', 'utf'], ] try_func(skytools.quote_ident, sql_ident) def _sort_urlenc(func): def wrapper(data): res = func(data) return '&'.join(sorted(res.split('&'))) return 
wrapper def test_db_urlencode(): t_urlenc = [ [{}, ""], [{'a': 1}, "a=1"], [{'a': None}, "a"], [{'qwe': 1, u'zz': u"qwe"}, 'qwe=1&zz=qwe'], [ordered_dict({'qwe': 1, u'zz': u"qwe"}), 'qwe=1&zz=qwe'], [{'a': '\000%&'}, "a=%00%25%26"], [dbrow, 'data=value&id=123'], [{'a': Decimal("1")}, "a=1"], ] try_func(_sort_urlenc(skytools._cquoting.db_urlencode), t_urlenc) try_func(_sort_urlenc(skytools._pyquoting.db_urlencode), t_urlenc) try_func(_sort_urlenc(skytools.db_urlencode), t_urlenc) def test_db_urldecode(): t_urldec = [ ["", {}], ["a=b&c", {'a': 'b', 'c': None}], ["&&b=f&&", {'b': 'f'}], [u"abc=qwe", {'abc': 'qwe'}], ["b=", {'b': ''}], ["b=%00%45", {'b': '\x00E'}], ] try_func(skytools._cquoting.db_urldecode, t_urldec) try_func(skytools._pyquoting.db_urldecode, t_urldec) try_func(skytools.db_urldecode, t_urldec) def test_unescape(): t_unesc = [ ["", ""], ["\\N", "N"], ["abc", "abc"], [u"abc", "abc"], [r"\0\000\001\01\1", "\0\000\001\001\001"], [r"a\001b\tc\r\n", "a\001b\tc\r\n"], ] try_func(skytools._cquoting.unescape, t_unesc) try_func(skytools._pyquoting.unescape, t_unesc) try_func(skytools.unescape, t_unesc) def test_quote_bytea_literal(): bytea_raw = [ [None, "null"], [b"", "''"], [b"a'\tb", "E'a''\\\\011b'"], [b"a\\'b", r"E'a\\\\''b'"], [b"\t\344", r"E'\\011\\344'"], ] try_func(skytools.quote_bytea_literal, bytea_raw) def test_quote_bytea_copy(): bytea_raw = [ [None, "\\N"], [b"", ""], [b"a'\tb", "a'\\\\011b"], [b"a\\'b", r"a\\\\'b"], [b"\t\344", r"\\011\\344"], ] try_func(skytools.quote_bytea_copy, bytea_raw) def test_quote_statement(): sql = "set a=%s, b=%s, c=%s" args = [None, u"qwe'qwe", 6.6] eq_(skytools.quote_statement(sql, args), "set a=null, b='qwe''qwe', c='6.6'") sql = "set a=%(a)s, b=%(b)s, c=%(c)s" args = dict(a=None, b="qwe'qwe", c=6.6) eq_(skytools.quote_statement(sql, args), "set a=null, b='qwe''qwe', c='6.6'") def test_quote_json(): json_string_vals = [ [None, "null"], ['', '""'], [u'xx', '"xx"'], ['qwe"qwe\t', '"qwe\\"qwe\\t"'], ['\x01', '"\\u0001"'], ] try_func(skytools.quote_json, json_string_vals) def test_unquote_ident(): idents = [ ['qwe', 'qwe'], [u'qwe', 'qwe'], ['"qwe"', 'qwe'], ['"q""w\\\\e"', 'q"w\\\\e'], ] try_func(skytools.unquote_ident, idents) @raises(Exception) def test_unquote_ident_fail(): skytools.unquote_ident('asd"asd') python-skytools-3.4/tox.ini000066400000000000000000000014041356323561300161070ustar00rootroot00000000000000 [tox] envlist = lint3,py36 [package] name = skytools deps = psycopg2 test_deps = nose==1.3.7 coverage==4.5.4 lint_deps = pylint==2.4.4 [testenv] changedir = {envsitepackagesdir} deps = {[package]deps} {[package]test_deps} commands = coverage erase coverage run --rcfile "{toxinidir}/.coveragerc" --include "{[package]name}/*" \ -m nose -P --with-doctest --all-modules {[package]name} "{toxinidir}/tests" coverage html -d "{toxinidir}/tmp/cover-{envname}" \ --title "Coverage for {envname}" \ --rcfile "{toxinidir}/.coveragerc" coverage report --rcfile "{toxinidir}/.coveragerc" [testenv:lint3] basepython = python3 deps = {[package]deps} {[package]lint_deps} commands = pylint --rcfile={toxinidir}/.pylintrc {[package]name}
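# Typical local usage (assuming tox is installed):
#   tox -e py36    # run the nose test suite with doctests and coverage
#   tox -e lint3   # run pylint using the project's .pylintrc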