python-skytools-3.6.1 source tree
(archive pax header: git commit b4633810efe9f7640fbc2028005be548fbad7ccb)

==== python-skytools-3.6.1/.coveragerc ====

[report]
exclude_lines =
    pragma: no cover
    if __name__
    except ImportError:
    raise NotImplementedError
omit =
    .tox/*
    tests/*

[paths]
source =
    skytools/
    */skytools/

==== python-skytools-3.6.1/.github/workflows/ci.yml ====

#
# https://docs.github.com/en/actions/reference
# https://github.com/actions
#

name: CI

on:
  pull_request: {}
  push: {}

jobs:

  pylint:
    name: "PyLint"
    runs-on: ubuntu-latest
    strategy:
      matrix:
        test:
          - {PY: "3.8", TOXENV: "lint"}
    steps:
      - name: "Checkout"
        uses: actions/checkout@v2
      - name: "Setup Python ${{matrix.test.PY}}"
        uses: actions/setup-python@v2
        with:
          python-version: ${{matrix.test.PY}}
      - name: "Install tox"
        run: python -m pip -q install tox
      - name: "Test"
        env:
          TOXENV: ${{matrix.test.TOXENV}}
        run: python -m tox -r

  no_database:
    name: "${{matrix.test.osname}} + Python ${{matrix.test.PY}} ${{matrix.test.arch}}"
    runs-on: ${{matrix.test.os}}
    strategy:
      matrix:
        test:
          - {os: "ubuntu-16.04", osname: "Ubuntu 16.04", PY: "3.6", TOXENV: "py36", arch: "x64"}
          - {os: "ubuntu-18.04", osname: "Ubuntu 18.04", PY: "3.7", TOXENV: "py37", arch: "x64"}
          - {os: "ubuntu-20.04", osname: "Ubuntu 20.04", PY: "3.8", TOXENV: "py38", arch: "x64"}
          - {os: "ubuntu-20.04", osname: "Ubuntu 20.04", PY: "3.9.0-rc.2", TOXENV: "py39", arch: "x64"}
          - {os: "macos-10.15", osname: "MacOS 10.15", PY: "3.6", TOXENV: "py36", arch: "x64"}
          - {os: "macos-10.15", osname: "MacOS 10.15", PY: "3.8", TOXENV: "py38", arch: "x64"}
          - {os: "windows-2016", osname: "Windows 2016", PY: "3.6", TOXENV: "py36", arch: "x86"}
          - {os: "windows-2016", osname: "Windows 2016", PY: "3.7", TOXENV: "py37", arch: "x64"}
          - {os: "windows-2019", osname: "Windows 2019", PY: "3.8", TOXENV: "py38", arch: "x86"}
          - {os: "windows-2019", osname: "Windows 2019", PY: "3.8", TOXENV: "py38", arch: "x64"}
          - {os: "ubuntu-20.04", osname: "Ubuntu 20.04", PY: "pypy3", TOXENV: "pypy3", arch: "x64"}
    steps:
      - name: "Checkout"
        uses: actions/checkout@v2
      - name: "Setup Python ${{matrix.test.PY}}"
        uses: actions/setup-python@v2
        with:
          python-version: ${{matrix.test.PY}}
          architecture: ${{matrix.test.arch}}
      - name: "Install tox"
        run: python -m pip -q install tox
      - name: "Build"
        run: python setup.py build
      - name: "Test"
        env:
          TOXENV: ${{matrix.test.TOXENV}}
        run: python -m tox -r -- --color=yes

  database:
    #if: false
    #name: "database test (disabled)"
    name: "PostgreSQL ${{matrix.test.PG}} + Python ${{matrix.test.PY}}"
    runs-on: ubuntu-20.04
    strategy:
      matrix:
        test:
          - {PY: "3.8", PG: "12", TOXENV: "py38"}
    steps:
      - name: "Checkout"
        uses: actions/checkout@v2
      - name: "Setup Python ${{matrix.test.PY}}"
        uses: actions/setup-python@v2
        with:
          python-version: ${{matrix.test.PY}}
      - name: "Install tox"
        run: |
          python -m pip -q install tox
      - name: "InstallDB"
        run: |
          echo "::group::apt-get-update"
          echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main ${{matrix.test.PG}}" \
          | sudo tee /etc/apt/sources.list.d/pgdg.list
          sudo -nH apt-get -q update
          echo "::endgroup::"

          echo "::group::apt-get-install"

          # disable new cluster creation
          sudo -nH mkdir -p /etc/postgresql-common/createcluster.d
          echo "create_main_cluster = false" | sudo -nH tee /etc/postgresql-common/createcluster.d/no-main.conf

          sudo -nH apt-get -qyu install postgresql-${{matrix.test.PG}}
          echo "::endgroup::"

          # tune environment
          echo "::add-path::/usr/lib/postgresql/${{matrix.test.PG}}/bin"
          echo "::set-env name=PGHOST::/tmp"
      - name: "StartDB"
        run: |
          rm -rf data log
          mkdir -p log
          LANG=C initdb data
          sed -ri -e "s,^[# ]*(unix_socket_directories).*,\\1='/tmp'," data/postgresql.conf
          pg_ctl -D data -l log/pg.log start || { cat log/pg.log ; exit 1; }
          sleep 1
          createdb testdb
      - name: "Test"
        env:
          TOXENV: ${{matrix.test.TOXENV}}
          TEST_DB: dbname=testdb host=/tmp
        run: |
          python -m tox -r -- --color=yes
      - name: "StopDB"
        run: |
          pg_ctl -D data stop
          rm -rf data log /tmp/.s.PGSQL*

  test_linux_wheels:
    name: "Wheel: manylinux1"
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with: {python-version: "3.8"}
      - run: |
          python setup.py sdist
          mv dist sdist
      - uses: "docker://quay.io/pypa/manylinux1_x86_64:latest"
        with: {entrypoint: "./.github/workflows/manylinux.sh"}
      - uses: "docker://quay.io/pypa/manylinux1_i686:latest"
        with: {entrypoint: "./.github/workflows/manylinux.sh"}
      - uses: actions/upload-artifact@v2
        with: {name: "dist", path: "dist"}

  test_other_wheels:
    name: "Wheel: ${{matrix.sys.name}}-${{matrix.sys.pyarch}}-${{matrix.pyver}}"
    runs-on: ${{matrix.sys.os}}
    strategy:
      matrix:
        sys:
          - {os: "macos-10.15", name: "MacOS 10.15", pyarch: "x64", opts: "--build-option --py-limited-api=cp36", repack: false}
          - {os: "windows-2019", name: "Windows 2019", pyarch: "x86", opts: "", repack: true}
          - {os: "windows-2019", name: "Windows 2019", pyarch: "x64", opts: "", repack: true}
        pyver: ["3.6"]
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with: {architecture: "${{matrix.sys.pyarch}}", python-version: "${{matrix.pyver}}"}
      - run: |
          python setup.py sdist
          mv dist sdist
      - name: "Build"
        shell: bash
        run: |
          pip install --disable-pip-version-check -U wheel
          pip wheel -v --disable-pip-version-check ${{matrix.sys.opts}} -w dist sdist/*.tar.gz
          ls -l dist
      - name: "Repack"
        if: ${{matrix.sys.repack}}
        shell: bash
        run: |
          mv dist dist2
          python .github/workflows/win32abi3.py -d dist dist2/*.whl
          ls -l dist
      - uses: actions/upload-artifact@v2
        with: {name: "dist", path: "dist"}
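
---- example: consuming TEST_DB ----

The database job above hands the server location to the test suite only
through the TEST_DB environment variable, so tests can skip cleanly when no
server is around. A minimal sketch of that convention; the skip logic here is
illustrative, not the suite's actual code:

import os

import skytools

connstr = os.environ.get("TEST_DB")
if connstr:
    # e.g. "dbname=testdb host=/tmp", as exported by the "Test" step above
    db = skytools.connect_database(connstr)
    curs = db.cursor()
    curs.execute("select 1")
    print(curs.fetchone())
else:
    print("TEST_DB not set, skipping database tests")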
==== python-skytools-3.6.1/.github/workflows/manylinux.sh ====

#! /bin/sh

# will be run inside manylinux docker
#
# https://github.com/pypa/manylinux
# https://www.python.org/dev/peps/pep-0513/ - manylinux1
# https://www.python.org/dev/peps/pep-0571/ - manylinux2010
# https://www.python.org/dev/peps/pep-0599/ - manylinux2014
#

set -e
set -x

PYLIST="cp36-cp36m"
PYDEPS=""
DSTDIR="dist"
BLDDIR="build/${AUDITWHEEL_PLAT}"
WHEELOPTS="--build-option --py-limited-api=cp36"
PIPOPTS="--no-cache-dir --disable-pip-version-check"

# build initial wheel
build_wheel() {
    if test -n "${PYDEPS}"; then
        pip install ${PIPOPTS} -U ${PYDEPS}
    fi
    pip wheel ${PIPOPTS} -w "${BLDDIR}" $WHEELOPTS .
}

# build wheels for requested python versions
for tag in ${PYLIST}; do
    PATH="/opt/python/${tag}/bin:${PATH}" \
    build_wheel
done

# use auditwheel to rebuild wheels in BLDDIR to DSTDIR
for whl in "${BLDDIR}"/*.whl; do
    auditwheel repair -w "${DSTDIR}" "${whl}"
done
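
---- example: running manylinux.sh locally ----

In CI the script runs as the entrypoint of the manylinux images (see the
docker:// steps in ci.yml). A rough local equivalent, sketched with an assumed
mount path of /ws; the GitHub runner uses its own workspace path:

import os
import subprocess

subprocess.run([
    "docker", "run", "--rm",
    "-v", os.getcwd() + ":/ws", "-w", "/ws",    # mount the checkout
    "quay.io/pypa/manylinux1_x86_64:latest",    # same image as the CI job
    "./.github/workflows/manylinux.sh",         # same entrypoint as the CI job
], check=True)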
==== python-skytools-3.6.1/.github/workflows/release.yml ====

#
# https://docs.github.com/en/actions/reference
# https://github.com/actions
#

name: REL

on:
  push:
    tags: ["v[0-9]*"]

jobs:

  sdist:
    name: "Build source package"
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with: {python-version: "3.8"}
      - run: python setup.py sdist
      - uses: actions/upload-artifact@v2
        with: {name: "dist", path: "dist"}

  linux_wheels:
    name: "Wheel: manylinux1"
    needs: [sdist]
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - uses: actions/download-artifact@v2
        with: {name: "dist", path: "sdist"}
      - uses: "docker://quay.io/pypa/manylinux1_x86_64:latest"
        with: {entrypoint: "./.github/workflows/manylinux.sh"}
      - uses: "docker://quay.io/pypa/manylinux1_i686:latest"
        with: {entrypoint: "./.github/workflows/manylinux.sh"}
      - uses: actions/upload-artifact@v2
        with: {name: "dist", path: "dist"}

  other_wheels:
    name: "Wheel: ${{matrix.sys.name}}-${{matrix.sys.pyarch}}-${{matrix.pyver}}"
    needs: [sdist]
    runs-on: ${{matrix.sys.os}}
    strategy:
      matrix:
        sys:
          - {os: "macos-10.15", name: "MacOS 10.15", pyarch: "x64", opts: "--build-option --py-limited-api=cp36", repack: false}
          - {os: "windows-2019", name: "Windows 2019", pyarch: "x86", opts: "", repack: true}
          - {os: "windows-2019", name: "Windows 2019", pyarch: "x64", opts: "", repack: true}
        pyver: ["3.6"]
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with: {architecture: "${{matrix.sys.pyarch}}", python-version: "${{matrix.pyver}}"}
      - uses: actions/download-artifact@v2
        with: {name: "dist", path: "sdist"}
      - name: "Build"
        shell: bash
        run: |
          pip install --disable-pip-version-check -U wheel
          pip wheel --disable-pip-version-check ${{matrix.sys.opts}} -w dist sdist/*.tar.gz
          ls -l dist
      - name: "Repack"
        if: ${{matrix.sys.repack}}
        shell: bash
        run: |
          mv dist dist2
          python .github/workflows/win32abi3.py -d dist dist2/*.whl
          ls -l dist
      - uses: actions/upload-artifact@v2
        with: {name: "dist", path: "dist"}

  publish:
    name: "Publish"
    runs-on: ubuntu-20.04
    needs: [linux_wheels, other_wheels]
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with: {python-version: "3.8"}
      - uses: actions/download-artifact@v2
        with: {name: "dist", path: "dist"}
      - name: "Prepare"
        run: |
          PACKAGE=$(python setup.py --name)
          VERSION=$(python setup.py --version)
          TGZ="${PACKAGE}-${VERSION}.tar.gz"

          # default - gh:release, pypi
          # PRERELEASE - gh:prerelease, pypi
          # DRAFT - gh:draft,prerelease, testpypi
          PRERELEASE="false"
          DRAFT="false"
          if echo "${VERSION}" | grep -qE '(a|b|rc)'; then PRERELEASE="true"; fi
          if echo "${VERSION}" | grep -qE '(dev)'; then DRAFT="true"; PRERELEASE="true"; fi

          test "${{github.ref}}" = "refs/tags/v${VERSION}" || { echo "ERR: tag mismatch"; exit 1; }
          test -f "dist/${TGZ}" || { echo "ERR: sdist failed"; exit 1; }

          echo "::set-env name=PACKAGE::${PACKAGE}"
          echo "::set-env name=VERSION::${VERSION}"
          echo "::set-env name=TGZ::${TGZ}"
          echo "::set-env name=PRERELEASE::${PRERELEASE}"
          echo "::set-env name=DRAFT::${DRAFT}"

          sudo -nH apt-get -u -y install pandoc
          pandoc --version

          mkdir -p tmp
          make -s shownote > tmp/note.md
          cat tmp/note.md
      - name: "Create release"
        id: github_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
        with:
          tag_name: ${{github.ref}}
          release_name: ${{env.PACKAGE}} v${{env.VERSION}}
          body_path: tmp/note.md
          prerelease: ${{env.PRERELEASE}}
          draft: ${{env.DRAFT}}
      - name: "Upload to Github"
        id: github_upload
        uses: actions/upload-release-asset@v1
        env:
          GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}}
        with:
          upload_url: ${{steps.github_release.outputs.upload_url}}
          asset_path: dist/${{env.TGZ}}
          asset_name: ${{env.TGZ}}
          asset_content_type: application/x-gzip
      - name: "Upload to PYPI"
        id: pypi_upload
        env:
          PYPI_TOKEN: ${{secrets.PYPI_TOKEN}}
          PYPI_TEST_TOKEN: ${{secrets.PYPI_TEST_TOKEN}}
        run: |
          pip install --disable-pip-version-check -U twine
          ls -l dist
          if test "${DRAFT}" = "false"; then
            python -m twine upload -u __token__ -p ${PYPI_TOKEN} \
                --repository pypi --disable-progress-bar dist/*
          else
            python -m twine upload -u __token__ -p ${PYPI_TEST_TOKEN} \
                --repository testpypi --disable-progress-bar dist/*
          fi
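
---- example: release-channel rules from the "Prepare" step ----

The shell greps above implement a small decision table: plain versions become
a GitHub release and go to pypi, a/b/rc versions become prereleases, and dev
versions become draft prereleases that go to testpypi. The same table,
restated as a sketch (not code from this repo):

def release_channel(version):
    """Return (prerelease, draft) for a version string."""
    prerelease = any(t in version for t in ("a", "b", "rc"))
    draft = "dev" in version
    if draft:
        prerelease = True
    return prerelease, draft

assert release_channel("3.6.1") == (False, False)
assert release_channel("3.7rc1") == (True, False)
assert release_channel("3.7.dev0") == (True, True)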
==== python-skytools-3.6.1/.github/workflows/win32abi3.py ====

#! /usr/bin/env python3

"""Convert windows wheel to abi3-like multi-python wheel.

This is a workaround for missing win32-abi3 support in pip/wheel/bdist_wheel.

Requirements:

- extensions must be compiled as abi3 compatible.
- python 3.5+, older don't have stable libc abi.

"""

import argparse
import base64
import hashlib
import os
import os.path
import re
import sys
import zipfile

#from wheel.wheelfile import WheelFile, WHEEL_INFO_RE

RC_FN = re.compile(r"""
    ^
    (?P<namever> (?P<name> .+? ) - (?P<ver> \d [^-]* ) )
    (?P<build> - \d [^-]* )?
    - (?P<python>[a-z][^-]*)
    - (?P<abi>[a-z][a-z0-9]+)
    - (?P<arch>[a-z][a-z0-9_]+)
    [.]whl $
""", re.X)

_quiet = False
_verbose = False


def writemsg(fd, msg, args):
    if args:
        msg = msg % args
    fd.write(msg + "\n")
    fd.flush()


def printf(msg, *args):
    if not _quiet:
        writemsg(sys.stdout, msg, args)


def dprintf(msg, *args):
    if _verbose:
        writemsg(sys.stdout, msg, args)


def eprintf(msg, *args):
    writemsg(sys.stderr, msg, args)


def die(msg, *args):
    eprintf(msg, *args)
    sys.exit(1)


def convert_filename(fn, pyvers):
    m = RC_FN.match(fn)
    if not m:
        die("Unsupported wheel name: %s", fn)
    namever = m.group("namever")
    build = m.group("build") or ""
    abi = m.group("abi")
    arch = m.group("arch")
    if arch.startswith("win"):
        abi = "none"    # should be "abi3"
    newtag = "%s-%s-%s" % (pyvers, abi, arch)
    fn2 = "%s%s-%s.whl" % (namever, build, newtag)
    return fn2, namever, newtag


def convert_tags(wheeldata, newtag):
    res = []
    for ln in wheeldata.decode().split("\n"):
        if ln.startswith("Tag:"):
            res.append("Tag: %s" % newtag)
        else:
            res.append(ln)
    return "\n".join(res).encode()


def digest(data):
    md = hashlib.sha256(data).digest()
    b64 = base64.urlsafe_b64encode(md).decode()
    return "sha256=" + b64.strip("=")


def convert_record(data, wheeldata):
    res = []
    for ln in data.decode().split("\n"):
        parts = ln.split(",")
        if len(parts) != 3:
            res.append(ln)
            continue
        elif parts[0].endswith("WHEEL"):
            ln2 = "%s,%s,%s" % (parts[0], digest(wheeldata), len(wheeldata))
            res.append(ln2)
        else:
            res.append(ln)
    return "\n".join(res).encode()


def convert_wheel(srcwheel, dstwheel, namever, newtag):
    recordfn = "%s.dist-info/RECORD" % namever
    wheelfn = "%s.dist-info/WHEEL" % namever
    printf("Creating %s", dstwheel)
    wheel = None
    record = None
    with zipfile.ZipFile(srcwheel, "r") as zsrc:
        wheel = convert_tags(zsrc.read(wheelfn), newtag)
        if not wheel:
            die("WHEEL entry not found")
        with zipfile.ZipFile(dstwheel, "w") as zdst:
            for info in zsrc.infolist():
                if info.is_dir():
                    continue
                filename = info.filename.replace("\\", "/")
                info2 = zipfile.ZipInfo(filename=filename, date_time=info.date_time)
                info2.compress_type = zipfile.ZIP_DEFLATED
                if filename == wheelfn:
                    dprintf("  Converting %s", filename)
                    zdst.writestr(info2, wheel)
                elif filename == recordfn:
                    dprintf("  Converting %s", filename)
                    record = convert_record(zsrc.read(info), wheel)
                    zdst.writestr(info2, record)
                else:
                    dprintf("  Copying %s", filename)
                    zdst.writestr(info2, zsrc.read(info))
    if not record:
        die("RECORD entry not found")


def main(argv=None):
    """Convert win32/64 wheels to be abi3-like.
    """
    global _verbose, _quiet
    p = argparse.ArgumentParser(
        description=main.__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("wheel", metavar="WHEEL", nargs="+", help="Wheels to convert")
    p.add_argument("-d", dest="dest", help="target dir", default=".")
    p.add_argument("-p", dest="pyvers", help="python versions", default="cp36.cp37.cp38")
    p.add_argument("-q", dest="quiet", help="no info messages", action="store_true")
    p.add_argument("-v", dest="verbose", help="debug messages", action="store_true")
    args = p.parse_args(sys.argv[1:] if argv is None else argv)
    if args.quiet:
        _quiet = True
    elif args.verbose:
        _verbose = True
    os.makedirs(args.dest, exist_ok=True)
    for fn in args.wheel:
        srcfn = os.path.basename(fn)
        dstfn, namever, newtag = convert_filename(srcfn, args.pyvers)
        fn2 = os.path.join(args.dest, dstfn)
        convert_wheel(fn, fn2, namever, newtag)


if __name__ == "__main__":
    main()
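
---- example: what win32abi3.py does to a wheel name ----

Following RC_FN and convert_filename() above, a single-version Windows wheel
is renamed so pip will accept it on every Python in the -p list (file names
below are hypothetical):

# invocation used by the workflows:
#   python .github/workflows/win32abi3.py -d dist dist2/*.whl

fn2, namever, newtag = convert_filename(
    "skytools-3.6.1-cp38-cp38-win_amd64.whl", "cp36.cp37.cp38")
# fn2     == "skytools-3.6.1-cp36.cp37.cp38-none-win_amd64.whl"
# namever == "skytools-3.6.1"
# newtag  == "cp36.cp37.cp38-none-win_amd64"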
==== python-skytools-3.6.1/.gitignore ====

__pycache__
*.pyc
*.swp
*.o
*.so
*.egg-info
*.debhelper
*.log
*.substvars
*-stamp
debian/files
debian/python-skytools
debian/python3-skytools
.pytype
.pytest_cache
.mypy_cache
.tox
.coverage
.pybuild
cover
*.xml
MANIFEST
build
tmp
dist

==== python-skytools-3.6.1/.pylintrc ====

[MASTER]

# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
extension-pkg-whitelist=

# Add files or directories to the blacklist. They should be base names, not
# paths.
ignore=CVS,tmp,dist

# Add files or directories matching the regex patterns to the blacklist. The
# regex matches against base names, not paths.
ignore-patterns=

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=

# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use.
jobs=1

# Control the amount of potential inferred values when inferring a single
# object. This can help the performance when dealing with large functions or
# complex, nested conditions.
limit-inference-results=100

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
load-plugins=

# Pickle collected data for later comparisons.
persistent=yes

# Specify a configuration file.
#rcfile=

# When enabled, pylint would attempt to guess common misconfiguration and emit
# user-friendly hints instead of false-positive error messages.
suggestion-mode=yes

# Allow loading of arbitrary C extensions. Extensions are imported into the
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no


[MESSAGES CONTROL]

# Only show warnings with the listed confidence levels. Leave empty to show
# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
confidence=

# Disable the message, report, category or checker with the given id(s). You
# can either give multiple identifiers separated by comma (,) or put this
# option multiple times (only on the command line, not in the configuration
# file where it should appear only once). You can also use "--disable=all" to
# disable everything first and then reenable specific checks. For example, if
# you want to run only the similarities checker, you can use "--disable=all
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=bad-continuation,
        bad-whitespace,
        bare-except,
        broad-except,
        consider-using-in,
        consider-using-ternary,
        fixme,
        global-statement,
        invalid-name,
        missing-docstring,
        no-else-raise,
        no-else-return,
        no-self-use,
        trailing-newlines,
        unused-argument,
        unused-variable,
        using-constant-test,
        useless-object-inheritance,
        duplicate-code,
        arguments-differ,
        multiple-statements,
        len-as-condition,
        chained-comparison,
        unnecessary-pass,
        cyclic-import,
        invalid-name,
        bad-continuation,
        too-many-ancestors,
        import-outside-toplevel,
        protected-access,
        try-except-raise,
        deprecated-module,
        no-else-break,
        no-else-continue

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
# multiple time (only on the command line, not in the configuration file where
# it should appear only once). See also the "--disable" option for examples.
enable=c-extension-no-member


[REPORTS]

# Python expression which should return a note less than 10 (10 is the highest
# note). You have access to the variables errors warning, statement which
# respectively contain the number of errors / warnings messages and the total
# number of statements analyzed. This is used by the global evaluation report
# (RP0004).
evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)

# Template used to display messages. This is a python new-style format string
# used to format the message information. See doc for all details.
#msg-template=

# Set the output format. Available formats are text, parseable, colorized, json
# and msvs (visual studio). You can also give a reporter class, e.g.
# mypackage.mymodule.MyReporterClass.
output-format=colorized

# Tells whether to display a full report or only the messages.
reports=no

# Activate the evaluation score.
score=no


[REFACTORING]

# Maximum number of nested blocks for function / method body
max-nested-blocks=10

# Complete name of functions that never returns. When checking for
# inconsistent-return-statements if a never returning function is called then
# it will be considered as an explicit return statement and no message will be
# printed.
never-returning-functions=sys.exit


[LOGGING]

# Format style used to check logging format string. `old` means using %
# formatting, while `new` is for `{}` formatting.
logging-format-style=old

# Logging modules to check that the string format arguments are in logging
# function parameter format.
logging-modules=logging


[MISCELLANEOUS]

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
      XXX,
      TODO


[SPELLING]

# Limits count of emitted suggestions for spelling mistakes.
max-spelling-suggestions=4

# Spelling dictionary name. Available dictionaries: none. To make it working
# install python-enchant package..
#spelling-dict=en_US

# List of comma separated words that should not be checked.
spelling-ignore-words=usr,bin,env

# A path to a file that contains private dictionary; one word per line.
spelling-private-dict-file=.local.dict

# Tells whether to store unknown words to indicated private dictionary in
# --spelling-private-dict-file option instead of raising a message.
spelling-store-unknown-words=no


[BASIC]

# Naming style matching correct argument names.
argument-naming-style=snake_case

# Regular expression matching correct argument names. Overrides argument-
# naming-style.
#argument-rgx=

# Naming style matching correct attribute names.
attr-naming-style=snake_case

# Regular expression matching correct attribute names. Overrides attr-naming-
# style.
#attr-rgx=

# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
          bar,
          baz,
          toto,
          tutu,
          tata

# Naming style matching correct class attribute names.
class-attribute-naming-style=any

# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style.
#class-attribute-rgx=

# Naming style matching correct class names.
class-naming-style=PascalCase

# Regular expression matching correct class names. Overrides class-naming-
# style.
#class-rgx=

# Naming style matching correct constant names.
const-naming-style=UPPER_CASE

# Regular expression matching correct constant names. Overrides const-naming-
# style.
#const-rgx=

# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1

# Naming style matching correct function names.
function-naming-style=snake_case

# Regular expression matching correct function names. Overrides function-
# naming-style.
#function-rgx=

# Good variable names which should always be accepted, separated by a comma.
good-names=i,
           j,
           k,
           ex,
           Run,
           _

# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no

# Naming style matching correct inline iteration names.
inlinevar-naming-style=any

# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style.
#inlinevar-rgx=

# Naming style matching correct method names.
method-naming-style=snake_case

# Regular expression matching correct method names. Overrides method-naming-
# style.
#method-rgx=

# Naming style matching correct module names.
module-naming-style=snake_case

# Regular expression matching correct module names. Overrides module-naming-
# style.
#module-rgx=

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=

# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_

# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty

# Naming style matching correct variable names.
variable-naming-style=snake_case

# Regular expression matching correct variable names. Overrides variable-
# naming-style.
#variable-rgx=


[STRING]

# This flag controls whether the implicit-str-concat-in-sequence should
# generate a warning on implicit string concatenation in sequences defined over
# several lines.
check-str-concat-over-line-jumps=no


[SIMILARITIES]

# Ignore comments when computing similarities.
ignore-comments=yes

# Ignore docstrings when computing similarities.
ignore-docstrings=yes

# Ignore imports when computing similarities.
ignore-imports=no

# Minimum lines number of a similarity.
min-similarity-lines=4


[VARIABLES]

# List of additional names supposed to be defined in builtins. Remember that
# you should avoid defining new builtins when possible.
additional-builtins=

# Tells whether unused global variables should be treated as a violation.
allow-global-unused-variables=yes

# List of strings which can identify a callback function by name. A callback
# name must start or end with one of those strings.
callbacks=cb_,
          _cb

# A regular expression matching the name of dummy variables (i.e. expected to
# not be used).
dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_

# Argument names that match this expression will be ignored. Default to name
# with leading underscore.
ignored-argument-names=_.*|^ignored_|^unused_

# Tells whether we should check for unused import in __init__ files.
init-import=no

# List of qualified module names which can have objects that can redefine
# builtins.
redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io


[TYPECHECK]

# List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers.
contextmanager-decorators=contextlib.contextmanager

# List of members which are set dynamically and missed by pylint inference
# system, and so shouldn't trigger E1101 when accessed. Python regular
# expressions are accepted.
generated-members=

# Tells whether missing members accessed in mixin class should be ignored. A
# mixin class is detected if its name ends with "mixin" (case insensitive).
ignore-mixin-members=yes

# Tells whether to warn about missing members when the owner of the attribute
# is inferred to be None.
ignore-none=yes

# This flag controls whether pylint should warn about no-member and similar
# checks whenever an opaque object is returned when inferring. The inference
# can return multiple potential results while evaluating a Python object, but
# some branches might not be evaluated, which results in partial inference. In
# that case, it might be useful to still emit no-member and other checks for
# the rest of the inferred objects.
ignore-on-opaque-inference=yes

# List of class names for which member attributes should not be checked (useful
# for classes with dynamically set attributes). This supports the use of
# qualified names.
ignored-classes=optparse.Values,thread._local,_thread._local

# List of module names for which member attributes should not be checked
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis. It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=

# Show a hint with possible names when a member name was not found. The aspect
# of finding the hint is based on edit distance.
missing-member-hint=yes

# The minimum edit distance a name should have in order to be considered a
# similar match for a missing member name.
missing-member-hint-distance=1

# The total number of similar names that should be taken in consideration when
# showing a hint for a missing member.
missing-member-max-choices=1


[FORMAT]

# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
expected-line-ending-format=LF

# Regexp for a line that is allowed to be longer than the limit.
ignore-long-lines=^\s*(# )?<?https?://\S+>?$

# Number of spaces of indent required inside a hanging or continued line.
indent-after-paren=4

# String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1
# tab).
indent-string='    '

# Maximum number of characters on a single line.
max-line-length=160

# Maximum number of lines in a module.
max-module-lines=10000

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1  : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
               dict-separator

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no

# Allow the body of an if to be on the same line as the test if there is no
# else.
single-line-if-stmt=no


[CLASSES]

# List of method names used to declare (i.e. assign) instance attributes.
defining-attr-methods=__init__,
                      __new__,
                      setUp

# List of member names, which should be excluded from the protected access
# warning.
exclude-protected=_asdict,
                  _fields,
                  _replace,
                  _source,
                  _make

# List of valid names for the first argument in a class method.
valid-classmethod-first-arg=cls

# List of valid names for the first argument in a metaclass class method.
valid-metaclass-classmethod-first-arg=cls


[DESIGN]

# Maximum number of arguments for function / method.
max-args=15

# Maximum number of attributes for a class (see R0902).
max-attributes=37

# Maximum number of boolean expressions in an if statement.
max-bool-expr=5

# Maximum number of branch for function / method body.
max-branches=50

# Maximum number of locals for function / method body.
max-locals=45

# Maximum number of parents for a class (see R0901).
max-parents=7

# Maximum number of public methods for a class (see R0904).
max-public-methods=420

# Maximum number of return / yield for function / method body.
max-returns=16

# Maximum number of statements in function / method body.
max-statements=150

# Minimum number of public methods for a class (see R0903).
min-public-methods=0


[IMPORTS]

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no

# Analyse import fallback blocks. This can be used to support both Python 2 and
# 3 compatible code, which means that the block might have code that exists
# only in one or another interpreter, leading to false positives when analysed.
analyse-fallback-blocks=no

# Deprecated modules which should not be used, separated by a comma.
deprecated-modules=optparse,tkinter.tix

# Create a graph of external dependencies in the given file (report RP0402 must
# not be disabled).
ext-import-graph=

# Create a graph of every (i.e. internal and external) dependencies in the
# given file (report RP0402 must not be disabled).
import-graph=

# Create a graph of internal dependencies in the given file (report RP0402 must
# not be disabled).
int-import-graph=

# Force import order to recognize a module as part of the standard
# compatibility libraries.
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant


[EXCEPTIONS]

# Exceptions that will emit a warning when being caught. Defaults to
# "BaseException, Exception".
overgeneral-exceptions=BaseException,
                       Exception

==== python-skytools-3.6.1/AUTHORS ====

Maintainers
-----------

Marko Kreen
Petr Jelinek
Sasha Aliashkevich

Contributors
------------

Aleksei Plotnikov
André Malo
Andrew Dunstan
Artyom Nosov
Asko Oja
Asko Tiidumaa
Cédric Villemain
Charles Duffy
Devrim Gündüz
Dimitri Fontaine
Dmitriy V'jukov
Doug Gorley
Eero Oja
Egon Valdmees
Emiel van de Laar
Erik Jones
Glenn Davy
Götz Lange
Hannu Krosing
Hans-Juergen Schoenig
Jason Buberel
Juta Vaks
Kaarel Kitsemets
Kristo Kaiv
Luc Van Hoeylandt
Lukáš Lalinský
Marcin Stępnicki
Mark Kirkwood
Martin Otto
Martin Pihlak
Nico Mandery
Petr Jelinek
Pierre-Emmanuel André
Priit Kustala
Sasha Aliashkevich
Sébastien Lardière
Sergey Burladyan
Sergey Konoplev
Shoaib Mir
Steve Singer
Tarvi Pillessaar
Tony Arkles
Zoltán Böszörményi

==== python-skytools-3.6.1/COPYRIGHT ====

Copyright (c) 2007-2020 Skytools Authors

Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.

THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR
IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

==== python-skytools-3.6.1/MANIFEST.in ====

include modules/*.[ch]
include tests/*.py tests/*.ini
include tox.ini .coveragerc .pylintrc
include MANIFEST.in
include README.rst NEWS.rst

==== python-skytools-3.6.1/Makefile ====

VERSION = $(shell python3 setup.py --version)
RXVERSION = $(shell python3 setup.py --version | sed 's/\./[.]/g')
TAG = v$(VERSION)
NEWS = NEWS.rst

all:

clean:
	rm -rf build *.egg-info */__pycache__ tests/*.pyc
	rm -rf .pybuild MANIFEST

xclean: clean
	rm -rf .tox dist

checkver:
	@echo "Checking version"
	@grep -q '^Skytools $(RXVERSION)\b' $(NEWS) \
	|| { echo "Version '$(VERSION)' not in $(NEWS)"; exit 1; }
	@echo "Checking git repo"
	@git diff --stat --exit-code || { echo "ERROR: Unclean repo"; exit 1; }

release: checkver
	git tag $(TAG)
	git push github $(TAG):$(TAG)

unrelease:
	git push github :$(TAG)
	git tag -d $(TAG)

shownote:
	awk -v VER="$(VERSION)" -f etc/note.awk $(NEWS) \
	| pandoc -f rst -t gfm --wrap=none

==== python-skytools-3.6.1/NEWS.rst ====

NEWS
====

Skytools 3.6.1 (2020-09-29)
---------------------------

Fixes:

* scripting: Do not set .my_name on connection, does not work on plain
  Psycopg connection.
* cquoting: Work around pypy3 PyBytes_Check bug.
* modules: Use multiphase init.

Skytools 3.6 (2020-08-11)
-------------------------

Feature removal:

* Remove ancient compat code from psycopgwrapper:

  - dict* and iter* methods
  - getattr access to fields.
  - Keepalive tuning from connect_database().  That is built-in to libpq
    since 9.0.
  - Require psycopg 2.5+

Cleanups:

* Switch C modules to use stable ABI only (abi3).
* Remove Debian packaging.
* Upgrade apipkg to 1.5.
* Remove Py2 compat.

Skytools 3.5 (2020-07-18)
-------------------------

Fixes:

* dbservice: py3 fix for row.values()
* skylog: Use logging.setLogRecordFactory for adding extra fields
* fileutil, sockutil: fixes for win32.
* natsort: py3 fix, improve rules.

Cleanups:

* Set up Github Actions for CI and release.
* Use "with" for opening files.
* Drop py2 syntax.
* Code reformat.
* Convert nose+doctests to pytest.

Skytools 3.4 (2019-11-14)
-------------------------

* Support Postgres 10 sequences
* Make full_copy text-based
* Allow None fields in magic_insert
* Fix iterator use in magic insert
* Fix Python3 bugs
* Switch off Python2 tests, to avoid wasting time.

Skytools 3.3 (2017-09-21)
-------------------------

* Separate 'skytools' module out from big package
* Python 3 support

Skytools 3.2 and older
----------------------

See old changes here:

https://github.com/pgq/skytools-legacy/blob/master/NEWS
==== python-skytools-3.6.1/README.rst ====

Skytools - Utilities for writing Python scripts
===============================================

This is the low-level utility module split out from the old Skytools
meta-package.

It contains various utilities for writing database scripts.
Database specific utilities are mainly meant for PostgreSQL.

Features
--------

* Support for background scripts

  - Daemonizing
  - logging
  - config parsing

* Database tools

  - Tuned connection
  - DB structure examining
  - SQL parsing
  - COPY I/O

* Time utilities

  - ISO timestamp parsing
  - datetime to timestamp

* Text utilities

  - Natural sort
  - Fast urlencode I/O

TODO
----

* Move from optparse to argparse
* Doc cleanup
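
---- example: the quoting and urlencode helpers from Python ----

A quick tour of the text and database utilities listed above, following the
documented semantics of the C module further down (the top-level re-exports
are assumed here, which is how the skytools API presents these helpers):

import skytools

skytools.quote_literal("O'Reilly")    # "'O''Reilly'"
skytools.quote_literal(None)          # "null"
skytools.quote_copy("a\tb")           # "a\\tb" - COPY-safe field
skytools.db_urlencode({"id": "1", "note": None})    # "id=1&note"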
==== python-skytools-3.6.1/etc/note.awk ====

# extract version notes for version VER

/^[-_0-9a-zA-Z]+ v?[0-9]/ {
    if ($2 == VER) {
        good = 1
        next
    } else {
        good = 0
    }
}

/^(===|---)/ {
    next
}

{
    if (good) {
        # also remove sphinx syntax
        print gensub(/:(\w+):`~?([^`]+)`/, "``\\2``", "g")
    }
}

==== python-skytools-3.6.1/modules/cquoting.c ====

/*
 * Fast quoting functions for Python.
 */

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdbool.h>

#ifdef _MSC_VER
#define inline __inline
#define strcasecmp stricmp
#endif

/* inheritance check is broken in pypy3 */
#ifdef PYPY_VERSION_NUM
#undef PyBytes_Check
#define PyBytes_Check PyBytes_CheckExact
#undef PyDict_Check
#define PyDict_Check PyDict_CheckExact
#endif

#include "get_buffer.h"

/*
 * Common buffer management.
 */

struct Buf {
    unsigned char *ptr;
    Py_ssize_t pos;
    Py_ssize_t alloc;
};

static unsigned char *buf_init(struct Buf *buf, Py_ssize_t init_size)
{
    if (init_size < 256)
        init_size = 256;
    buf->ptr = PyMem_Malloc(init_size);
    if (buf->ptr) {
        buf->pos = 0;
        buf->alloc = init_size;
    }
    return buf->ptr;
}

/* return new pos */
static unsigned char *buf_enlarge(struct Buf *buf, Py_ssize_t need_room)
{
    Py_ssize_t alloc = buf->alloc;
    Py_ssize_t need_size = buf->pos + need_room;
    unsigned char *ptr;

    /* no alloc needed */
    if (need_size < alloc)
        return buf->ptr + buf->pos;

    if (alloc <= need_size / 2)
        alloc = need_size;
    else
        alloc = alloc * 2;

    ptr = PyMem_Realloc(buf->ptr, alloc);
    if (!ptr)
        return NULL;
    buf->ptr = ptr;
    buf->alloc = alloc;
    return buf->ptr + buf->pos;
}

static void buf_free(struct Buf *buf)
{
    PyMem_Free(buf->ptr);
    buf->ptr = NULL;
    buf->pos = buf->alloc = 0;
}

static inline unsigned char *buf_get_target_for(struct Buf *buf, Py_ssize_t len)
{
    if (buf->pos + len <= buf->alloc)
        return buf->ptr + buf->pos;
    else
        return buf_enlarge(buf, len);
}

static inline void buf_set_target(struct Buf *buf, unsigned char *newpos)
{
    assert(buf->ptr + buf->pos <= newpos);
    assert(buf->ptr + buf->alloc >= newpos);
    buf->pos = newpos - buf->ptr;
}

static inline int buf_put(struct Buf *buf, unsigned char c)
{
    if (buf->pos < buf->alloc) {
        buf->ptr[buf->pos++] = c;
        return 1;
    } else if (buf_enlarge(buf, 1)) {
        buf->ptr[buf->pos++] = c;
        return 1;
    }
    return 0;
}

static PyObject *buf_pystr(struct Buf *buf, Py_ssize_t start_pos, unsigned char *newpos)
{
    PyObject *res;
    if (newpos)
        buf_set_target(buf, newpos);
    res = PyUnicode_FromStringAndSize((char *)buf->ptr + start_pos, buf->pos - start_pos);
    buf_free(buf);
    return res;
}

/*
 * Common argument parsing.
 */

typedef PyObject *(*quote_fn)(unsigned char *src, Py_ssize_t src_len);

static PyObject *common_quote(PyObject *args, quote_fn qfunc)
{
    unsigned char *src = NULL;
    Py_ssize_t src_len = 0;
    PyObject *arg, *res, *strtmp = NULL;

    if (!PyArg_ParseTuple(args, "O", &arg))
        return NULL;
    if (arg != Py_None) {
        src_len = get_buffer(arg, &src, &strtmp);
        if (src_len < 0)
            return NULL;
    }
    res = qfunc(src, src_len);
    Py_CLEAR(strtmp);
    return res;
}

/*
 * Simple quoting functions.
 */

static const char doc_quote_literal[] =
"Quote a literal value for SQL.\n"
"\n"
"If string contains '\\', it is quoted and result is prefixed with E.\n"
"Input value of None results in string \"null\" without quotes.\n"
"\n"
"C implementation.\n";

static PyObject *quote_literal_body(unsigned char *src, Py_ssize_t src_len)
{
    struct Buf buf;
    unsigned char *esc, *dst, *src_end = src + src_len;
    unsigned int start_ofs = 1;

    if (src == NULL)
        return PyUnicode_FromString("null");

    esc = dst = buf_init(&buf, src_len * 2 + 2 + 1);
    if (!dst)
        return NULL;

    *dst++ = ' ';
    *dst++ = '\'';
    while (src < src_end) {
        if (*src == '\\') {
            *dst++ = '\\';
            start_ofs = 0;
        } else if (*src == '\'') {
            *dst++ = '\'';
        }
        *dst++ = *src++;
    }
    *dst++ = '\'';
    if (start_ofs == 0)
        *esc = 'E';
    return buf_pystr(&buf, start_ofs, dst);
}

static PyObject *quote_literal(PyObject *self, PyObject *args)
{
    return common_quote(args, quote_literal_body);
}
/* COPY field */

static const char doc_quote_copy[] =
"Quoting for COPY data.  None is converted to \\N.\n\n"
"C implementation.";

static PyObject *quote_copy_body(unsigned char *src, Py_ssize_t src_len)
{
    unsigned char *dst, *src_end = src + src_len;
    struct Buf buf;

    if (src == NULL)
        return PyUnicode_FromString("\\N");

    dst = buf_init(&buf, src_len * 2);
    if (!dst)
        return NULL;

    while (src < src_end) {
        switch (*src) {
            case '\t': *dst++ = '\\'; *dst++ = 't'; src++; break;
            case '\n': *dst++ = '\\'; *dst++ = 'n'; src++; break;
            case '\r': *dst++ = '\\'; *dst++ = 'r'; src++; break;
            case '\\': *dst++ = '\\'; *dst++ = '\\'; src++; break;
            default: *dst++ = *src++; break;
        }
    }
    return buf_pystr(&buf, 0, dst);
}

static PyObject *quote_copy(PyObject *self, PyObject *args)
{
    return common_quote(args, quote_copy_body);
}

/* raw bytea for byteain() */

static const char doc_quote_bytea_raw[] =
"Quoting for bytea parser.  Returns None as None.\n"
"\n"
"C implementation.";

static PyObject *quote_bytea_raw_body(unsigned char *src, Py_ssize_t src_len)
{
    unsigned char *dst, *src_end = src + src_len;
    struct Buf buf;

    if (src == NULL) {
        Py_INCREF(Py_None);
        return Py_None;
    }

    dst = buf_init(&buf, src_len * 4);
    if (!dst)
        return NULL;

    while (src < src_end) {
        if (*src < 0x20 || *src >= 0x7F) {
            *dst++ = '\\';
            *dst++ = '0' + (*src >> 6);
            *dst++ = '0' + ((*src >> 3) & 7);
            *dst++ = '0' + (*src & 7);
            src++;
        } else {
            if (*src == '\\')
                *dst++ = '\\';
            *dst++ = *src++;
        }
    }
    return buf_pystr(&buf, 0, dst);
}

static PyObject *quote_bytea_raw(PyObject *self, PyObject *args)
{
    return common_quote(args, quote_bytea_raw_body);
}

/* SQL unquote */

static const char doc_unquote_literal[] =
"Unquote SQL value.\n\n"
"E'..' -> extended quoting.\n"
"'..' -> standard or extended quoting\n"
"null -> None\n"
"other -> returned as-is\n\n"
"C implementation.\n";

static PyObject *do_sql_ext(unsigned char *src, Py_ssize_t src_len)
{
    unsigned char *dst, *src_end = src + src_len;
    struct Buf buf;

    dst = buf_init(&buf, src_len);
    if (!dst)
        return NULL;

    while (src < src_end) {
        if (*src == '\'') {
            src++;
            if (src < src_end && *src == '\'') {
                *dst++ = *src++;
                continue;
            }
            goto failed;
        }
        if (*src != '\\') {
            *dst++ = *src++;
            continue;
        }
        if (++src >= src_end)
            goto failed;
        switch (*src) {
            case 't': *dst++ = '\t'; src++; break;
            case 'n': *dst++ = '\n'; src++; break;
            case 'r': *dst++ = '\r'; src++; break;
            case 'a': *dst++ = '\a'; src++; break;
            case 'b': *dst++ = '\b'; src++; break;
            default:
                if (*src >= '0' && *src <= '7') {
                    unsigned char c = *src++ - '0';
                    if (src < src_end && *src >= '0' && *src <= '7') {
                        c = (c << 3) | ((*src++) - '0');
                        if (src < src_end && *src >= '0' && *src <= '7')
                            c = (c << 3) | ((*src++) - '0');
                    }
                    *dst++ = c;
                } else {
                    *dst++ = *src++;
                }
        }
    }
    return buf_pystr(&buf, 0, dst);

failed:
    PyErr_Format(PyExc_ValueError, "Broken extended SQL string");
    return NULL;
}

static PyObject *do_sql_std(unsigned char *src, Py_ssize_t src_len)
{
    unsigned char *dst, *src_end = src + src_len;
    struct Buf buf;

    dst = buf_init(&buf, src_len);
    if (!dst)
        return NULL;

    while (src < src_end) {
        if (*src != '\'') {
            *dst++ = *src++;
            continue;
        }
        src++;
        if (src >= src_end || *src != '\'')
            goto failed;
        *dst++ = *src++;
    }
    return buf_pystr(&buf, 0, dst);

failed:
    PyErr_Format(PyExc_ValueError, "Broken standard SQL string");
    return NULL;
}

static PyObject *do_dolq(unsigned char *src, Py_ssize_t src_len)
{
    /* src_len >= 2, '$' in start and end */
    unsigned char *src_end = src + src_len;
    unsigned char *p1 = src + 1, *p2 = src_end - 2;

    while (p1 < src_end && *p1 != '$')
        p1++;
    while (p2 > src && *p2 != '$')
        p2--;
    if (p2 <= p1)
        goto failed;
    p1++; /* position after '$' */
    /* the opening $tag$ and closing $tag$ must have equal length and bytes */
    if ((p1 - src) != (src_end - p2))
        goto failed;
    if (memcmp(src, p2, p1 - src) != 0)
        goto failed;

    return PyUnicode_FromStringAndSize((char *)p1, p2 - p1);

failed:
    PyErr_Format(PyExc_ValueError, "Broken dollar-quoted string");
    return NULL;
}

static PyObject *unquote_literal(PyObject *self, PyObject *args)
{
    unsigned char *src = NULL;
    Py_ssize_t src_len = 0;
    int stdstr = 0;
    PyObject *value = NULL;
    PyObject *tmp = NULL;
    PyObject *res = NULL;

    if (!PyArg_ParseTuple(args, "O|i", &value, &stdstr))
        return NULL;
    src_len = get_buffer(value, &src, &tmp);
    if (src_len < 0)
        return NULL;

    if (src_len == 4 && strcasecmp((char *)src, "null") == 0) {
        Py_INCREF(Py_None);
        res = Py_None;
    } else if (src_len >= 2 && src[0] == '$' && src[src_len - 1] == '$') {
        res = do_dolq(src, src_len);
    } else if (src_len < 2 || src[src_len - 1] != '\'') {
        /* seems invalid, return as-is */
        Py_INCREF(value);
        res = value;
    } else if (src[0] == '\'') {
        src++;
        src_len -= 2;
        res = stdstr ? do_sql_std(src, src_len) : do_sql_ext(src, src_len);
    } else if (src_len > 2 && (src[0] | 0x20) == 'e' && src[1] == '\'') {
        src += 2;
        src_len -= 3;
        res = do_sql_ext(src, src_len);
    }
    if (tmp)
        Py_CLEAR(tmp);
    return res;
}

/* C unescape */

static const char doc_unescape[] =
"Unescape C-style escaped string.\n\n"
"C implementation.";

static PyObject *unescape_body(unsigned char *src, Py_ssize_t src_len)
{
    unsigned char *dst, *src_end = src + src_len;
    struct Buf buf;

    if (src == NULL) {
        PyErr_Format(PyExc_TypeError, "None not allowed");
        return NULL;
    }

    dst = buf_init(&buf, src_len);
    if (!dst)
        return NULL;

    while (src < src_end) {
        if (*src != '\\') {
            *dst++ = *src++;
            continue;
        }
        if (++src >= src_end)
            goto failed;
        switch (*src) {
            case 't': *dst++ = '\t'; src++; break;
            case 'n': *dst++ = '\n'; src++; break;
            case 'r': *dst++ = '\r'; src++; break;
            case 'a': *dst++ = '\a'; src++; break;
            case 'b': *dst++ = '\b'; src++; break;
            default:
                if (*src >= '0' && *src <= '7') {
                    unsigned char c = *src++ - '0';
                    if (src < src_end && *src >= '0' && *src <= '7') {
                        c = (c << 3) | ((*src++) - '0');
                        if (src < src_end && *src >= '0' && *src <= '7')
                            c = (c << 3) | ((*src++) - '0');
                    }
                    *dst++ = c;
                } else {
                    *dst++ = *src++;
                }
        }
    }
    return buf_pystr(&buf, 0, dst);

failed:
    PyErr_Format(PyExc_ValueError, "Broken string - \\ at the end");
    return NULL;
}

static PyObject *unescape(PyObject *self, PyObject *args)
{
    return common_quote(args, unescape_body);
}

/*
 * urlencode of dict
 */

static bool urlenc(struct Buf *buf, PyObject *obj)
{
    Py_ssize_t len;
    unsigned char *src, *dst;
    PyObject *strtmp = NULL;
    static const unsigned char hextbl[] = "0123456789abcdef";
    bool ok = false;

    len = get_buffer(obj, &src, &strtmp);
    if (len < 0)
        goto failed;
    dst = buf_get_target_for(buf, len * 3);
    if (!dst)
        goto failed;

    while (len--) {
        if ((*src >= 'a' && *src <= 'z')
            || (*src >= 'A' && *src <= 'Z')
            || (*src >= '0' && *src <= '9')
            || (*src == '.' || *src == '_' || *src == '-'))
        {
            *dst++ = *src++;
        } else if (*src == ' ') {
            *dst++ = '+';
            src++;
        } else {
            *dst++ = '%';
            *dst++ = hextbl[*src >> 4];
            *dst++ = hextbl[*src & 0xF];
            src++;
        }
    }
    buf_set_target(buf, dst);
    ok = true;
failed:
    Py_CLEAR(strtmp);
    return ok;
}

/* urlencode key+val pair.  val can be None */
static bool urlenc_keyval(struct Buf *buf, PyObject *key, PyObject *value, bool needAmp)
{
    if (needAmp && !buf_put(buf, '&'))
        return false;
    if (!urlenc(buf, key))
        return false;
    if (value != Py_None) {
        if (!buf_put(buf, '='))
            return false;
        if (!urlenc(buf, value))
            return false;
    }
    return true;
}

/* encode native dict using PyDict_Next */
static PyObject *encode_dict(PyObject *data)
{
    PyObject *key, *value;
    Py_ssize_t pos = 0;
    bool needAmp = false;
    struct Buf buf;

    if (!buf_init(&buf, 1024))
        return NULL;

    while (PyDict_Next(data, &pos, &key, &value)) {
        if (!urlenc_keyval(&buf, key, value, needAmp))
            goto failed;
        needAmp = true;
    }
    return buf_pystr(&buf, 0, NULL);
failed:
    buf_free(&buf);
    return NULL;
}

/* encode custom object using .items() */
static PyObject *encode_dictlike(PyObject *data)
{
    PyObject *key = NULL, *value = NULL, *tup, *iter;
    struct Buf buf;
    bool needAmp = false;

    if (!buf_init(&buf, 1024))
        return NULL;

    iter = PyObject_CallMethod(data, "items", NULL);
    if (iter == NULL) {
        buf_free(&buf);
        return NULL;
    }
    while ((tup = PyIter_Next(iter))) {
        key = PySequence_GetItem(tup, 0);
        value = key ? PySequence_GetItem(tup, 1) : NULL;
        Py_CLEAR(tup);
        if (!key || !value)
            goto failed;
        if (!urlenc_keyval(&buf, key, value, needAmp))
            goto failed;
        needAmp = true;
        Py_CLEAR(key);
        Py_CLEAR(value);
    }
    /* allow error from iterator */
    if (PyErr_Occurred())
        goto failed;
    Py_CLEAR(iter);
    return buf_pystr(&buf, 0, NULL);
failed:
    buf_free(&buf);
    Py_CLEAR(iter);
    Py_CLEAR(key);
    Py_CLEAR(value);
    return NULL;
}

static const char doc_db_urlencode[] =
"Urlencode for database records.\n"
"If a value is None the key is output without '='.\n"
"\n"
"C implementation.";

static PyObject *db_urlencode(PyObject *self, PyObject *args)
{
    PyObject *data;

    if (!PyArg_ParseTuple(args, "O", &data))
        return NULL;

    if (PyDict_Check(data)) {
        return encode_dict(data);
    } else {
        return encode_dictlike(data);
    }
}

/*
 * urldecode to dict
 */

static inline int gethex(unsigned char c)
{
    if (c >= '0' && c <= '9')
        return c - '0';
    c |= 0x20;
    if (c >= 'a' && c <= 'f')
        return c - 'a' + 10;
    return -1;
}

static PyObject *get_elem(unsigned char *buf, unsigned char **src_p, unsigned char *src_end)
{
    int c1, c2;
    unsigned char *src = *src_p;
    unsigned char *dst = buf;

    while (src < src_end) {
        switch (*src) {
            case '%':
                if (++src + 2 > src_end)
                    goto hex_incomplete;
                if ((c1 = gethex(*src++)) < 0)
                    goto hex_invalid;
                if ((c2 = gethex(*src++)) < 0)
                    goto hex_invalid;
                *dst++ = (c1 << 4) | c2;
                break;
            case '+':
                *dst++ = ' ';
                src++;
                break;
            case '&':
            case '=':
                goto gotit;
            default:
                *dst++ = *src++;
        }
    }
gotit:
    *src_p = src;
    return PyUnicode_FromStringAndSize((char *)buf, dst - buf);
hex_incomplete:
    PyErr_Format(PyExc_ValueError, "Incomplete hex code");
    return NULL;
hex_invalid:
    PyErr_Format(PyExc_ValueError, "Invalid hex code");
    return NULL;
}

static const char doc_db_urldecode[] =
"Urldecode from string to dict.\n"
"NULL are detected by missing '='.\n"
"Duplicate keys are ignored - only latest is kept.\n"
"\n"
"C implementation.";

static PyObject *db_urldecode(PyObject *self, PyObject *args)
{
    unsigned char *src, *src_end;
    Py_ssize_t src_len;
    PyObject *dict = NULL, *key = NULL, *value = NULL;
    struct Buf buf;

    if (!PyArg_ParseTuple(args, "s#", &src, &src_len))
        return NULL;

    if (!buf_init(&buf, src_len))
        return NULL;

    dict = PyDict_New();
    if (!dict) {
        buf_free(&buf);
        return NULL;
    }

    src_end = src + src_len;
    while (src < src_end) {
        if (*src == '&') {
            src++;
            continue;
        }
        key = get_elem(buf.ptr, &src, src_end);
        if (!key)
            goto failed;
        if (src < src_end && *src == '=') {
            src++;
            value = get_elem(buf.ptr, &src, src_end);
            if (value == NULL)
                goto failed;
        } else {
            Py_INCREF(Py_None);
            value = Py_None;
        }
        /* lessen memory usage by interning */
        PyUnicode_InternInPlace(&key);
        if (PyDict_SetItem(dict, key, value) < 0)
            goto failed;
        Py_CLEAR(key);
        Py_CLEAR(value);
    }
    buf_free(&buf);
    return dict;
failed:
    buf_free(&buf);
    Py_CLEAR(key);
    Py_CLEAR(value);
    Py_CLEAR(dict);
    return NULL;
}

/*
 * Module initialization
 */

static PyMethodDef methods[] = {
    { "quote_literal", quote_literal, METH_VARARGS, doc_quote_literal },
    { "quote_copy", quote_copy, METH_VARARGS, doc_quote_copy },
    { "quote_bytea_raw", quote_bytea_raw, METH_VARARGS, doc_quote_bytea_raw },
    { "unescape", unescape, METH_VARARGS, doc_unescape },
    { "db_urlencode", db_urlencode, METH_VARARGS, doc_db_urlencode },
    { "db_urldecode", db_urldecode, METH_VARARGS, doc_db_urldecode },
    { "unquote_literal", unquote_literal, METH_VARARGS, doc_unquote_literal },
    { NULL }
};

static PyModuleDef_Slot slots[] = {{0, NULL}};

static struct PyModuleDef module = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_cquoting",
    .m_doc = "fast quoting for skytools",
    .m_size = 0,
    .m_methods = methods,
    .m_slots = slots
};

PyMODINIT_FUNC
PyInit__cquoting(void)
{
    return PyModuleDef_Init(&module);
}
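
---- example: the decoding half of _cquoting ----

The decoders mirror the encoders above; a round-trip sketch using the
documented semantics (top-level re-exports assumed, as in the earlier
example):

import skytools

skytools.unquote_literal("'O''Reilly'")     # "O'Reilly"
skytools.unquote_literal("null")            # None
skytools.unquote_literal("$tag$x$tag$")     # "x" - dollar quoting
skytools.db_urldecode("id=1&note")          # {'id': '1', 'note': None}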
==== python-skytools-3.6.1/modules/get_buffer.h ====

/*
 * Get string data from Python object.
 */

static Py_ssize_t get_buffer(PyObject *obj, unsigned char **buf_p, PyObject **tmp_obj_p)
{
    PyObject *str = NULL;
    Py_ssize_t res;

    /* check for None */
    if (obj == Py_None) {
        PyErr_Format(PyExc_TypeError, "None is not allowed");
        return -1;
    }

    /* quick path for bytes */
    if (PyBytes_Check(obj)) {
        if (PyBytes_AsStringAndSize(obj, (char**)buf_p, &res) < 0)
            return -1;
        return res;
    }

    /* convert to bytes */
    if (PyUnicode_Check(obj)) {
        /* no direct string access in abi3 */
        *tmp_obj_p = PyUnicode_AsUTF8String(obj);
    } else if (PyMemoryView_Check(obj) || PyByteArray_Check(obj)) {
        /* no direct buffer access in abi3 */
        *tmp_obj_p = PyBytes_FromObject(obj);
    } else {
        /* Not a string-like object, run str() on it. */
        str = PyObject_Str(obj);
        if (str == NULL)
            return -1;
        *tmp_obj_p = PyUnicode_AsUTF8String(str);
        Py_CLEAR(str);
    }
    if (*tmp_obj_p == NULL)
        return -1;
    if (PyBytes_AsStringAndSize(*tmp_obj_p, (char**)buf_p, &res) < 0)
        return -1;
    return res;
}
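
---- example: input types accepted via get_buffer() ----

Because get_buffer() falls back to str() for arbitrary objects, the quoting
entry points accept more than plain strings; a sketch (top-level re-exports
assumed):

import skytools

skytools.quote_literal(b"bytes")            # "'bytes'"
skytools.quote_literal(bytearray(b"ba"))    # buffer objects are copied first
skytools.quote_literal(25)                  # str() fallback: "'25'"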
==== python-skytools-3.6.1/modules/hashtext.c ====

/*
 * Postgres hashes for Python.
 */

#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdint.h>
#include <string.h>

typedef uint32_t (*hash_fn_t)(const void *src, Py_ssize_t src_len);

typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;

#define rot(x, k) (((x)<<(k)) | ((x)>>(32-(k))))

/*
 * Old Postgres hashtext()
 */

#define mix_old(a,b,c) \
{ \
    a -= b; a -= c; a ^= ((c)>>13); \
    b -= c; b -= a; b ^= ((a)<<8); \
    c -= a; c -= b; c ^= ((b)>>13); \
    a -= b; a -= c; a ^= ((c)>>12); \
    b -= c; b -= a; b ^= ((a)<<16); \
    c -= a; c -= b; c ^= ((b)>>5); \
    a -= b; a -= c; a ^= ((c)>>3); \
    b -= c; b -= a; b ^= ((a)<<10); \
    c -= a; c -= b; c ^= ((b)>>15); \
}

static uint32_t hash_old_hashtext(const void *_k, Py_ssize_t keylen)
{
    const unsigned char *k = _k;
    uint32 a, b, c;
    Py_ssize_t len;

    /* Set up the internal state */
    len = keylen;
    a = b = 0x9e3779b9;     /* the golden ratio; an arbitrary value */
    c = 3923095;            /* initialize with an arbitrary value */

    /* handle most of the key */
    while (len >= 12) {
        a += (k[0] + ((uint32) k[1] << 8) + ((uint32) k[2] << 16) + ((uint32) k[3] << 24));
        b += (k[4] + ((uint32) k[5] << 8) + ((uint32) k[6] << 16) + ((uint32) k[7] << 24));
        c += (k[8] + ((uint32) k[9] << 8) + ((uint32) k[10] << 16) + ((uint32) k[11] << 24));
        mix_old(a, b, c);
        k += 12;
        len -= 12;
    }

    /* handle the last 11 bytes */
    c += (uint32)keylen;
    switch (len)    /* all the case statements fall through */
    {
        case 11: c += ((uint32) k[10] << 24);
        case 10: c += ((uint32) k[9] << 16);
        case 9: c += ((uint32) k[8] << 8);
            /* the first byte of c is reserved for the length */
        case 8: b += ((uint32) k[7] << 24);
        case 7: b += ((uint32) k[6] << 16);
        case 6: b += ((uint32) k[5] << 8);
        case 5: b += k[4];
        case 4: a += ((uint32) k[3] << 24);
        case 3: a += ((uint32) k[2] << 16);
        case 2: a += ((uint32) k[1] << 8);
        case 1: a += k[0];
            /* case 0: nothing left to add */
    }
    mix_old(a, b, c);

    /* report the result */
    return c;
}

/*
 * New Postgres hashtext()
 */

#define UINT32_ALIGN_MASK 3

#define mix_new(a,b,c) \
{ \
    a -= c;  a ^= rot(c, 4);  c += b; \
    b -= a;  b ^= rot(a, 6);  a += c; \
    c -= b;  c ^= rot(b, 8);  b += a; \
    a -= c;  a ^= rot(c,16);  c += b; \
    b -= a;  b ^= rot(a,19);  a += c; \
    c -= b;  c ^= rot(b, 4);  b += a; \
}

#define final_new(a,b,c) \
{ \
    c ^= b; c -= rot(b,14); \
    a ^= c; a -= rot(c,11); \
    b ^= a; b -= rot(a,25); \
    c ^= b; c -= rot(b,16); \
    a ^= c; a -= rot(c, 4); \
    b ^= a; b -= rot(a,14); \
    c ^= b; c -= rot(b,24); \
}
k[1] << 16); /* fall through */ case 1: a += ((uint32) k[0] << 24); /* case 0: nothing left to add */ } #else /* !WORDS_BIGENDIAN */ switch (len) { case 11: c += ((uint32) k[10] << 24); /* fall through */ case 10: c += ((uint32) k[9] << 16); /* fall through */ case 9: c += ((uint32) k[8] << 8); /* the lowest byte of c is reserved for the length */ /* fall through */ case 8: b += ka[1]; a += ka[0]; break; case 7: b += ((uint32) k[6] << 16); /* fall through */ case 6: b += ((uint32) k[5] << 8); /* fall through */ case 5: b += k[4]; /* fall through */ case 4: a += ka[0]; break; case 3: a += ((uint32) k[2] << 16); /* fall through */ case 2: a += ((uint32) k[1] << 8); /* fall through */ case 1: a += k[0]; /* case 0: nothing left to add */ } #endif /* WORDS_BIGENDIAN */ } else { /* Code path for non-aligned source data */ /* handle most of the key */ while (len >= 12) { #ifdef WORDS_BIGENDIAN a += (k[3] + ((uint32) k[2] << 8) + ((uint32) k[1] << 16) + ((uint32) k[0] << 24)); b += (k[7] + ((uint32) k[6] << 8) + ((uint32) k[5] << 16) + ((uint32) k[4] << 24)); c += (k[11] + ((uint32) k[10] << 8) + ((uint32) k[9] << 16) + ((uint32) k[8] << 24)); #else /* !WORDS_BIGENDIAN */ a += (k[0] + ((uint32) k[1] << 8) + ((uint32) k[2] << 16) + ((uint32) k[3] << 24)); b += (k[4] + ((uint32) k[5] << 8) + ((uint32) k[6] << 16) + ((uint32) k[7] << 24)); c += (k[8] + ((uint32) k[9] << 8) + ((uint32) k[10] << 16) + ((uint32) k[11] << 24)); #endif /* WORDS_BIGENDIAN */ mix_new(a, b, c); k += 12; len -= 12; } /* handle the last 11 bytes */ #ifdef WORDS_BIGENDIAN switch (len) /* all the case statements fall through */ { case 11: c += ((uint32) k[10] << 8); case 10: c += ((uint32) k[9] << 16); case 9: c += ((uint32) k[8] << 24); /* the lowest byte of c is reserved for the length */ case 8: b += k[7]; case 7: b += ((uint32) k[6] << 8); case 6: b += ((uint32) k[5] << 16); case 5: b += ((uint32) k[4] << 24); case 4: a += k[3]; case 3: a += ((uint32) k[2] << 8); case 2: a += ((uint32) k[1] << 16); case 1: a += ((uint32) k[0] << 24); /* case 0: nothing left to add */ } #else /* !WORDS_BIGENDIAN */ switch (len) /* all the case statements fall through */ { case 11: c += ((uint32) k[10] << 24); case 10: c += ((uint32) k[9] << 16); case 9: c += ((uint32) k[8] << 8); /* the lowest byte of c is reserved for the length */ case 8: b += ((uint32) k[7] << 24); case 7: b += ((uint32) k[6] << 16); case 6: b += ((uint32) k[5] << 8); case 5: b += k[4]; case 4: a += ((uint32) k[3] << 24); case 3: a += ((uint32) k[2] << 16); case 2: a += ((uint32) k[1] << 8); case 1: a += k[0]; /* case 0: nothing left to add */ } #endif /* WORDS_BIGENDIAN */ } final_new(a, b, c); /* report the result */ return c; } /* * Common argument parsing. */ static PyObject *run_hash(PyObject *args, hash_fn_t real_hash) { unsigned char *src = NULL; Py_ssize_t src_len = 0; int32_t hash; if (!PyArg_ParseTuple(args, "s#", &src, &src_len)) return NULL; hash = real_hash(src, src_len); return PyLong_FromLong(hash); } /* * Python wrappers around actual hash functions. 
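 *
 * Minimal usage sketch from the Python side (assumes the compiled module,
 * which the skytools package re-exports):
 *
 *   >>> from skytools import hashtext_old, hashtext_new
 *   >>> hashtext_new('somekey')   # signed 32-bit int, same as PG >= 8.4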
*/ static PyObject *hashtext_old(PyObject *self, PyObject *args) { return run_hash(args, hash_old_hashtext); } static PyObject *hashtext_new(PyObject *self, PyObject *args) { return run_hash(args, hash_new_hashtext); } /* * Module initialization */ static PyMethodDef methods[] = { { "hashtext_old", hashtext_old, METH_VARARGS, "Old Postgres hashtext().\n" }, { "hashtext_new", hashtext_new, METH_VARARGS, "New Postgres hashtext().\n" }, { NULL } }; static PyModuleDef_Slot slots[] = {{0, NULL}}; static struct PyModuleDef module = { PyModuleDef_HEAD_INIT, .m_name = "_chashtext", .m_doc = "String hash functions", .m_size = 0, .m_methods = methods, .m_slots = slots }; PyMODINIT_FUNC PyInit__chashtext(void) { return PyModuleDef_Init(&module); } python-skytools-3.6.1/setup.cfg000066400000000000000000000011121373460333300165500ustar00rootroot00000000000000[tool:pytest] testpaths = tests [flake8] max-line-length = 120 ignore = W391, # blank line at end of file E265, # comment start - "# " E712, # no "==" in comparison to bool E301, # expected 1 blank line, found 0 W504, # line break after binary operator W503, # line break before binary operator E741, # ambiguous variable name E266, # too many leading '#' for block comment [pytype] exclude = skytools/apipkg.py skytools/fileutil.py inputs = skytools/*.py keep_going = True disable = import-error [mypy] ignore_missing_imports = True python-skytools-3.6.1/setup.py000066400000000000000000000031071373460333300164470ustar00rootroot00000000000000"""Setup for skytools module. """ import sys from setuptools import Extension, setup # load version _version = None with open("skytools/installer_config.py") as f: for ln in f: if ln.startswith("package_version"): _version = ln.split()[2].strip("\"'") len(_version) # load info with open("README.rst") as f: ldesc = f.read().strip() sdesc = ldesc.split("\n")[0].split("-", 1)[-1].strip() # use only stable abi abi3_options = dict( define_macros=[('Py_LIMITED_API', '0x03050000')], py_limited_api=True, ) # run actual setup setup( name="skytools", description="Utilities for database scripts", long_description=ldesc, version=_version, license="ISC", url="https://github.com/pgq/python-skytools", maintainer="Marko Kreen", maintainer_email="markokr@gmail.com", packages=["skytools"], ext_modules = [ Extension("skytools._cquoting", ["modules/cquoting.c"], **abi3_options), Extension("skytools._chashtext", ["modules/hashtext.c"], **abi3_options), ], zip_safe=False, classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Console", "Intended Audience :: Developers", "License :: OSI Approved :: ISC License (ISCL)", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX", "Programming Language :: Python :: 3", "Topic :: Database", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Utilities", ], ) python-skytools-3.6.1/skytools/000077500000000000000000000000001373460333300166235ustar00rootroot00000000000000python-skytools-3.6.1/skytools/__init__.py000066400000000000000000000175111373460333300207410ustar00rootroot00000000000000 """Tools for Python database scripts.""" # pylint:disable=redefined-builtin,unused-wildcard-import,wildcard-import try: from skytools import apipkg as _apipkg except ImportError: # make pylint think everything is imported immediately from skytools.adminscript import * from skytools.config import * from skytools.dbservice import * from skytools.dbstruct import * from skytools.fileutil import * from skytools.gzlog 
import * from skytools.hashtext import * from skytools.natsort import * from skytools.parsing import * from skytools.psycopgwrapper import * from skytools.querybuilder import * from skytools.quoting import * from skytools.scripting import * from skytools.skylog import * from skytools.sockutil import * from skytools.sqltools import * from skytools.timeutil import * from skytools.utf8 import * _symbols = { # skytools.adminscript 'AdminScript': 'skytools.adminscript:AdminScript', # skytools.config 'Config': 'skytools.config:Config', # skytools.dbservice 'DBService': 'skytools.dbservice:DBService', 'ServiceContext': 'skytools.dbservice:ServiceContext', 'TableAPI': 'skytools.dbservice:TableAPI', 'get_record': 'skytools.dbservice:get_record', 'get_record_list': 'skytools.dbservice:get_record_list', 'make_record': 'skytools.dbservice:make_record', 'make_record_array': 'skytools.dbservice:make_record_array', # skytools.dbstruct 'SeqStruct': 'skytools.dbstruct:SeqStruct', 'TableStruct': 'skytools.dbstruct:TableStruct', 'T_ALL': 'skytools.dbstruct:T_ALL', 'T_CONSTRAINT': 'skytools.dbstruct:T_CONSTRAINT', 'T_DEFAULT': 'skytools.dbstruct:T_DEFAULT', 'T_GRANT': 'skytools.dbstruct:T_GRANT', 'T_INDEX': 'skytools.dbstruct:T_INDEX', 'T_OWNER': 'skytools.dbstruct:T_OWNER', 'T_PARENT': 'skytools.dbstruct:T_PARENT', 'T_PKEY': 'skytools.dbstruct:T_PKEY', 'T_RULE': 'skytools.dbstruct:T_RULE', 'T_SEQUENCE': 'skytools.dbstruct:T_SEQUENCE', 'T_TABLE': 'skytools.dbstruct:T_TABLE', 'T_TRIGGER': 'skytools.dbstruct:T_TRIGGER', # skytools.fileutil 'signal_pidfile': 'skytools.fileutil:signal_pidfile', 'write_atomic': 'skytools.fileutil:write_atomic', # skytools.gzlog 'gzip_append': 'skytools.gzlog:gzip_append', # skytools.hashtext 'hashtext_old': 'skytools.hashtext:hashtext_old', 'hashtext_new': 'skytools.hashtext:hashtext_new', # skytools.natsort 'natsort': 'skytools.natsort:natsort', 'natsort_icase': 'skytools.natsort:natsort_icase', 'natsorted': 'skytools.natsort:natsorted', 'natsorted_icase': 'skytools.natsort:natsorted_icase', 'natsort_key': 'skytools.natsort:natsort_key', 'natsort_key_icase': 'skytools.natsort:natsort_key_icase', # skytools.parsing 'dedent': 'skytools.parsing:dedent', 'hsize_to_bytes': 'skytools.parsing:hsize_to_bytes', 'merge_connect_string': 'skytools.parsing:merge_connect_string', 'parse_acl': 'skytools.parsing:parse_acl', 'parse_connect_string': 'skytools.parsing:parse_connect_string', 'parse_logtriga_sql': 'skytools.parsing:parse_logtriga_sql', 'parse_pgarray': 'skytools.parsing:parse_pgarray', 'parse_sqltriga_sql': 'skytools.parsing:parse_sqltriga_sql', 'parse_statements': 'skytools.parsing:parse_statements', 'parse_tabbed_table': 'skytools.parsing:parse_tabbed_table', 'sql_tokenizer': 'skytools.parsing:sql_tokenizer', # skytools.psycopgwrapper 'connect_database': 'skytools.psycopgwrapper:connect_database', 'DBError': 'skytools.psycopgwrapper:DBError', 'I_AUTOCOMMIT': 'skytools.psycopgwrapper:I_AUTOCOMMIT', 'I_READ_COMMITTED': 'skytools.psycopgwrapper:I_READ_COMMITTED', 'I_REPEATABLE_READ': 'skytools.psycopgwrapper:I_REPEATABLE_READ', 'I_SERIALIZABLE': 'skytools.psycopgwrapper:I_SERIALIZABLE', # skytools.querybuilder 'PLPyQuery': 'skytools.querybuilder:PLPyQuery', 'PLPyQueryBuilder': 'skytools.querybuilder:PLPyQueryBuilder', 'QueryBuilder': 'skytools.querybuilder:QueryBuilder', 'plpy_exec': 'skytools.querybuilder:plpy_exec', 'run_exists': 'skytools.querybuilder:run_exists', 'run_lookup': 'skytools.querybuilder:run_lookup', 'run_query': 'skytools.querybuilder:run_query', 'run_query_row': 
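    # (all entries in this mapping are resolved lazily by apipkg.initpkg(),
    # called at the bottom of this module)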
'skytools.querybuilder:run_query_row', # skytools.quoting 'db_urldecode': 'skytools.quoting:db_urldecode', 'db_urlencode': 'skytools.quoting:db_urlencode', 'json_decode': 'skytools.quoting:json_decode', 'json_encode': 'skytools.quoting:json_encode', 'make_pgarray': 'skytools.quoting:make_pgarray', 'quote_bytea_copy': 'skytools.quoting:quote_bytea_copy', 'quote_bytea_literal': 'skytools.quoting:quote_bytea_literal', 'quote_bytea_raw': 'skytools.quoting:quote_bytea_raw', 'quote_copy': 'skytools.quoting:quote_copy', 'quote_fqident': 'skytools.quoting:quote_fqident', 'quote_ident': 'skytools.quoting:quote_ident', 'quote_json': 'skytools.quoting:quote_json', 'quote_literal': 'skytools.quoting:quote_literal', 'quote_statement': 'skytools.quoting:quote_statement', 'unescape': 'skytools.quoting:unescape', 'unescape_copy': 'skytools.quoting:unescape_copy', 'unquote_fqident': 'skytools.quoting:unquote_fqident', 'unquote_ident': 'skytools.quoting:unquote_ident', 'unquote_literal': 'skytools.quoting:unquote_literal', # skytools.scripting 'BaseScript': 'skytools.scripting:BaseScript', 'daemonize': 'skytools.scripting:daemonize', 'DBScript': 'skytools.scripting:DBScript', 'UsageError': 'skytools.scripting:UsageError', # skytools.skylog 'getLogger': 'skytools.skylog:getLogger', # skytools.sockutil 'set_cloexec': 'skytools.sockutil:set_cloexec', 'set_nonblocking': 'skytools.sockutil:set_nonblocking', 'set_tcp_keepalive': 'skytools.sockutil:set_tcp_keepalive', # skytools.sqltools 'dbdict': 'skytools.sqltools:dbdict', 'CopyPipe': 'skytools.sqltools:CopyPipe', 'DBFunction': 'skytools.sqltools:DBFunction', 'DBLanguage': 'skytools.sqltools:DBLanguage', 'DBObject': 'skytools.sqltools:DBObject', 'DBSchema': 'skytools.sqltools:DBSchema', 'DBTable': 'skytools.sqltools:DBTable', 'Snapshot': 'skytools.sqltools:Snapshot', 'db_install': 'skytools.sqltools:db_install', 'exists_function': 'skytools.sqltools:exists_function', 'exists_language': 'skytools.sqltools:exists_language', 'exists_schema': 'skytools.sqltools:exists_schema', 'exists_sequence': 'skytools.sqltools:exists_sequence', 'exists_table': 'skytools.sqltools:exists_table', 'exists_temp_table': 'skytools.sqltools:exists_temp_table', 'exists_type': 'skytools.sqltools:exists_type', 'exists_view': 'skytools.sqltools:exists_view', 'fq_name': 'skytools.sqltools:fq_name', 'fq_name_parts': 'skytools.sqltools:fq_name_parts', 'full_copy': 'skytools.sqltools:full_copy', 'get_table_columns': 'skytools.sqltools:get_table_columns', 'get_table_oid': 'skytools.sqltools:get_table_oid', 'get_table_pkeys': 'skytools.sqltools:get_table_pkeys', 'installer_apply_file': 'skytools.sqltools:installer_apply_file', 'installer_find_file': 'skytools.sqltools:installer_find_file', 'magic_insert': 'skytools.sqltools:magic_insert', 'mk_delete_sql': 'skytools.sqltools:mk_delete_sql', 'mk_insert_sql': 'skytools.sqltools:mk_insert_sql', 'mk_update_sql': 'skytools.sqltools:mk_update_sql', # skytools.timeutil 'FixedOffsetTimezone': 'skytools.timeutil:FixedOffsetTimezone', 'datetime_to_timestamp': 'skytools.timeutil:datetime_to_timestamp', 'parse_iso_timestamp': 'skytools.timeutil:parse_iso_timestamp', # skytools.utf8 'safe_utf8_decode': 'skytools.utf8:safe_utf8_decode', } __all__ = tuple(_symbols) _symbols['__version__'] = 'skytools.installer_config:package_version' # lazy-import exported vars _apipkg.initpkg(__name__, _symbols, {'apipkg': _apipkg}) python-skytools-3.6.1/skytools/_pyquoting.py000066400000000000000000000110761373460333300214000ustar00rootroot00000000000000"""Various helpers for 
string quoting/unquoting. Here is pure Python that should match C code in _cquoting. """ import re from urllib.parse import quote_plus, unquote_plus # noqa __all__ = ( "quote_literal", "quote_copy", "quote_bytea_raw", "db_urlencode", "db_urldecode", "unescape", "unquote_literal", ) # # SQL quoting # def quote_literal(s): r"""Quote a literal value for SQL. If string contains '\\', extended E'' quoting is used, otherwise standard quoting. Input value of None results in string "null" without quotes. Python implementation. """ if s is None: return "null" s = str(s).replace("'", "''") s2 = s.replace("\\", "\\\\") if len(s) != len(s2): return "E'" + s2 + "'" return "'" + s2 + "'" def quote_copy(s): """Quoting for copy command. None is converted to \\N. Python implementation. """ if s is None: return "\\N" s = str(s) s = s.replace("\\", "\\\\") s = s.replace("\t", "\\t") s = s.replace("\n", "\\n") s = s.replace("\r", "\\r") return s _bytea_map = None def quote_bytea_raw(s): """Quoting for bytea parser. Returns None as None. Python implementation. """ global _bytea_map if s is None: return None if not isinstance(s, bytes): raise TypeError("Expect bytes") if 1 and _bytea_map is None: _bytea_map = {} for i in range(256): c = i if i < 0x20 or i >= 0x7F: _bytea_map[c] = "\\%03o" % i elif i == ord("\\"): _bytea_map[c] = "\\\\" else: _bytea_map[c] = '%c' % i return "".join([_bytea_map[b] for b in s]) # # Database specific urlencode and urldecode. # def db_urlencode(dict_val): """Database specific urlencode. Encode None as key without '='. That means that in "foo&bar=", foo is NULL and bar is empty string. Python implementation. """ elem_list = [] for k, v in dict_val.items(): if v is None: elem = quote_plus(str(k)) else: elem = quote_plus(str(k)) + '=' + quote_plus(str(v)) elem_list.append(elem) return '&'.join(elem_list) def db_urldecode(qs): """Database specific urldecode. Decode key without '=' as None. This also does not support one key several times. Python implementation. """ res = {} for elem in qs.split('&'): if not elem: continue pair = elem.split('=', 1) name = unquote_plus(pair[0]) if len(pair) == 1: res[name] = None else: res[name] = unquote_plus(pair[1]) return res # # Remove C-like backslash escapes # _esc_re = r"\\([0-7]{1,3}|.)" _esc_rc = re.compile(_esc_re) _esc_map = { 't': '\t', 'n': '\n', 'r': '\r', 'a': '\a', 'b': '\b', "'": "'", '"': '"', '\\': '\\', } def _sub_unescape_c(m): """unescape single escape seq.""" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: return _esc_map[v] except KeyError: return v else: return chr(int(v, 8)) def unescape(val): """Removes C-style escapes from string. Python implementation. """ return _esc_rc.sub(_sub_unescape_c, val) _esql_re = r"''|\\([0-7]{1,3}|.)" _esql_rc = re.compile(_esql_re) def _sub_unescape_sqlext(m): """Unescape extended-quoted string.""" if m.group() == "''": return "'" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: return _esc_map[v] except KeyError: return v return chr(int(v, 8)) def unquote_literal(val, stdstr=False): """Unquotes SQL string. E'..' -> extended quoting. '..' 
-> standard or extended quoting null -> None other -> returned as-is """ if val[0] == "'" and val[-1] == "'": if stdstr: return val[1:-1].replace("''", "'") else: return _esql_rc.sub(_sub_unescape_sqlext, val[1:-1]) elif len(val) > 2 and val[0] in ('E', 'e') and val[1] == "'" and val[-1] == "'": return _esql_rc.sub(_sub_unescape_sqlext, val[2:-1]) elif len(val) >= 2 and val[0] == '$' and val[-1] == '$': p1 = val.find('$', 1) p2 = val.rfind('$', 1, -1) if p1 > 0 and p2 > p1: t1 = val[:p1 + 1] t2 = val[p2:] if t1 == t2: return val[len(t1):-len(t1)] raise ValueError("Bad dollar-quoted string") elif val.lower() == "null": return None return val python-skytools-3.6.1/skytools/adminscript.py000066400000000000000000000077161373460333300215250ustar00rootroot00000000000000"""Admin scripting. """ # allow getargspec # pylint:disable=deprecated-method import inspect import sys import skytools __all__ = ['AdminScript'] class AdminScript(skytools.DBScript): """Contains common admin script tools. Second argument (first is .ini file) is taken as command name. If class method 'cmd_' + arg exists, it is called, otherwise error is given. """ commands_without_pidfile = {} def __init__(self, service_name, args): """AdminScript init.""" super(AdminScript, self).__init__(service_name, args) if len(self.args) < 2: self.log.error("need command") sys.exit(1) cmd = self.args[1] if cmd in self.commands_without_pidfile: self.pidfile = None if self.pidfile: self.pidfile = self.pidfile + ".admin" def work(self): """Non-looping work function, calls command function.""" self.set_single_loop(1) cmd = self.args[1] cmdargs = self.args[2:] # find function fname = "cmd_" + cmd.replace('-', '_') if not hasattr(self, fname): self.log.error('bad subcommand, see --help for usage') sys.exit(1) fn = getattr(self, fname) # check if correct number of arguments (args, varargs, ___varkw, ___defaults) = inspect.getargspec(fn) n_args = len(args) - 1 # drop 'self' if varargs is None and n_args != len(cmdargs): helpstr = "" if n_args: helpstr = ": " + " ".join(args[1:]) self.log.error("command '%s' got %d args, but expects %d%s", cmd, len(cmdargs), n_args, helpstr) sys.exit(1) # run command fn(*cmdargs) def fetch_list(self, db, sql, args, keycol=None): """Fetch a resultset from db, optionally turning it into value list.""" curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() db.commit() if not keycol: res = rows else: res = [r[keycol] for r in rows] return res def display_table(self, db, desc, sql, args=(), fields=(), fieldfmt=None): """Display multirow query as a table.""" self.log.debug("display_table: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() db.commit() if len(rows) == 0: return 0 if not fieldfmt: fieldfmt = {} if not fields: fields = [f[0] for f in curs.description] widths = [15] * len(fields) for row in rows: for i, k in enumerate(fields): rlen = row[k] and len(str(row[k])) or 0 widths[i] = widths[i] > rlen and widths[i] or rlen widths = [w + 2 for w in widths] fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) if desc: print(desc) print(fmt % tuple(fields)) print(fmt % tuple(['-' * (w - 2) for w in widths])) #print(fmt % tuple(['-'*15] * len(fields))) for row in rows: vals = [] for field in fields: val = row[field] if field in fieldfmt: val = fieldfmt[field](val) vals.append(val) print(fmt % tuple(vals)) print('\n') return 1 def exec_stmt(self, db, sql, args): """Run regular non-query SQL on db.""" self.log.debug("exec_stmt: %s", 
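                       # quote_statement() inlines the parameters, so the
                       # logged line is runnable SQL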
skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) db.commit() def exec_query(self, db, sql, args): """Run regular query SQL on db.""" self.log.debug("exec_query: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) res = curs.fetchall() db.commit() return res python-skytools-3.6.1/skytools/apipkg.py000066400000000000000000000150631373460333300204550ustar00rootroot00000000000000""" apipkg: control the exported namespace of a Python package. see https://pypi.python.org/pypi/apipkg (c) holger krekel, 2009 - MIT license """ #pylint: skip-file import os import sys from types import ModuleType __version__ = "1.5" def _py_abspath(path): """ special version of abspath that will leave paths from jython jars alone """ if path.startswith('__pyclasspath__'): return path else: return os.path.abspath(path) def distribution_version(name): """try to get the version of the named distribution, returs None on failure""" from pkg_resources import get_distribution, DistributionNotFound try: dist = get_distribution(name) except DistributionNotFound: pass else: return dist.version def initpkg(pkgname, exportdefs, attr=None, eager=False): """ initialize given package from the export definitions. """ attr = attr or {} oldmod = sys.modules.get(pkgname) d = {} f = getattr(oldmod, '__file__', None) if f: f = _py_abspath(f) d['__file__'] = f if hasattr(oldmod, '__version__'): d['__version__'] = oldmod.__version__ if hasattr(oldmod, '__loader__'): d['__loader__'] = oldmod.__loader__ if hasattr(oldmod, '__path__'): d['__path__'] = [_py_abspath(p) for p in oldmod.__path__] if hasattr(oldmod, '__package__'): d['__package__'] = oldmod.__package__ if '__doc__' not in exportdefs and getattr(oldmod, '__doc__', None): d['__doc__'] = oldmod.__doc__ d.update(attr) if hasattr(oldmod, "__dict__"): oldmod.__dict__.update(d) mod = ApiModule(pkgname, exportdefs, implprefix=pkgname, attr=d) sys.modules[pkgname] = mod # eagerload in bypthon to avoid their monkeypatching breaking packages if 'bpython' in sys.modules or eager: for module in list(sys.modules.values()): if isinstance(module, ApiModule): module.__dict__ def importobj(modpath, attrname): """imports a module, then resolves the attrname on it""" module = __import__(modpath, None, None, ['__doc__']) if not attrname: return module retval = module names = attrname.split(".") for x in names: retval = getattr(retval, x) return retval class ApiModule(ModuleType): """the magical lazy-loading module standing""" def __docget(self): try: return self.__doc except AttributeError: if '__doc__' in self.__map__: return self.__makeattr('__doc__') def __docset(self, value): self.__doc = value __doc__ = property(__docget, __docset) def __init__(self, name, importspec, implprefix=None, attr=None): super().__init__(name) self.__all__ = [x for x in importspec if x != '__onfirstaccess__'] self.__map__ = {} self.__implprefix__ = implprefix or name if attr: for name, val in attr.items(): # print "setting", self.__name__, name, val setattr(self, name, val) for name, importspec in importspec.items(): if isinstance(importspec, dict): subname = '%s.%s' % (self.__name__, name) apimod = ApiModule(subname, importspec, implprefix) sys.modules[subname] = apimod setattr(self, name, apimod) else: parts = importspec.split(':') modpath = parts.pop(0) attrname = parts and parts[0] or "" if modpath[0] == '.': modpath = implprefix + modpath if not attrname: subname = '%s.%s' % (self.__name__, name) apimod = AliasModule(subname, modpath) 
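                    # register the alias in sys.modules as well, so a plain
                    # "import pkg.sub" also resolves to the lazy module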
sys.modules[subname] = apimod if '.' not in name: setattr(self, name, apimod) else: self.__map__[name] = (modpath, attrname) def __repr__(self): repr_list = [] if hasattr(self, '__version__'): repr_list.append("version=" + repr(self.__version__)) if hasattr(self, '__file__'): repr_list.append('from ' + repr(self.__file__)) if repr_list: return '' % (self.__name__, " ".join(repr_list)) return '' % (self.__name__,) def __makeattr(self, name): """lazily compute value for name or raise AttributeError if unknown.""" # print "makeattr", self.__name__, name target = None if '__onfirstaccess__' in self.__map__: target = self.__map__.pop('__onfirstaccess__') importobj(*target)() try: modpath, attrname = self.__map__[name] except KeyError: if target is not None and name != '__onfirstaccess__': # retry, onfirstaccess might have set attrs return getattr(self, name) raise AttributeError(name) else: result = importobj(modpath, attrname) setattr(self, name, result) try: del self.__map__[name] except KeyError: pass # in a recursive-import situation a double-del can happen return result __getattr__ = __makeattr @property def __dict__(self): # force all the content of the module # to be loaded when __dict__ is read dictdescr = ModuleType.__dict__['__dict__'] dict = dictdescr.__get__(self) if dict is not None: hasattr(self, 'some') for name in self.__all__: try: self.__makeattr(name) except AttributeError: pass return dict def AliasModule(modname, modpath, attrname=None): mod = [] def getmod(): if not mod: x = importobj(modpath, None) if attrname is not None: x = getattr(x, attrname) mod.append(x) return mod[0] class AliasModule(ModuleType): def __repr__(self): x = modpath if attrname: x += "." + attrname return '' % (modname, x) def __getattribute__(self, name): try: return getattr(getmod(), name) except ImportError: return None def __setattr__(self, name, value): setattr(getmod(), name, value) def __delattr__(self, name): delattr(getmod(), name) return AliasModule(str(modname)) python-skytools-3.6.1/skytools/checker.py000066400000000000000000000504151373460333300206060ustar00rootroot00000000000000"""Catch moment when tables are in sync on master and slave. 
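Contains TableRepair (dumps both tables, sorts and diffs the dumps, and
generates fix SQL), Syncer (locks the table and waits until the queue
consumer catches up before taking repeatable-read snapshots on both sides)
and Checker (drives compare/repair over a configured table list).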
""" import os import subprocess import sys import time import skytools class TableRepair: """Checks that tables in two databases are in sync.""" def __init__(self, table_name, log): self.table_name = table_name self.fq_table_name = skytools.quote_fqident(table_name) self.log = log self.pkey_list = [] self.common_fields = [] self.apply_fixes = False self.apply_cursor = None self.reset() def reset(self): self.cnt_insert = 0 self.cnt_update = 0 self.cnt_delete = 0 self.total_src = 0 self.total_dst = 0 self.pkey_list = [] self.common_fields = [] self.apply_fixes = False self.apply_cursor = None def do_repair(self, src_db, dst_db, where, pfx='repair', apply_fixes=False): """Actual comparison.""" self.reset() src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.apply_fixes = apply_fixes if apply_fixes: self.apply_cursor = dst_curs self.log.info('Checking %s', self.table_name) copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where) dump_src = "%s.%s.src" % (pfx, self.table_name) dump_dst = "%s.%s.dst" % (pfx, self.table_name) fix = "%s.%s.fix" % (pfx, self.table_name) self.log.info("Dumping src table: %s", self.table_name) self.dump_table(copy_tbl, src_curs, dump_src) src_db.commit() self.log.info("Dumping dst table: %s", self.table_name) self.dump_table(copy_tbl, dst_curs, dump_dst) dst_db.commit() self.log.info("Sorting src table: %s", self.table_name) self.do_sort(dump_src, dump_src + '.sorted') self.log.info("Sorting dst table: %s", self.table_name) self.do_sort(dump_dst, dump_dst + '.sorted') self.dump_compare(dump_src + ".sorted", dump_dst + ".sorted", fix) os.unlink(dump_src) os.unlink(dump_dst) os.unlink(dump_src + ".sorted") os.unlink(dump_dst + ".sorted") if apply_fixes: dst_db.commit() def do_sort(self, src, dst): p = subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) s_ver = p.communicate()[0] del p xenv = os.environ.copy() xenv['LANG'] = 'C' xenv['LC_ALL'] = 'C' cmdline = ['sort', '-T', '.'] if s_ver.find("coreutils") > 0: cmdline.append('-S') cmdline.append('30%') cmdline.append('-o') cmdline.append(dst) cmdline.append(src) p = subprocess.Popen(cmdline, env=xenv) if p.wait() != 0: raise Exception('sort failed') def gen_copy_tbl(self, src_curs, dst_curs, where): """Create COPY expession from common fields.""" self.pkey_list = skytools.get_table_pkeys(src_curs, self.table_name) dst_pkey = skytools.get_table_pkeys(dst_curs, self.table_name) if dst_pkey != self.pkey_list: self.log.error('pkeys do not match') sys.exit(1) src_cols = skytools.get_table_columns(src_curs, self.table_name) dst_cols = skytools.get_table_columns(dst_curs, self.table_name) field_list = [] for f in self.pkey_list: field_list.append(f) for f in src_cols: if f in self.pkey_list: continue if f in dst_cols: field_list.append(f) self.common_fields = field_list fqlist = [skytools.quote_ident(col) for col in field_list] tbl_expr = "select %s from %s" % (",".join(fqlist), self.fq_table_name) if where: tbl_expr += ' where ' + where tbl_expr = "COPY (%s) TO STDOUT" % tbl_expr self.log.debug("using copy expr: %s", tbl_expr) return tbl_expr def dump_table(self, copy_cmd, curs, fn): """Dump table to disk.""" with open(fn, "w", 64 * 1024) as f: curs.copy_expert(copy_cmd, f) self.log.info('%s: Got %d bytes', self.table_name, f.tell()) def get_row(self, ln): """Parse a row into dict.""" if not ln: return None t = ln[:-1].split('\t') row = {} for i in range(len(self.common_fields)): row[self.common_fields[i]] = t[i] return row def dump_compare(self, src_fn, dst_fn, fix): """Dump + 
compare single table.""" self.log.info("Comparing dumps: %s", self.table_name) f1 = open(src_fn, "r", 64 * 1024) f2 = open(dst_fn, "r", 64 * 1024) src_ln = f1.readline() dst_ln = f2.readline() if src_ln: self.total_src += 1 if dst_ln: self.total_dst += 1 if os.path.isfile(fix): os.unlink(fix) while src_ln or dst_ln: keep_src = keep_dst = 0 if src_ln != dst_ln: src_row = self.get_row(src_ln) dst_row = self.get_row(dst_ln) diff = self.cmp_keys(src_row, dst_row) if diff > 0: # src > dst self.got_missed_delete(dst_row, fix) keep_src = 1 elif diff < 0: # src < dst self.got_missed_insert(src_row, fix) keep_dst = 1 else: if self.cmp_data(src_row, dst_row) != 0: self.got_missed_update(src_row, dst_row, fix) if not keep_src: src_ln = f1.readline() if src_ln: self.total_src += 1 if not keep_dst: dst_ln = f2.readline() if dst_ln: self.total_dst += 1 self.log.info("finished %s: src: %d rows, dst: %d rows," " missed: %d inserts, %d updates, %d deletes", self.table_name, self.total_src, self.total_dst, self.cnt_insert, self.cnt_update, self.cnt_delete) f1.close() f2.close() def got_missed_insert(self, src_row, fn): """Create sql for missed insert.""" self.cnt_insert += 1 fld_list = self.common_fields fq_list = [] val_list = [] for f in fld_list: fq_list.append(skytools.quote_ident(f)) v = skytools.unescape_copy(src_row[f]) val_list.append(skytools.quote_literal(v)) q = "insert into %s (%s) values (%s);" % ( self.fq_table_name, ", ".join(fq_list), ", ".join(val_list)) self.show_fix(q, 'insert', fn) def got_missed_update(self, src_row, dst_row, fn): """Create sql for missed update.""" self.cnt_update += 1 fld_list = self.common_fields set_list = [] whe_list = [] for f in self.pkey_list: self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(src_row[f])) for f in fld_list: v1 = src_row[f] v2 = dst_row[f] if self.cmp_value(v1, v2) == 0: continue self.addeq(set_list, skytools.quote_ident(f), skytools.unescape_copy(v1)) self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(v2)) q = "update only %s set %s where %s;" % ( self.fq_table_name, ", ".join(set_list), " and ".join(whe_list)) self.show_fix(q, 'update', fn) def got_missed_delete(self, dst_row, fn): """Create sql for missed delete.""" self.cnt_delete += 1 whe_list = [] for f in self.pkey_list: self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(dst_row[f])) q = "delete from only %s where %s;" % (self.fq_table_name, " and ".join(whe_list)) self.show_fix(q, 'delete', fn) def show_fix(self, q, desc, fn): """Print/write/apply repair sql.""" self.log.debug("missed %s: %s", desc, q) with open(fn, "a") as f: f.write("%s\n" % q) if self.apply_fixes: self.apply_cursor.execute(q) def addeq(self, dst_list, f, v): """Add quoted SET.""" vq = skytools.quote_literal(v) s = "%s = %s" % (f, vq) dst_list.append(s) def addcmp(self, dst_list, f, v): """Add quoted comparison.""" if v is None: s = "%s is null" % f else: vq = skytools.quote_literal(v) s = "%s = %s" % (f, vq) dst_list.append(s) def cmp_data(self, src_row, dst_row): """Compare data field-by-field.""" for k in self.common_fields: v1 = src_row[k] v2 = dst_row[k] if self.cmp_value(v1, v2) != 0: return -1 return 0 def cmp_value(self, v1, v2): """Compare single field, tolerates tz vs notz dates.""" if v1 == v2: return 0 # try to work around tz vs. 
notz z1 = len(v1) z2 = len(v2) if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+': v1 = v1[:-3] if v1 == v2: return 0 elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+': v2 = v2[:-3] if v1 == v2: return 0 return -1 def cmp_keys(self, src_row, dst_row): """Compare primary keys of the rows. Returns 1 if src > dst, -1 if src < dst and 0 if src == dst""" # None means table is done. tag it larger than any existing row. if src_row is None: if dst_row is None: return 0 return 1 elif dst_row is None: return -1 for k in self.pkey_list: v1 = src_row[k] v2 = dst_row[k] if v1 < v2: return -1 elif v1 > v2: return 1 return 0 class Syncer(skytools.DBScript): """Checks that tables in two databases are in sync.""" lock_timeout = 10 ticker_lag_limit = 20 consumer_lag_limit = 20 def sync_table(self, cstr1, cstr2, queue_name, consumer_name, table_name): """Syncer main function. Returns (src_db, dst_db) that are in transaction where table should be in sync. """ setup_db = self.get_database('setup_db', connstr=cstr1, autocommit=1) lock_db = self.get_database('lock_db', connstr=cstr1) src_db = self.get_database('src_db', connstr=cstr1, isolation_level=skytools.I_REPEATABLE_READ) dst_db = self.get_database('dst_db', connstr=cstr2, isolation_level=skytools.I_REPEATABLE_READ) lock_curs = lock_db.cursor() setup_curs = setup_db.cursor() src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.check_consumer(setup_curs, queue_name, consumer_name) # lock table in separate connection self.log.info('Locking %s', table_name) self.set_lock_timeout(lock_curs) lock_time = time.time() lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(table_name)) # now wait until consumer has updated target table until locking self.log.info('Syncing %s', table_name) # consumer must get further than this tick self.force_tick(setup_curs, queue_name) # try to force second tick also self.force_tick(setup_curs, queue_name) # take server time setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')") tpos = setup_curs.fetchone()[0] # now wait while True: time.sleep(0.5) q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)" setup_curs.execute(q, [tpos, queue_name, consumer_name]) res = setup_curs.fetchall() if len(res) == 0: raise Exception('No such consumer: %s/%s' % (queue_name, consumer_name)) row = res[0] self.log.debug("tpos=%s now=%s lag=%s ok=%s", tpos, row[1], row[2], row[0]) if row[0]: break # limit lock time if time.time() > lock_time + self.lock_timeout: self.log.error('Consumer lagging too much, exiting') lock_db.rollback() sys.exit(1) # take snapshot on provider side src_db.commit() src_curs.execute("SELECT 1") # take snapshot on subscriber side dst_db.commit() dst_curs.execute("SELECT 1") # release lock lock_db.commit() self.close_database('setup_db') self.close_database('lock_db') return (src_db, dst_db) def set_lock_timeout(self, curs): ms = int(1000 * self.lock_timeout) if ms > 0: q = "SET LOCAL statement_timeout = %d" % ms self.log.debug(q) curs.execute(q) def check_consumer(self, curs, queue_name, consumer_name): """ Before locking anything check if consumer is working ok. 
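        Fetches ticker lag and consumer lag from pgq and exits if the
        consumer is lagging more than ~10 seconds behind the ticker.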
""" self.log.info("Queue: %s Consumer: %s", queue_name, consumer_name) curs.execute('select current_database()') self.log.info('Actual db: %s', curs.fetchone()[0]) # get ticker lag q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);" curs.execute(q, [queue_name]) ticker_lag = curs.fetchone()[0] self.log.info("Ticker lag: %s", ticker_lag) # get consumer lag q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);" curs.execute(q, [queue_name, consumer_name]) res = curs.fetchall() if len(res) == 0: self.log.error('check_consumer: No such consumer: %s/%s', queue_name, consumer_name) sys.exit(1) consumer_lag = res[0][0] # check that lag is acceptable self.log.info("Consumer lag: %s", consumer_lag) if consumer_lag > ticker_lag + 10: self.log.error('Consumer lagging too much, cannot proceed') sys.exit(1) def force_tick(self, curs, queue_name): """ Force tick into source queue so that consumer can move on faster """ q = "select pgq.force_tick(%s)" curs.execute(q, [queue_name]) res = curs.fetchone() cur_pos = res[0] start = time.time() while True: time.sleep(0.5) curs.execute(q, [queue_name]) res = curs.fetchone() if res[0] != cur_pos: # new pos return res[0] # dont loop more than 10 secs dur = time.time() - start if dur > 10 and not self.options.force: raise Exception("Ticker seems dead") class Checker(Syncer): """Checks that tables in two databases are in sync. Config options:: ## data_checker ## confdb = dbname=confdb host=confdb.service extra_connstr = user=marko # one of: compare, repair, repair-apply, compare-repair-apply check_type = compare # random params used in queries cluster_name = instance_name = proxy_host = proxy_db = # list of tables to be compared table_list = foo, bar, baz where_expr = (hashtext(key_user_name) & %%(max_slot)s) in (%%(slots)s) # gets no args source_query = select h.hostname, d.db_name from dba.cluster c join dba.cluster_host ch on (ch.key_cluster = c.id_cluster) join conf.host h on (h.id_host = ch.key_host) join dba.database d on (d.key_host = ch.key_host) where c.db_name = '%(cluster_name)s' and c.instance_name = '%(instance_name)s' and d.mk_db_type = 'partition' and d.mk_db_status = 'active' order by d.db_name, h.hostname target_query = select db_name, hostname, slots, max_slot from dba.get_cross_targets(%%(hostname)s, %%(db_name)s, '%(proxy_host)s', '%(proxy_db)s') consumer_query = select q.queue_name, c.consumer_name from conf.host h join dba.database d on (d.key_host = h.id_host) join dba.pgq_queue q on (q.key_database = d.id_database) join dba.pgq_consumer c on (c.key_queue = q.id_queue) where h.hostname = %%(hostname)s and d.db_name = %%(db_name)s and q.queue_name like 'xm%%%%' """ def __init__(self, args): """Checker init.""" super(Checker, self).__init__('data_checker', args) self.set_single_loop(1) self.log.info('Checker starting %s', str(args)) self.lock_timeout = self.cf.getfloat('lock_timeout', 10) self.table_list = self.cf.getlist('table_list') def work(self): """Syncer main function.""" source_query = self.cf.get('source_query') target_query = self.cf.get('target_query') consumer_query = self.cf.get('consumer_query') where_expr = self.cf.get('where_expr') extra_connstr = self.cf.get('extra_connstr') check = self.cf.get('check_type', 'compare') confdb = self.get_database('confdb', autocommit=1) curs = confdb.cursor() curs.execute(source_query) for src_row in curs.fetchall(): s_host = src_row['hostname'] s_db = src_row['db_name'] curs.execute(consumer_query, src_row) r = curs.fetchone() consumer_name = 
r['consumer_name'] queue_name = r['queue_name'] curs.execute(target_query, src_row) for dst_row in curs.fetchall(): d_db = dst_row['db_name'] d_host = dst_row['hostname'] cstr1 = "dbname=%s host=%s %s" % (s_db, s_host, extra_connstr) cstr2 = "dbname=%s host=%s %s" % (d_db, d_host, extra_connstr) where = where_expr % dst_row self.log.info('Source: db=%s host=%s queue=%s consumer=%s', s_db, s_host, queue_name, consumer_name) self.log.info('Target: db=%s host=%s where=%s', d_db, d_host, where) for tbl in self.table_list: src_db, dst_db = self.sync_table(cstr1, cstr2, queue_name, consumer_name, tbl) if check == 'compare': self.do_compare(tbl, src_db, dst_db, where) elif check == 'repair': r = TableRepair(tbl, self.log) r.do_repair(src_db, dst_db, where, 'fix.' + tbl, False) elif check == 'repair-apply': r = TableRepair(tbl, self.log) r.do_repair(src_db, dst_db, where, 'fix.' + tbl, True) elif check == 'compare-repair-apply': ok = self.do_compare(tbl, src_db, dst_db, where) if not ok: r = TableRepair(tbl, self.log) r.do_repair(src_db, dst_db, where, 'fix.' + tbl, True) else: raise Exception('unknown check type') self.reset() def do_compare(self, tbl, src_db, dst_db, where): """Actual comparison.""" src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.log.info('Counting %s', tbl) q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" % where q = self.cf.get('compare_sql', q) q = q.replace('_TABLE_', skytools.quote_fqident(tbl)) f = "%(cnt)d rows, checksum=%(chksum)s" f = self.cf.get('compare_fmt', f) self.log.debug("srcdb: %s", q) src_curs.execute(q) src_row = src_curs.fetchone() src_str = f % src_row self.log.info("srcdb: %s", src_str) self.log.debug("dstdb: %s", q) dst_curs.execute(q) dst_row = dst_curs.fetchone() dst_str = f % dst_row self.log.info("dstdb: %s", dst_str) src_db.commit() dst_db.commit() if src_str != dst_str: self.log.warning("%s: Results do not match!", tbl) return False else: self.log.info("%s: OK!", tbl) return True if __name__ == '__main__': script = Checker(sys.argv[1:]) script.start() python-skytools-3.6.1/skytools/config.py000066400000000000000000000243751373460333300204550ustar00rootroot00000000000000"""Nicer config class. """ import os import os.path import re import socket from configparser import MAX_INTERPOLATION_DEPTH, ConfigParser from configparser import Error as ConfigError from configparser import ( ExtendedInterpolation, Interpolation, InterpolationDepthError, InterpolationError, NoOptionError, NoSectionError, ) import skytools __all__ = ( 'Config', 'NoOptionError', 'ConfigError', 'ConfigParser', 'ExtendedConfigParser', 'ExtendedCompatConfigParser' ) class Config: """Bit improved ConfigParser. Additional features: - Remembers section. - Accepts defaults in get() functions. - List value support. """ def __init__(self, main_section, filename, sane_config=None, user_defs=None, override=None, ignore_defs=False): """Initialize Config and read from file. 
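        Minimal usage sketch ('myscript.ini' is a hypothetical file)::

            cf = Config('myscript', 'myscript.ini')
            connstr = cf.get('db', 'dbname=testdb')   # default if key missing
            loops = cf.getint('loop_count', 1)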
""" # use config file name as default job_name if filename: job_name = os.path.splitext(os.path.basename(filename))[0] else: job_name = main_section # initialize defaults, make them usable in config file if ignore_defs: self.defs = {} else: self.defs = { 'job_name': job_name, 'service_name': main_section, 'host_name': socket.gethostname(), } if filename: self.defs['config_dir'] = os.path.dirname(filename) self.defs['config_file'] = filename if user_defs: self.defs.update(user_defs) self.main_section = main_section self.filename = filename self.override = override or {} self.cf = ConfigParser() if filename is None: self.cf.add_section(main_section) elif not os.path.isfile(filename): raise ConfigError('Config file not found: ' + filename) self.reload() def reload(self): """Re-reads config file.""" if self.filename: self.cf.read(self.filename) if not self.cf.has_section(self.main_section): raise NoSectionError(self.main_section) # apply default if key not set for k, v in self.defs.items(): if not self.cf.has_option(self.main_section, k): self.cf.set(self.main_section, k, v) # apply overrides if self.override: for k, v in self.override.items(): self.cf.set(self.main_section, k, v) def get(self, key, default=None): """Reads string value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return str(self.cf.get(self.main_section, key)) def getint(self, key, default=None): """Reads int value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getint(self.main_section, key) def getboolean(self, key, default=None): """Reads boolean value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getboolean(self.main_section, key) def getfloat(self, key, default=None): """Reads float value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getfloat(self.main_section, key) def getlist(self, key, default=None): """Reads comma-separated list from key.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default s = self.get(key).strip() res = [] if not s: return res for v in s.split(","): res.append(v.strip()) return res def getdict(self, key, default=None): """Reads key-value dict from parameter. Key and value are separated with ':'. If missing, key itself is taken as value. """ if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default s = self.get(key).strip() res = {} if not s: return res for kv in s.split(","): tmp = kv.split(':', 1) if len(tmp) > 1: k = tmp[0].strip() v = tmp[1].strip() else: k = kv.strip() v = k res[k] = v return res def getfile(self, key, default=None): """Reads filename from config. In addition to reading string value, expands ~ to user directory. """ fn = self.get(key, default) if fn == "" or fn == "-": return fn # simulate that the cwd is script location #path = os.path.dirname(sys.argv[0]) # seems bad idea, cwd should be cwd fn = os.path.expanduser(fn) return fn def getbytes(self, key, default=None): """Reads a size value in human format, if not set then default. 
Examples: 1, 2 B, 3K, 4 MB """ if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) s = default else: s = self.cf.get(self.main_section, key) return skytools.hsize_to_bytes(s) def get_wildcard(self, key, values=(), default=None): """Reads a wildcard property from conf and returns its string value, if not set then default.""" orig_key = key keys = [key] for wild in values: key = key.replace('*', wild, 1) keys.append(key) keys.reverse() for k in keys: if self.cf.has_option(self.main_section, k): return self.cf.get(self.main_section, k) if default is None: raise NoOptionError(orig_key, self.main_section) return default def sections(self): """Returns list of sections in config file, excluding DEFAULT.""" return self.cf.sections() def has_section(self, section): """Checks if section is present in config file, excluding DEFAULT.""" return self.cf.has_section(section) def clone(self, main_section): """Return new Config() instance with new main section on same config file.""" return Config(main_section, self.filename) def options(self): """Return list of options in main section.""" return self.cf.options(self.main_section) def has_option(self, opt): """Checks if option exists in main section.""" return self.cf.has_option(self.main_section, opt) def items(self): """Returns list of (name, value) for each option in main section.""" return self.cf.items(self.main_section) # define some aliases (short-cuts / backward compatibility cruft) getbool = getboolean class ExtendedInterpolationCompat(Interpolation): _EXT_VAR_RX = r'\$\$|\$\{[^(){}]+\}' _OLD_VAR_RX = r'%%|%\([^(){}]+\)s' _var_rc = re.compile('(%s|%s)' % (_EXT_VAR_RX, _OLD_VAR_RX)) _bad_rc = re.compile('[%$]') def before_get(self, parser, section, option, rawval, defaults): dst = [] self._interpolate_ext(dst, parser, section, option, rawval, defaults, set()) return ''.join(dst) def before_set(self, parser, section, option, value): sub = self._var_rc.sub('', value) if self._bad_rc.search(sub): raise ValueError("invalid interpolation syntax in %r" % value) return value def _interpolate_ext(self, dst, parser, section, option, rawval, defaults, loop_detect): if not rawval: return if len(loop_detect) > MAX_INTERPOLATION_DEPTH: raise InterpolationDepthError(option, section, rawval) xloop = (section, option) if xloop in loop_detect: raise InterpolationError(section, option, 'Loop detected: %r in %r' % (xloop, loop_detect)) loop_detect.add(xloop) parts = self._var_rc.split(rawval) for i, frag in enumerate(parts): fullkey = None use_vars = defaults if i % 2 == 0: dst.append(frag) continue if frag in ('$$', '%%'): dst.append(frag[0]) continue if frag.startswith('${') and frag.endswith('}'): fullkey = frag[2:-1] # use section access only for new-style keys if ':' in fullkey: ksect, key = fullkey.split(':', 1) use_vars = None else: ksect, key = section, fullkey elif frag.startswith('%(') and frag.endswith(')s'): fullkey = frag[2:-2] ksect, key = section, fullkey else: raise InterpolationError(section, option, 'Internal parse error: %r' % frag) key = parser.optionxform(key) newpart = parser.get(ksect, key, raw=True, vars=use_vars) if newpart is None: raise InterpolationError(ksect, key, 'Key referenced is None') self._interpolate_ext(dst, parser, ksect, key, newpart, defaults, loop_detect) loop_detect.remove(xloop) class ExtendedConfigParser(ConfigParser): """ConfigParser that uses Python3-style extended interpolation by default. 
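    Example: with "[db] host = localhost" elsewhere in the file, another
    section can write "connstr = host=${db:host}".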
Syntax: ${var} and ${section:var} """ _DEFAULT_INTERPOLATION = ExtendedInterpolation() class ExtendedCompatConfigParser(ExtendedConfigParser): r"""Support both extended "${}" syntax from python3 and old "%()s" too. New ${} syntax allows ${key} to refer key in same section, and ${sect:key} to refer key in other sections. """ _DEFAULT_INTERPOLATION = ExtendedInterpolationCompat() python-skytools-3.6.1/skytools/dbservice.py000066400000000000000000000566551373460333300211640ustar00rootroot00000000000000""" Class used to handle multiset receiving and returning PL/Python procedures """ import skytools from skytools import dbdict try: import plpy except ImportError: pass __all__ = ( 'DBService', 'ServiceContext', 'get_record', 'get_record_list', 'make_record', 'make_record_array', 'TableAPI', #'log_result', 'transform_fields' ) def transform_fields(rows, key_fields, name_field, data_field): """Convert multiple-rows per key input array to one-row, multiple-column output array. The input arrays must be sorted by the key fields. """ cur_key = None cur_row = None res = [] for r in rows: k = [r[f] for f in key_fields] if k != cur_key: cur_key = k cur_row = {} for f in key_fields: cur_row[f] = r[f] res.append(cur_row) cur_row[r[name_field]] = r[data_field] return res # render_table def render_table(rows, fields): """ Render result rows as a table. Returns array of lines. """ widths = [15] * len(fields) for row in rows: for i, k in enumerate(fields): rlen = len(str(row.get(k))) widths[i] = widths[i] > rlen and widths[i] or rlen widths = [w + 2 for w in widths] fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) lines = [] lines.append(fmt % tuple(fields)) lines.append(fmt % tuple(['-' * 15] * len(fields))) for row in rows: lines.append(fmt % tuple([str(row.get(k)) for k in fields])) return lines # data conversion to and from url def get_record(arg): """ Parse data for one urlencoded record. Useful for turning incoming serialized data into structure usable for manipulation. """ if not arg: return dbdict() # allow array of single record if arg[0] in ('{', '['): lst = skytools.parse_pgarray(arg) if len(lst) != 1: raise ValueError('get_record() expects exactly 1 row, got %d' % len(lst)) arg = lst[0] # parse record return dbdict(skytools.db_urldecode(arg)) def get_record_list(array): """ Parse array of urlencoded records. Useful for turning incoming serialized data into structure usable for manipulation. """ if array is None: return [] if not isinstance(array, list): array = skytools.parse_pgarray(array) return [get_record(el) for el in array] def get_record_lists(tbl, field): """ Create dictionary of lists from given list using field as grouping criteria Used for master detail operatons to group detail records according to master id """ records = dbdict() for rec in tbl: master_id = str(rec[field]) records.setdefault(master_id, []).append(rec) return records def _make_record_convert(row): """Converts complex values.""" d = row.copy() for k, v in d.items(): if isinstance(v, list): d[k] = skytools.make_pgarray(v) return skytools.db_urlencode(d) def make_record(row): """ Takes record as dict and returns it as urlencoded string. Used to send data out of db service layer.or to fake incoming calls """ for v in row.values(): if isinstance(v, list): return _make_record_convert(row) return skytools.db_urlencode(row) def make_record_array(rowlist): """ Takes list of records got from plpy execute and turns it into postgers aray string. Used to send data out of db service layer. 
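    Example: make_record_array([{'id': '1'}, {'id': '2'}]) -> '{id=1,id=2}'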
""" return '{' + ','.join([make_record(row) for row in rowlist]) + '}' def get_result_items(rec_list, name): """ Get return values from result """ for r in rec_list: if r['res_code'] == name: return get_record_list(r['res_rows']) return None def log_result(log, rec_list): """ Sends dbservice execution logs to logfile """ msglist = get_result_items(rec_list, "_status") if msglist is None: if rec_list: log.warning('Unhandled output result: _status res_code not present.') else: for msg in msglist: log.debug(msg['_message']) class DBService: """ Wrap parameterized query handling and multiset stored procedure writing """ ROW = "_row" # name of the fake field where internal record id is stored FIELD = "_field" # parameter name for the field in record that is related to current message PARAM = "_param" # name of the parameter to which message relates SKIP = "skip" # used when record is needed for it's data but is not been updated INSERT = "insert" UPDATE = "update" DELETE = "delete" INFO = "info" # just informative message for the user NOTICE = "notice" # more than info less than warning WARNING = "warning" # warning message, something is out of ordinary ERROR = "error" # error found but execution continues until check then error is raised FATAL = "fatal" # execution is terminated at once and all found errors returned rows_found = 0 def __init__(self, context, global_dict=None): """ This object must be initiated in the beginning of each db service """ rec = skytools.db_urldecode(context) self._context = context # used to run dbservice in retval self.global_dict = global_dict # used for cacheing query plans self._retval = [] # used to collect return resultsets self._is_test = 'is_test' in rec # used to convert output into human readable form self.sqls = None # if sqls stays None then no recording of sqls is done if "show_sql" in rec: # api must add exected sql to resultset self.sqls = [] # sql's executed by dbservice, used for dubugging self.can_save = True # used to keep value most severe error found so far self.messages = [] # used to hold list of messages to be returned to the user # error and message handling def tell_user(self, severity, code, message, params=None, **kvargs): """ Adds another message to the set of messages to be sent back to user If error message then can_save is set false If fatal message then error or found errors are raised at once """ params = params or kvargs #plpy.notice("%s %s: %s %s" % (severity, code, message, str(params))) params["_severity"] = severity params["_code"] = code params["_message"] = message self.messages.append(params) if severity == self.ERROR: self.can_save = False if severity == self.FATAL: self.can_save = False self.raise_if_errors() def raise_if_errors(self): """ To be used in places where before continuing must be chcked if errors have been found Raises found errors packing them into error message as urlencoded string """ if not self.can_save: msgs = "Dbservice error(s): " + make_record_array(self.messages) plpy.error(msgs) # run sql meant mostly for select but not limited to def create_query(self, sql, params=None, **kvargs): """ Returns initialized querybuilder object for building complex dynamic queries """ params = params or kvargs return skytools.PLPyQueryBuilder(sql, params, self.global_dict, self.sqls) def run_query(self, sql, params=None, **kvargs): """ Helper function if everything you need is just paramertisized execute Sets rows_found that is coneninet to use when you don't need result just want to know how many rows were affected """ 
params = params or kvargs rows = skytools.plpy_exec(self.global_dict, sql, params) # convert result rows to dbdict if rows: rows = [dbdict(r) for r in rows] self.rows_found = len(rows) else: self.rows_found = 0 return rows def run_query_row(self, sql, params=None, **kvargs): """ Helper function if everything you need is just paramertisized execute to fetch one row only. If not found none is returned """ params = params or kvargs rows = self.run_query(sql, params) if len(rows) == 0: return None return rows[0] def run_exists(self, sql, params=None, **kvargs): """ Helper function to find out that record in given table exists using values in dict as criteria. Takes away all the hassle of preparing statements and processing returned result giving out just one boolean """ params = params or kvargs self.run_query(sql, params) return self.rows_found def run_lookup(self, sql, params=None, **kvargs): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. Uses plan cache if used inside db service """ params = params or kvargs rows = self.run_query(sql, params) if len(rows) == 0: return None row = rows[0] return list(row.values())[0] # resultset handling def return_next(self, rows, res_name, severity=None): """ Adds given set of rows to resultset """ self._retval.append([res_name, rows]) if severity is not None and len(rows) == 0: self.tell_user(severity, "dbsXXXX", "No matching records found") return rows def return_next_sql(self, sql, params, res_name, severity=None): """ Exectes query and adds recors resultset """ rows = self.run_query(sql, params) return self.return_next(rows, res_name, severity) def retval(self, service_name=None, params=None, **kvargs): """ Return collected resultsets and append to the end messages to the users Method is called usually as last statement in dbservice to return the results Also converts results into desired format """ params = params or kvargs self.raise_if_errors() if len(self.messages): self.return_next(self.messages, "_status") if self.sqls is not None and len(self.sqls): self.return_next(self.sqls, "_sql") results = [] for r in self._retval: res_name = r[0] rows = r[1] res_count = str(len(rows)) if self._is_test and len(rows) > 0: results.append([res_name, res_count, res_name]) n = 1 for trow in render_table(rows, rows[0].keys()): results.append([res_name, n, trow]) n += 1 else: res_rows = make_record_array(rows) results.append([res_name, res_count, res_rows]) if service_name: sql = "select * from %s( {i_context}, {i_params} );" % skytools.quote_fqident(service_name) par = dbdict(i_context=self._context, i_params=make_record(params)) res = self.run_query(sql, par) for r in res: results.append((r.res_code, r.res_text, r.res_rows)) return results # miscellaneous def check_required(self, record_name, record, severity, *fields): """ Checks if all required fields are present in record Used to validate incoming data Returns list of field names that are missing or empty """ missing = [] params = {self.PARAM: record_name} if self.ROW in record: params[self.ROW] = record[self.ROW] for field in fields: params[self.FIELD] = field if field in record: if record[field] is None or (isinstance(record[field], str) and len(record[field]) == 0): self.tell_user(severity, "dbsXXXX", "Required value missing: {%s}.{%s}" % ( self.PARAM, self.FIELD), **params) missing.append(field) else: self.tell_user(severity, "dbsXXXX", "Required field missing: {%s}.{%s}" % ( self.PARAM, self.FIELD), 
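                               # params carries the _param/_row/_field context
                               # so callers can point at the exact record and field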
**params) missing.append(field) return missing # TableAPI class TableAPI: """ Class for managing single-record updates using a primary key """ _table = None # schema name and table name _where = None # where condition used for update and delete _id = None # name of the primary key field _id_type = None # column type of primary key _op = None # operation currently carried out _ctx = None # context object for username and version _logging = True # should tapi log data changes _row = None # row identifier from calling program def __init__(self, ctx, table, create_log=True, id_type='int8'): """ Table name is used to construct insert, update and delete statements. Table must have a primary key field whose name is in format id_<table> (prefix "id_" plus the table name). Tablename should be in format schema.tablename. """ self._ctx = ctx self._table = skytools.quote_fqident(table) self._id = "id_" + skytools.fq_name_parts(table)[1] self._id_type = id_type self._where = '%s = {%s:%s}' % (skytools.quote_ident(self._id), self._id, self._id_type) self._logging = create_log def _log(self, result, original=None): """ Log the change into table log.changelog """ if not self._logging: return changes = [] for key in result.keys(): if self._op == 'update': if key in original: if str(original[key]) != str(result[key]): changes.append(key + ": " + str(original[key]) + " -> " + str(result[key])) else: changes.append(key + ": " + str(result[key])) self._ctx.log(self._table, result[self._id], self._op, "\n".join(changes)) def _version_check(self, original, version): if original is None: self._ctx.tell_user( self._ctx.INFO, "dbsXXXX", "Record ({table}.{field}={id}) has been deleted by other user " "while you were editing. Check version ({ver}) in changelog for details.", table=self._table, field=self._id, id=original[self._id], ver=original.version, _row=self._row ) if version is not None and original.version is not None: if int(version) != int(original.version): self._ctx.tell_user( self._ctx.INFO, "dbsXXXX", "Record ({table}.{field}={id}) has been changed by other user while you were editing. " "Version in db: ({db_ver}) and version sent by caller ({caller_ver}). 
" "See changelog for details.", table=self._table, field=self._id, id=original[self._id], db_ver=original.version, caller_ver=version, _row=self._row ) def _insert(self, data): fields = [] values = [] for key in data.keys(): if data[key] is not None: # ignore empty fields.append(skytools.quote_ident(key)) values.append("{" + key + "}") sql = "insert into %s (%s) values (%s) returning *;" % ( self._table, ",".join(fields), ",".join(values) ) result = self._ctx.run_query_row(sql, data) self._log(result) return result def _update(self, data, version): sql = "select * from %s where %s" % (self._table, self._where) original = self._ctx.run_query_row(sql, data) self._version_check(original, version) pairs = [] for key in data.keys(): if data[key] is None: pairs.append(key + " = NULL") else: pairs.append(key + " = {" + key + "}") sql = "update %s set %s where %s returning *;" % (self._table, ", ".join(pairs), self._where) result = self._ctx.run_query_row(sql, data) self._log(result, original) return result def _delete(self, data, version): sql = "delete from %s where %s returning *;" % (self._table, self._where) result = self._ctx.run_query_row(sql, data) self._version_check(result, version) self._log(result) return result def do(self, data): """ Do dml according to special field _op that must be given together wit data """ result = data # so it is initialized for skip self._op = data.pop(self._ctx.OP) # determines operation done self._row = data.pop(self._ctx.ROW, None) # internal record id used for error reporting if self._row is None: # if no _row variable was provided self._row = data.get(self._id, None) # use id instead if self._id in data and data[self._id]: # if _id field is given if int(data[self._id]) < 0: # and it is fake key generated by ui data.pop(self._id) # remove fake key so real one can be assigned version = data.get('version', None) # version sent from caller data['version'] = self._ctx.version # current transaction id is stored in each record if self._op == self._ctx.INSERT: result = self._insert(data) elif self._op == self._ctx.UPDATE: result = self._update(data, version) elif self._op == self._ctx.DELETE: result = self._delete(data, version) elif self._op == self._ctx.SKIP: pass else: self._ctx.tell_user(self._ctx.ERROR, "dbsXXXX", "Unahndled _op='{op}' value in TableAPI (table={table}, id={id})", op=self._op, table=self._table, id=data[self._id]) result[self._ctx.OP] = self._op result[self._ctx.ROW] = self._row return result # ServiceContext class ServiceContext(DBService): OP = "_op" # name of the fake field where record modificaton operation is stored def __init__(self, context, global_dict=None): """ This object must be initiated in the beginning of each db service """ super(ServiceContext, self).__init__(context, global_dict) rec = skytools.db_urldecode(context) if "username" not in rec: plpy.error("Username must be provided in db service context parameter") self.username = rec['username'] # used for logging purposes res = plpy.execute("select txid_current() as txid;") row = res[0] self.version = row["txid"] self.rows_found = 0 # Flag set by run query to inicate number of rows got # logging def log(self, _object_type, _key_object, _change_op, _payload): """ Log stuff into the changelog whatever seems relevant to be logged """ self.run_query( "select log.log_change( {version}, {username}, {object_type}, {key_object}, {change_op}, {payload} );", version=self.version, username=self.username, object_type=_object_type, key_object=_key_object, change_op=_change_op, 
payload=_payload) # data conversion to and from url def get_record(self, arg): """ Parse data for one urlencoded record. Useful for turning incoming serialized data into a structure usable for manipulation. """ return get_record(arg) def get_record_list(self, array): """ Parse array of urlencoded records. Useful for turning incoming serialized data into a structure usable for manipulation. """ return get_record_list(array) def get_list_groups(self, tbl, field): """ Create dictionary of lists from given list using field as grouping criteria. Used for master-detail operations to group detail records according to master id. """ return get_record_lists(tbl, field) def make_record(self, row): """ Takes record as dict and returns it as urlencoded string. Used to send data out of the db service layer or to fake incoming calls. """ return make_record(row) def make_record_array(self, rowlist): """ Takes list of records got from plpy execute and turns it into a postgres array string. Used to send data out of the db service layer. """ return make_record_array(rowlist) # tapi based dml functions def _changelog(self, fields): log = True if fields: if '_log' in fields: if not fields.pop('_log'): log = False if '_log_id' in fields: fields.pop('_log_id') if '_log_field' in fields: fields.pop('_log_field') return log def tapi_do(self, tablename, row, **fields): """ Convenience function for just doing the change without creating a tapi object first. Fields may contain additional overriding values that are applied before do. """ tapi = TableAPI(self, tablename, self._changelog(fields)) row = row or dbdict() if fields: row.update(fields) return tapi.do(row) def tapi_do_set(self, tablename, rows, **fields): """ Does changes to a list of detail rows. Used for normal foreign keys in master-detail relationships. Does deletes first, then updates, then inserts, to avoid uniqueness problems. """ tapi = TableAPI(self, tablename, self._changelog(fields)) results, updates, inserts = [], [], [] for row in rows: if fields: row.update(fields) if row[self.OP] == self.DELETE: results.append(tapi.do(row)) elif row[self.OP] == self.UPDATE: updates.append(row) else: inserts.append(row) for row in updates: results.append(tapi.do(row)) for row in inserts: results.append(tapi.do(row)) return results # resultset handling def retval_dbservice(self, service_name, ctx, **params): """ Runs a service with the standard interface. Convenient for calling select services from other services, for example to return data after doing a save. """ self.raise_if_errors() service_sql = "select * from %s( {i_context}, {i_params} );" % skytools.quote_fqident(service_name) service_params = {"i_context": ctx, "i_params": self.make_record(params)} results = self.run_query(service_sql, service_params) retval = self.retval() for r in results: retval.append((r.res_code, r.res_text, r.res_rows)) return retval # miscellaneous def field_copy(self, rec, *keys): """ Used to copy a subset of fields from one record into another, example: dbs.field_copy(record, "start_date", "key_colo", "key_rack") """ retval = dbdict() for key in keys: if key in rec: retval[key] = rec[key] return retval def field_set(self, **fields): """ Fills dict with given values and returns the resulting dict. If dict was not provided with the call, it is created. """ return fields python-skytools-3.6.1/skytools/dbstruct.py000066400000000000000000000567541373460333300210470ustar00rootroot00000000000000"""Find table structure and allow CREATE/DROP elements from it. 
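A minimal usage sketch (illustrative; assumes an open cursor ``curs``
and an existing table "public.mytable"):

    s = TableStruct(curs, 'public.mytable')
    create_sql = s.get_create_sql(T_ALL)  # CREATE statements for table and related objects
    s.drop(curs, T_INDEX)                 # drop only the indexes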
""" import re import skytools from skytools import quote_fqident, quote_ident __all__ = ( 'TableStruct', 'SeqStruct', 'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER', 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL', 'T_SEQUENCE', 'T_PARENT', 'T_DEFAULT', ) T_TABLE = 1 << 0 T_CONSTRAINT = 1 << 1 T_INDEX = 1 << 2 T_TRIGGER = 1 << 3 T_RULE = 1 << 4 T_GRANT = 1 << 5 T_OWNER = 1 << 6 T_SEQUENCE = 1 << 7 T_PARENT = 1 << 8 T_DEFAULT = 1 << 9 T_PKEY = 1 << 20 # special, one of constraints T_ALL = (T_TABLE | T_CONSTRAINT | T_INDEX | T_SEQUENCE | T_TRIGGER | T_RULE | T_GRANT | T_OWNER | T_DEFAULT) # # Utility functions # def find_new_name(curs, name): """Create new object name for case the old exists. Needed when creating a new table besides old one. """ # cut off previous numbers m = re.search('_[0-9]+$', name) if m: name = name[:m.start()] # now loop for i in range(1, 1000): tname = "%s_%d" % (name, i) q = "select count(1) from pg_class where relname = %s" curs.execute(q, [tname]) if curs.fetchone()[0] == 0: return tname # failed raise Exception('find_new_name failed') def rx_replace(rx, sql, new_part): """Find a regex match and replace that part with new_part.""" m = re.search(rx, sql, re.I) if not m: raise Exception('rx_replace failed: rx=%r sql=%r new=%r' % (rx, sql, new_part)) p1 = sql[:m.start()] p2 = sql[m.end():] return p1 + new_part + p2 # # Schema objects # class TElem: """Keeps info about one metadata object.""" SQL = "" type = 0 def get_create_sql(self, curs, new_name=None): """Return SQL statement for creating or None if not supported.""" return None def get_drop_sql(self, curs): """Return SQL statement for dropping or None of not supported.""" return None @classmethod def get_load_sql(cls, pgver): """Return SQL statement for finding objects.""" return cls.SQL class TConstraint(TElem): """Info about constraint.""" type = T_CONSTRAINT SQL = """ SELECT c.conname as name, pg_get_constraintdef(c.oid) as def, c.contype, i.indisclustered as is_clustered FROM pg_constraint c LEFT JOIN pg_index i ON c.conrelid = i.indrelid AND c.conname = (SELECT r.relname FROM pg_class r WHERE r.oid = i.indexrelid) WHERE c.conrelid = %(oid)s AND c.contype != 'f' """ def __init__(self, table_name, row): """Init constraint.""" self.table_name = table_name self.name = row['name'] self.defn = row['def'] self.contype = row['contype'] self.is_clustered = row['is_clustered'] # tag pkeys if self.contype == 'p': self.type += T_PKEY def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" # no ONLY here as table with childs (only case that matters) # cannot have contraints that childs do not have fmt = "ALTER TABLE %s ADD CONSTRAINT %s\n %s;" if new_table_name: name = self.name if self.contype in ('p', 'u'): name = find_new_name(curs, self.name) qtbl = quote_fqident(new_table_name) qname = quote_ident(name) else: qtbl = quote_fqident(self.table_name) qname = quote_ident(self.name) sql = fmt % (qtbl, qname, self.defn) if self.is_clustered: sql += ' ALTER TABLE ONLY %s\n CLUSTER ON %s;' % (qtbl, qname) return sql def get_drop_sql(self, curs): """Generate removal sql.""" fmt = "ALTER TABLE ONLY %s\n DROP CONSTRAINT %s;" sql = fmt % (quote_fqident(self.table_name), quote_ident(self.name)) return sql class TIndex(TElem): """Info about index.""" type = T_INDEX SQL = """ SELECT n.nspname || '.' 
|| c.relname as name, pg_get_indexdef(i.indexrelid) as defn, c.relname as local_name, i.indisclustered as is_clustered FROM pg_index i, pg_class c, pg_namespace n WHERE c.oid = i.indexrelid AND i.indrelid = %(oid)s AND n.oid = c.relnamespace AND NOT EXISTS (select objid from pg_depend where classid = %(pg_class_oid)s and objid = c.oid and deptype = 'i') """ def __init__(self, table_name, row): self.name = row['name'] self.defn = row['defn'].replace(' USING ', '\n USING ', 1) + ';' self.is_clustered = row['is_clustered'] self.table_name = table_name self.local_name = row['local_name'] def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" if new_table_name: # fixme: seems broken iname = find_new_name(curs, self.name) tname = new_table_name pnew = "INDEX %s ON %s " % (quote_ident(iname), quote_fqident(tname)) rx = r"\bINDEX[ ][a-z0-9._]+[ ]ON[ ][a-z0-9._]+[ ]" sql = rx_replace(rx, self.defn, pnew) else: sql = self.defn iname = self.local_name tname = self.table_name if self.is_clustered: sql += ' ALTER TABLE ONLY %s\n CLUSTER ON %s;' % ( quote_fqident(tname), quote_ident(iname)) return sql def get_drop_sql(self, curs): return 'DROP INDEX %s;' % quote_fqident(self.name) class TRule(TElem): """Info about rule.""" type = T_RULE SQL = """SELECT rw.*, pg_get_ruledef(rw.oid) as def FROM pg_rewrite rw WHERE rw.ev_class = %(oid)s AND rw.rulename <> '_RETURN'::name """ def __init__(self, table_name, row, new_name=None): self.table_name = table_name self.name = row['rulename'] self.defn = row['def'] self.enabled = row.get('ev_enabled', 'O') def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" if not new_table_name: sql = self.defn table = self.table_name else: idrx = r'''([a-z0-9._]+|"([^"]+|"")+")+''' # fixme: broken / quoting rx = r"\bTO[ ]" + idrx rc = re.compile(rx, re.X) m = rc.search(self.defn) if not m: raise Exception('Cannot find table name in rule') old_tbl = m.group(1) new_tbl = quote_fqident(new_table_name) sql = self.defn.replace(old_tbl, new_tbl) table = new_table_name if self.enabled != 'O': # O - rule fires in origin and local modes # D - rule is disabled # R - rule fires in replica mode # A - rule fires always action = {'R': 'ENABLE REPLICA', 'A': 'ENABLE ALWAYS', 'D': 'DISABLE'}[self.enabled] sql += ('\nALTER TABLE %s %s RULE %s;' % (table, action, self.name)) return sql def get_drop_sql(self, curs): return 'DROP RULE %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) class TTrigger(TElem): """Info about trigger.""" type = T_TRIGGER def __init__(self, table_name, row): self.table_name = table_name self.name = row['name'] self.defn = row['def'] + ';' self.defn = self.defn.replace('FOR EACH', '\n FOR EACH', 1) def get_create_sql(self, curs, new_table_name=None): """Generate creation SQL.""" if not new_table_name: return self.defn # fixme: broken / quoting rx = r"\bON[ ][a-z0-9._]+[ ]" pnew = "ON %s " % new_table_name return rx_replace(rx, self.defn, pnew) def get_drop_sql(self, curs): return 'DROP TRIGGER %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) @classmethod def get_load_sql(cls, pg_vers): """Return SQL statement for finding objects.""" sql = "SELECT tgname as name, pg_get_triggerdef(oid) as def "\ " FROM pg_trigger "\ " WHERE tgrelid = %(oid)s AND " if pg_vers >= 90000: sql += "NOT tgisinternal" else: sql += "NOT tgisconstraint" return sql class TParent(TElem): """Info about parent table (inheritance).""" type = T_PARENT SQL = """ SELECT n.nspname||'.'||c.relname AS name FROM pg_inherits i JOIN pg_class 
c ON i.inhparent = c.oid JOIN pg_namespace n ON c.relnamespace = n.oid WHERE i.inhrelid = %(oid)s """ def __init__(self, table_name, row): self.name = table_name self.parent_name = row['name'] def get_create_sql(self, curs, new_table_name=None): return 'ALTER TABLE ONLY %s\n INHERIT %s' % (quote_fqident(self.name), quote_fqident(self.parent_name)) def get_drop_sql(self, curs): return 'ALTER TABLE ONLY %s\n NO INHERIT %s' % (quote_fqident(self.name), quote_fqident(self.parent_name)) class TOwner(TElem): """Info about table owner.""" type = T_OWNER SQL = """ SELECT pg_get_userbyid(relowner) as owner FROM pg_class WHERE oid = %(oid)s """ def __init__(self, table_name, row, new_name=None): self.table_name = table_name self.name = 'Owner' self.owner = row['owner'] def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" if not new_name: new_name = self.table_name return 'ALTER TABLE %s\n OWNER TO %s;' % (quote_fqident(new_name), quote_ident(self.owner)) class TGrant(TElem): """Info about permissions.""" type = T_GRANT SQL = "SELECT relacl FROM pg_class where oid = %(oid)s" # Sync with: src/include/utils/acl.h acl_map = { 'a': 'INSERT', 'r': 'SELECT', 'w': 'UPDATE', 'd': 'DELETE', 'D': 'TRUNCATE', 'x': 'REFERENCES', 't': 'TRIGGER', 'X': 'EXECUTE', 'U': 'USAGE', 'C': 'CREATE', 'T': 'TEMPORARY', 'c': 'CONNECT', # old 'R': 'RULE', } def acl_to_grants(self, acl): if acl == "arwdRxt": # ALL for tables return "ALL", "" # callers expect a (grants, grants-with-grant-option) pair i = 0 lst1 = [] lst2 = [] while i < len(acl): a = self.acl_map[acl[i]] if i + 1 < len(acl) and acl[i + 1] == '*': lst2.append(a) i += 2 else: lst1.append(a) i += 1 return ", ".join(lst1), ", ".join(lst2) def parse_relacl(self, relacl): """Parse ACL to tuple of (user, acl, who)""" if relacl is None: return [] tup_list = [] for sacl in skytools.parse_pgarray(relacl): acl = skytools.parse_acl(sacl) if not acl: continue tup_list.append(acl) return tup_list def __init__(self, table_name, row, new_name=None): self.name = table_name self.acl_list = self.parse_relacl(row['relacl']) def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" if not new_name: new_name = self.name qtarget = quote_fqident(new_name) sql_list = [] for role, acl, ___who in self.acl_list: qrole = quote_ident(role) astr1, astr2 = self.acl_to_grants(acl) if astr1: sql = "GRANT %s ON %s\n TO %s;" % (astr1, qtarget, qrole) sql_list.append(sql) if astr2: sql = "GRANT %s ON %s\n TO %s WITH GRANT OPTION;" % (astr2, qtarget, qrole) sql_list.append(sql) return "\n".join(sql_list) def get_drop_sql(self, curs): sql_list = [] for user, ___acl, ___who in self.acl_list: sql = "REVOKE ALL ON %s FROM %s;" % (quote_fqident(self.name), quote_ident(user)) sql_list.append(sql) return "\n".join(sql_list) class TColumnDefault(TElem): """Info about table column default value.""" type = T_DEFAULT SQL = """ select a.attname as name, pg_get_expr(d.adbin, d.adrelid) as expr from pg_attribute a left join pg_attrdef d on (d.adrelid = a.attrelid and d.adnum = a.attnum) where a.attrelid = %(oid)s and not a.attisdropped and a.atthasdef and a.attnum > 0 order by a.attnum; """ def __init__(self, table_name, row): self.table_name = table_name self.name = row['name'] self.expr = row['expr'] def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" tbl = new_name or self.table_name sql = "ALTER TABLE ONLY %s ALTER COLUMN %s\n SET DEFAULT %s;" % ( quote_fqident(tbl), quote_ident(self.name), self.expr) return sql def get_drop_sql(self, curs): return "ALTER TABLE %s ALTER COLUMN %s\n DROP DEFAULT;" % ( 
quote_fqident(self.table_name), quote_ident(self.name)) class TColumn(TElem): """Info about table column.""" SQL = """ select a.attname as name, quote_ident(a.attname) as qname, format_type(a.atttypid, a.atttypmod) as dtype, a.attnotnull, (select max(char_length(aa.attname)) from pg_attribute aa where aa.attrelid = %(oid)s) as maxcol, pg_get_serial_sequence(%(fq2name)s, a.attname) as seqname from pg_attribute a left join pg_attrdef d on (d.adrelid = a.attrelid and d.adnum = a.attnum) where a.attrelid = %(oid)s and not a.attisdropped and a.attnum > 0 order by a.attnum; """ seqname = None def __init__(self, table_name, row): self.name = row['name'] fname = row['qname'].ljust(row['maxcol'] + 3) self.column_def = fname + ' ' + row['dtype'] if row['attnotnull']: self.column_def += ' not null' self.sequence = None if row['seqname']: self.seqname = skytools.unquote_fqident(row['seqname']) class TGPDistKey(TElem): """Info about GreenPlum table distribution keys""" SQL = """ select a.attname as name from pg_attribute a, gp_distribution_policy p where p.localoid = %(oid)s and a.attrelid = %(oid)s and a.attnum = any(p.attrnums) order by a.attnum; """ def __init__(self, table_name, row): self.name = row['name'] class TTable(TElem): """Info about table only (columns).""" type = T_TABLE def __init__(self, table_name, col_list, dist_key_list=None): self.name = table_name self.col_list = col_list self.dist_key_list = dist_key_list def get_create_sql(self, curs, new_name=None): """Generate creation SQL.""" if not new_name: new_name = self.name sql = "CREATE TABLE %s (" % quote_fqident(new_name) sep = "\n " for c in self.col_list: sql += sep + c.column_def sep = ",\n " sql += "\n)" if self.dist_key_list is not None: if self.dist_key_list != []: sql += "\ndistributed by(%s)" % ','.join(c.name for c in self.dist_key_list) else: sql += '\ndistributed randomly' sql += ";" return sql def get_drop_sql(self, curs): return "DROP TABLE %s;" % quote_fqident(self.name) class TSeq(TElem): """Info about sequence.""" type = T_SEQUENCE SQL_PG10 = """ SELECT %(fq2name)s::name AS sequence_name, s.last_value, p.seqstart AS start_value, p.seqincrement AS increment_by, p.seqmax AS max_value, p.seqmin AS min_value, p.seqcache AS cache_value, s.log_cnt, s.is_called, p.seqcycle AS is_cycled, %(owner)s as owner FROM pg_catalog.pg_sequence p, %(fqname)s s WHERE p.seqrelid = %(fq2name)s::regclass::oid """ SQL_PG9 = """ SELECT %(fq2name)s AS sequence_name, last_value, start_value, increment_by, max_value, min_value, cache_value, log_cnt, is_called, is_cycled, %(owner)s AS "owner" FROM %(fqname)s """ @classmethod def get_load_sql(cls, pg_vers): """Return SQL statement for finding objects.""" if pg_vers < 100000: return cls.SQL_PG9 return cls.SQL_PG10 def __init__(self, seq_name, row): self.name = seq_name defn = '' self.owner = row['owner'] if row.get('increment_by', 1) != 1: defn += ' INCREMENT BY %d' % row['increment_by'] if row.get('min_value', 1) != 1: defn += ' MINVALUE %d' % row['min_value'] if row.get('max_value', 9223372036854775807) != 9223372036854775807: defn += ' MAXVALUE %d' % row['max_value'] last_value = row['last_value'] if row['is_called']: last_value += row.get('increment_by', 1) if last_value >= row.get('max_value', 9223372036854775807): raise Exception('duh, seq passed max_value') if last_value != 1: defn += ' START %d' % last_value if row.get('cache_value', 1) != 1: defn += ' CACHE %d' % row['cache_value'] if row.get('is_cycled'): defn += ' CYCLE ' if self.owner: defn += ' OWNED BY %s' % self.owner self.defn = 
defn def get_create_sql(self, curs, new_seq_name=None): """Generate creation SQL.""" # we are in table def, forget full def if self.owner: sql = "ALTER SEQUENCE %s\n OWNED BY %s;" % ( quote_fqident(self.name), self.owner) return sql name = self.name if new_seq_name: name = new_seq_name sql = 'CREATE SEQUENCE %s %s;' % (quote_fqident(name), self.defn) return sql def get_drop_sql(self, curs): if self.owner: return '' return 'DROP SEQUENCE %s;' % quote_fqident(self.name) # # Main table object, loads all the others # class BaseStruct: """Collects and manages all info about a higher-level db object. Allows issuing CREATE/DROP statements for any group of elements. """ object_list = [] def __init__(self, curs, name): """Initializes class by loading info about table_name from database.""" self.name = name self.fqname = quote_fqident(name) def _load_elem(self, curs, name, args, eclass): """Fetch element(s) from db.""" elem_list = [] #print "Loading %s, name=%s, args=%s" % (repr(eclass), repr(name), repr(args)) sql = eclass.get_load_sql(curs.connection.server_version) curs.execute(sql % args) for row in curs.fetchall(): elem_list.append(eclass(name, row)) return elem_list def create(self, curs, objs, new_table_name=None, log=None): """Issues CREATE statements for requested set of objects. If new_table_name is given, creates the table under that name and also tries to rename all indexes/constraints that conflict with the existing table. """ for o in self.object_list: if o.type & objs: sql = o.get_create_sql(curs, new_table_name) if not sql: continue if log: log.info('Creating %s' % o.name) log.debug(sql) curs.execute(sql) def drop(self, curs, objs, log=None): """Issues DROP statements for requested set of objects.""" # make sure the creating & dropping happen in reverse order olist = self.object_list[:] olist.reverse() for o in olist: if o.type & objs: sql = o.get_drop_sql(curs) if not sql: continue if log: log.info('Dropping %s' % o.name) log.debug(sql) curs.execute(sql) def get_create_sql(self, objs): res = [] for o in self.object_list: if o.type & objs: sql = o.get_create_sql(None, None) if sql: res.append(sql) return "".join(res) class TableStruct(BaseStruct): """Collects and manages all info about a table. Allows issuing CREATE/DROP statements for any group of elements. """ def __init__(self, curs, table_name): """Initializes class by loading info about table_name from database.""" super(TableStruct, self).__init__(curs, table_name) self.table_name = table_name # fill args schema, name = skytools.fq_name_parts(table_name) args = { 'schema': schema, 'table': name, 'fqname': self.fqname, 'fq2name': skytools.quote_literal(self.fqname), 'oid': skytools.get_table_oid(curs, table_name), 'pg_class_oid': skytools.get_table_oid(curs, 'pg_catalog.pg_class'), } # load table struct self.col_list = self._load_elem(curs, self.name, args, TColumn) # if db is GP then read also table distribution keys if skytools.exists_table(curs, "pg_catalog.gp_distribution_policy"): self.dist_key_list = self._load_elem(curs, self.name, args, TGPDistKey) else: self.dist_key_list = None self.object_list = [TTable(table_name, self.col_list, self.dist_key_list)] self.seq_list = [] # load seqs for col in self.col_list: if col.seqname: fqname = quote_fqident(col.seqname) owner = self.fqname + '.' 
+ quote_ident(col.name) seq_args = { 'fqname': fqname, 'fq2name': skytools.quote_literal(fqname), 'owner': skytools.quote_literal(owner), } self.seq_list += self._load_elem(curs, col.seqname, seq_args, TSeq) self.object_list += self.seq_list # load additional objects to_load = [TColumnDefault, TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner, TParent] for eclass in to_load: self.object_list += self._load_elem(curs, self.name, args, eclass) def get_column_list(self): """Returns list of column names the table has.""" res = [] for c in self.col_list: res.append(c.name) return res class SeqStruct(BaseStruct): """Collects and manages all info about a sequence. Allows issuing CREATE/DROP statements for any group of elements. """ def __init__(self, curs, seq_name): """Initializes class by loading info about seq_name from database.""" super(SeqStruct, self).__init__(curs, seq_name) # fill args args = { 'fqname': self.fqname, 'fq2name': skytools.quote_literal(self.fqname), 'owner': 'null', } # load table struct self.object_list = self._load_elem(curs, seq_name, args, TSeq) def manual_check(): from skytools import connect_database db = connect_database("dbname=fooz") curs = db.cursor() s = TableStruct(curs, "public.data1") s.drop(curs, T_ALL) s.create(curs, T_ALL) s.create(curs, T_ALL, "data1_new") s.create(curs, T_PKEY) if __name__ == '__main__': manual_check() python-skytools-3.6.1/skytools/fileutil.py000066400000000000000000000076711373460333300210230ustar00rootroot00000000000000"""File utilities """ import errno import os import sys __all__ = ['write_atomic', 'signal_pidfile'] def write_atomic_unix(fn, data, bakext=None, mode='b'): """Write file with rename.""" if mode not in ['', 'b', 't']: raise ValueError("unsupported fopen mode") if mode == 'b' and not isinstance(data, bytes): data = data.encode('utf8') # write new data to tmp file fn2 = fn + '.new' with open(fn2, 'w' + mode) as f: f.write(data) # link old data to bak file if bakext: if bakext.find('/') >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: os.unlink(fnb) except OSError as e: if e.errno != errno.ENOENT: raise try: os.link(fn, fnb) except OSError as e: if e.errno != errno.ENOENT: raise # win32 does not like replace if sys.platform == 'win32': try: os.remove(fn) except BaseException: pass # atomically replace file os.rename(fn2, fn) def signal_pidfile(pidfile, sig): """Send a signal to process whose ID is located in pidfile. Read only first line of pidfile to support multiline pidfiles like postmaster.pid. Returns True if successful, False if pidfile does not exist or the process itself is dead. Any other errors will be passed on as exceptions.""" ln = '' try: with open(pidfile, 'r') as f: ln = f.readline().strip() pid = int(ln) if sig == 0 and sys.platform == 'win32': return win32_detect_pid(pid) os.kill(pid, sig) return True except (IOError, OSError) as ex: if ex.errno not in (errno.ESRCH, errno.ENOENT): raise except ValueError: # this leaves a slight race when someone is just creating the file, # but the more common case is an old empty file. 
if not ln: return False raise ValueError('Corrupt pidfile: %s' % pidfile) return False def win32_detect_pid(pid): """Process detection for win32.""" # avoid pywin32 dependency, use ctypes instead import ctypes # win32 constants PROCESS_QUERY_INFORMATION = 1024 STILL_ACTIVE = 259 ERROR_INVALID_PARAMETER = 87 ERROR_ACCESS_DENIED = 5 # Load kernel32.dll k = ctypes.windll.kernel32 OpenProcess = k.OpenProcess OpenProcess.restype = ctypes.c_void_p # query pid exit code h = OpenProcess(PROCESS_QUERY_INFORMATION, 0, pid) if h is None: err = k.GetLastError() if err == ERROR_INVALID_PARAMETER: return False if err == ERROR_ACCESS_DENIED: return True raise OSError(errno.EFAULT, "Unknown win32error: " + str(err)) code = ctypes.c_int() k.GetExitCodeProcess(h, ctypes.byref(code)) k.CloseHandle(h) return code.value == STILL_ACTIVE def win32_write_atomic(fn, data, bakext=None, mode='b'): """Write file with rename for win32.""" if mode not in ['', 'b', 't']: raise ValueError("unsupported fopen mode") if mode == 'b' and not isinstance(data, bytes): data = data.encode('utf8') # write new data to tmp file fn2 = fn + '.new' with open(fn2, 'w' + mode) as f: f.write(data) # move old data to bak file if bakext: if bakext.find('/') >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: os.remove(fnb) except OSError as e: if e.errno != errno.ENOENT: raise try: os.rename(fn, fnb) except OSError as e: if e.errno != errno.ENOENT: raise else: try: os.remove(fn) except BaseException: pass # replace file os.rename(fn2, fn) if sys.platform == 'win32': write_atomic = win32_write_atomic else: write_atomic = write_atomic_unix python-skytools-3.6.1/skytools/gzlog.py000066400000000000000000000013461373460333300203230ustar00rootroot00000000000000"""Atomic append of gzipped data. The point is - if several gzip streams are concatenated, they are read back as one whole stream. """ import gzip import io __all__ = ('gzip_append',) def gzip_append(filename, data, level=6): """Append a block of data to file with safety checks.""" # compress data buf = io.BytesIO() with gzip.GzipFile(fileobj=buf, compresslevel=level, mode="w") as g: g.write(data) zdata = buf.getvalue() # append, safely with open(filename, "ab+", 0) as f: f.seek(0, 2) pos = f.tell() try: f.write(zdata) except Exception as ex: # rollback on error f.seek(pos, 0) f.truncate() raise ex python-skytools-3.6.1/skytools/hashtext.py000066400000000000000000000064301373460333300210300ustar00rootroot00000000000000""" Implementation of Postgres hashing function. 
hashtext_old() - used up to PostgreSQL 8.3 hashtext_new() - used since PostgreSQL 8.4 """ import struct import sys try: from skytools._chashtext import hashtext_new, hashtext_old except ImportError: def hashtext_old(v): return hashtext_old_py(v) def hashtext_new(v): return hashtext_new_py(v) __all__ = ("hashtext_old", "hashtext_new") # pad for last partial block PADDING = b'\0' * 12 def uint32(x): """python does not have a 32 bit integer type, so we need this hack to produce uint32 after bit operations""" return x & 0xffffffff # # Old Postgres hashtext() - lookup2 with custom initval # FMT_OLD = struct.Struct("<LLL") def mix_old(a, b, c): a = uint32((a - b - c) ^ (c >> 13)) b = uint32((b - c - a) ^ (a << 8)) c = uint32((c - a - b) ^ (b >> 13)) a = uint32((a - b - c) ^ (c >> 12)) b = uint32((b - c - a) ^ (a << 16)) c = uint32((c - a - b) ^ (b >> 5)) a = uint32((a - b - c) ^ (c >> 3)) b = uint32((b - c - a) ^ (a << 10)) c = uint32((c - a - b) ^ (b >> 15)) return a, b, c def hashtext_old_py(k): """Old Postgres hashtext()""" if isinstance(k, str): k = k.encode() remain = len(k) pos = 0 a = b = 0x9e3779b9 c = 3923095 # handle most of the key while remain >= 12: a2, b2, c2 = FMT_OLD.unpack_from(k, pos) a, b, c = mix_old(a + a2, b + b2, c + c2) pos += 12 remain -= 12 # handle the last 11 bytes a2, b2, c2 = FMT_OLD.unpack_from(k[pos:] + PADDING, 0) # the lowest byte of c is reserved for the length c2 = (c2 << 8) + len(k) a, b, c = mix_old(a + a2, b + b2, c + c2) # convert to signed int if c & 0x80000000: c = -0x100000000 + c return int(c) # # New Postgres hashtext() - hacked lookup3: # - custom initval # - calls mix() when len=12 # - shifted c in last block on little-endian # FMT_NEW = struct.Struct("=LLL") def rol32(x, k): return ((x) << (k)) | (uint32(x) >> (32 - (k))) def mix_new(a, b, c): a = (a - c) ^ rol32(c, 4) c += b b = (b - a) ^ rol32(a, 6) a += c c = (c - b) ^ rol32(b, 8) b += a a = (a - c) ^ rol32(c, 16) c += b b = (b - a) ^ rol32(a, 19) a += c c = (c - b) ^ rol32(b, 4) b += a return uint32(a), uint32(b), uint32(c) def final_new(a, b, c): c = (c ^ b) - rol32(b, 14) a = (a ^ c) - rol32(c, 11) b = (b ^ a) - rol32(a, 25) c = (c ^ b) - rol32(b, 16) a = (a ^ c) - rol32(c, 4) b = (b ^ a) - rol32(a, 14) c = (c ^ b) - rol32(b, 24) return uint32(a), uint32(b), uint32(c) def hashtext_new_py(k): """New Postgres hashtext()""" if isinstance(k, str): k = k.encode() remain = len(k) pos = 0 a = b = c = 0x9e3779b9 + len(k) + 3923095 # handle most of the key while remain >= 12: a2, b2, c2 = FMT_NEW.unpack_from(k, pos) a, b, c = mix_new(a + a2, b + b2, c + c2) pos += 12 remain -= 12 # handle the last 11 bytes a2, b2, c2 = FMT_NEW.unpack_from(k[pos:] + PADDING, 0) if sys.byteorder == 'little': c2 = c2 << 8 a, b, c = final_new(a + a2, b + b2, c + c2) # convert to signed int if c & 0x80000000: c = -0x100000000 + c return int(c) python-skytools-3.6.1/skytools/installer_config.py000066400000000000000000000003141373460333300225160ustar00rootroot00000000000000 """SQL script locations.""" __all__ = ['sql_locations'] sql_locations = [ "/usr/share/skytools3", ] # PEP 440 version: [N!]N(.N)*[{a|b|rc}N][.postN][.devN] package_version = "3.6.1" skylog = 0 python-skytools-3.6.1/skytools/natsort.py000066400000000000000000000036641373460333300206760ustar00rootroot00000000000000"""Natural sort. Compares numeric parts numerically. Rules: * String consists of numeric and non-numeric parts. * Parts are compared pairwise, if both are numeric then numerically, otherwise textually. 
(Only first fragment can have different type) Extra rules for version numbers: * In textual comparison, '~' is less than anything else, including end-of-string. * In numeric comparison, numbers that start with '0' are compared as decimal fractions, thus a number that starts with more zeroes is smaller. """ import re as _re _rc = _re.compile(r'\d+|\D+', _re.A) __all__ = ( 'natsort_key', 'natsort', 'natsorted', 'natsort_key_icase', 'natsort_icase', 'natsorted_icase' ) def natsort_key(s): """Returns tuple that sorts according to natsort rules. """ # key consists of triplets (type:int, magnitude:int, value:str) key = [] if '~' in s: s = s.replace('~', '\0') for frag in _rc.findall(s): if frag < '0': key.extend((1, 0, frag + '\1')) elif frag < '1': key.extend((2, len(frag.lstrip('0')) - len(frag), frag)) elif frag < ':': key.extend((2, len(frag), frag)) else: key.extend((3, 0, frag + '\1')) if not key or key[-3] == 2: key.extend((1, 0, '\1')) return tuple(key) def natsort(lst): """Natural in-place sort, case-sensitive.""" lst.sort(key=natsort_key) def natsorted(lst): """Return copy of list, sorted in natural order, case-sensitive. """ return sorted(lst, key=natsort_key) # case-insensitive api def natsort_key_icase(s): """Case-insensitive natsort key (lowercases the string first).""" return natsort_key(s.lower()) def natsort_icase(lst): """Natural in-place sort, case-insensitive.""" lst.sort(key=natsort_key_icase) def natsorted_icase(lst): """Return copy of list, sorted in natural order, case-insensitive. """ return sorted(lst, key=natsort_key_icase) python-skytools-3.6.1/skytools/parsing.py000066400000000000000000000325661373460333300206520ustar00rootroot00000000000000"""Various parsers for Postgres-specific data formats. """ import re import skytools __all__ = ( "parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements", 'sql_tokenizer', 'parse_sqltriga_sql', "parse_acl", "dedent", "hsize_to_bytes", "parse_connect_string", "merge_connect_string", ) _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X) def parse_pgarray(array): r"""Parse Postgres array and return list of items inside it. 
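Illustrative examples (elements come back as strings, NULL as None):

    >>> parse_pgarray('{1,2,3}')
    ['1', '2', '3']
    >>> parse_pgarray('{a,NULL,"quoted, item"}')
    ['a', None, 'quoted, item']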
""" if array is None: return None if not array or array[0] not in ("{", "[") or array[-1] != '}': raise ValueError("bad array format: must be surrounded with {}") res = [] pos = 1 # skip optional dimensions descriptor "[a,b]={...}" if array[0] == "[": pos = array.find('{') + 1 if pos < 1: raise ValueError("bad array format 2: must be surrounded with {}") while True: m = _rc_listelem.search(array, pos) if not m: break pos2 = m.end() item = array[pos:pos2] if len(item) == 4 and item.upper() == "NULL": val = None else: if item and item[0] == '"': item = item[1:-1] val = skytools.unescape(item) res.append(val) pos = pos2 + 1 if array[pos2] == "}": break elif array[pos2] != ",": raise ValueError("bad array format: expected ,} got " + repr(array[pos2])) if pos < len(array) - 1: raise ValueError("bad array format: failed to parse completely (pos=%d len=%d)" % (pos, len(array))) return res # # parse logtriga partial sql # class _logtriga_parser: """Parses logtriga/sqltriga partial SQL to values.""" pklist = None def tokenizer(self, sql): """Token generator.""" for ___typ, tok in sql_tokenizer(sql, ignore_whitespace=True): yield tok def parse_insert(self, tk, fields, values, key_fields, key_values): """Handler for inserts.""" # (col1, col2) values ('data', null) if next(tk) != "(": raise ValueError("syntax error") while True: fields.append(next(tk)) t = next(tk) if t == ")": break elif t != ",": raise ValueError("syntax error") if next(tk).lower() != "values": raise ValueError("syntax error, expected VALUES") if next(tk) != "(": raise ValueError("syntax error, expected (") while True: values.append(next(tk)) t = next(tk) if t == ")": break if t == ",": continue raise ValueError("expected , or ) got " + t) t = next(tk) raise ValueError("expected EOF, got " + repr(t)) def parse_update(self, tk, fields, values, key_fields, key_values): """Handler for updates.""" # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2' while True: fields.append(next(tk)) if next(tk) != "=": raise ValueError("syntax error") values.append(next(tk)) t = next(tk) if t == ",": continue elif t.lower() == "where": break else: raise ValueError("syntax error, expected WHERE or , got " + repr(t)) while True: fld = next(tk) key_fields.append(fld) self.pklist.append(fld) if next(tk) != "=": raise ValueError("syntax error") key_values.append(next(tk)) t = next(tk) if t.lower() != "and": raise ValueError("syntax error, expected AND got " + repr(t)) def parse_delete(self, tk, fields, values, key_fields, key_values): """Handler for deletes.""" # pk1 = 'pk1' and pk2 = 'pk2' while True: fld = next(tk) key_fields.append(fld) self.pklist.append(fld) if next(tk) != "=": raise ValueError("syntax error") key_values.append(next(tk)) t = next(tk) if t.lower() != "and": raise ValueError("syntax error, expected AND, got " + repr(t)) def _create_dbdict(self, fields, values): fields = [skytools.unquote_ident(f) for f in fields] values = [skytools.unquote_literal(f) for f in values] return skytools.dbdict(zip(fields, values)) def parse_sql(self, op, sql, pklist=None, splitkeys=False): """Main entry point.""" if pklist is None: self.pklist = [] else: self.pklist = pklist tk = self.tokenizer(sql) fields = [] values = [] key_fields = [] key_values = [] try: if op == "I": self.parse_insert(tk, fields, values, key_fields, key_values) elif op == "U": self.parse_update(tk, fields, values, key_fields, key_values) elif op == "D": self.parse_delete(tk, fields, values, key_fields, key_values) raise ValueError("syntax error") except StopIteration: pass 
# last sanity check if (len(fields) + len(key_fields) == 0 or len(fields) != len(values) or len(key_fields) != len(key_values)): raise ValueError("syntax error, fields do not match values") if splitkeys: return (self._create_dbdict(key_fields, key_values), self._create_dbdict(fields, values)) return self._create_dbdict(fields + key_fields, values + key_values) def parse_logtriga_sql(op, sql, splitkeys=False): return parse_sqltriga_sql(op, sql, splitkeys=splitkeys) def parse_sqltriga_sql(op, sql, pklist=None, splitkeys=False): """Parse partial SQL used by pgq.sqltriga() back to data values. Parser has following limitations: - Expects standard_quoted_strings = off - Does not support dollar quoting. - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1)) - WHERE expression must not contain IS (NOT) NULL - Does not support updating pk value, unless you use the splitkeys parameter. Returns dict of col->data pairs. """ return _logtriga_parser().parse_sql(op, sql, pklist, splitkeys=splitkeys) def parse_tabbed_table(txt): r"""Parse a tab-separated table into list of dicts. Expect first row to be column names. Very primitive. """ txt = txt.replace("\r\n", "\n") fields = None data = [] for ln in txt.split("\n"): if not ln: continue if not fields: fields = ln.split("\t") continue cols = ln.split("\t") if len(cols) != len(fields): continue row = dict(zip(fields, cols)) data.append(row) return data _extstr = r""" ['] (?: [^'\\]+ | \\. | [']['] )* ['] """ _stdstr = r""" ['] (?: [^']+ | [']['] )* ['] """ _name = r""" (?: [a-z_][a-z0-9_$]* | " (?: [^"]+ | "" )* " ) """ _ident = r""" (?P<ident> %s ) """ % _name _fqident = r""" (?P<ident> %s (?: \. %s )* ) """ % (_name, _name) _base_sql = r""" (?P<dolq> (?P<dname> [$] (?: [_a-z][_a-z0-9]*)? [$] ) .*? (?P=dname) ) | (?P<num> [0-9][0-9.e]* ) | (?P<numarg> [$] [0-9]+ ) | (?P<pyold> [%][(] [a-z_][a-z0-9_]* [)] [s] ) | (?P<pynew> [{] [^{}]+ [}] ) | (?P<ws> (?: \s+ | [/][*] .*? [*][/] | [-][-][^\n]* )+ ) | (?P<sym> (?: [-+*~!@#^&|?/%<>=]+ | [,()\[\].:;] ) ) | (?P<error> . )""" _base_sql_fq = r"%s | %s" % (_fqident, _base_sql) _base_sql = r"%s | %s" % (_ident, _base_sql) _std_sql = r"""(?: (?P<str> [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql) _std_sql_fq = r"""(?: (?P<str> [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql_fq) _ext_sql = r"""(?: (?P<str> [E]? %s ) | %s )""" % (_extstr, _base_sql) _ext_sql_fq = r"""(?: (?P<str> [E]? %s ) | %s )""" % (_extstr, _base_sql_fq) _std_sql_rc = _ext_sql_rc = None _std_sql_fq_rc = _ext_sql_fq_rc = None def sql_tokenizer(sql, standard_quoting=False, ignore_whitespace=False, fqident=False, show_location=False): r"""Parse SQL into tokens. Iterator, returns (toktype, tokstr) tuples. """ global _std_sql_rc, _ext_sql_rc, _std_sql_fq_rc, _ext_sql_fq_rc if not _std_sql_rc: _std_sql_rc = re.compile(_std_sql, re.X | re.I | re.S) _ext_sql_rc = re.compile(_ext_sql, re.X | re.I | re.S) _std_sql_fq_rc = re.compile(_std_sql_fq, re.X | re.I | re.S) _ext_sql_fq_rc = re.compile(_ext_sql_fq, re.X | re.I | re.S) if standard_quoting: if fqident: rc = _std_sql_fq_rc else: rc = _std_sql_rc else: if fqident: rc = _ext_sql_fq_rc else: rc = _ext_sql_rc pos = 0 while True: m = rc.match(sql, pos) if not m: break pos = m.end() typ = m.lastgroup if ignore_whitespace and typ == "ws": continue tk = m.group() if show_location: yield (typ, tk, pos) else: yield (typ, tk) _copy_from_stdin_re = r"copy.*from\s+stdin" _copy_from_stdin_rc = None def parse_statements(sql, standard_quoting=False): """Parse multi-statement string into separate statements. Yields statements one by one. 
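Illustrative example (it is a generator, so wrap it in list()):

    >>> list(parse_statements('select 1; select 2;'))
    ['select 1;', 'select 2;']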
""" global _copy_from_stdin_rc if not _copy_from_stdin_rc: _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I) tokens = [] pcount = 0 # '(' level for typ, t in sql_tokenizer(sql, standard_quoting=standard_quoting): # skip whitespace and comments before statement if len(tokens) == 0 and typ == "ws": continue # keep the rest tokens.append(t) if t == "(": pcount += 1 elif t == ")": pcount -= 1 elif t == ";" and pcount == 0: sql = "".join(tokens) if _copy_from_stdin_rc.match(sql): raise ValueError("copy from stdin not supported") yield "".join(tokens) tokens = [] if len(tokens) > 0: yield "".join(tokens) if pcount != 0: raise ValueError("syntax error - unbalanced parenthesis") _acl_name = r'(?: [0-9a-z_]+ | " (?: [^"]+ | "" )* " )' _acl_re = r''' \s* (?: group \s+ | user \s+ )? (?P %s )? (?P = [a-z*]* )? (?P / %s )? \s* $ ''' % (_acl_name, _acl_name) _acl_rc = None def parse_acl(acl): """Parse ACL entry. """ global _acl_rc if not _acl_rc: _acl_rc = re.compile(_acl_re, re.I | re.X) m = _acl_rc.match(acl) if not m: return None target = m.group('tgt') perm = m.group('perm') owner = m.group('owner') if target: target = skytools.unquote_ident(target) if perm: perm = perm[1:] if owner: owner = skytools.unquote_ident(owner[1:]) return (target, perm, owner) def dedent(doc): r"""Relaxed dedent. - takes whitespace to be removed from first indented line. - allows empty or non-indented lines at the start - allows first line to be unindented - skips empty lines at the start - ignores indent of empty lines - if line does not match common indent, is stays unchanged """ pfx = None res = [] for ln in doc.splitlines(): ln = ln.rstrip() if not pfx and len(res) < 2: if not ln: continue wslen = len(ln) - len(ln.lstrip()) pfx = ln[: wslen] if pfx: if ln.startswith(pfx): ln = ln[len(pfx):] res.append(ln) res.append('') return '\n'.join(res) def hsize_to_bytes(input_str): """ Convert sizes from human format to bytes (string to integer) """ m = re.match(r"^([0-9]+) *([KMGTPEZY]?)B?$", input_str.strip(), re.IGNORECASE) if not m: raise ValueError("cannot parse: %s" % input_str) units = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] nbytes = int(m.group(1)) * 1024 ** units.index(m.group(2).upper()) return nbytes # # Connect string parsing # _cstr_rx = r""" \s* (\w+) \s* = \s* ( ' ( \\.| [^'\\] )* ' | \S+ ) \s* """ _cstr_unesc_rx = r"\\(.)" _cstr_badval_rx = r"[\s'\\]" _cstr_rc = None _cstr_unesc_rc = None _cstr_badval_rc = None def parse_connect_string(cstr): r"""Parse Postgres connect string. """ global _cstr_rc, _cstr_unesc_rc if not _cstr_rc: _cstr_rc = re.compile(_cstr_rx, re.X) _cstr_unesc_rc = re.compile(_cstr_unesc_rx) pos = 0 res = [] while pos < len(cstr): m = _cstr_rc.match(cstr, pos) if not m: raise ValueError('Invalid connect string') pos = m.end() k = m.group(1) v = m.group(2) if v[0] == "'": v = _cstr_unesc_rc.sub(r"\1", v) res.append((k, v)) return res def merge_connect_string(cstr_arg_list): """Put fragments back together. """ global _cstr_badval_rc if not _cstr_badval_rc: _cstr_badval_rc = re.compile(_cstr_badval_rx) buf = [] for k, v in cstr_arg_list: if not v or _cstr_badval_rc.search(v): v = v.replace('\\', r'\\') v = v.replace("'", r"\'") v = "'" + v + "'" buf.append("%s=%s" % (k, v)) return ' '.join(buf) python-skytools-3.6.1/skytools/plpy_applyrow.py000066400000000000000000000147131373460333300221240ustar00rootroot00000000000000"""PLPY helper module for applying row events from pgq.logutriga(). 
""" import skytools try: import plpy except ImportError: pass ## TODO: automatic fkey detection # find FK columns FK_SQL = """ SELECT (SELECT array_agg( (SELECT attname::text FROM pg_attribute WHERE attrelid = conrelid AND attnum = conkey[i])) FROM generate_series(1, array_upper(conkey, 1)) i) AS kcols, (SELECT array_agg( (SELECT attname::text FROM pg_attribute WHERE attrelid = confrelid AND attnum = confkey[i])) FROM generate_series(1, array_upper(confkey, 1)) i) AS fcols, confrelid::regclass::text AS ftable FROM pg_constraint WHERE conrelid = {tbl}::regclass AND contype='f' """ class DataError(Exception): "Invalid data" def colfilter_full(rnew, rold): return rnew def colfilter_changed(rnew, rold): res = {} for k, _ in rnew: if rnew[k] != rold[k]: res[k] = rnew[k] return res def canapply_dummy(rnew, rold): return True def canapply_tstamp_helper(rnew, rold, tscol): tnew = rnew[tscol] told = rold[tscol] if not tnew[0].isdigit(): raise DataError('invalid timestamp') if not told[0].isdigit(): raise DataError('invalid timestamp') return tnew > told def applyrow(tblname, ev_type, new_row, backup_row=None, alt_pkey_cols=None, fkey_cols=None, fkey_ref_table=None, fkey_ref_cols=None, fn_canapply=canapply_dummy, fn_colfilter=colfilter_full): """Core logic. Actual decisions will be done in callback functions. - [IUD]: If row referenced by fkey does not exist, event is not applied - If pkey does not exist but alt_pkey does, row is not applied. @param tblname: table name, schema-qualified @param ev_type: [IUD]:pkey1,pkey2 @param alt_pkey_cols: list of alternatice columns to consuder @param fkey_cols: columns in this table that refer to other table @param fkey_ref_table: other table referenced here @param fkey_ref_cols: column in other table that must match @param fn_canapply: callback function, gets new and old row, returns whether the row should be applied @param fn_colfilter: callback function, gets new and old row, returns dict of final columns to be applied """ gd = None # parse ev_type tmp = ev_type.split(':', 1) if len(tmp) != 2 or tmp[0] not in ('I', 'U', 'D'): raise DataError('Unsupported ev_type: ' + repr(ev_type)) if not tmp[1]: raise DataError('No pkey in event') cmd = tmp[0] pkey_cols = tmp[1].split(',') qtblname = skytools.quote_fqident(tblname) # parse ev_data fields = skytools.db_urldecode(new_row) if ev_type.find('}') >= 0: raise DataError('Really suspicious activity') if ",".join(fields.keys()).find('}') >= 0: raise DataError('Really suspicious activity 2') # generate pkey expressions tmp = ["%s = {%s}" % (skytools.quote_ident(k), k) for k in pkey_cols] pkey_expr = " and ".join(tmp) alt_pkey_expr = None if alt_pkey_cols: tmp = ["%s = {%s}" % (skytools.quote_ident(k), k) for k in alt_pkey_cols] alt_pkey_expr = " and ".join(tmp) log = "data ok" # # Row data seems fine, now apply it # if fkey_ref_table: tmp = [] for k, rk in zip(fkey_cols, fkey_ref_cols): tmp.append("%s = {%s}" % (skytools.quote_ident(rk), k)) fkey_expr = " and ".join(tmp) q = "select 1 from only %s where %s" % ( skytools.quote_fqident(fkey_ref_table), fkey_expr) res = skytools.plpy_exec(gd, q, fields) if not res: return "IGN: parent row does not exist" log += ", fkey ok" # fetch old row if alt_pkey_expr: q = "select * from only %s where %s for update" % (qtblname, alt_pkey_expr) res = skytools.plpy_exec(gd, q, fields) if res: oldrow = res[0] # if altpk matches, but pk not, then delete need_del = 0 for k in pkey_cols: # fixme: proper type cmp? 
if fields[k] != str(oldrow[k]): need_del = 1 break if need_del: log += ", altpk del" q = "delete from only %s where %s" % (qtblname, alt_pkey_expr) skytools.plpy_exec(gd, q, fields) res = None else: log += ", altpk ok" else: # no altpk q = "select * from only %s where %s for update" % (qtblname, pkey_expr) res = skytools.plpy_exec(None, q, fields) # got old row, with same pk and altpk if res: oldrow = res[0] log += ", old row" ok = fn_canapply(fields, oldrow) if ok: log += ", new row better" if not ok: # ignore the update return "IGN:" + log + ", current row more up-to-date" else: log += ", no old row" oldrow = None if res: if cmd == 'I': cmd = 'U' else: if cmd == 'U': cmd = 'I' # allow column changes if oldrow: fields2 = fn_colfilter(fields, oldrow) for k in pkey_cols: if k not in fields2: fields2[k] = fields[k] fields = fields2 # apply change if cmd == 'I': q = skytools.mk_insert_sql(fields, tblname, pkey_cols) elif cmd == 'U': q = skytools.mk_update_sql(fields, tblname, pkey_cols) elif cmd == 'D': q = skytools.mk_delete_sql(fields, tblname, pkey_cols) else: plpy.error('Huh') plpy.execute(q) return log def ts_conflict_handler(gd, args): """Conflict handling based on timestamp column.""" conf = skytools.db_urldecode(args[0]) timefield = conf['timefield'] ev_type = args[1] ev_data = args[2] ev_extra1 = args[3] ev_extra2 = args[4] #ev_extra3 = args[5] #ev_extra4 = args[6] altpk = None if 'altpk' in conf: altpk = conf['altpk'].split(',') def ts_canapply(rnew, rold): return canapply_tstamp_helper(rnew, rold, timefield) return applyrow( ev_extra1, ev_type, ev_data, backup_row=ev_extra2, alt_pkey_cols=altpk, fkey_ref_table=conf.get('fkey_ref_table'), fkey_ref_cols=conf.get('fkey_ref_cols'), fkey_cols=conf.get('fkey_cols'), fn_canapply=ts_canapply ) python-skytools-3.6.1/skytools/psycopgwrapper.py000066400000000000000000000033551373460333300222700ustar00rootroot00000000000000"""Wrapper around psycopg2. Database connection provides regular DB-API 2.0 interface. Connection object methods:: .cursor() .commit() .rollback() .close() Cursor methods:: .execute(query[, args]) .fetchone() .fetchall() Sample usage:: db = self.get_database('somedb') curs = db.cursor() # query arguments as array q = "select * from table where id = %s and name = %s" curs.execute(q, [1, 'somename']) # query arguments as dict q = "select id, name from table where id = %(id)s and name = %(name)s" curs.execute(q, {'id': 1, 'name': 'somename'}) # loop over resultset for row in curs.fetchall(): # columns can be asked by index: id = row[0] name = row[1] # and by name: id = row['id'] name = row['name'] # now commit the transaction db.commit() """ import psycopg2 import psycopg2.extensions import psycopg2.extras from psycopg2 import Error as DBError __all__ = ( 'connect_database', 'DBError', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_REPEATABLE_READ', 'I_SERIALIZABLE', ) #: Isolation level for db.set_isolation_level() I_AUTOCOMMIT = psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT #: Isolation level for db.set_isolation_level() I_READ_COMMITTED = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED #: Isolation level for db.set_isolation_level() I_REPEATABLE_READ = psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ #: Isolation level for db.set_isolation_level() I_SERIALIZABLE = psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE def connect_database(connstr): """Create a db connection with DictCursor. 
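Illustrative example:

    db = connect_database("host=localhost dbname=testdb")
    curs = db.cursor()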
""" db = psycopg2.connect(connstr, cursor_factory=psycopg2.extras.DictCursor) return db python-skytools-3.6.1/skytools/querybuilder.py000066400000000000000000000277121373460333300217220ustar00rootroot00000000000000"""Helper classes for complex query generation. Main target is code execution under PL/Python. Query parameters are referenced as C{{key}} or C{{key:type}}. Type will be given to C{plpy.prepare}. If C{type} is missing, C{text} is assumed. See L{plpy_exec} for examples. """ import skytools try: import plpy except ImportError: plpy = None __all__ = ( 'QueryBuilder', 'PLPyQueryBuilder', 'PLPyQuery', 'plpy_exec', "run_query", "run_query_row", "run_lookup", "run_exists", ) PARAM_INLINE = 0 # quote_literal() PARAM_DBAPI = 1 # %()s PARAM_PLPY = 2 # $n class QArgConf: """Per-query arg-type config object.""" param_type = None class QArg: """Place-holder for a query parameter.""" def __init__(self, name, value, pos, conf): self.name = name self.value = value self.pos = pos self.conf = conf def __str__(self): if self.conf.param_type == PARAM_INLINE: return skytools.quote_literal(self.value) elif self.conf.param_type == PARAM_DBAPI: return "%s" elif self.conf.param_type == PARAM_PLPY: return "$%d" % self.pos else: raise Exception("bad QArgConf.param_type") # need an structure with fast remove-from-middle # and append operations. class DList: """Simple double-linked list.""" __slots__ = ('next', 'prev') def __init__(self): self.next = self self.prev = self def append(self, obj): obj.next = self obj.prev = self.prev self.prev.next = obj self.prev = obj def remove(self, obj): obj.next.prev = obj.prev obj.prev.next = obj.next obj.next = obj.prev = None def empty(self): return self.next is self def pop(self): """Remove and return first element.""" obj = None if not self.empty(): obj = self.next self.remove(obj) return obj class CachedPlan(DList): """Wrapper around prepared plan.""" __slots__ = ('key', 'plan') def __init__(self, key, plan): super(CachedPlan, self).__init__() self.key = key # (sql, (types)) self.plan = plan class PlanCache: """Cache for limited amount of plans.""" def __init__(self, maxplans=100): self.maxplans = maxplans self.plan_map = {} self.plan_list = DList() def get_plan(self, sql, types): """Prepare the plan and cache it.""" t = (sql, tuple(types)) if t in self.plan_map: pc = self.plan_map[t] # put to the end self.plan_list.remove(pc) self.plan_list.append(pc) return pc.plan # prepare new plan plan = plpy.prepare(sql, types) # add to cache pc = CachedPlan(t, plan) self.plan_list.append(pc) self.plan_map[t] = pc # remove plans if too much while len(self.plan_map) > self.maxplans: # this is ugly workaround for pylint drop = self.plan_list.pop() del self.plan_map[getattr(drop, 'key')] return plan class QueryBuilderCore: """Helper for query building. """ def __init__(self, sqlexpr, params): """Init the object. @param sqlexpr: Partial sql fragment. @param params: Dict of parameter values. """ self._params = params self._arg_type_list = [] self._arg_value_list = [] self._sql_parts = [] self._arg_conf = QArgConf() self._nargs = 0 if sqlexpr: self.add(sqlexpr, required=True) def add(self, expr, sql_type="text", required=False): """Add SQL fragment to query. """ self._add_expr('', expr, self._params, sql_type, required) def get_sql(self, param_type=PARAM_INLINE): """Return generated SQL (thus far) as string. Possible values for param_type: - 0: Insert values quoted with quote_literal() - 1: Insert %()s in place of parameters. - 2: Insert $n in place of parameters. 
""" self._arg_conf.param_type = param_type tmp = [str(part) for part in self._sql_parts] return "".join(tmp) def _add_expr(self, pfx, expr, params, sql_type, required): parts = [] types = [] values = [] nargs = self._nargs if pfx: parts.append(pfx) pos = 0 while True: # find start of next argument a1 = expr.find('{', pos) if a1 < 0: parts.append(expr[pos:]) break # find end end of argument name a2 = expr.find('}', a1) if a2 < 0: raise Exception("missing argument terminator: " + expr) # add plain sql if a1 > pos: parts.append(expr[pos:a1]) pos = a2 + 1 # get arg name, check if exists k = expr[a1 + 1: a2] # split name from type tpos = k.rfind(':') if tpos > 0: kparam = k[:tpos] ktype = k[tpos + 1:] else: kparam = k ktype = sql_type # params==None means params are checked later if params is not None and kparam not in params: if required: raise Exception("required parameter missing: " + kparam) # optional fragment, param missing, skip it return # got arg nargs += 1 if params is not None: val = params[kparam] else: val = kparam values.append(val) types.append(ktype) arg = QArg(kparam, val, nargs, self._arg_conf) parts.append(arg) # add interesting parts to the main sql self._sql_parts.extend(parts) if types: self._arg_type_list.extend(types) if values: self._arg_value_list.extend(values) self._nargs = nargs class QueryBuilder(QueryBuilderCore): def execute(self, curs): """Client-side query execution on DB-API 2.0 cursor. Calls C{curs.execute()} with proper arguments. Returns result of curs.execute(), although that does not return anything interesting. Later curs.fetch* methods must be called to get result. """ q = self.get_sql(PARAM_DBAPI) args = self._params return curs.execute(q, args) class PLPyQueryBuilder(QueryBuilderCore): def __init__(self, sqlexpr, params, plan_cache=None, sqls=None): """Init the object. @param sqlexpr: Partial sql fragment. @param params: Dict of parameter values. @param plan_cache: (PL/Python) A dict object where to store the plan cache, under the key C{"plan_cache"}. If not given, plan will not be cached and values will be inserted directly to query. Usually either C{GD} or C{SD} should be given here. @param sqls: list object where to append executed sqls (used for debugging) """ super(PLPyQueryBuilder, self).__init__(sqlexpr, params) self._sqls = sqls if plan_cache is not None: if 'plan_cache' not in plan_cache: plan_cache['plan_cache'] = PlanCache() self._plan_cache = plan_cache['plan_cache'] else: self._plan_cache = None def execute(self): """Server-side query execution via plpy. Query can be run either cached or uncached, depending on C{plan_cache} setting given to L{__init__}. Returns result of plpy.execute(). """ args = self._arg_value_list types = self._arg_type_list if self._sqls is not None: self._sqls.append({"sql": self.get_sql(PARAM_INLINE)}) if self._plan_cache is not None: sql = self.get_sql(PARAM_PLPY) plan = self._plan_cache.get_plan(sql, types) res = plpy.execute(plan, args) else: sql = self.get_sql(PARAM_INLINE) res = plpy.execute(sql) if res: res = [skytools.dbdict(r) for r in res] return res class PLPyQuery: """Static, cached PL/Python query that uses QueryBuilder formatting. See L{plpy_exec} for simple usage. 
""" def __init__(self, sql): qb = QueryBuilder(sql, None) p_sql = qb.get_sql(PARAM_PLPY) p_types = qb._arg_type_list self.plan = plpy.prepare(p_sql, p_types) self.arg_map = qb._arg_value_list self.sql = sql def execute(self, arg_dict, all_keys_required=True): try: if all_keys_required: arg_list = [arg_dict[k] for k in self.arg_map] else: arg_list = [arg_dict.get(k) for k in self.arg_map] return plpy.execute(self.plan, arg_list) except KeyError: need = set(self.arg_map) got = set(arg_dict.keys()) missing = list(need.difference(got)) plpy.error("Missing arguments: [%s] QUERY: %s" % ( ','.join(missing), repr(self.sql))) def __repr__(self): return 'PLPyQuery<%s>' % self.sql def plpy_exec(gd, sql, args, all_keys_required=True): """Cached plan execution for PL/Python. @param gd: dict to store cached plans under. If None, caching is disabled. @param sql: SQL statement to execute. @param args: dict of arguments to query. @param all_keys_required: if False, missing key is taken as NULL, instead of throwing error. """ if gd is None: return PLPyQueryBuilder(sql, args).execute() try: sq = gd['plq_cache'][sql] except KeyError: if 'plq_cache' not in gd: gd['plq_cache'] = {} sq = PLPyQuery(sql) gd['plq_cache'][sql] = sq return sq.execute(args, all_keys_required) # some helper functions for convenient sql execution def run_query(cur, sql, params=None, **kwargs): """ Helper function if everything you need is just paramertisized execute Sets rows_found that is coneninet to use when you don't need result just want to know how many rows were affected """ params = params or kwargs sql = QueryBuilder(sql, params).get_sql(0) cur.execute(sql) rows = cur.fetchall() # convert result rows to dbdict if rows: rows = [skytools.dbdict(r) for r in rows] return rows def run_query_row(cur, sql, params=None, **kwargs): """ Helper function if everything you need is just paramertisized execute to fetch one row only. If not found none is returned """ params = params or kwargs rows = run_query(cur, sql, params) if len(rows) == 0: return None return rows[0] def run_lookup(cur, sql, params=None, **kwargs): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. """ params = params or kwargs sql = QueryBuilder(sql, params).get_sql(0) cur.execute(sql) row = cur.fetchone() if row is None: return None return row[0] def run_exists(cur, sql, params=None, **kwargs): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. """ params = params or kwargs val = run_lookup(cur, sql, params) return val is not None # fake plpy for testing class fake_plpy: log = [] def prepare(self, sql, types): self.log.append("DBG: plpy.prepare(%s, %s)" % (repr(sql), repr(types))) return ('PLAN', sql, types) def execute(self, plan, args=()): self.log.append("DBG: plpy.execute(%s, %s)" % (repr(plan), repr(args))) def error(self, msg): self.log.append("DBG: plpy.error(%s)" % repr(msg)) # make plpy available if not plpy: plpy = fake_plpy() GD = {} python-skytools-3.6.1/skytools/quoting.py000066400000000000000000000131531373460333300206660ustar00rootroot00000000000000"""Various helpers for string quoting/unquoting. 
""" import json import re try: from skytools._cquoting import ( db_urldecode, db_urlencode, quote_bytea_raw, quote_copy, quote_literal, unescape, unquote_literal, ) except ImportError: from skytools._pyquoting import ( db_urldecode, db_urlencode, quote_bytea_raw, quote_copy, quote_literal, unescape, unquote_literal, ) __all__ = ( # _pyqoting / _cquoting "db_urldecode", "db_urlencode", "quote_bytea_raw", "quote_copy", "quote_literal", "unescape", "unquote_literal", # local "quote_bytea_literal", "quote_bytea_copy", "quote_statement", "quote_ident", "quote_fqident", "quote_json", "unescape_copy", "unquote_ident", "unquote_fqident", "json_encode", "json_decode", "make_pgarray", ) # # SQL quoting # def quote_bytea_literal(s): """Quote bytea for regular SQL.""" return quote_literal(quote_bytea_raw(s)) def quote_bytea_copy(s): """Quote bytea for COPY.""" return quote_copy(quote_bytea_raw(s)) def quote_statement(sql, dict_or_list): """Quote whole statement. Data values are taken from dict or list or tuple. """ if hasattr(dict_or_list, 'items'): qdict = {} for k, v in dict_or_list.items(): qdict[k] = quote_literal(v) return sql % qdict else: qvals = [quote_literal(v) for v in dict_or_list] return sql % tuple(qvals) # reserved keywords (RESERVED_KEYWORD + TYPE_FUNC_NAME_KEYWORD) _ident_kwmap = { "all": 1, "analyse": 1, "analyze": 1, "and": 1, "any": 1, "array": 1, "as": 1, "asc": 1, "asymmetric": 1, "authorization": 1, "between": 1, "binary": 1, "both": 1, "case": 1, "cast": 1, "check": 1, "collate": 1, "collation": 1, "column": 1, "concurrently": 1, "constraint": 1, "create": 1, "cross": 1, "current_catalog": 1, "current_date": 1, "current_role": 1, "current_schema": 1, "current_time": 1, "current_timestamp": 1, "current_user": 1, "default": 1, "deferrable": 1, "desc": 1, "distinct": 1, "do": 1, "else": 1, "end": 1, "errors": 1, "except": 1, "false": 1, "fetch": 1, "for": 1, "foreign": 1, "freeze": 1, "from": 1, "full": 1, "grant": 1, "group": 1, "having": 1, "ilike": 1, "in": 1, "initially": 1, "inner": 1, "intersect": 1, "into": 1, "is": 1, "isnull": 1, "join": 1, "lateral": 1, "leading": 1, "left": 1, "like": 1, "limit": 1, "localtime": 1, "localtimestamp": 1, "natural": 1, "new": 1, "not": 1, "notnull": 1, "null": 1, "off": 1, "offset": 1, "old": 1, "on": 1, "only": 1, "or": 1, "order": 1, "outer": 1, "over": 1, "overlaps": 1, "placing": 1, "primary": 1, "references": 1, "returning": 1, "right": 1, "select": 1, "session_user": 1, "similar": 1, "some": 1, "symmetric": 1, "table": 1, "tablesample": 1, "then": 1, "to": 1, "trailing": 1, "true": 1, "union": 1, "unique": 1, "user": 1, "using": 1, "variadic": 1, "verbose": 1, "when": 1, "where": 1, "window": 1, "with": 1, } _ident_bad = re.compile(r"[^a-z0-9_]|^[0-9]") def quote_ident(s): """Quote SQL identifier. If is checked against weird symbols and keywords. """ if _ident_bad.search(s) or s in _ident_kwmap: s = '"%s"' % s.replace('"', '""') elif not s: return '""' return s def quote_fqident(s): """Quote fully qualified SQL identifier. The '.' is taken as namespace separator and all parts are quoted separately """ tmp = s.split('.', 1) if len(tmp) == 1: return 'public.' 
+ quote_ident(s) return '.'.join([quote_ident(name) for name in tmp]) # # quoting for JSON strings # _jsre = re.compile(r'[\x00-\x1F\\/"]') _jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r", "\t": "\\t", "\\": "\\\\", '"': '\\"', "/": "\\/", # to avoid html attacks } def _json_quote_char(m): """Quote single char.""" c = m.group(0) try: return _jsmap[c] except KeyError: return r"\u%04x" % ord(c) def quote_json(s): """JSON style quoting.""" if s is None: return "null" return '"%s"' % _jsre.sub(_json_quote_char, s) def unescape_copy(val): r"""Removes C-style escapes, also converts "\N" to None. """ if val == r"\N": return None return unescape(val) def unquote_ident(val): """Unquotes possibly quoted SQL identifier. """ if len(val) > 1 and val[0] == '"' and val[-1] == '"': return val[1:-1].replace('""', '"') if val.find('"') > 0: raise Exception('unsupported syntax') return val.lower() def unquote_fqident(val): """Unquotes fully-qualified possibly quoted SQL identifier. """ tmp = val.split('.', 1) return '.'.join([unquote_ident(i) for i in tmp]) def json_encode(val=None, **kwargs): """Creates JSON string from Python object. """ return json.dumps(val or kwargs) def json_decode(s): """Parses JSON string into Python object. """ return json.loads(s) # # Create Postgres array # # any chars not in "good" set? main bad ones: [ ,{}\"] _pgarray_bad_rx = r"[^0-9a-z_.%&=()<>*/+-]" _pgarray_bad_rc = None def _quote_pgarray_elem(s): if s is None: return 'NULL' s = str(s) if _pgarray_bad_rc.search(s): s = s.replace('\\', '\\\\') return '"' + s.replace('"', r'\"') + '"' elif not s: return '""' return s def make_pgarray(lst): r"""Formats Python list as Postgres array. Reverse of parse_pgarray(). """ global _pgarray_bad_rc if _pgarray_bad_rc is None: _pgarray_bad_rc = re.compile(_pgarray_bad_rx) items = [_quote_pgarray_elem(v) for v in lst] return '{' + ','.join(items) + '}' python-skytools-3.6.1/skytools/scripting.py000066400000000000000000001135071373460333300212060ustar00rootroot00000000000000"""Useful functions and classes for database scripts. """ import argparse import errno import logging import logging.config import logging.handlers import optparse import os import select import signal import sys import time import skytools import skytools.skylog try: import skytools.installer_config default_skylog = skytools.installer_config.skylog except ImportError: default_skylog = 0 __all__ = ( 'BaseScript', 'UsageError', 'daemonize', 'DBScript', ) class UsageError(Exception): """User induced error.""" # # daemon mode # def daemonize(): """Turn the process into daemon. Goes background and disables all i/o. 
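    Call sketch; in this module it is normally invoked through
    run_single_process(), and the worker call here is illustrative::

        daemonize()       # parent exits, child continues detached
        run_main_loop()   # whatever the script actually does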
""" # launch new process, kill parent pid = os.fork() if pid != 0: os._exit(0) # start new session os.setsid() # stop i/o fd = os.open("/dev/null", os.O_RDWR) os.dup2(fd, 0) os.dup2(fd, 1) os.dup2(fd, 2) if fd > 2: os.close(fd) # # Pidfile locking+cleanup & daemonization combined # def run_single_process(runnable, daemon, pidfile): """Run runnable class, possibly daemonized, locked on pidfile.""" # check if another process is running if pidfile and os.path.isfile(pidfile): if skytools.signal_pidfile(pidfile, 0): print("Pidfile exists, another process running?") sys.exit(1) else: print("Ignoring stale pidfile") # daemonize if needed if daemon: daemonize() # clean only own pidfile own_pidfile = False try: if pidfile: data = str(os.getpid()) skytools.write_atomic(pidfile, data) own_pidfile = True runnable.run() finally: if own_pidfile: try: os.remove(pidfile) except BaseException: pass # # logging setup # _log_config_done = 0 _log_init_done = {} def _init_log(job_name, service_name, cf, log_level, is_daemon): """Logging setup happens here.""" global _log_config_done got_skylog = 0 use_skylog = cf.getint("use_skylog", default_skylog) # if non-daemon, avoid skylog if script is running on console. # set use_skylog=2 to disable. if not is_daemon and use_skylog == 1: # pylint gets spooked by it's own stdout wrapper and refuses to shut down # about it. 'noqa' tells prospector to ignore all warnings here. if sys.stdout.isatty(): # noqa use_skylog = 0 # load logging config if needed if use_skylog and not _log_config_done: # python logging.config braindamage: # cannot specify external classess without such hack logging.skylog = skytools.skylog skytools.skylog.set_service_name(service_name, job_name) # load general config flist = cf.getlist('skylog_locations', ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']) for fn in flist: fn = os.path.expanduser(fn) if os.path.isfile(fn): defs = {'job_name': job_name, 'service_name': service_name} logging.config.fileConfig(fn, defs, False) got_skylog = 1 break _log_config_done = 1 if not got_skylog: sys.stderr.write("skylog.ini not found!\n") sys.exit(1) # avoid duplicate logging init for job_name log = logging.getLogger(job_name) if job_name in _log_init_done: return log _log_init_done[job_name] = 1 # tune level on root logger root = logging.getLogger() root.setLevel(log_level) # compatibility: specify ini file in script config def_fmt = '%(asctime)s %(process)s %(levelname)s %(message)s' def_datefmt = '' # None logfile = cf.getfile("logfile", "") if logfile: fstr = cf.get('logfmt_file', def_fmt) fstr_date = cf.get('logdatefmt_file', def_datefmt) if log_level < logging.INFO: fstr = cf.get('logfmt_file_verbose', fstr) fstr_date = cf.get('logdatefmt_file_verbose', fstr_date) fmt = logging.Formatter(fstr, fstr_date) size = cf.getint('log_size', 10 * 1024 * 1024) num = cf.getint('log_count', 3) file_hdlr = logging.handlers.RotatingFileHandler( logfile, 'a', size, num) file_hdlr.setFormatter(fmt) root.addHandler(file_hdlr) # if skylog.ini is disabled or not available, log at least to stderr if not got_skylog: fstr = cf.get('logfmt_console', def_fmt) fstr_date = cf.get('logdatefmt_console', def_datefmt) if log_level < logging.INFO: fstr = cf.get('logfmt_console_verbose', fstr) fstr_date = cf.get('logdatefmt_console_verbose', fstr_date) stream_hdlr = logging.StreamHandler() fmt = logging.Formatter(fstr, fstr_date) stream_hdlr.setFormatter(fmt) root.addHandler(stream_hdlr) return log class BaseScript: """Base class for service scripts. 
Handles logging, daemonizing, config, errors. Config template:: ## Parameters for skytools.BaseScript ## # how many seconds to sleep between work loops # if missing or 0, then instead sleeping, the script will exit loop_delay = 1.0 # where to log logfile = ~/log/%(job_name)s.log # where to write pidfile pidfile = ~/pid/%(job_name)s.pid # per-process name to use in logging #job_name = %(config_name)s # whether centralized logging should be used # search-path [ ./skylog.ini, ~/.skylog.ini, /etc/skylog.ini ] # 0 - disabled # 1 - enabled, unless non-daemon on console (os.isatty()) # 2 - always enabled #use_skylog = 0 # where to find skylog.ini #skylog_locations = skylog.ini, ~/.skylog.ini, /etc/skylog.ini # how many seconds to sleep after catching a exception #exception_sleep = 20 """ service_name = None job_name = None cf = None cf_defaults = {} pidfile = None # >0 - sleep time if work() requests sleep # 0 - exit if work requests sleep # <0 - run work() once [same as looping=0] loop_delay = 1.0 # 0 - run work() once # 1 - run work() repeatedly looping = 1 # result from last work() call: # 1 - there is probably more work, don't sleep # 0 - no work, sleep before calling again # -1 - exception was thrown work_state = 1 # setup logger here, this allows override by subclass log = logging.getLogger('skytools.BaseScript') # start time started = 0 # set to True to use argparse ARGPARSE = False def __init__(self, service_name, args): """Script setup. User class should override work() and optionally __init__(), startup(), reload(), reset(), shutdown() and init_optparse(). NB: In case of daemon, __init__() and startup()/work()/shutdown() will be run in different processes. So nothing fancy should be done in __init__(). @param service_name: unique name for script. It will be also default job_name, if not specified in config. @param args: cmdline args (sys.argv[1:]), but can be overridden """ self.service_name = service_name self.go_daemon = 0 self.need_reload = 0 self.exception_count = 0 self.stat_dict = {} self.log_level = logging.INFO # parse command line self.options, self.args = self.parse_args(args) # check args if self.options.version: self.print_version() sys.exit(0) if self.options.daemon: self.go_daemon = 1 if self.options.quiet: self.log_level = logging.WARNING if self.options.verbose: if self.options.verbose > 1: self.log_level = skytools.skylog.TRACE else: self.log_level = logging.DEBUG self.cf_override = {} if self.options.set: for a in self.options.set: k, v = a.split('=', 1) self.cf_override[k.strip()] = v.strip() if self.options.ini: self.print_ini() sys.exit(0) # read config file self.reload() # init logging _init_log(self.job_name, self.service_name, self.cf, self.log_level, self.go_daemon) # send signal, if needed if self.options.cmd == "kill": self.send_signal(signal.SIGTERM) elif self.options.cmd == "stop": self.send_signal(signal.SIGINT) elif self.options.cmd == "reload": self.send_signal(signal.SIGHUP) def parse_args(self, args): if self.ARGPARSE: parser = self.init_argparse() options = parser.parse_args(args) args = getattr(options, "args", []) else: parser = self.init_optparse() options, args = parser.parse_args(args) return options, args def print_version(self): service = self.service_name ver = getattr(self, '__version__', None) if ver: service += ' version %s' % ver print('%s, Skytools version %s' % (service, getattr(skytools, '__version__'))) def print_ini(self): """Prints out ini file from doc string of the script of default for dbscript Used by --ini option on command line. 
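        Example invocation (script name is illustrative)::

            $ python myscript.py --ini > myscript.ini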
""" # current service name print("[%s]\n" % self.service_name) # walk class hierarchy bases = [self.__class__] while len(bases) > 0: parents = [] for c in bases: for p in c.__bases__: if p not in parents: parents.append(p) doc = c.__doc__ if doc: self._print_ini_frag(doc) bases = parents def _print_ini_frag(self, doc): # use last '::' block as config template pos = doc and doc.rfind('::\n') or -1 if pos < 0: return doc = doc[pos + 2:].rstrip() doc = skytools.dedent(doc) # merge overrided options into output for ln in doc.splitlines(): vals = ln.split('=', 1) if len(vals) != 2: print(ln) continue k = vals[0].strip() v = vals[1].strip() if k and k[0] == '#': print(ln) k = k[1:] if k in self.cf_override: print('%s = %s' % (k, self.cf_override[k])) elif k in self.cf_override: if v: print('#' + ln) print('%s = %s' % (k, self.cf_override[k])) else: print(ln) print('') def load_config(self): """Loads and returns skytools.Config instance. By default it uses first command-line argument as config file name. Can be overridden. """ if len(self.args) < 1: print("need config file, use --help for help.") sys.exit(1) conf_file = self.args[0] return skytools.Config(self.service_name, conf_file, user_defs=self.cf_defaults, override=self.cf_override) def init_optparse(self, parser=None): """Initialize a OptionParser() instance that will be used to parse command line arguments. Note that it can be overridden both directions - either DBScript will initialize an instance and pass it to user code or user can initialize and then pass to DBScript.init_optparse(). @param parser: optional OptionParser() instance, where DBScript should attach its own arguments. @return: initialized OptionParser() instance. """ if parser: p = parser else: p = optparse.OptionParser() p.set_usage("%prog [options] INI") # generic options p.add_option("-q", "--quiet", action="store_true", help="log only errors and warnings") p.add_option("-v", "--verbose", action="count", help="log verbosely") p.add_option("-d", "--daemon", action="store_true", help="go background") p.add_option("-V", "--version", action="store_true", help="print version info and exit") p.add_option("", "--ini", action="store_true", help="display sample ini file") p.add_option("", "--set", action="append", help="override config setting (--set 'PARAM=VAL')") # control options g = optparse.OptionGroup(p, 'control running process') g.add_option("-r", "--reload", action="store_const", const="reload", dest="cmd", help="reload config (send SIGHUP)") g.add_option("-s", "--stop", action="store_const", const="stop", dest="cmd", help="stop program safely (send SIGINT)") g.add_option("-k", "--kill", action="store_const", const="kill", dest="cmd", help="kill program immediately (send SIGTERM)") p.add_option_group(g) return p def init_argparse(self, parser=None): """Initialize a ArgumentParser() instance that will be used to parse command line arguments. Note that it can be overridden both directions - either BaseScript will initialize an instance and pass it to user code or user can initialize and then pass to BaseScript.init_optparse(). @param parser: optional ArgumentParser() instance, where BaseScript should attach its own arguments. @return: initialized ArgumentParser() instance. 
""" if parser: p = parser else: p = argparse.ArgumentParser() # generic options p.add_argument("-q", "--quiet", action="store_true", help="log only errors and warnings") p.add_argument("-v", "--verbose", action="count", help="log verbosely") p.add_argument("-d", "--daemon", action="store_true", help="go background") p.add_argument("-V", "--version", action="store_true", help="print version info and exit") p.add_argument("--ini", action="store_true", help="display sample ini file") p.add_argument("--set", action="append", help="override config setting (--set 'PARAM=VAL')") p.add_argument("args", nargs="*") # control options g = p.add_argument_group('control running process') g.add_argument("-r", "--reload", action="store_const", const="reload", dest="cmd", help="reload config (send SIGHUP)") g.add_argument("-s", "--stop", action="store_const", const="stop", dest="cmd", help="stop program safely (send SIGINT)") g.add_argument("-k", "--kill", action="store_const", const="kill", dest="cmd", help="kill program immediately (send SIGTERM)") return p def send_signal(self, sig): if not self.pidfile: self.log.warning("No pidfile in config, nothing to do") elif os.path.isfile(self.pidfile): alive = skytools.signal_pidfile(self.pidfile, sig) if not alive: self.log.warning("pidfile exists, but process not running") else: self.log.warning("No pidfile, process not running") sys.exit(0) def set_single_loop(self, do_single_loop): """Changes whether the script will loop or not.""" if do_single_loop: self.looping = 0 else: self.looping = 1 def _boot_daemon(self): run_single_process(self, self.go_daemon, self.pidfile) def start(self): """This will launch main processing thread.""" if self.go_daemon: if not self.pidfile: self.log.error("Daemon needs pidfile") sys.exit(1) self.run_func_safely(self._boot_daemon) def stop(self): """Safely stops processing loop.""" self.looping = 0 def reload(self): "Reload config." # avoid double loading on startup if not self.cf: self.cf = self.load_config() else: self.cf.reload() self.log.info("Config reloaded") self.job_name = self.cf.get("job_name") self.pidfile = self.cf.getfile("pidfile", '') self.loop_delay = self.cf.getfloat("loop_delay", self.loop_delay) self.exception_sleep = self.cf.getfloat("exception_sleep", 20) self.exception_quiet = self.cf.getlist("exception_quiet", []) self.exception_grace = self.cf.getfloat("exception_grace", 5 * 60) self.exception_reset = self.cf.getfloat("exception_reset", 15 * 60) def hook_sighup(self, sig, frame): "Internal SIGHUP handler. Minimal code here." self.need_reload = 1 last_sigint = 0 def hook_sigint(self, sig, frame): "Internal SIGINT handler. Minimal code here." self.stop() t = time.time() if t - self.last_sigint < 1: self.log.warning("Double ^C, fast exit") sys.exit(1) self.last_sigint = t def stat_get(self, key): """Reads a stat value.""" try: value = self.stat_dict[key] except KeyError: value = None return value def stat_put(self, key, value): """Sets a stat value.""" self.stat_dict[key] = value def stat_increase(self, key, increase=1): """Increases a stat value.""" try: self.stat_dict[key] += increase except KeyError: self.stat_dict[key] = increase def send_stats(self): "Send statistics to log." res = [] for k, v in self.stat_dict.items(): res.append("%s: %s" % (k, v)) if len(res) == 0: return logmsg = "{%s}" % ", ".join(res) self.log.info(logmsg) self.stat_dict = {} def reset(self): "Something bad happened, reset all state." pass def run(self): "Thread main loop." 
# run startup, safely self.run_func_safely(self.startup) while True: # reload config, if needed if self.need_reload: self.reload() self.need_reload = 0 # do some work work = self.run_once() if not self.looping or self.loop_delay < 0: break # remember work state self.work_state = work # should sleep? if not work: if self.loop_delay > 0: self.sleep(self.loop_delay) if not self.looping: break else: break # run shutdown, safely? self.shutdown() def run_once(self): state = self.run_func_safely(self.work, True) # send stats that was added self.send_stats() return state last_func_fail = None def run_func_safely(self, func, prefer_looping=False): "Run users work function, safely." try: r = func() if self.last_func_fail and time.time() > self.last_func_fail + self.exception_reset: self.last_func_fail = None # set exception count to 0 after success self.exception_count = 0 return r except UsageError as d: self.log.error(str(d)) sys.exit(1) except MemoryError: try: # complex logging may not succeed self.log.exception("Job %s out of memory, exiting", self.job_name) except MemoryError: self.log.fatal("Out of memory") sys.exit(1) except SystemExit as d: self.send_stats() if prefer_looping and self.looping and self.loop_delay > 0: self.log.info("got SystemExit(%s), exiting", str(d)) self.reset() raise d except KeyboardInterrupt: self.send_stats() if prefer_looping and self.looping and self.loop_delay > 0: self.log.info("got KeyboardInterrupt, exiting") self.reset() sys.exit(1) except Exception as d: try: # this may fail too self.send_stats() except BaseException: pass if self.last_func_fail is None: self.last_func_fail = time.time() emsg = str(d).rstrip() self.reset() self.exception_hook(d, emsg) # reset and sleep self.reset() if prefer_looping and self.looping and self.loop_delay > 0: # increase exception count & sleep self.exception_count += 1 self.sleep_on_exception() return -1 sys.exit(1) def sleep(self, secs): """Make script sleep for some amount of time.""" try: time.sleep(secs) except IOError as ex: if ex.errno != errno.EINTR: raise def sleep_on_exception(self): """Make script sleep for some amount of time when an exception occurs. To implement more advance exception sleeping like exponential backoff you can override this method. Also note that you can use self.exception_count to track the number of consecutive exceptions. """ self.sleep(self.exception_sleep) def _is_quiet_exception(self, ex): return ((self.exception_quiet == ["ALL"] or ex.__class__.__name__ in self.exception_quiet) and self.last_func_fail and time.time() < self.last_func_fail + self.exception_grace) def exception_hook(self, det, emsg): """Called on after exception processing. Can do additional logging. @param det: exception details @param emsg: exception msg """ lm = "Job %s crashed: %s" % (self.job_name, emsg) if self._is_quiet_exception(det): self.log.warning(lm) else: self.log.exception(lm) def work(self): """Here should user's processing happen. Return value is taken as boolean - if true, the next loop starts immediately. If false, DBScript sleeps for a loop_delay. """ raise Exception("Nothing implemented?") def startup(self): """Will be called just before entering main loop. In case of daemon, if will be called in same process as work(), unlike __init__(). """ self.started = time.time() # set signals if hasattr(signal, 'SIGHUP'): signal.signal(signal.SIGHUP, self.hook_sighup) if hasattr(signal, 'SIGINT'): signal.signal(signal.SIGINT, self.hook_sigint) def shutdown(self): """Will be called just after exiting main loop. 
In case of daemon, if will be called in same process as work(), unlike __init__(). """ pass # define some aliases (short-cuts / backward compatibility cruft) stat_add = stat_put # Old, deprecated function. stat_inc = stat_increase ## ## DBScript ## #: how old connections need to be closed DEF_CONN_AGE = 20 * 60 # 20 min class DBScript(BaseScript): """Base class for database scripts. Handles database connection state. Config template:: ## Parameters for skytools.DBScript ## # default lifetime for database connections (in seconds) #connection_lifetime = 1200 """ def __init__(self, service_name, args): """Script setup. User class should override work() and optionally __init__(), startup(), reload(), reset() and init_optparse(). NB: in case of daemon, the __init__() and startup()/work() will be run in different processes. So nothing fancy should be done in __init__(). @param service_name: unique name for script. It will be also default job_name, if not specified in config. @param args: cmdline args (sys.argv[1:]), but can be overridden """ self.db_cache = {} self._db_defaults = {} self._listen_map = {} # dbname: channel_list super(DBScript, self).__init__(service_name, args) def connection_hook(self, dbname, conn): pass def set_database_defaults(self, dbname, **kwargs): self._db_defaults[dbname] = kwargs def add_connect_string_profile(self, connstr, profile): """Add extra profile info to connect string. """ if profile: extra = self.cf.get("%s_extra_connstr" % profile, '') if extra: connstr += ' ' + extra return connstr def get_database(self, dbname, autocommit=0, isolation_level=-1, cache=None, connstr=None, profile=None): """Load cached database connection. User must not store it permanently somewhere, as all connections will be invalidated on reset. """ max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE) if not cache: cache = dbname params = {} defs = self._db_defaults.get(cache, {}) params.update(defs) if isolation_level >= 0: params['isolation_level'] = isolation_level elif autocommit: params['isolation_level'] = 0 elif params.get('autocommit', 0): params['isolation_level'] = 0 elif 'isolation_level' not in params: params['isolation_level'] = skytools.I_READ_COMMITTED if 'max_age' not in params: params['max_age'] = max_age if cache in self.db_cache: dbc = self.db_cache[cache] if connstr is None: connstr = self.cf.get(dbname, '') if connstr: connstr = self.add_connect_string_profile(connstr, profile) dbc.check_connstr(connstr) else: if not connstr: connstr = self.cf.get(dbname) connstr = self.add_connect_string_profile(connstr, profile) # connstr might contain password, it is not a good idea to log it filtered_connstr = connstr pos = connstr.lower().find('password') if pos >= 0: filtered_connstr = connstr[:pos] + ' [...]' self.log.debug("Connect '%s' to '%s'", cache, filtered_connstr) dbc = DBCachedConn(cache, connstr, params['max_age'], setup_func=self.connection_hook) self.db_cache[cache] = dbc clist = [] if cache in self._listen_map: clist = self._listen_map[cache] return dbc.get_connection(params['isolation_level'], clist) def close_database(self, dbname): """Explicitly close a cached connection. Next call to get_database() will reconnect. """ if dbname in self.db_cache: dbc = self.db_cache[dbname] dbc.reset() del self.db_cache[dbname] def reset(self): "Something bad happened, reset all connections." 
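        # Drop all cached connections; the next get_database() call
        # will reconnect lazily.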
for dbc in self.db_cache.values(): dbc.reset() self.db_cache = {} super(DBScript, self).reset() def run_once(self): state = super(DBScript, self).run_once() # reconnect if needed for dbc in self.db_cache.values(): dbc.refresh() return state def exception_hook(self, d, emsg): """Log database and query details from exception.""" curs = getattr(d, 'cursor', None) conn = getattr(curs, 'connection', None) if conn: # db connection sql = getattr(curs, 'query', None) or '?' if isinstance(sql, bytes): sql = sql.decode('utf8') if len(sql) > 200: # avoid logging londiste huge batched queries sql = sql[:60] + " ..." lm = "Job %s got error on connection: %s. Query: %s" % ( self.job_name, emsg, sql) if self._is_quiet_exception(d): self.log.warning(lm) else: self.log.exception(lm) else: super(DBScript, self).exception_hook(d, emsg) def sleep(self, secs): """Make script sleep for some amount of time.""" fdlist = [] for dbname in self._listen_map: if dbname not in self.db_cache: continue fd = self.db_cache[dbname].fileno() if fd is None: continue fdlist.append(fd) if not fdlist: return super(DBScript, self).sleep(secs) try: if hasattr(select, 'poll'): p = select.poll() for fd in fdlist: p.register(fd, select.POLLIN) p.poll(int(secs * 1000)) else: select.select(fdlist, [], [], secs) except select.error: self.log.info('wait canceled') return None def _exec_cmd(self, curs, sql, args, quiet=False, prefix=None): """Internal tool: Run SQL on cursor.""" if self.options.verbose: self.log.debug("exec_cmd: %s", skytools.quote_statement(sql, args)) _pfx = "" if prefix: _pfx = "[%s] " % prefix curs.execute(sql, args) ok = True rows = curs.fetchall() for row in rows: try: code = row['ret_code'] msg = row['ret_note'] except KeyError: self.log.error("Query does not conform to exec_cmd API:") self.log.error("SQL: %s", skytools.quote_statement(sql, args)) self.log.error("Row: %s", repr(row.copy())) sys.exit(1) level = code // 100 if level == 1: self.log.debug("%s%d %s", _pfx, code, msg) elif level == 2: if quiet: self.log.debug("%s%d %s", _pfx, code, msg) else: self.log.info("%s%s", _pfx, msg) elif level == 3: self.log.warning("%s%s", _pfx, msg) else: self.log.error("%s%s", _pfx, msg) self.log.debug("Query was: %s", skytools.quote_statement(sql, args)) ok = False return (ok, rows) def _exec_cmd_many(self, curs, sql, baseargs, extra_list, quiet=False, prefix=None): """Internal tool: Run SQL on cursor multiple times.""" ok = True rows = [] for a in extra_list: (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a], quiet, prefix) if not tmp_ok: ok = False rows += tmp_rows return (ok, rows) def exec_cmd(self, db_or_curs, q, args, commit=True, quiet=False, prefix=None): """Run SQL on db with code/value error handling.""" if hasattr(db_or_curs, 'cursor'): db = db_or_curs curs = db.cursor() else: db = None curs = db_or_curs (ok, rows) = self._exec_cmd(curs, q, args, quiet, prefix) if ok: if commit and db: db.commit() return rows else: if db: db.rollback() if self.options.verbose: raise Exception("db error") # error is already logged sys.exit(1) def exec_cmd_many(self, db_or_curs, sql, baseargs, extra_list, commit=True, quiet=False, prefix=None): """Run SQL on db multiple times.""" if hasattr(db_or_curs, 'cursor'): db = db_or_curs curs = db.cursor() else: db = None curs = db_or_curs (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list, quiet, prefix) if ok: if commit and db: db.commit() return rows else: if db: db.rollback() if self.options.verbose: raise Exception("db error") # error is already logged 
sys.exit(1) def execute_with_retry(self, dbname, stmt, args, exceptions=None): """ Execute SQL and retry if it fails. Return number of retries and current valid cursor, or raise an exception. """ sql_retry = self.cf.getbool("sql_retry", False) sql_retry_max_count = self.cf.getint("sql_retry_max_count", 10) sql_retry_max_time = self.cf.getint("sql_retry_max_time", 300) sql_retry_formula_a = self.cf.getint("sql_retry_formula_a", 1) sql_retry_formula_b = self.cf.getint("sql_retry_formula_b", 5) sql_retry_formula_cap = self.cf.getint("sql_retry_formula_cap", 60) elist = exceptions or tuple() stime = time.time() tried = 0 dbc = None while True: try: if dbc is None: if dbname not in self.db_cache: self.get_database(dbname, autocommit=1) dbc = self.db_cache[dbname] if dbc.isolation_level != skytools.I_AUTOCOMMIT: raise skytools.UsageError("execute_with_retry: autocommit required") else: dbc.reset() curs = dbc.get_connection(dbc.isolation_level).cursor() curs.execute(stmt, args) break except elist as e: if not sql_retry or tried >= sql_retry_max_count or time.time() - stime >= sql_retry_max_time: raise self.log.info("Job %s got error on connection %s: %s", self.job_name, dbname, e) except BaseException: raise # y = a + bx , apply cap y = sql_retry_formula_a + sql_retry_formula_b * tried if sql_retry_formula_cap is not None and y > sql_retry_formula_cap: y = sql_retry_formula_cap tried += 1 self.log.info("Retry #%i in %i seconds ...", tried, y) self.sleep(y) return tried, curs def listen(self, dbname, channel): """Make connection listen for specific event channel. Listening will be activated on next .get_database() call. Basically this means that DBScript.sleep() will poll for events on that db connection, so when event appears, script will be woken up. """ if dbname not in self._listen_map: self._listen_map[dbname] = [] clist = self._listen_map[dbname] if channel not in clist: clist.append(channel) def unlisten(self, dbname, channel='*'): """Stop connection for listening on specific event channel. Listening will stop on next .get_database() call. """ if dbname not in self._listen_map: return if channel == '*': del self._listen_map[dbname] return clist = self._listen_map[dbname] try: clist.remove(channel) except ValueError: pass class DBCachedConn: """Cache a db connection.""" def __init__(self, name, loc, max_age=DEF_CONN_AGE, verbose=False, setup_func=None, channels=()): self.name = name self.loc = loc self.conn = None self.conn_time = 0 self.max_age = max_age self.isolation_level = -1 self.verbose = verbose self.setup_func = setup_func self.listen_channel_list = [] def fileno(self): if not self.conn: return None return self.conn.cursor().fileno() def get_connection(self, isolation_level=-1, listen_channel_list=()): # default isolation_level is READ COMMITTED if isolation_level < 0: isolation_level = skytools.I_READ_COMMITTED # new conn? 
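        # First call connects and pins the isolation level; later callers
        # must request the same level, otherwise the else-branch below
        # raises a conflict error.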
if not self.conn: self.isolation_level = isolation_level self.conn = skytools.connect_database(self.loc) self.conn.set_isolation_level(isolation_level) self.conn_time = time.time() if self.setup_func: self.setup_func(self.name, self.conn) else: if self.isolation_level != isolation_level: raise Exception("Conflict in isolation_level") self._sync_listen(listen_channel_list) # done return self.conn def _sync_listen(self, new_clist): if not new_clist and not self.listen_channel_list: return curs = self.conn.cursor() for ch in self.listen_channel_list: if ch not in new_clist: curs.execute("UNLISTEN %s" % skytools.quote_ident(ch)) for ch in new_clist: if ch not in self.listen_channel_list: curs.execute("LISTEN %s" % skytools.quote_ident(ch)) if self.isolation_level != skytools.I_AUTOCOMMIT: self.conn.commit() self.listen_channel_list = new_clist[:] def refresh(self): if not self.conn: return #for row in self.conn.notifies(): # if row[0].lower() == "reload": # self.reset() # return if not self.max_age: return if time.time() - self.conn_time >= self.max_age: self.reset() def reset(self): if not self.conn: return # drop reference conn = self.conn self.conn = None self.listen_channel_list = [] # close try: conn.close() except BaseException: pass def check_connstr(self, connstr): """Drop connection if connect string has changed. """ if self.loc != connstr: self.reset() python-skytools-3.6.1/skytools/skylog.py000066400000000000000000000252151373460333300205120ustar00rootroot00000000000000"""Our log handlers for Python's logging package. """ import logging import logging.handlers import os import socket import time from logging import LoggerAdapter import skytools import skytools.tnetstrings __all__ = ['getLogger'] # add TRACE level TRACE = 5 logging.TRACE = TRACE logging.addLevelName(TRACE, 'TRACE') # extra info to be added to each log record _service_name = 'unknown_svc' _job_name = 'unknown_job' _hostname = socket.gethostname() try: _hostaddr = socket.gethostbyname(_hostname) except BaseException: _hostaddr = "0.0.0.0" _log_extra = { 'job_name': _job_name, 'service_name': _service_name, 'hostname': _hostname, 'hostaddr': _hostaddr, } def set_service_name(service_name, job_name): """Set info about current script.""" global _service_name, _job_name _service_name = service_name _job_name = job_name _log_extra['job_name'] = _job_name _log_extra['service_name'] = _service_name # Make extra fields available to all log records _old_factory = logging.getLogRecordFactory() def _new_factory(*args, **kwargs): record = _old_factory(*args, **kwargs) record.__dict__.update(_log_extra) return record logging.setLogRecordFactory(_new_factory) # configurable file logger class EasyRotatingFileHandler(logging.handlers.RotatingFileHandler): """Easier setup for RotatingFileHandler.""" def __init__(self, filename, maxBytes=10 * 1024 * 1024, backupCount=3): """Args same as for RotatingFileHandler, but in filename '~' is expanded.""" fn = os.path.expanduser(filename) super().__init__(fn, maxBytes=maxBytes, backupCount=backupCount) # send JSON message over UDP class UdpLogServerHandler(logging.handlers.DatagramHandler): """Sends log records over UDP to logserver in JSON format.""" # map logging levels to logserver levels _level_map = { logging.DEBUG: 'DEBUG', logging.INFO: 'INFO', logging.WARNING: 'WARN', logging.ERROR: 'ERROR', logging.CRITICAL: 'FATAL', } # JSON message template _log_template = '{\n\t'\ '"logger": "skytools.UdpLogServer",\n\t'\ '"timestamp": %.0f,\n\t'\ '"level": "%s",\n\t'\ '"thread": null,\n\t'\ 
'"message": %s,\n\t'\ '"properties": {"application":"%s", "apptype": "%s", "type": "sys", "hostname":"%s", "hostaddr": "%s"}\n'\ '}\n' # cut longer msgs MAXMSG = 1024 def makePickle(self, record): """Create message in JSON format.""" # get & cut msg msg = self.format(record) if len(msg) > self.MAXMSG: msg = msg[:self.MAXMSG] txt_level = self._level_map.get(record.levelno, "ERROR") hostname = _hostname hostaddr = _hostaddr jobname = _job_name svcname = _service_name pkt = self._log_template % ( time.time() * 1000, txt_level, skytools.quote_json(msg), jobname, svcname, hostname, hostaddr ) return pkt def send(self, s): """Disable socket caching.""" sock = self.makeSocket() if not isinstance(s, bytes): s = s.encode('utf8') sock.sendto(s, (self.host, self.port)) sock.close() # send TNetStrings message over UDP class UdpTNetStringsHandler(logging.handlers.DatagramHandler): """ Sends log records in TNetStrings format over UDP. """ # LogRecord fields to send send_fields = [ 'created', 'exc_text', 'levelname', 'levelno', 'message', 'msecs', 'name', 'hostaddr', 'hostname', 'job_name', 'service_name'] _udp_reset = 0 def makePickle(self, record): """ Create message in TNetStrings format. """ msg = {} self.format(record) # render 'message' attribute and others for k in self.send_fields: msg[k] = record.__dict__[k] tnetstr = skytools.tnetstrings.dumps(msg) return tnetstr def send(self, s): """ Cache socket for a moment, then recreate it. """ now = time.time() if now - 1 > self._udp_reset: if self.sock: self.sock.close() self.sock = self.makeSocket() self._udp_reset = now self.sock.sendto(s, (self.host, self.port)) class LogDBHandler(logging.handlers.SocketHandler): """Sends log records into PostgreSQL server. Additionally, does some statistics aggregating, to avoid overloading log server. It subclasses SocketHandler to get throtthling for failed connections. """ # map codes to string _level_map = { logging.DEBUG: 'DEBUG', logging.INFO: 'INFO', logging.WARNING: 'WARNING', logging.ERROR: 'ERROR', logging.CRITICAL: 'FATAL', } def __init__(self, connect_string): """ Initializes the handler with a specific connection string. """ super().__init__(None, None) self.closeOnError = 1 self.connect_string = connect_string self.stat_cache = {} self.stat_flush_period = 60 # send first stat line immediately self.last_stat_flush = 0 def createSocket(self): try: super().createSocket() except BaseException: self.sock = self.makeSocket() def makeSocket(self, timeout=1): """Create server connection. 
In this case its not socket but database connection.""" db = skytools.connect_database(self.connect_string) db.set_isolation_level(0) # autocommit return db def emit(self, record): """Process log record.""" # we do not want log debug messages if record.levelno < logging.INFO: return try: self.process_rec(record) except (SystemExit, KeyboardInterrupt): raise except BaseException: self.handleError(record) def process_rec(self, record): """Aggregate stats if needed, and send to logdb.""" # render msg msg = self.format(record) # dont want to send stats too ofter if record.levelno == logging.INFO and msg and msg[0] == "{": self.aggregate_stats(msg) if time.time() - self.last_stat_flush >= self.stat_flush_period: self.flush_stats(_job_name) return if record.levelno < logging.INFO: self.flush_stats(_job_name) # dont send more than one line ln = msg.find('\n') if ln > 0: msg = msg[:ln] txt_level = self._level_map.get(record.levelno, "ERROR") self.send_to_logdb(_job_name, txt_level, msg) def aggregate_stats(self, msg): """Sum stats together, to lessen load on logdb.""" msg = msg[1:-1] for rec in msg.split(", "): k, v = rec.split(": ") agg = self.stat_cache.get(k, 0) if v.find('.') >= 0: agg += float(v) else: agg += int(v) self.stat_cache[k] = agg def flush_stats(self, service): """Send acquired stats to logdb.""" res = [] for k, v in self.stat_cache.items(): res.append("%s: %s" % (k, str(v))) if len(res) > 0: logmsg = "{%s}" % ", ".join(res) self.send_to_logdb(service, "INFO", logmsg) self.stat_cache = {} self.last_stat_flush = time.time() def send_to_logdb(self, service, level, msg): """Actual sending is done here.""" if self.sock is None: self.createSocket() if self.sock: logcur = self.sock.cursor() query = "select * from log.add(%s, %s, %s)" logcur.execute(query, [level, service, msg]) # fix unicode bug in SysLogHandler class SysLogHandler(logging.handlers.SysLogHandler): """Fixes unicode bug in logging.handlers.SysLogHandler.""" # be compatible with both 2.6 and 2.7 socktype = socket.SOCK_DGRAM _udp_reset = 0 def _custom_format(self, record): msg = self.format(record) + '\000' # We need to convert record level to lowercase, maybe this will # change in the future. prio = '<%d>' % self.encodePriority(self.facility, self.mapPriority(record.levelname)) msg = prio + msg return msg def emit(self, record): """ Emit a record. The record is formatted, and then sent to the syslog server. If exception information is present, it is NOT sent to the server. """ msg = self._custom_format(record) # Message is a string. 
Convert to bytes as required by RFC 5424 if not isinstance(msg, bytes): msg = msg.encode('utf-8') ## this puts BOM in wrong place #if codecs: # msg = codecs.BOM_UTF8 + msg try: if self.unixsocket: try: self.socket.send(msg) except socket.error: self._connect_unixsocket(self.address) self.socket.send(msg) elif self.socktype == socket.SOCK_DGRAM: now = time.time() if now - 1 > self._udp_reset: self.socket.close() self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._udp_reset = now self.socket.sendto(msg, self.address) else: self.socket.sendall(msg) except (KeyboardInterrupt, SystemExit): raise except BaseException: self.handleError(record) class SysLogHostnameHandler(SysLogHandler): """Slightly modified standard SysLogHandler - sends also hostname and service type""" def _custom_format(self, record): msg = self.format(record) format_string = '<%d> %s %s %s\000' msg = format_string % (self.encodePriority(self.facility, self.mapPriority(record.levelname)), _hostname, _service_name, msg) return msg # add missing aliases (that are in Logger class) if not hasattr(LoggerAdapter, 'fatal'): LoggerAdapter.fatal = LoggerAdapter.critical class SkyLogger(LoggerAdapter): """Adds API to existing Logger. """ def trace(self, msg, *args, **kwargs): """Log message with severity TRACE.""" self.log(TRACE, msg, *args, **kwargs) def getLogger(name=None, **kwargs_extra): """Get logger with extra functionality. Adds additional log levels, and extra fields to log record. name - name for logging.getLogger() kwargs_extra - extra fields to add to log record """ log = logging.getLogger(name) return SkyLogger(log, kwargs_extra) python-skytools-3.6.1/skytools/sockutil.py000066400000000000000000000067041373460333300210410ustar00rootroot00000000000000"""Various low-level utility functions for sockets.""" import os import socket import sys try: import fcntl except ImportError: fcntl = None __all__ = ( 'set_tcp_keepalive', 'set_nonblocking', 'set_cloexec', ) def set_tcp_keepalive(fd, keepalive=True, tcp_keepidle=4 * 60, tcp_keepcnt=4, tcp_keepintvl=15): """Turn on TCP keepalive. The fd can be either numeric or socket object with 'fileno' method. OS defaults for SO_KEEPALIVE=1: - Linux: (7200, 9, 75) - can configure all. - MacOS: (7200, 8, 75) - can configure only tcp_keepidle. - Win32: (7200, 5|10, 1) - can configure tcp_keepidle and tcp_keepintvl. Our defaults: (240, 4, 15). """ # usable on this OS? if not hasattr(socket, 'SO_KEEPALIVE') or not hasattr(socket, 'fromfd'): return # need socket object if isinstance(fd, socket.SocketType): s = fd else: if hasattr(fd, 'fileno'): fd = fd.fileno() s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) # skip if unix socket if getattr(socket, 'AF_UNIX', None): if not isinstance(s.getsockname(), tuple): return # no keepalive? 
if not keepalive: s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0) return # basic keepalive s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # detect available options TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None) TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None) TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None) TCP_KEEPALIVE = getattr(socket, 'TCP_KEEPALIVE', None) SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None) if TCP_KEEPIDLE is None and TCP_KEEPALIVE is None and sys.platform == 'darwin': TCP_KEEPALIVE = 0x10 # configure if TCP_KEEPCNT is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPCNT, tcp_keepcnt) if TCP_KEEPINTVL is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPINTVL, tcp_keepintvl) if TCP_KEEPIDLE is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPIDLE, tcp_keepidle) elif TCP_KEEPALIVE is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, tcp_keepidle) elif SIO_KEEPALIVE_VALS is not None and fcntl: fcntl.ioctl(s.fileno(), SIO_KEEPALIVE_VALS, (1, tcp_keepidle * 1000, tcp_keepintvl * 1000)) def set_nonblocking(fd, onoff=True): """Toggle the O_NONBLOCK flag. If onoff==None then return current setting. Actual sockets from 'socket' module should use .setblocking() method, this is for situations where it is not available. Eg. pipes from 'subprocess' module. """ if fcntl is None: return onoff flags = fcntl.fcntl(fd, fcntl.F_GETFL) if onoff is None: return (flags & os.O_NONBLOCK) > 0 if onoff: flags |= os.O_NONBLOCK else: flags &= ~os.O_NONBLOCK fcntl.fcntl(fd, fcntl.F_SETFL, flags) return onoff def set_cloexec(fd, onoff=True): """Toggle the FD_CLOEXEC flag. If onoff==None then return current setting. Some libraries do it automatically (eg. libpq). Others do not (Python stdlib). """ if fcntl is None: return True flags = fcntl.fcntl(fd, fcntl.F_GETFD) if onoff is None: return (flags & fcntl.FD_CLOEXEC) > 0 if onoff: flags |= fcntl.FD_CLOEXEC else: flags &= ~fcntl.FD_CLOEXEC fcntl.fcntl(fd, fcntl.F_SETFD, flags) return onoff python-skytools-3.6.1/skytools/sqltools.py000066400000000000000000000426611373460333300210660ustar00rootroot00000000000000"""Database tools. """ import io import os import skytools __all__ = ( "fq_name_parts", "fq_name", "get_table_oid", "get_table_pkeys", "get_table_columns", "exists_schema", "exists_table", "exists_type", "exists_sequence", "exists_temp_table", "exists_view", "exists_function", "exists_language", "Snapshot", "magic_insert", "CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction", "DBLanguage", "db_install", "installer_find_file", "installer_apply_file", "dbdict", "mk_insert_sql", "mk_update_sql", "mk_delete_sql", ) class dbdict(dict): """Wrapper on actual dict that allows accessing dict keys as attributes. """ # obj.foo access def __getattr__(self, k): "Return attribute." try: return self[k] except KeyError: raise AttributeError(k) def __setattr__(self, k, v): "Set attribute." self[k] = v def __delattr__(self, k): "Remove attribute." del self[k] def merge(self, other): for key in other: if key not in self: self[key] = other[key] # # Fully qualified table name # def fq_name_parts(tbl): """Return fully qualified name parts. """ tmp = tbl.split('.', 1) if len(tmp) == 1: return ['public', tbl] return tmp def fq_name(tbl): """Return fully qualified name. 
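    Examples::

        fq_name('tbl')       # 'public.tbl'
        fq_name('foo.tbl')   # 'foo.tbl'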
""" return '.'.join(fq_name_parts(tbl)) # # info about table # def get_table_oid(curs, table_name): """Find Postgres OID for table.""" schema, name = fq_name_parts(table_name) q = """select c.oid from pg_namespace n, pg_class c where c.relnamespace = n.oid and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchall() if len(res) == 0: raise Exception('Table not found: ' + table_name) return res[0][0] def get_table_pkeys(curs, tbl): """Return list of pkey column names.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_index i, pg_attribute k"\ " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\ " AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped"\ " ORDER BY k.attnum" curs.execute(q, [oid]) return [row[0] for row in curs.fetchall()] def get_table_columns(curs, tbl): """Return list of column names for table.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_attribute k"\ " WHERE k.attrelid = %s"\ " AND k.attnum > 0 AND NOT k.attisdropped"\ " ORDER BY k.attnum" curs.execute(q, [oid]) return [row[0] for row in curs.fetchall()] # # exist checks # def exists_schema(curs, schema): """Does schema exists?""" q = "select count(1) from pg_namespace where nspname = %s" curs.execute(q, [schema]) res = curs.fetchone() return res[0] def exists_table(curs, table_name): """Does table exists?""" schema, name = fq_name_parts(table_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'r' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_sequence(curs, seq_name): """Does sequence exists?""" schema, name = fq_name_parts(seq_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'S' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_view(curs, view_name): """Does view exists?""" schema, name = fq_name_parts(view_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'v' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_type(curs, type_name): """Does type exists?""" schema, name = fq_name_parts(type_name) q = """select count(1) from pg_namespace n, pg_type t where t.typnamespace = n.oid and n.nspname = %s and t.typname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return res[0] def exists_function(curs, function_name, nargs): """Does function exists?""" # this does not check arg types, so may match several functions schema, name = fq_name_parts(function_name) q = """select count(1) from pg_namespace n, pg_proc p where p.pronamespace = n.oid and p.pronargs = %s and n.nspname = %s and p.proname = %s""" curs.execute(q, [nargs, schema, name]) res = curs.fetchone() # if unqualified function, check builtin functions too if not res[0] and function_name.find('.') < 0: name = "pg_catalog." 
def exists_language(curs, lang_name):
    """Does PL exist?"""
    q = """select count(1) from pg_language
            where lanname = %s"""
    curs.execute(q, [lang_name])
    res = curs.fetchone()
    return res[0]

def exists_temp_table(curs, tbl):
    """Does temp table exist?"""
    # correct way, works only on 8.2 and newer
    q = "select 1 from pg_class where relname = %s and relnamespace = pg_my_temp_schema()"
    curs.execute(q, [tbl])
    tmp = curs.fetchall()
    return len(tmp) > 0

#
# Support for PostgreSQL snapshot
#

class Snapshot:
    """Represents a PostgreSQL snapshot.
    """
    def __init__(self, str_val):
        "Create snapshot from string."
        self.sn_str = str_val
        tmp = str_val.split(':')
        if len(tmp) != 3:
            raise ValueError('Unknown format for snapshot')
        self.xmin = int(tmp[0])
        self.xmax = int(tmp[1])
        self.txid_list = []
        if tmp[2] != "":
            for s in tmp[2].split(','):
                self.txid_list.append(int(s))

    def contains(self, txid):
        "Is txid visible in snapshot."
        txid = int(txid)
        if txid < self.xmin:
            return True
        if txid >= self.xmax:
            return False
        if txid in self.txid_list:
            return False
        return True

#
# Copy helpers
#

def _gen_dict_copy(tbl, row, fields, qfields):
    tmp = []
    for f in fields:
        v = row.get(f)
        tmp.append(skytools.quote_copy(v))
    return "\t".join(tmp)

def _gen_dict_insert(tbl, row, fields, qfields):
    tmp = []
    for f in fields:
        v = row.get(f)
        tmp.append(skytools.quote_literal(v))
    fmt = "insert into %s (%s) values (%s);"
    return fmt % (tbl, ",".join(qfields), ",".join(tmp))

def _gen_list_copy(tbl, row, fields, qfields):
    tmp = []
    for i in range(len(fields)):
        try:
            v = row[i]
        except IndexError:
            v = None
        tmp.append(skytools.quote_copy(v))
    return "\t".join(tmp)

def _gen_list_insert(tbl, row, fields, qfields):
    tmp = []
    for i in range(len(fields)):
        try:
            v = row[i]
        except IndexError:
            v = None
        tmp.append(skytools.quote_literal(v))
    fmt = "insert into %s (%s) values (%s);"
    return fmt % (tbl, ",".join(qfields), ",".join(tmp))

def magic_insert(curs, tablename, data, fields=None, use_insert=False, quoted_table=False):
    r"""Copy/insert a list of dict/list data into a database table.

    If curs is None, the COPY or INSERT statements are returned as a string.
    For a list of dicts the field list is optional, as it is possible to
    guess it from the dict keys.
    """
    if len(data) == 0:
        return None

    if fields is not None:
        fields = list(fields)   # get rid of iterator

    # decide how to process
    if hasattr(data[0], 'keys'):
        if fields is None:
            fields = data[0].keys()
        if use_insert:
            row_func = _gen_dict_insert
        else:
            row_func = _gen_dict_copy
    else:
        if fields is None:
            raise Exception("Non-dict data needs field list")
        if use_insert:
            row_func = _gen_list_insert
        else:
            row_func = _gen_list_copy

    qfields = [skytools.quote_ident(f) for f in fields]
    if quoted_table:
        qtablename = tablename
    else:
        qtablename = skytools.quote_fqident(tablename)

    # init processing
    buf = io.StringIO()
    if curs is None and use_insert == 0:
        fmt = "COPY %s (%s) FROM STDIN;\n"
        buf.write(fmt % (qtablename, ",".join(qfields)))

    # process data
    for row in data:
        buf.write(row_func(qtablename, row, fields, qfields))
        buf.write("\n")

    # if user needs only string, return it
    if curs is None:
        if use_insert == 0:
            buf.write("\\.\n")
        return buf.getvalue()

    # do the actual copy/inserts
    if use_insert:
        curs.execute(buf.getvalue())
    else:
        buf.seek(0)
        hdr = "%s (%s)" % (qtablename, ",".join(qfields))
        curs.copy_from(buf, hdr)
    return None

#
# Full COPY of table from one db to another
#

class CopyPipe(io.TextIOBase):
    """Splits one big COPY to chunks.
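    Usage sketch (illustrative only; assumes open psycopg2 cursors
    `src_curs` and `dst_curs`, which are not defined here):

        buf = CopyPipe(dst_curs, 'public.target_table')
        src_curs.copy_expert("COPY public.source_table TO stdout", buf)
        buf.flush()

    full_copy() below implements this same pattern, adding optional
    column lists, a WHERE condition and row rewrite hooks.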
""" def __init__(self, dstcurs, tablename=None, limit=512 * 1024, sql_from=None): super(CopyPipe, self).__init__() self.tablename = tablename self.sql_from = sql_from self.dstcurs = dstcurs self.buf = io.StringIO() self.limit = limit #hook for new data, hook func should return new data #def write_hook(obj, data): # return data self.write_hook = None #hook for flush, hook func result is discarded # def flush_hook(obj): # return None self.flush_hook = None self.total_rows = 0 self.total_bytes = 0 def write(self, data): """New row from psycopg """ if self.write_hook: data = self.write_hook(self, data) self.total_bytes += len(data) # it's chars now... self.total_rows += 1 self.buf.write(data) if self.buf.tell() >= self.limit: self.flush() def flush(self): """Send data out. """ if self.flush_hook: self.flush_hook(self) if self.buf.tell() <= 0: return self.buf.seek(0) if self.sql_from: self.dstcurs.copy_expert(self.sql_from, self.buf) else: self.dstcurs.copy_from(self.buf, self.tablename) self.buf.seek(0) self.buf.truncate() def full_copy(tablename, src_curs, dst_curs, column_list=(), condition=None, dst_tablename=None, dst_column_list=None, write_hook=None, flush_hook=None): """COPY table from one db to another.""" # default dst table and dst columns to source ones dst_tablename = dst_tablename or tablename dst_column_list = dst_column_list or column_list[:] if len(dst_column_list) != len(column_list): raise Exception('src and dst column lists must match in length') def build_qfields(cols): if cols: return ",".join([skytools.quote_ident(f) for f in cols]) else: return "*" def build_statement(table, cols): qtable = skytools.quote_fqident(table) if cols: qfields = build_qfields(cols) return "%s (%s)" % (qtable, qfields) else: return qtable dst = build_statement(dst_tablename, dst_column_list) if condition: src = "(SELECT %s FROM %s WHERE %s)" % (build_qfields(column_list), skytools.quote_fqident(tablename), condition) else: src = build_statement(tablename, column_list) sql_to = "COPY %s TO stdout" % src sql_from = "COPY %s FROM stdin" % dst buf = CopyPipe(dst_curs, sql_from=sql_from) buf.write_hook = write_hook buf.flush_hook = flush_hook src_curs.copy_expert(sql_to, buf) buf.flush() return (buf.total_bytes, buf.total_rows) # # SQL installer # class DBObject: """Base class for installable DB objects.""" name = None sql = None sql_file = None def __init__(self, name, sql=None, sql_file=None): """Generic dbobject init.""" self.name = name self.sql = sql self.sql_file = sql_file def create(self, curs, log=None): """Create a dbobject.""" if log: log.info('Installing %s' % self.name) if self.sql: sql = self.sql elif self.sql_file: fn = self.find_file() if log: log.info(" Reading from %s" % fn) with open(fn, "r") as f: sql = f.read() else: raise Exception('object not defined') for stmt in skytools.parse_statements(sql): #if log: log.debug(repr(stmt)) curs.execute(stmt) def find_file(self): """Find install script file.""" return installer_find_file(self.sql_file) class DBSchema(DBObject): """Handles db schema.""" def exists(self, curs): """Does schema exists.""" return exists_schema(curs, self.name) class DBTable(DBObject): """Handles db table.""" def exists(self, curs): """Does table exists.""" return exists_table(curs, self.name) class DBFunction(DBObject): """Handles db function.""" def __init__(self, name, nargs, sql=None, sql_file=None): """Function object - number of args is significant.""" super(DBFunction, self).__init__(name, sql, sql_file) self.nargs = nargs def exists(self, curs): """Does 
class DBLanguage(DBObject):
    """Handles db language."""
    def __init__(self, name):
        """PL object - creation happens with CREATE LANGUAGE."""
        super(DBLanguage, self).__init__(name, sql="create language %s" % name)

    def exists(self, curs):
        """Does PL exist?"""
        return exists_language(curs, self.name)

def db_install(curs, obj_list, log=None):
    """Installs list of objects into db."""
    for obj in obj_list:
        if not obj.exists(curs):
            obj.create(curs, log)
        else:
            if log:
                log.info('%s is installed' % obj.name)

def installer_find_file(filename):
    """Find SQL script from pre-defined paths."""
    full_fn = None
    if filename[0] == "/":
        if os.path.isfile(filename):
            full_fn = filename
    else:
        from skytools.installer_config import sql_locations
        dir_list = sql_locations
        for fdir in dir_list:
            fn = os.path.join(fdir, filename)
            if os.path.isfile(fn):
                full_fn = fn
                break

    if not full_fn:
        raise Exception('File not found: ' + filename)
    return full_fn

def installer_apply_file(db, filename, log):
    """Find SQL file and apply it to db, statement-by-statement."""
    fn = installer_find_file(filename)
    with open(fn, "r") as f:
        sql = f.read()
    if log:
        log.info("applying %s" % fn)
    curs = db.cursor()
    for stmt in skytools.parse_statements(sql):
        #log.debug(repr(stmt))
        curs.execute(stmt)

#
# Generate INSERT/UPDATE/DELETE statement
#

def mk_insert_sql(row, tbl, pkey_list=None, field_map=None):
    """Generate INSERT statement from dict data.
    """
    col_list = []
    val_list = []
    if field_map:
        for src, dst in field_map.items():
            col_list.append(skytools.quote_ident(dst))
            val_list.append(skytools.quote_literal(row[src]))
    else:
        for c, v in row.items():
            col_list.append(skytools.quote_ident(c))
            val_list.append(skytools.quote_literal(v))
    col_str = ", ".join(col_list)
    val_str = ", ".join(val_list)
    return "insert into %s (%s) values (%s);" % (
        skytools.quote_fqident(tbl), col_str, val_str)

def mk_update_sql(row, tbl, pkey_list, field_map=None):
    """Generate UPDATE statement from dict data.
    """
    if len(pkey_list) < 1:
        raise Exception("update needs pkeys")
    set_list = []
    whe_list = []
    pkmap = {}
    for k in pkey_list:
        pkmap[k] = 1
        new_k = field_map and field_map[k] or k
        col = skytools.quote_ident(new_k)
        val = skytools.quote_literal(row[k])
        whe_list.append("%s = %s" % (col, val))

    if field_map:
        for src, dst in field_map.items():
            if src not in pkmap:
                col = skytools.quote_ident(dst)
                val = skytools.quote_literal(row[src])
                set_list.append("%s = %s" % (col, val))
    else:
        for col, val in row.items():
            if col not in pkmap:
                col = skytools.quote_ident(col)
                val = skytools.quote_literal(val)
                set_list.append("%s = %s" % (col, val))
    return "update only %s set %s where %s;" % (skytools.quote_fqident(tbl),
                                                ", ".join(set_list),
                                                " and ".join(whe_list))

def mk_delete_sql(row, tbl, pkey_list, field_map=None):
    """Generate DELETE statement from dict data.
    """
    if len(pkey_list) < 1:
        raise Exception("delete needs pkeys")
    whe_list = []
    for k in pkey_list:
        new_k = field_map and field_map[k] or k
        col = skytools.quote_ident(new_k)
        val = skytools.quote_literal(row[k])
        whe_list.append("%s = %s" % (col, val))
    whe_str = " and ".join(whe_list)
    return "delete from only %s where %s;" % (skytools.quote_fqident(tbl), whe_str)

python-skytools-3.6.1/skytools/timeutil.py000066400000000000000000000065761373460333300210430ustar00rootroot00000000000000"""Fill gaps in Python time APIs.

parse_iso_timestamp:
    Parse reasonable subset of ISO_8601 timestamp formats.
    [ http://en.wikipedia.org/wiki/ISO_8601 ]
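    For example (these match tests/test_timeutil.py):

        parse_iso_timestamp('2005-06-01 15:00')           # naive datetime
        parse_iso_timestamp('2005-06-01 15:00:33+02:00')  # tz-aware datetime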
datetime_to_timestamp:
    Get POSIX timestamp from datetime() object.
"""

import re
import time
from datetime import datetime, timedelta, tzinfo

__all__ = (
    'parse_iso_timestamp', 'FixedOffsetTimezone', 'datetime_to_timestamp',
)

class FixedOffsetTimezone(tzinfo):
    """Fixed offset in minutes east from UTC."""
    __slots__ = ('__offset', '__name')

    def __init__(self, offset):
        super(FixedOffsetTimezone, self).__init__()
        self.__offset = timedelta(minutes=offset)

        # numeric tz name
        h, m = divmod(abs(offset), 60)
        if offset < 0:
            h = -h
        if m:
            self.__name = "%+03d:%02d" % (h, m)
        else:
            self.__name = "%+03d" % h

    def utcoffset(self, dt):
        return self.__offset

    def tzname(self, dt):
        return self.__name

    def dst(self, dt):
        return ZERO

ZERO = timedelta(0)

#
# Parse ISO_8601 timestamps.
#

"""
TODO:
- support more combinations from ISO 8601 (only reasonable ones)
- cache TZ objects
- make it faster?
"""

_iso_regex = r"""
    \s*
    (?P<year> \d\d\d\d) [-] (?P<month> \d\d) [-] (?P<day> \d\d) [ T]
    (?P<hour> \d\d) [:] (?P<min> \d\d)
    (?: [:] (?P<sec> \d\d ) (?: [.,] (?P<ss> \d+))? )?
    (?: \s* (?P<tzsign> [-+]) (?P<tzhr> \d\d) (?: [:]? (?P<tzmin> \d\d))?
      | (?P<tzname> Z )
    )?
    \s* $
    """
_iso_rc = None

def parse_iso_timestamp(s, default_tz=None):
    """Parse ISO timestamp to datetime object.

    YYYY-MM-DD[ T]HH:MM[:SS[.ss]][-+HH[:MM]]

    Assumes that second fractions are zero-trimmed from the end,
    so '.15' means 150000 microseconds.

    If the timezone offset is not present, use default_tz as tzinfo.
    By default it is None, meaning the datetime object will be without tz.

    Only fixed offset timezones are supported.
    """
    global _iso_rc
    if _iso_rc is None:
        _iso_rc = re.compile(_iso_regex, re.X)
    m = _iso_rc.match(s)
    if not m:
        raise ValueError('Date not in ISO format: %s' % repr(s))
    tz = default_tz
    if m.group('tzsign'):
        tzofs = int(m.group('tzhr')) * 60
        if m.group('tzmin'):
            tzofs += int(m.group('tzmin'))
        if m.group('tzsign') == '-':
            tzofs = -tzofs
        tz = FixedOffsetTimezone(tzofs)
    elif m.group('tzname'):
        tz = UTC
    return datetime(
        int(m.group('year')), int(m.group('month')), int(m.group('day')),
        int(m.group('hour')), int(m.group('min')),
        m.group('sec') and int(m.group('sec')) or 0,
        m.group('ss') and int(m.group('ss').ljust(6, '0')) or 0,
        tz
    )

#
# POSIX timestamp from datetime()
#

UTC = FixedOffsetTimezone(0)
TZ_EPOCH = datetime.fromtimestamp(0, UTC)
UTC_NOTZ_EPOCH = datetime.utcfromtimestamp(0)

def datetime_to_timestamp(dt, local_time=True):
    """Get posix timestamp from datetime() object.

    if dt is without timezone, then local_time specifies
    whether it's UTC or local time.

    Returns seconds since epoch as float.
    """
    if dt.tzinfo:
        delta = dt - TZ_EPOCH
        return delta.total_seconds()
    elif local_time:
        s = time.mktime(dt.timetuple())
        return s + (dt.microsecond / 1000000.0)
    else:
        delta = dt - UTC_NOTZ_EPOCH
        return delta.total_seconds()

python-skytools-3.6.1/skytools/tnetstrings.py000066400000000000000000000066421373460333300215670ustar00rootroot00000000000000"""TNetStrings.
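A quick round-trip example (mirrors tests/test_tnetstrings.py):

    dumps([1, 2])          ->  b'8:1:1#1:2#]'
    loads(b'8:1:1#1:2#]')  ->  [1, 2]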
""" import codecs __all__ = ['loads', 'dumps'] _memstr_types = (str, bytes, memoryview) _struct_types = (list, tuple, dict) _inttypes = (int,) _decode_utf8 = codecs.getdecoder('utf8') def _dumps(dst, val): if isinstance(val, _struct_types): tlenpos = len(dst) tlen = 0 dst.append(None) if isinstance(val, dict): for k in val: tlen += _dumps(dst, k) tlen += _dumps(dst, val[k]) dst.append(b'}') else: for v in val: tlen += _dumps(dst, v) dst.append(b']') dst[tlenpos] = b'%d:' % tlen return len(dst[tlenpos]) + tlen + 1 elif isinstance(val, _memstr_types): if isinstance(val, str): bval = val.encode('utf8') elif isinstance(val, bytes): bval = val else: bval = memoryview(val).tobytes() tval = b'%d:%s,' % (len(bval), bval) elif isinstance(val, bool): tval = val and b'4:true!' or b'5:false!' elif isinstance(val, _inttypes): bval = b'%d' % val tval = b'%d:%s#' % (len(bval), bval) elif isinstance(val, float): bval = b'%r' % val tval = b'%d:%s^' % (len(bval), bval) elif val is None: tval = b'0:~' else: raise TypeError("Object type not supported: %r" % val) dst.append(tval) return len(tval) def _loads(buf): pos = 0 maxlen = min(len(buf), 9) while buf[pos:pos + 1] != b':': pos += 1 if pos > maxlen: raise ValueError("Too large length") lenbytes = buf[: pos].tobytes() tlen = int(lenbytes) ofs = len(lenbytes) + 1 endofs = ofs + tlen val = buf[ofs: endofs] code = buf[endofs: endofs + 1] rest = buf[endofs + 1:] if len(val) + 1 != tlen + len(code): raise ValueError("failed to load value, invalid length") if code == b',': return _decode_utf8(val)[0], rest elif code == b'#': return int(val.tobytes(), 10), rest elif code == b'^': return float(val.tobytes()), rest elif code == b']': listobj = [] while val: elem, val = _loads(val) listobj.append(elem) return listobj, rest elif code == b'}': dictobj = {} while val: k, val = _loads(val) if not isinstance(k, str): raise ValueError("failed to load value, invalid key type") dictobj[k], val = _loads(val) return dictobj, rest elif code == b'!': if val == b'true': return True, rest if val == b'false': return False, rest raise ValueError("failed to load value, invalid boolean value") elif code == b'~': if val == b'': return None, rest raise ValueError("failed to load value, invalid null value") else: raise ValueError("failed to load value, invalid value code") # # Public API # def dumps(val): """Dump object tree as TNetString value. """ dst = [] _dumps(dst, val) return b''.join(dst) def loads(binval): """Parse TNetstring from byte string. """ if not isinstance(binval, (bytes, memoryview)): raise TypeError("Bytes or memoryview required") obj, rest = _loads(memoryview(binval)) if rest: raise ValueError("Not all data processed") return obj # old compat? parse = loads dump = dumps python-skytools-3.6.1/skytools/utf8.py000066400000000000000000000042161373460333300200660ustar00rootroot00000000000000r"""UTF-8 sanitizer. Python's UTF-8 parser is quite relaxed, this creates problems when talking with other software that uses stricter parsers. 
""" import codecs import re __all__ = ('safe_utf8_decode', 'sanitize_unicode') # by default, use same symbol as 'replace' REPLACEMENT_SYMBOL = chr(0xFFFD) # 65533 _urc = None def _fix_utf8(m): """Merge UTF16 surrogates, replace others""" u = m.group() if len(u) == 2: # merge into single symbol c1 = ord(u[0]) c2 = ord(u[1]) c = 0x10000 + ((c1 & 0x3FF) << 10) + (c2 & 0x3FF) return chr(c) else: # use replacement symbol return REPLACEMENT_SYMBOL def sanitize_unicode(u): """Fix invalid symbols in unicode string.""" global _urc if not isinstance(u, str): raise TypeError('Need unicode string') # regex for finding invalid chars, works on unicode string if not _urc: rx = u"[\uD800-\uDBFF] [\uDC00-\uDFFF]? | [\0\uDC00-\uDFFF]" _urc = re.compile(rx, re.X) # now find and fix UTF16 surrogates m = _urc.search(u) if m: u = _urc.sub(_fix_utf8, u) return u def safe_replace(exc): """Replace only one symbol at a time. Builtin .decode('xxx', 'replace') replaces several symbols together, which is unsafe. """ c2 = REPLACEMENT_SYMBOL # we could assume latin1 #if 0: # c1 = exc.object[exc.start] # c2 = chr(ord(c1)) return c2, exc.start + 1 # register, it will be globally available codecs.register_error("safe_replace", safe_replace) def safe_utf8_decode(s): """Decode UTF-8 safely. Acts like str.decode('utf8', 'replace') but also fixes UTF16 surrogates and NUL bytes, which Python's default decoder does not do. @param s: utf8-encoded byte string @return: tuple of (was_valid_utf8, unicode_string) """ # decode with error detection ok = True try: # expect no errors by default u = s.decode('utf8') except UnicodeDecodeError: u = s.decode('utf8', 'safe_replace') ok = False u2 = sanitize_unicode(u) if u is not u2: ok = False return (ok, u2) python-skytools-3.6.1/tests/000077500000000000000000000000001373460333300160765ustar00rootroot00000000000000python-skytools-3.6.1/tests/config.ini000066400000000000000000000007561373460333300200540ustar00rootroot00000000000000[base] foo = 1 bar = %(foo)s bool-true1 = 1 bool-true2 = true bool-false1 = 0 bool-false2 = false float-val = 2.0 list-val1 = list-val2 = a, 1, asd, ppp dict-val1 = dict-val2 = a : 1, b : 2, z file-val1 = - file-val2 = ~/foo bytes-val1 = 4 bytes-val2 = 2k wild-*-* = w2 wild-a-* = w.a wild-a-b = w.a.b vars1 = V2=%(vars2)s vars2 = V3=%(vars3)s vars3 = Q3 bad1 = B2=%(bad2)s bad2 = %(missing1)s %(missing2)s [DEFAULT] all = yes [other] test = try [testscript] opt = test db = python-skytools-3.6.1/tests/test_api.py000066400000000000000000000002211373460333300202530ustar00rootroot00000000000000 import skytools def test_version(): a = skytools.natsort_key(skytools.__version__) b = skytools.natsort_key('3.3') assert a >= b python-skytools-3.6.1/tests/test_config.py000066400000000000000000000126161373460333300207620ustar00rootroot00000000000000 import io import os.path import sys import pytest from skytools.config import ( Config, ConfigError, ExtendedCompatConfigParser, InterpolationError, NoOptionError, NoSectionError, ) TOP = os.path.dirname(__file__) CONFIG = os.path.join(TOP, 'config.ini') def test_config_str(): cf = Config('base', CONFIG) assert cf.get('foo') == '1' assert cf.get('missing', 'q') == 'q' with pytest.raises(NoOptionError): cf.get('missing') def test_config_int(): cf = Config('base', CONFIG) assert cf.getint('foo') == 1 assert cf.getint('missing', 2) == 2 with pytest.raises(NoOptionError): cf.getint('missing') def test_config_float(): cf = Config('base', CONFIG) assert cf.getfloat('float-val') == 2.0 assert cf.getfloat('missing', 3.0) == 3.0 with 
pytest.raises(NoOptionError):
        cf.getfloat('missing')

def test_config_bool():
    cf = Config('base', CONFIG)
    assert cf.getboolean('bool-true1') == True
    assert cf.getboolean('bool-true2') == True
    assert cf.getboolean('missing', True) == True
    with pytest.raises(NoOptionError):
        cf.getboolean('missing')

    assert cf.getboolean('bool-false1') == False
    assert cf.getboolean('bool-false2') == False
    assert cf.getboolean('missing', False) == False
    with pytest.raises(NoOptionError):
        cf.getbool('missing')

def test_config_list():
    cf = Config('base', CONFIG)
    assert cf.getlist('list-val1') == []
    assert cf.getlist('list-val2') == ['a', '1', 'asd', 'ppp']
    assert cf.getlist('missing', [1]) == [1]
    with pytest.raises(NoOptionError):
        cf.getlist('missing')

def test_config_dict():
    cf = Config('base', CONFIG)
    assert cf.getdict('dict-val1') == {}
    assert cf.getdict('dict-val2') == {'a': '1', 'b': '2', 'z': 'z'}
    assert cf.getdict('missing', {'a': 1}) == {'a': 1}
    with pytest.raises(NoOptionError):
        cf.getdict('missing')

def test_config_file():
    cf = Config('base', CONFIG)
    assert cf.getfile('file-val1') == '-'
    if sys.platform != 'win32':
        assert cf.getfile('file-val2')[0] == '/'
    assert cf.getfile('missing', 'qwe') == 'qwe'
    with pytest.raises(NoOptionError):
        cf.getfile('missing')

def test_config_bytes():
    cf = Config('base', CONFIG)
    assert cf.getbytes('bytes-val1') == 4
    assert cf.getbytes('bytes-val2') == 2048
    assert cf.getbytes('missing', '3k') == 3072
    with pytest.raises(NoOptionError):
        cf.getbytes('missing')

def test_config_wildcard():
    cf = Config('base', CONFIG)
    assert cf.get_wildcard('wild-*-*', ['a', 'b']) == 'w.a.b'
    assert cf.get_wildcard('wild-*-*', ['a', 'x']) == 'w.a'
    assert cf.get_wildcard('wild-*-*', ['q', 'b']) == 'w2'
    assert cf.get_wildcard('missing-*-*', ['1', '2'], 'def') == 'def'
    with pytest.raises(NoOptionError):
        cf.get_wildcard('missing-*-*', ['1', '2'])

def test_config_default():
    cf = Config('base', CONFIG)
    assert cf.get('all') == 'yes'

def test_config_other():
    cf = Config('base', CONFIG)
    assert sorted(cf.sections()) == ['base', 'other']
    assert cf.has_section('base') == True
    assert cf.has_section('other') == True
    assert cf.has_section('missing') == False
    assert cf.has_section('DEFAULT') == False
    assert cf.has_option('missing') == False
    assert cf.has_option('all') == True
    assert cf.has_option('foo') == True

    cf2 = cf.clone('other')
    opts = list(sorted(cf2.options()))
    assert opts == [
        'all', 'config_dir', 'config_file', 'host_name',
        'job_name', 'service_name', 'test'
    ]
    assert len(cf2.items()) == len(cf2.options())

def test_loading():
    with pytest.raises(NoSectionError):
        Config('random', CONFIG)
    with pytest.raises(NoSectionError):
        Config('random', CONFIG)
    with pytest.raises(ConfigError):
        Config('random', 'random.ini')

def test_nofile():
    cf = Config('base', None, user_defs={'a': '1'})
    assert cf.sections() == ['base']
    assert cf.get('a') == '1'

    cf = Config('base', None, user_defs={'a': '1'}, ignore_defs=True)
    assert cf.get('a', '2') == '2'

def test_override():
    cf = Config('base', CONFIG, override={'foo': 'overrided'})
    assert cf.get('foo') == 'overrided'

def test_vars():
    cf = Config('base', CONFIG)
    assert cf.get('vars1') == 'V2=V3=Q3'
    with pytest.raises(InterpolationError):
        cf.get('bad1')

def test_extended_compat():
    config = u'[foo]\nkey = ${sub} $${nosub}\nsub = 2\n[bar]\nkey = ${foo:key}\n'
    cf = ExtendedCompatConfigParser()
    cf.read_file(io.StringIO(config), 'conf.ini')
    assert cf.get('bar', 'key') == '2 ${nosub}'

    config = u'[foo]\nloop1= ${loop1}\nloop2 = ${loop3}\nloop3 = ${loop2}\n'
    cf = ExtendedCompatConfigParser()
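    # The config above interpolates in a cycle (loop1 -> loop1,
    # loop2 -> loop3 -> loop2), so the parser must detect the reference
    # loop and raise InterpolationError instead of recursing forever: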
cf.read_file(io.StringIO(config), 'conf.ini') with pytest.raises(InterpolationError): cf.get('foo', 'loop1') with pytest.raises(InterpolationError): cf.get('foo', 'loop2') config = u'[foo]\nkey = %(sub)s ${sub}\nsub = 2\n[bar]\nkey = %(foo:key)s\nkey2 = ${foo:key}\n' cf = ExtendedCompatConfigParser() cf.read_file(io.StringIO(config), 'conf.ini') assert cf.get('bar', 'key2') == '2 2' with pytest.raises(NoOptionError): cf.get('bar', 'key') config = u'[foo]\nkey = ${bad:xxx}\n[bad]\nsub = 1\n' cf = ExtendedCompatConfigParser() cf.read_file(io.StringIO(config), 'conf.ini') with pytest.raises(NoOptionError): cf.get('foo', 'key') python-skytools-3.6.1/tests/test_dbservice.py000066400000000000000000000010511373460333300214520ustar00rootroot00000000000000 from skytools.dbservice import transform_fields def test_transform_fields(): rows = [] rows.append({'time': '22:00', 'metric': 'count', 'value': 100}) rows.append({'time': '22:00', 'metric': 'dur', 'value': 7}) rows.append({'time': '23:00', 'metric': 'count', 'value': 200}) rows.append({'time': '23:00', 'metric': 'dur', 'value': 5}) res = list(transform_fields(rows, ['time'], 'metric', 'value')) assert res[0] == {'count': 100, 'dur': 7, 'time': '22:00'} assert res[1] == {'count': 200, 'dur': 5, 'time': '23:00'} python-skytools-3.6.1/tests/test_fileutil.py000066400000000000000000000005701373460333300213260ustar00rootroot00000000000000 import os import tempfile from skytools.fileutil import write_atomic def test_write_atomic(): pidfn = tempfile.mktemp('.pid') write_atomic(pidfn, "1") write_atomic(pidfn, "2") os.remove(pidfn) def test_write_atomic_bak(): pidfn = tempfile.mktemp('.pid') write_atomic(pidfn, "1", '.bak') write_atomic(pidfn, "2", '.bak') os.remove(pidfn) python-skytools-3.6.1/tests/test_gzlog.py000066400000000000000000000011651373460333300206340ustar00rootroot00000000000000 import gzip import os import tempfile from skytools.gzlog import gzip_append def test_gzlog(): fd, tmpname = tempfile.mkstemp(suffix='.gz') os.close(fd) try: blk = b'1234567890' * 100 write_total = 0 for i in range(5): gzip_append(tmpname, blk) write_total += len(blk) read_total = 0 with gzip.open(tmpname) as rfd: while True: blk = rfd.read(512) if not blk: break read_total += len(blk) finally: os.remove(tmpname) assert read_total == write_total python-skytools-3.6.1/tests/test_hashtext.py000066400000000000000000000046211373460333300213420ustar00rootroot00000000000000from skytools.hashtext import ( hashtext_new, hashtext_new_py, hashtext_old, hashtext_old_py, ) def test_hashtext_new_const(): c0 = [hashtext_new_py(b'x' * (0 * 5 + j)) for j in range(5)] c1 = [hashtext_new_py(b'x' * (1 * 5 + j)) for j in range(5)] c2 = [hashtext_new_py(b'x' * (2 * 5 + j)) for j in range(5)] assert c0 == [-1477818771, 1074944137, -1086392228, -1992236649, -1379736791] assert c1 == [-370454118, 1489915569, -66683019, -2126973000, 1651296771] assert c2 == [755764456, -1494243903, 631527812, 28686851, -9498641] def test_hashtext_old_const(): c0 = [hashtext_old_py(b'x' * (0 * 5 + j)) for j in range(5)] c1 = [hashtext_old_py(b'x' * (1 * 5 + j)) for j in range(5)] c2 = [hashtext_old_py(b'x' * (2 * 5 + j)) for j in range(5)] assert c0 == [-863449762, 37835117, 294739542, -320432768, 1007638138] assert c1 == [1422906842, -261065348, 59863994, -162804943, 1736144510] assert c2 == [-682756517, 317827663, -495599455, -1411793989, 1739997714] def test_hashtext_new_impl(): data = b'HypficUjFitraxlumCitcemkiOkIkthi' p = [hashtext_new_py(data[:l]) for l in range(len(data) + 1)] c = 
[hashtext_new(data[:l]) for l in range(len(data) + 1)] assert p == c, '%s <> %s' % (p, c) def test_hashtext_old_impl(): data = b'HypficUjFitraxlumCitcemkiOkIkthi' p = [hashtext_old_py(data[:l]) for l in range(len(data) + 1)] c = [hashtext_old(data[:l]) for l in range(len(data) + 1)] assert p == c, '%s <> %s' % (p, c) def test_hashtext_new_input_types(): data = b'HypficUjFitraxlumCitcemkiOkIkthi' exp = hashtext_new(data) assert hashtext_new(data.decode("utf8")) == exp #assert hashtext_new(memoryview(data)) == exp #assert hashtext_new(bytearray(data)) == exp assert hashtext_new_py(data) == exp assert hashtext_new_py(data.decode("utf8")) == exp #assert hashtext_new_py(memoryview(data)) == exp #assert hashtext_new_py(bytearray(data)) == exp def test_hashtext_old_input_types(): data = b'HypficUjFitraxlumCitcemkiOkIkthi' exp = hashtext_old(data) assert hashtext_old(data.decode("utf8")) == exp #assert hashtext_old(memoryview(data)) == exp #assert hashtext_old(bytearray(data)) == exp assert hashtext_old_py(data) == exp assert hashtext_old_py(data.decode("utf8")) == exp #assert hashtext_old_py(memoryview(data)) == exp #assert hashtext_old_py(bytearray(data)) == exp python-skytools-3.6.1/tests/test_kwcheck.py000066400000000000000000000040761373460333300211350ustar00rootroot00000000000000"""Check if SQL keywords are up-to-date. """ import os.path import re import skytools.quoting versions = [ "94", "95", "96", "9.4", "9.5", "9.6", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19" ] locations = [ "/usr/include/postgresql/{VER}/server/parser/kwlist.h", "~/src/pgsql/pg{VER}/src/include/parser/kwlist.h", "~/src/pgsql/postgresql/src/include/parser/kwlist.h", ] def _load_kwlist(fn, full_map, cur_map): fn = os.path.expanduser(fn) if not os.path.isfile(fn): return with open(fn, 'rt') as f: data = f.read() rc = re.compile(r'PG_KEYWORD[(]"(.*)" , \s* \w+ , \s* (\w+) [)]', re.X) for kw, cat in rc.findall(data): full_map[kw] = cat if cat == 'UNRESERVED_KEYWORD': continue if cat == 'COL_NAME_KEYWORD': continue cur_map[kw] = cat def test_kwcheck(): """Compare keyword list in quoting.py to the one in postgres sources """ kwset = set(skytools.quoting._ident_kwmap) full_map = {} # all types from kwlist.h cur_map = {} # only kwlist.h new_list = [] # missing from kwset obsolete_list = [] # in kwset, but not in cur_map done = set() for loc in locations: for ver in versions: fn = loc.format(VER=ver) if fn not in done: _load_kwlist(fn, full_map, cur_map) done.add(fn) if not full_map: return for kw in sorted(cur_map): if kw not in kwset: new_list.append((kw, cur_map[kw])) kwset.add(kw) for k in sorted(kwset): if k not in full_map: # especially obsolete obsolete_list.append((k, '!FULL')) elif k not in cur_map: # slightly obsolete obsolete_list.append((k, '!CUR')) assert new_list == [] # here we need to keep older keywords around longer #assert obsolete_list == [] # [('between', '!CUR'), ('errors', '!FULL'), ('new', '!CUR'), # ('off', '!CUR'), ('old', '!CUR'), ('over', '!CUR')] python-skytools-3.6.1/tests/test_natsort.py000066400000000000000000000030171373460333300212020ustar00rootroot00000000000000 from skytools.natsort import ( natsort, natsort_icase, natsort_key, natsorted, natsorted_icase, ) def test_natsorted(): res = natsorted(['1', 'ver-1.11', '', 'ver-1.0']) assert res == ['', '1', 'ver-1.0', 'ver-1.11'] def test_natsort(): res = ['a1', '2a', '.1'] natsort(res) assert res == ['.1', '2a', 'a1'] def test_natsorted_icase(): res = natsorted_icase(['Ver-1.1', 'vEr-1.11', '', 'veR-1.0']) assert res == ['', 
'veR-1.0', 'Ver-1.1', 'vEr-1.11'] def test_natsort_icase(): res = ['Ver-1.1', 'vEr-1.11', '', 'veR-1.0'] natsort_icase(res) assert res == ['', 'veR-1.0', 'Ver-1.1', 'vEr-1.11'] def _natcmp(a, b): k1 = natsort_key(a) k2 = natsort_key(b) if k1 < k2: return 'ok' return f"fail: a='{a}' > b='{b}'" def test_natsort_order(): assert _natcmp('1', '2') == 'ok' assert _natcmp('2', '11') == 'ok' assert _natcmp('.', '1') == 'ok' assert _natcmp('1', 'a') == 'ok' assert _natcmp('a~1', 'ab') == 'ok' assert _natcmp('a~1', 'a') == 'ok' assert _natcmp('a~1', 'a1') == 'ok' assert _natcmp('00', '0') == 'ok' assert _natcmp('001', '0') == 'ok' assert _natcmp('0', '01') == 'ok' assert _natcmp('011', '02') == 'ok' assert _natcmp('00~1', '0~1') == 'ok' assert _natcmp('~~~', '~~') == 'ok' assert _natcmp('1~beta0', '1') == 'ok' assert _natcmp('1', '1.0') == 'ok' assert _natcmp('~', '') == 'ok' assert _natcmp('', '0') == 'ok' assert _natcmp('', '1') == 'ok' assert _natcmp('', 'a') == 'ok' python-skytools-3.6.1/tests/test_parsing.py000066400000000000000000000162021373460333300211530ustar00rootroot00000000000000 import pytest from skytools.parsing import ( dedent, hsize_to_bytes, merge_connect_string, parse_acl, parse_connect_string, parse_logtriga_sql, parse_pgarray, parse_sqltriga_sql, parse_statements, parse_tabbed_table, sql_tokenizer, ) def test_parse_pgarray(): assert parse_pgarray('{}') == [] assert parse_pgarray('{a,b,null,"null"}') == ['a', 'b', None, 'null'] assert parse_pgarray(r'{"a,a","b\"b","c\\c"}') == ['a,a', 'b"b', 'c\\c'] assert parse_pgarray("[0,3]={1,2,3}") == ['1', '2', '3'] assert parse_pgarray(None) is None with pytest.raises(ValueError): parse_pgarray('}{') with pytest.raises(ValueError): parse_pgarray('[1]=}') with pytest.raises(ValueError): parse_pgarray('{"..." , }') with pytest.raises(ValueError): parse_pgarray('{"..." 
; }') with pytest.raises(ValueError): parse_pgarray('{"}') with pytest.raises(ValueError): parse_pgarray('{"..."}zzz') with pytest.raises(ValueError): parse_pgarray('{"..."}z') def test_parse_sqltriga_sql(): # Insert event row = parse_logtriga_sql('I', '(id, data) values (1, null)') assert row == {'data': None, 'id': '1'} row = parse_sqltriga_sql('I', '(id, data) values (1, null)', pklist=["id"]) assert row == {'data': None, 'id': '1'} # Update event row = parse_logtriga_sql('U', "data='foo' where id = 1") assert row == {'data': 'foo', 'id': '1'} # Delete event row = parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'") assert row == {'id': '1', 'id2': "str'val"} # Insert event, splitkeys keys, row = parse_logtriga_sql('I', '(id, data) values (1, null)', splitkeys=True) assert keys == {} assert row == {'data': None, 'id': '1'} keys, row = parse_logtriga_sql('I', '(id, data) values (1, null)', splitkeys=True) assert keys == {} assert row == {'data': None, 'id': '1'} # Update event, splitkeys keys, row = parse_logtriga_sql('U', "data='foo' where id = 1", splitkeys=True) assert keys == {'id': '1'} assert row == {'data': 'foo'} keys, row = parse_logtriga_sql('U', "data='foo',type=3 where id = 1", splitkeys=True) assert keys == {'id': '1'} assert row == {'data': 'foo', 'type': '3'} # Delete event, splitkeys keys, row = parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'", splitkeys=True) assert keys == {'id': '1', 'id2': "str'val"} # generic with pytest.raises(ValueError): parse_logtriga_sql('J', "(id, data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id) values (1, null)") # insert errors with pytest.raises(ValueError): parse_logtriga_sql('I', "insert (id, data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id; data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) select (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) values of (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) values (1; null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) values (1, null) ;") with pytest.raises(ValueError, match="EOF"): parse_logtriga_sql('I', "(id, data) values (1, null) , insert") # update errors with pytest.raises(ValueError): parse_logtriga_sql('U', "(id,data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('U', "id,data") with pytest.raises(ValueError): parse_logtriga_sql('U', "data='foo';type=3 where id = 1") with pytest.raises(ValueError): parse_logtriga_sql('U', "data='foo' where id>1") with pytest.raises(ValueError): parse_logtriga_sql('U', "data='foo' where id=1 or true") # delete errors with pytest.raises(ValueError): parse_logtriga_sql('D', "foo,1") with pytest.raises(ValueError): parse_logtriga_sql('D', "foo = 1 ,") def test_parse_tabbed_table(): assert parse_tabbed_table('col1\tcol2\nval1\tval2\n') == [ {'col1': 'val1', 'col2': 'val2'} ] # skip rows with different size assert parse_tabbed_table('col1\tcol2\nval1\tval2\ntmp\n') == [ {'col1': 'val1', 'col2': 'val2'} ] def test_sql_tokenizer(): res = sql_tokenizer("select * from a.b", ignore_whitespace=True) assert list(res) == [ ('ident', 'select'), ('sym', '*'), ('ident', 'from'), ('ident', 'a'), ('sym', '.'), ('ident', 'b') ] res = sql_tokenizer("\"c olumn\",'str''val'") assert list(res) == [ ('ident', '"c olumn"'), ('sym', ','), ('str', "'str''val'") ] res = sql_tokenizer('a.b a."b "" c" a.1', fqident=True, ignore_whitespace=True) 
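    # With fqident=True the tokenizer merges schema-qualified names like a.b
    # into a single 'ident' token, but only while the dotted tail is itself an
    # identifier: in 'a.1' the trailing number is not, so it stays three tokens.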
assert list(res) == [ ('ident', 'a.b'), ('ident', 'a."b "" c"'), ('ident', 'a'), ('sym', '.'), ('num', '1') ] res = sql_tokenizer(r"set 'a''\' + E'\''", standard_quoting=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'set'), ('str', "'a''\\'"), ('sym', '+'), ('str', "E'\\''") ] res = sql_tokenizer('a.b a."b "" c" a.1', fqident=True, standard_quoting=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'a.b'), ('ident', 'a."b "" c"'), ('ident', 'a'), ('sym', '.'), ('num', '1') ] res = sql_tokenizer('a.b\nc;', show_location=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'a', 1), ('sym', '.', 2), ('ident', 'b', 3), ('ident', 'c', 5), ('sym', ';', 6) ] def test_parse_statements(): res = parse_statements("begin; select 1; select 'foo'; end;") assert list(res) == ['begin;', 'select 1;', "select 'foo';", 'end;'] res = parse_statements("select (select 2+(select 3;);) ; select 4;") assert list(res) == ['select (select 2+(select 3;);) ;', 'select 4;'] with pytest.raises(ValueError): list(parse_statements('select ());')) with pytest.raises(ValueError): list(parse_statements('copy from stdin;')) def test_parse_acl(): assert parse_acl('user=rwx/owner') == ('user', 'rwx', 'owner') assert parse_acl('" ""user"=rwx/" ""owner"') == (' "user', 'rwx', ' "owner') assert parse_acl('user=rwx') == ('user', 'rwx', None) assert parse_acl('=/f') == (None, '', 'f') # is this ok? assert parse_acl('?') is None def test_dedent(): assert dedent(' Line1:\n Line 2\n') == 'Line1:\n Line 2\n' res = dedent(' \nLine1:\n Line 2\n Line 3\n Line 4') assert res == 'Line1:\nLine 2\n Line 3\n Line 4\n' def test_hsize_to_bytes(): assert hsize_to_bytes('10G') == 10737418240 assert hsize_to_bytes('12k') == 12288 with pytest.raises(ValueError): hsize_to_bytes("x") def test_parse_connect_string(): assert parse_connect_string("host=foo") == [('host', 'foo')] res = parse_connect_string(r" host = foo password = ' f\\\o\'o ' ") assert res == [('host', 'foo'), ('password', "' f\\o'o '")] with pytest.raises(ValueError): parse_connect_string(r" host = ") def test_merge_connect_string(): res = merge_connect_string([('host', 'ip'), ('pass', ''), ('x', ' ')]) assert res == "host=ip pass='' x=' '" python-skytools-3.6.1/tests/test_querybuilder.py000066400000000000000000000057501373460333300222320ustar00rootroot00000000000000 from skytools.querybuilder import ( PARAM_DBAPI, PARAM_INLINE, PARAM_PLPY, DList, PlanCache, QueryBuilder, plpy, plpy_exec, ) def test_dlist(): root = DList() assert root.empty() == True elem1 = DList() elem2 = DList() elem3 = DList() root.append(elem1) root.append(elem2) root.append(elem3) assert root.empty() == False assert elem1.empty() == False root.remove(elem2) root.remove(elem3) root.remove(elem1) assert root.empty() == True assert elem1.next is None assert elem2.next is None assert elem3.next is None assert elem1.prev is None assert elem2.prev is None assert elem3.prev is None def test_cached_plan(): cache = PlanCache(3) p1 = cache.get_plan('sql1', ['text']) assert p1 is cache.get_plan('sql1', ['text']) p2 = cache.get_plan('sql1', ['int']) assert p2 is cache.get_plan('sql1', ['int']) assert p1 is not p2 p3 = cache.get_plan('sql3', ['text']) assert p3 is cache.get_plan('sql3', ['text']) p4 = cache.get_plan('sql4', ['text']) assert p4 is cache.get_plan('sql4', ['text']) p1x = cache.get_plan('sql1', ['text']) assert p1 is not p1x def test_querybuilder_core(): args = {'success': 't', 'total': 45, 'ccy': 'EEK', 'id': 556} q = QueryBuilder("update orders set total = {total} where id = 
{id}", args) q.add(" and optional = {non_exist}") q.add(" and final = {success}") exp = "update orders set total = '45' where id = '556' and final = 't'" assert q.get_sql(PARAM_INLINE) == exp exp = "update orders set total = %s where id = %s and final = %s" assert q.get_sql(PARAM_DBAPI) == exp exp = "update orders set total = $1 where id = $2 and final = $3" assert q.get_sql(PARAM_PLPY) == exp def test_plpy_exec(): GD = {} plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '1', 'arg2': '2'}) assert plpy.log == [ "DBG: plpy.prepare('select $1, $2, $3', ['text', 'int4', 'text'])", "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['1', '2', '1'])", ] plpy.log.clear() plpy_exec(None, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '1', 'arg2': '2'}) assert plpy.log == [ """DBG: plpy.execute("select '1', '2', '1'", ())""" ] plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3', 'arg2': '4'}) assert plpy.log == [ "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['3', '4', '3'])" ] plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3'}) assert plpy.log == [ """DBG: plpy.error("Missing arguments: [arg2] QUERY: 'select {arg1}, {arg2:int4}, {arg1}'")""" ] plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3'}, False) assert plpy.log == [ "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['3', None, '3'])" ] python-skytools-3.6.1/tests/test_quoting.py000066400000000000000000000175571373460333300212140ustar00rootroot00000000000000"""Extra tests for quoting module. """ from decimal import Decimal import psycopg2.extras import pytest import skytools._cquoting import skytools._pyquoting import skytools.psycopgwrapper from skytools.quoting import ( json_decode, json_encode, make_pgarray, quote_fqident, unescape_copy, unquote_fqident, ) class fake_cursor: """create a DictCursor row""" index = {'id': 0, 'data': 1} description = ['x', 'x'] dbrow = psycopg2.extras.DictRow(fake_cursor()) dbrow[0] = '123' dbrow[1] = 'value' def try_func(qfunc, data_list): for val, exp in data_list: got = qfunc(val) assert got == exp def try_catch(qfunc, data_list, exc): for d in data_list: with pytest.raises(exc): qfunc(d) def test_quote_literal(): sql_literal = [ [None, "null"], ["", "''"], ["a'b", "'a''b'"], [r"a\'b", r"E'a\\''b'"], [1, "'1'"], [True, "'True'"], [Decimal(1), "'1'"], [u'qwe', "'qwe'"] ] try_func(skytools._cquoting.quote_literal, sql_literal) try_func(skytools._pyquoting.quote_literal, sql_literal) try_func(skytools.quote_literal, sql_literal) qliterals_common = [ (r"""null""", None), (r"""NULL""", None), (r"""123""", "123"), (r"""''""", r""""""), (r"""'a''b''c'""", r"""a'b'c"""), (r"""'foo'""", r"""foo"""), (r"""E'foo'""", r"""foo"""), (r"""E'a\n\t\a\b\0\z\'b'""", "a\n\t\x07\x08\x00z'b"), (r"""$$$$""", r""), (r"""$$qw$e$z$$""", r"qw$e$z"), (r"""$qq$$aa$$$'"\\$qq$""", '$aa$$$\'"\\\\'), (u"'qwe'", 'qwe'), ] bad_dol_literals = [ ('$$', '$$'), #('$$q', '$$q'), ('$$q$', '$$q$'), ('$q$q$', '$q$q$'), ('$q$q$x$', '$q$q$x$'), ] def test_unquote_literal(): qliterals_nonstd = qliterals_common + [ (r"""'a\\b\\c'""", r"""a\b\c"""), (r"""e'a\\b\\c'""", r"""a\b\c"""), ] try_func(skytools._cquoting.unquote_literal, qliterals_nonstd) try_func(skytools._pyquoting.unquote_literal, qliterals_nonstd) try_func(skytools.unquote_literal, qliterals_nonstd) for v1, v2 in bad_dol_literals: with pytest.raises(ValueError): 
skytools._pyquoting.unquote_literal(v1) with pytest.raises(ValueError): skytools._cquoting.unquote_literal(v1) with pytest.raises(ValueError): skytools.unquote_literal(v1) def test_unquote_literal_std(): qliterals_std = qliterals_common + [ (r"''", r""), (r"'foo'", r"foo"), (r"E'foo'", r"foo"), (r"'\\''z'", r"\\'z"), ] for val, exp in qliterals_std: assert skytools._cquoting.unquote_literal(val, True) == exp assert skytools._pyquoting.unquote_literal(val, True) == exp assert skytools.unquote_literal(val, True) == exp def test_quote_copy(): sql_copy = [ [None, "\\N"], ["", ""], ["a'\tb", "a'\\tb"], [r"a\'b", r"a\\'b"], [1, "1"], [True, "True"], [u"qwe", "qwe"], [Decimal(1), "1"], ] try_func(skytools._cquoting.quote_copy, sql_copy) try_func(skytools._pyquoting.quote_copy, sql_copy) try_func(skytools.quote_copy, sql_copy) def test_quote_bytea_raw(): sql_bytea_raw = [ [None, None], [b"", ""], [b"a'\tb", "a'\\011b"], [b"a\\'b", r"a\\'b"], [b"\t\344", r"\011\344"], ] try_func(skytools._cquoting.quote_bytea_raw, sql_bytea_raw) try_func(skytools._pyquoting.quote_bytea_raw, sql_bytea_raw) try_func(skytools.quote_bytea_raw, sql_bytea_raw) def test_quote_bytea_raw_fail(): with pytest.raises(TypeError): skytools._pyquoting.quote_bytea_raw(u'qwe') #assert_raises(TypeError, skytools._cquoting.quote_bytea_raw, u'qwe') #assert_raises(TypeError, skytools.quote_bytea_raw, 'qwe') def test_quote_ident(): sql_ident = [ ['', '""'], ["a'\t\\\"b", '"a\'\t\\""b"'], ['abc_19', 'abc_19'], ['from', '"from"'], ['0foo', '"0foo"'], ['mixCase', '"mixCase"'], [u'utf', 'utf'], ] try_func(skytools.quote_ident, sql_ident) def test_fqident(): assert quote_fqident('tbl') == 'public.tbl' assert quote_fqident('Baz.Foo.Bar') == '"Baz"."Foo.Bar"' def _sort_urlenc(func): def wrapper(data): res = func(data) return '&'.join(sorted(res.split('&'))) return wrapper def test_db_urlencode(): t_urlenc = [ [{}, ""], [{'a': 1}, "a=1"], [{'a': None}, "a"], [{'qwe': 1, u'zz': u"qwe"}, 'qwe=1&zz=qwe'], [{'qwe': 1, u'zz': u"qwe"}, 'qwe=1&zz=qwe'], [{'a': '\000%&'}, "a=%00%25%26"], [dbrow, 'data=value&id=123'], [{'a': Decimal("1")}, "a=1"], ] try_func(_sort_urlenc(skytools._cquoting.db_urlencode), t_urlenc) try_func(_sort_urlenc(skytools._pyquoting.db_urlencode), t_urlenc) try_func(_sort_urlenc(skytools.db_urlencode), t_urlenc) def test_db_urldecode(): t_urldec = [ ["", {}], ["a=b&c", {'a': 'b', 'c': None}], ["&&b=f&&", {'b': 'f'}], [u"abc=qwe", {'abc': 'qwe'}], ["b=", {'b': ''}], ["b=%00%45", {'b': '\x00E'}], ] try_func(skytools._cquoting.db_urldecode, t_urldec) try_func(skytools._pyquoting.db_urldecode, t_urldec) try_func(skytools.db_urldecode, t_urldec) def test_unescape(): t_unesc = [ ["", ""], ["\\N", "N"], ["abc", "abc"], [u"abc", "abc"], [r"\0\000\001\01\1", "\0\000\001\001\001"], [r"a\001b\tc\r\n", "a\001b\tc\r\n"], ] try_func(skytools._cquoting.unescape, t_unesc) try_func(skytools._pyquoting.unescape, t_unesc) try_func(skytools.unescape, t_unesc) def test_quote_bytea_literal(): bytea_raw = [ [None, "null"], [b"", "''"], [b"a'\tb", "E'a''\\\\011b'"], [b"a\\'b", r"E'a\\\\''b'"], [b"\t\344", r"E'\\011\\344'"], ] try_func(skytools.quote_bytea_literal, bytea_raw) def test_quote_bytea_copy(): bytea_raw = [ [None, "\\N"], [b"", ""], [b"a'\tb", "a'\\\\011b"], [b"a\\'b", r"a\\\\'b"], [b"\t\344", r"\\011\\344"], ] try_func(skytools.quote_bytea_copy, bytea_raw) def test_quote_statement(): sql = "set a=%s, b=%s, c=%s" args = [None, u"qwe'qwe", 6.6] assert skytools.quote_statement(sql, args) == "set a=null, b='qwe''qwe', c='6.6'" sql = "set a=%(a)s, 
b=%(b)s, c=%(c)s" args = dict(a=None, b="qwe'qwe", c=6.6) assert skytools.quote_statement(sql, args) == "set a=null, b='qwe''qwe', c='6.6'" def test_quote_json(): json_string_vals = [ [None, "null"], ['', '""'], [u'xx', '"xx"'], ['qwe"qwe\t', '"qwe\\"qwe\\t"'], ['\x01', '"\\u0001"'], ] try_func(skytools.quote_json, json_string_vals) def test_unquote_ident(): idents = [ ['qwe', 'qwe'], [u'qwe', 'qwe'], ['"qwe"', 'qwe'], ['"q""w\\\\e"', 'q"w\\\\e'], ['Foo', 'foo'], ['"Wei "" rd"', 'Wei " rd'], ] try_func(skytools.unquote_ident, idents) def test_unquote_ident_fail(): with pytest.raises(Exception): skytools.unquote_ident('asd"asd') def test_unescape_copy(): assert unescape_copy(r'baz\tfo\'o') == "baz\tfo'o" assert unescape_copy(r'\N') is None def test_unquote_fqident(): assert unquote_fqident('Foo') == 'foo' assert unquote_fqident('"Foo"."Bar "" z"') == 'Foo.Bar " z' def test_json_encode(): assert json_encode({'a': 1}) == '{"a": 1}' assert json_encode('a') == '"a"' assert json_encode(['a']) == '["a"]' assert json_encode(a=1) == '{"a": 1}' def test_json_decode(): assert json_decode('[1]') == [1] def test_make_pgarray(): assert make_pgarray([]) == '{}' assert make_pgarray(['foo_3', 1, '', None]) == '{foo_3,1,"",NULL}' res = make_pgarray([None, ',', '\\', "'", '"', "{", "}", '_']) exp = '{NULL,",","\\\\","\'","\\"","{","}",_}' assert res == exp python-skytools-3.6.1/tests/test_scripting.py000066400000000000000000000075401373460333300215170ustar00rootroot00000000000000 import os import signal import sys import time import pytest import skytools from skytools.scripting import run_single_process WIN32 = sys.platform == "win32" CONF = os.path.join(os.path.dirname(__file__), "config.ini") TEST_DB = os.environ.get("TEST_DB") def checklog(log, word): with open(log, 'r') as f: return word in f.read() class Runner: def __init__(self, logfile, word, sleep=0): self.logfile = logfile self.word = word self.sleep = sleep def run(self): with open(self.logfile, "a") as f: f.write(self.word + "\n") time.sleep(self.sleep) @pytest.mark.skipif(WIN32, reason="cannot daemonize on win32") def test_bg_process(tmp_path): pidfile = str(tmp_path / "proc.pid") logfile = str(tmp_path / "proc.log") run_single_process(Runner(logfile, "STEP1"), False, pidfile) while skytools.signal_pidfile(pidfile, 0): time.sleep(1) assert checklog(logfile, "STEP1") # daemonize from other process pid = os.fork() if pid == 0: run_single_process(Runner(logfile, "STEP2", 10), True, pidfile) time.sleep(2) with pytest.raises(SystemExit): run_single_process(Runner(logfile, "STEP3"), False, pidfile) skytools.signal_pidfile(pidfile, signal.SIGTERM) while skytools.signal_pidfile(pidfile, 0): time.sleep(1) assert checklog(logfile, "STEP2") assert not checklog(logfile, "STEP3") class OptScript(skytools.BaseScript): ARGPARSE = False looping = 0 def send_signal(self, code): print("signal: %s" % code) sys.exit(0) def work(self): print("opt=%s" % self.cf.get("opt")) class ArgScript(OptScript): ARGPARSE = True def test_optparse_script(capsys): with pytest.raises(SystemExit): OptScript("testscript", ["-h"]) res = capsys.readouterr() assert "display" in res.out def test_argparse_script(capsys): with pytest.raises(SystemExit): ArgScript("testscript", ["-h"]) res = capsys.readouterr() assert "display" in res.out @pytest.mark.skipif(WIN32, reason="use signals on win32") def test_optparse_signals(capsys): with pytest.raises(SystemExit): OptScript("testscript", ["-s", CONF]) res = capsys.readouterr() assert "SIGINT" in res.out with pytest.raises(SystemExit): 
OptScript("testscript", ["-r", CONF]) res = capsys.readouterr() assert "SIGHUP" in res.out with pytest.raises(SystemExit): OptScript("testscript", ["-k", CONF]) res = capsys.readouterr() assert "SIGTERM" in res.out @pytest.mark.skipif(WIN32, reason="need to use signals") def test_argparse_signals(capsys): with pytest.raises(SystemExit): ArgScript("testscript", ["-s", CONF]) res = capsys.readouterr() assert "SIGINT" in res.out with pytest.raises(SystemExit): ArgScript("testscript", ["-r", CONF]) res = capsys.readouterr() assert "SIGHUP" in res.out with pytest.raises(SystemExit): ArgScript("testscript", ["-k", CONF]) res = capsys.readouterr() assert "SIGTERM" in res.out def test_optparse_confopt(capsys): s = ArgScript("testscript", [CONF]) s.start() res = capsys.readouterr() assert "opt=test" in res.out def test_argparse_confopt(capsys): s = ArgScript("testscript", [CONF]) s.start() res = capsys.readouterr() assert "opt=test" in res.out class DBScript(skytools.DBScript): ARGPARSE = True looping = 0 def work(self): db = self.get_database("db", connstr=TEST_DB) curs = db.cursor() curs.execute("select 1") curs.fetchall() self.close_database("db") print("OK") @pytest.mark.skipif(not TEST_DB, reason="need database config") def test_get_database(capsys): s = DBScript("testscript", [CONF]) s.start() res = capsys.readouterr() assert "OK" in res.out python-skytools-3.6.1/tests/test_skylog.py000066400000000000000000000005551373460333300210240ustar00rootroot00000000000000 import logging import skytools from skytools import skylog def test_trace_setup(): assert skylog.TRACE < logging.DEBUG assert skylog.TRACE == logging.TRACE assert logging.getLevelName(skylog.TRACE) == "TRACE" def test_skylog(): log = skytools.getLogger("test.skylog") log.trace("tracemsg") assert not log.isEnabledFor(logging.TRACE) python-skytools-3.6.1/tests/test_sockutil.py000066400000000000000000000016631373460333300213520ustar00rootroot00000000000000import os import socket import sys import pytest from skytools.sockutil import set_cloexec, set_nonblocking, set_tcp_keepalive def test_set_tcp_keepalive(): with socket.socket() as s: set_tcp_keepalive(s) @pytest.mark.skipif( sys.platform == 'win32', reason="set_nonblocking on fd does not work on win32" ) def test_set_nonblocking(): with socket.socket() as s: assert set_nonblocking(s, None) == False assert set_nonblocking(s, 1) == 1 assert set_nonblocking(s, None) == True def test_set_cloexec_file(): with open(os.devnull, 'rb') as f: assert set_cloexec(f, None) in (True, False) assert set_cloexec(f, True) == True assert set_cloexec(f, None) == True def test_set_cloexec_socket(): with socket.socket() as s: assert set_cloexec(s, None) in (True, False) assert set_cloexec(s, True) == True assert set_cloexec(s, None) == True python-skytools-3.6.1/tests/test_sqltools.py000066400000000000000000000105521373460333300213720ustar00rootroot00000000000000 import pytest from skytools.sqltools import ( Snapshot, dbdict, fq_name, fq_name_parts, magic_insert, mk_delete_sql, mk_insert_sql, mk_update_sql, ) def test_dbdict(): row = dbdict(a=1, b=2) assert (row.a, row.b, row['a'], row['b']) == (1, 2, 1, 2) row.c = 3 assert row['c'] == 3 del row.c with pytest.raises(AttributeError): row.c with pytest.raises(KeyError): row['c'] row.merge({'q': 4}) assert row.q == 4 def test_fq_name_parts(): assert fq_name_parts('tbl') == ['public', 'tbl'] assert fq_name_parts('foo.tbl') == ['foo', 'tbl'] assert fq_name_parts('foo.tbl.baz') == ['foo', 'tbl.baz'] def test_fq_name(): assert fq_name('tbl') == 'public.tbl' assert 
fq_name('foo.tbl') == 'foo.tbl' assert fq_name('foo.tbl.baz') == 'foo.tbl.baz' def test_snapshot(): sn = Snapshot('11:20:11,12,15') assert sn.contains(9) assert not sn.contains(11) assert sn.contains(17) assert not sn.contains(20) with pytest.raises(ValueError): Snapshot(':') def test_magic_insert(): res = magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2']) exp = 'COPY public.tbl (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' assert res == exp res = magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2'], use_insert=True) exp = "insert into public.tbl (col1,col2) values ('1','1');\ninsert into public.tbl (col1,col2) values ('2','2');\n" assert res == exp assert magic_insert(None, 'tbl', [], ['col1', 'col2']) is None res = magic_insert(None, 'tbl."1"', [[1, '1'], [2, '2']], ['col1', 'col2'], quoted_table=True) exp = 'COPY tbl."1" (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' assert res == exp with pytest.raises(Exception): magic_insert(None, 'tbl."1"', [[1, '1'], [2, '2']]) res = magic_insert(None, 'a.tbl', [{'a': 1}, {'a': 2}]) exp = 'COPY a.tbl (a) FROM STDIN;\n1\n2\n\\.\n' assert res == exp res = magic_insert(None, 'a.tbl', [{'a': 1}, {'a': 2}], use_insert=True) exp = "insert into a.tbl (a) values ('1');\ninsert into a.tbl (a) values ('2');\n" assert res == exp # More fields than data res = magic_insert(None, 'tbl', [[1, 'a']], ['col1', 'col2', 'col3']) exp = 'COPY public.tbl (col1,col2,col3) FROM STDIN;\n1\ta\t\\N\n\\.\n' assert res == exp res = magic_insert(None, 'tbl', [[1, 'a']], ['col1', 'col2', 'col3'], use_insert=True) exp = "insert into public.tbl (col1,col2,col3) values ('1','a',null);\n" assert res == exp res = magic_insert(None, 'tbl', [{'a': 1}, {'b': 2}], ['a', 'b'], use_insert=False) exp = 'COPY public.tbl (a,b) FROM STDIN;\n1\t\\N\n\\N\t2\n\\.\n' assert res == exp res = magic_insert(None, 'tbl', [{'a': 1}, {'b': 2}], ['a', 'b'], use_insert=True) exp = "insert into public.tbl (a,b) values ('1',null);\ninsert into public.tbl (a,b) values (null,'2');\n" assert res == exp def test_mk_insert_sql(): row = {'id': 1, 'data': None} res = mk_insert_sql(row, 'tbl') exp = "insert into public.tbl (id, data) values ('1', null);" assert res == exp fmap = {'id': 'id_', 'data': 'data_'} res = mk_insert_sql(row, 'tbl', ['x'], fmap) exp = "insert into public.tbl (id_, data_) values ('1', null);" assert res == exp def test_mk_update_sql(): res = mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2']) exp = 'update only public."Table" set data = E\'str\\\\\' where id = \'0\' and id2 = \'2\';' assert res == exp res = mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2'], {'id': '_id', 'id2': '_id2', 'data': '_data'}) exp = 'update only public."Table" set _data = E\'str\\\\\' where _id = \'0\' and _id2 = \'2\';' assert res == exp with pytest.raises(Exception): mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', []) def test_mk_delete_sql(): res = mk_delete_sql({'a': 1, 'b': 2, 'c': 3}, 'tablename', ['a', 'b']) exp = "delete from only public.tablename where a = '1' and b = '2';" assert res == exp res = mk_delete_sql({'a': 1, 'b': 2, 'c': 3}, 'tablename', ['a', 'b'], {'a': 'aa', 'b': 'bb'}) exp = "delete from only public.tablename where aa = '1' and bb = '2';" assert res == exp with pytest.raises(Exception): mk_delete_sql({'a': 1, 'b': 2, 'c': 3}, 'tablename', []) python-skytools-3.6.1/tests/test_timeutil.py000066400000000000000000000034531373460333300213500ustar00rootroot00000000000000 from datetime import datetime import 

python-skytools-3.6.1/tests/test_timeutil.py

from datetime import datetime

import pytest

from skytools.timeutil import UTC, datetime_to_timestamp, parse_iso_timestamp


def test_parse_iso_timestamp():
    res = str(parse_iso_timestamp('2005-06-01 15:00'))
    assert res == '2005-06-01 15:00:00'

    res = str(parse_iso_timestamp(' 2005-06-01T15:00 +02 '))
    assert res == '2005-06-01 15:00:00+02:00'

    res = str(parse_iso_timestamp('2005-06-01 15:00:33+02:00'))
    assert res == '2005-06-01 15:00:33+02:00'

    d = parse_iso_timestamp('2005-06-01 15:00:59.33 +02')
    assert d.strftime("%z %Z") == '+0200 +02'
    assert str(parse_iso_timestamp(str(d))) == '2005-06-01 15:00:59.330000+02:00'

    res = parse_iso_timestamp('2005-06-01 15:00-0530').strftime('%Y-%m-%d %H:%M %z %Z')
    assert res == '2005-06-01 15:00 -0530 -05:30'

    res = parse_iso_timestamp('2014-10-27T11:59:13Z').strftime('%Y-%m-%d %H:%M:%S %z %Z')
    assert res == '2014-10-27 11:59:13 +0000 +00'

    with pytest.raises(ValueError):
        parse_iso_timestamp('2014.10.27')


def test_datetime_to_timestamp():
    res = datetime_to_timestamp(parse_iso_timestamp("2005-06-01 15:00:59.5 +02"))
    assert res == 1117630859.5

    res = datetime_to_timestamp(datetime.fromtimestamp(1117630859.5, UTC))
    assert res == 1117630859.5

    res = datetime_to_timestamp(datetime.fromtimestamp(1117630859.5))
    assert res == 1117630859.5

    now = datetime.utcnow()
    now2 = datetime.utcfromtimestamp(datetime_to_timestamp(now, False))
    assert abs(now2.microsecond - now.microsecond) < 100
    now2 = now2.replace(microsecond=now.microsecond)
    assert now == now2

    now = datetime.now()
    now2 = datetime.fromtimestamp(datetime_to_timestamp(now))
    assert abs(now2.microsecond - now.microsecond) < 100
    now2 = now2.replace(microsecond=now.microsecond)
    assert now == now2


python-skytools-3.6.1/tests/test_tnetstrings.py

import pytest

from skytools.tnetstrings import dumps, loads


def ustr(v):
    return repr(v).replace("u'", "'")


def nstr(b):
    return b.decode('utf8')


def test_dumps_simple_values():
    vals = (None, False, True, 222, 333.0, b"foo", u"bar")
    tnvals = [nstr(dumps(v)) for v in vals]
    assert tnvals == [
        '0:~', '5:false!', '4:true!', '3:222#', '5:333.0^', '3:foo,', '3:bar,'
    ]


def test_dumps_complex_values():
    vals2 = [[], (), {}, [1, 2], {'a': 'b'}]
    tnvals2 = [nstr(dumps(v)) for v in vals2]
    assert tnvals2 == ['0:]', '0:]', '0:}', '8:1:1#1:2#]', '8:1:a,1:b,}']


def test_loads_simple():
    vals = (None, False, True, 222, 333.0, b"foo", u"bar")
    res = ustr([loads(dumps(v)) for v in vals])
    assert res == "[None, False, True, 222, 333.0, 'foo', 'bar']"


def test_loads_complex():
    vals2 = [[], (), {}, [1, 2], {'a': 'b'}]
    res = ustr([loads(dumps(v)) for v in vals2])
    assert res == "[[], [], {}, [1, 2], {'a': 'b'}]"


def test_dumps_mview():
    res = nstr(dumps([memoryview(b'zzz'), b'qqq']))
    assert res == '12:3:zzz,3:qqq,]'


def test_loads_errors():
    with pytest.raises(ValueError):
        loads(b'4:qwez!')
    with pytest.raises(ValueError):
        loads(b'4:')
    with pytest.raises(ValueError):
        loads(b'4:qwez')
    with pytest.raises(ValueError):
        loads(b'4')
    with pytest.raises(ValueError):
        loads(b'')
    with pytest.raises(ValueError):
        loads(b'999999999999999999:z,')
    with pytest.raises(TypeError):
        loads(u'qweqwe')
    with pytest.raises(ValueError):
        loads(b'4:true!0:~')
    with pytest.raises(ValueError):
        loads(b'1:X~')
    with pytest.raises(ValueError):
        loads(b'8:1:1#1:2#}')
    with pytest.raises(ValueError):
        loads(b'1:Xz')


def test_dumps_errors():
    with pytest.raises(TypeError):
        dumps(open)
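
# Added round-trip sketch (not in the original file): dumps() always yields
# bytes and loads() accepts only bytes -- the TypeError case above enforces
# that -- so a value survives a wire round-trip unchanged, except that
# tuples come back as lists (see test_loads_complex).
def _roundtrip_demo():
    payload = {'a': 'b'}
    wire = dumps(payload)            # b'8:1:a,1:b,}' per test_dumps_complex_values
    assert loads(wire) == payload    # parsed back into an equal dict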

python-skytools-3.6.1/tests/test_utf8.py

import pytest

from skytools.utf8 import safe_utf8_decode, sanitize_unicode


def test_safe_decode():
    assert safe_utf8_decode(b"foobar") == (True, "foobar")
    assert safe_utf8_decode(b'X\0Z') == (False, "X\uFFFDZ")
    assert safe_utf8_decode(b"OK") == (True, "OK")
    assert safe_utf8_decode(b'X\xF1Y') == (False, "X\uFFFDY")

    assert sanitize_unicode(u'\uD801\uDC01') == "\U00010401"

    with pytest.raises(TypeError):
        sanitize_unicode(b'qwe')

## these give different results in py27 and py35
# >>> _norm(safe_utf8_decode(b'X\xed\xa0\x80Y\xed\xb0\x89Z'))
# (False, ['X', 65533, 65533, 65533, 'Y', 65533, 65533, 65533, 'Z'])
# >>> _norm(safe_utf8_decode(b'X\xed\xa0\x80\xed\xb0\x89Z'))
# (False, ['X', 65533, 65533, 65533, 65533, 65533, 65533, 'Z'])


python-skytools-3.6.1/tox.ini

[tox]
envlist = lint,py3

[package]
name = skytools
deps =
    psycopg2-binary==2.8.5
test_deps =
    coverage==5.2
    pytest==5.4.3
    pytest-cov==2.10.0
lint_deps =
    pylint==2.5.3
    flake8==3.8.3
doc_deps =
    sphinx==3.1.2
    docutils==0.16

[testenv]
#changedir = {envsitepackagesdir}
changedir = {toxinidir}
deps =
    {[package]deps}
    {[package]test_deps}
passenv = TEST_DB
commands =
    pytest \
        --cov \
        --cov-report=term \
        --cov-report=xml:{toxinidir}/cover/coverage.xml \
        --cov-report=html:{toxinidir}/cover/{envname} \
        --cov-config={toxinidir}/.coveragerc \
        --rootdir={toxinidir} \
        {posargs}

[testenv:lint]
changedir = {toxinidir}
basepython = python3
deps =
    {[package]deps}
    {[package]lint_deps}
setenv = PYLINTRC={toxinidir}/.pylintrc
commands =
    flake8 tests
    pylint {[package]name}

[testenv:pytype]
changedir = {toxinidir}
basepython = python3
deps =
    {[package]deps}
    pytype==2020.7.30
commands =
    pytype {[package]name}

[testenv:mypy]
changedir = {toxinidir}
basepython = python3
deps =
    {[package]deps}
    mypy==0.782
commands =
    mypy {[package]name}

[testenv:docs]
basepython = python3
deps =
    {[package]deps}
    {[package]doc_deps}
changedir = doc
commands =
    sphinx-build -q -W -b html -d {envtmpdir}/doctrees . ../tmp/dochtml
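
# Usage notes (appended commentary, not part of the original file):
# [testenv] passes TEST_DB through to pytest, and the database-backed tests
# above skip themselves unless it is set, so a full local run looks roughly
# like this (the connection string is an example, adjust to your setup):
#   TEST_DB="host=localhost dbname=test" tox -e py38
# The lint environment runs both flake8 and pylint in one shot:
#   tox -e lint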