pax_global_header00006660000000000000000000000064144726556670014540gustar00rootroot0000000000000052 comment=ea0db107d516475b0b05c1cc0d614db47725ca66 python-skytools-3.9.2/000077500000000000000000000000001447265566700147615ustar00rootroot00000000000000python-skytools-3.9.2/.coveragerc000066400000000000000000000003641447265566700171050ustar00rootroot00000000000000[report] exclude_lines = pragma: no cover if __name__ except ImportError: raise NotImplementedError omit = .tox/* tests/* */tests/* [paths] source_pkgs = skytools/ */skytools/ [run] source_pkgs = skytools python-skytools-3.9.2/.github/000077500000000000000000000000001447265566700163215ustar00rootroot00000000000000python-skytools-3.9.2/.github/workflows/000077500000000000000000000000001447265566700203565ustar00rootroot00000000000000python-skytools-3.9.2/.github/workflows/ci.yml000066400000000000000000000155521447265566700215040ustar00rootroot00000000000000# # https://docs.github.com/en/actions/reference # https://github.com/actions # https://cibuildwheel.readthedocs.io/en/stable/options/ # # uses: https://github.com/actions/checkout @v3 # uses: https://github.com/actions/setup-python @v4 # uses: https://github.com/actions/download-artifact @v3 # uses: https://github.com/actions/upload-artifact @v3 # uses: https://github.com/pypa/cibuildwheel @v2.15 name: CI on: pull_request: {} push: {} jobs: check: name: "Check" runs-on: ubuntu-latest strategy: matrix: test: - {PY: "3.11", TOXENV: "lint"} steps: - name: "Checkout" uses: actions/checkout@v3 - name: "Setup Python ${{matrix.test.PY}}" uses: actions/setup-python@v4 with: python-version: ${{matrix.test.PY}} - run: python3 -m pip install -r etc/requirements.build.txt --disable-pip-version-check - name: "Test" env: TOXENV: ${{matrix.test.TOXENV}} run: python3 -m tox -r no_database: name: "${{matrix.test.osname}} + Python ${{matrix.test.PY}} ${{matrix.test.arch}}" runs-on: ${{matrix.test.os}} strategy: matrix: test: - {os: "ubuntu-latest", osname: "Linux", PY: 
"3.7", TOXENV: "py37", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "3.8", TOXENV: "py38", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "3.9", TOXENV: "py39", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "3.10", TOXENV: "py310", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "3.11", TOXENV: "py311", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "3.12", TOXENV: "py312", arch: "x64"} - {os: "macos-latest", osname: "MacOS", PY: "3.10", TOXENV: "py310", arch: "x64"} - {os: "macos-latest", osname: "MacOS", PY: "3.11", TOXENV: "py311", arch: "x64"} #- {os: "macos-latest", osname: "MacOS", PY: "3.12", TOXENV: "py312", arch: "x64"} - {os: "windows-latest", osname: "Windows", PY: "3.7", TOXENV: "py37", arch: "x86"} - {os: "windows-latest", osname: "Windows", PY: "3.8", TOXENV: "py38", arch: "x64"} - {os: "windows-latest", osname: "Windows", PY: "3.10", TOXENV: "py310", arch: "x86"} - {os: "windows-latest", osname: "Windows", PY: "3.11", TOXENV: "py311", arch: "x64"} #- {os: "windows-latest", osname: "Windows", PY: "3.12", TOXENV: "py312", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "pypy3.8", TOXENV: "pypy38", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "pypy3.9", TOXENV: "pypy39", arch: "x64"} - {os: "ubuntu-latest", osname: "Linux", PY: "pypy3.10", TOXENV: "pypy310", arch: "x64"} steps: - name: "Checkout" uses: actions/checkout@v3 - name: "Setup Python ${{matrix.test.PY}}" uses: actions/setup-python@v4 with: python-version: ${{matrix.test.PY}} architecture: ${{matrix.test.arch}} allow-prereleases: true - run: python3 -m pip install -r etc/requirements.build.txt --disable-pip-version-check - name: "Build" run: python setup.py build - name: "Test" env: TOXENV: ${{matrix.test.TOXENV}} run: python -m tox -r -- --color=yes database: name: "Python ${{matrix.test.PY}} + PostgreSQL ${{matrix.test.PG}}" runs-on: ubuntu-latest strategy: matrix: test: - {PY: "3.7", PG: "11", TOXENV: "py37"} - 
{PY: "3.8", PG: "12", TOXENV: "py38"} - {PY: "3.9", PG: "13", TOXENV: "py39"} - {PY: "3.10", PG: "14", TOXENV: "py310"} - {PY: "3.11", PG: "15", TOXENV: "py311"} #- {PY: "pypy3.9", PG: "15", TOXENV: "pypy39"} #- {PY: "pypy3.10", PG: "15", TOXENV: "pypy310"} steps: - name: "Checkout" uses: actions/checkout@v3 - name: "Setup Python ${{matrix.test.PY}}" uses: actions/setup-python@v4 with: python-version: ${{matrix.test.PY}} - run: python3 -m pip install -r etc/requirements.build.txt --disable-pip-version-check - name: "InstallDB" run: | echo "::group::apt-get-update" sudo -nH apt-get -q update sudo -nH apt-get -q install curl ca-certificates gnupg curl https://www.postgresql.org/media/keys/ACCC4CF8.asc \ | gpg --dearmor \ | sudo -nH tee /etc/apt/trusted.gpg.d/apt.postgresql.org.gpg echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main ${{matrix.test.PG}}" \ | sudo -nH tee /etc/apt/sources.list.d/pgdg.list sudo -nH apt-get -q update echo "::endgroup::" echo "::group::apt-get-install" # disable new cluster creation sudo -nH mkdir -p /etc/postgresql-common/createcluster.d echo "create_main_cluster = false" | sudo -nH tee /etc/postgresql-common/createcluster.d/no-main.conf sudo -nH apt-get -qyu install postgresql-${{matrix.test.PG}} echo "::endgroup::" # tune environment echo "/usr/lib/postgresql/${{matrix.test.PG}}/bin" >> $GITHUB_PATH echo "PGHOST=/tmp" >> $GITHUB_ENV - name: "StartDB" run: | rm -rf data log mkdir -p log LANG=C initdb data sed -ri -e "s,^[# ]*(unix_socket_directories).*,\\1='/tmp'," data/postgresql.conf pg_ctl -D data -l log/pg.log start || { cat log/pg.log ; exit 1; } sleep 1 createdb testdb - name: "Test" env: TOXENV: ${{matrix.test.TOXENV}} TEST_DB: dbname=testdb host=/tmp run: | python3 -m tox -r -- --color=yes - name: "StopDB" run: | pg_ctl -D data stop rm -rf data log /tmp/.s.PGSQL* cibuildwheel: name: "Wheels: ${{matrix.sys.name}} [${{matrix.sys.archs}}]" runs-on: ${{matrix.sys.os}} strategy: matrix: sys: - {os: 
"ubuntu-latest", name: "Linux", archs: "auto", qemu: false} - {os: "ubuntu-latest", name: "Linux", archs: "aarch64", qemu: true} - {os: "macos-latest", name: "MacOS", archs: "x86_64 arm64 universal2", qemu: false} - {os: "windows-latest", name: "Windows", archs: "auto", qemu: false} steps: - uses: actions/checkout@v3 - name: "Set up QEMU" if: ${{matrix.sys.qemu}} uses: docker/setup-qemu-action@v2 with: platforms: all - uses: pypa/cibuildwheel@v2.15 env: CIBW_ARCHS: "${{matrix.sys.archs}}" # cp38: cp37-macos does not support universal2/arm64 CIBW_BUILD: "cp38-* pp*-manylinux_x86_64" CIBW_SKIP: "pp37-*" - name: "Check" shell: bash run: | ls -l wheelhouse - uses: actions/upload-artifact@v3 with: {name: "dist", path: "wheelhouse"} python-skytools-3.9.2/.github/workflows/release.yml000066400000000000000000000100611447265566700225170ustar00rootroot00000000000000# # This runs when version tag is pushed # name: REL on: push: tags: ["v[0-9]*"] jobs: sdist: name: "Build source package" runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: {python-version: "3.11"} - run: python3 -m pip install -r etc/requirements.build.txt --disable-pip-version-check - run: python3 setup.py sdist - uses: actions/upload-artifact@v3 with: {name: "dist", path: "dist"} cibuildwheel: name: "Wheels: ${{matrix.sys.name}} [${{matrix.sys.archs}}]" runs-on: ${{matrix.sys.os}} strategy: matrix: sys: - {os: "ubuntu-latest", name: "Linux", archs: "auto", qemu: false} - {os: "ubuntu-latest", name: "Linux", archs: "aarch64", qemu: true} - {os: "macos-latest", name: "MacOS", archs: "x86_64 arm64 universal2", qemu: false} - {os: "windows-latest", name: "Windows", archs: "auto", qemu: false} steps: - uses: actions/checkout@v3 - name: "Set up QEMU" if: ${{matrix.sys.qemu}} uses: docker/setup-qemu-action@v2 with: platforms: all - uses: pypa/cibuildwheel@v2.15 env: CIBW_ARCHS: "${{matrix.sys.archs}}" # cp38: cp37-macos does not support universal2/arm64 CIBW_BUILD: "cp38-* 
pp*-manylinux_x86_64" CIBW_SKIP: "pp37-*" - name: "Check" shell: bash run: | ls -l wheelhouse - uses: actions/upload-artifact@v3 with: {name: "dist", path: "wheelhouse"} publish: name: "Publish" runs-on: ubuntu-latest needs: [sdist, cibuildwheel] steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v4 with: {python-version: "3.11"} - run: python3 -m pip install -r etc/requirements.build.txt --disable-pip-version-check - name: "Get files" uses: actions/download-artifact@v3 with: {name: "dist", path: "dist"} - name: "Install pandoc" run: | sudo -nH apt-get -u -y install pandoc pandoc --version - name: "Prepare" run: | PACKAGE=$(python3 setup.py --name) VERSION=$(python3 setup.py --version) TGZ="${PACKAGE}-${VERSION}.tar.gz" # default - gh:release, pypi # PRERELEASE - gh:prerelease, pypi # DRAFT - gh:draft,prerelease, testpypi PRERELEASE="false"; DRAFT="false" case "${VERSION}" in *[ab]*|*rc*) PRERELEASE="true";; *dev*) PRERELEASE="true"; DRAFT="true";; esac test "${{github.ref}}" = "refs/tags/v${VERSION}" || { echo "ERR: tag mismatch"; exit 1; } test -f "dist/${TGZ}" || { echo "ERR: sdist failed"; exit 1; } echo "PACKAGE=${PACKAGE}" >> $GITHUB_ENV echo "VERSION=${VERSION}" >> $GITHUB_ENV echo "TGZ=${TGZ}" >> $GITHUB_ENV echo "PRERELEASE=${PRERELEASE}" >> $GITHUB_ENV echo "DRAFT=${DRAFT}" >> $GITHUB_ENV mkdir -p tmp make -s shownote > tmp/note.md cat tmp/note.md ls -l dist - name: "Create Github release" env: GH_TOKEN: ${{secrets.GITHUB_TOKEN}} run: | title="${PACKAGE} v${VERSION}" ghf="--notes-file=./tmp/note.md" if test "${DRAFT}" = "true"; then ghf="${ghf} --draft"; fi if test "${PRERELEASE}" = "true"; then ghf="${ghf} --prerelease"; fi gh release create "v${VERSION}" "dist/${TGZ}" --title="${title}" ${ghf} - name: "Upload to PYPI" id: pypi_upload env: PYPI_TOKEN: ${{secrets.PYPI_TOKEN}} PYPI_TEST_TOKEN: ${{secrets.PYPI_TEST_TOKEN}} run: | ls -l dist if test "${DRAFT}" = "false"; then python -m twine upload -u __token__ -p ${PYPI_TOKEN} \ --repository 
pypi --disable-progress-bar dist/* else python -m twine upload -u __token__ -p ${PYPI_TEST_TOKEN} \ --repository testpypi --disable-progress-bar dist/* fi python-skytools-3.9.2/.gitignore000066400000000000000000000003571447265566700167560ustar00rootroot00000000000000__pycache__ *.pyc *.swp *.o *.so *.egg-info *.debhelper *.log *.substvars *-stamp debian/files debian/python-skytools debian/python3-skytools .pytype .pytest_cache .mypy_cache .tox .coverage .pybuild cover *.xml MANIFEST build tmp dist python-skytools-3.9.2/AUTHORS000066400000000000000000000013451447265566700160340ustar00rootroot00000000000000 Maintainers ----------- Marko Kreen Petr Jelinek Sasha Aliashkevich Contributors ------------ Aleksei Plotnikov André Malo Andrew Dunstan Artyom Nosov Asko Oja Asko Tiidumaa Cédric Villemain Charles Duffy Devrim Gündüz Dimitri Fontaine Dmitriy V'jukov Doug Gorley Eero Oja Egon Valdmees Emiel van de Laar Erik Jones Glenn Davy Götz Lange Hannu Krosing Hans-Juergen Schoenig Jason Buberel Juta Vaks Kaarel Kitsemets Kristo Kaiv Luc Van Hoeylandt Lukáš Lalinský Marcin Stępnicki Mark Kirkwood Martin Otto Martin Pihlak Nico Mandery Petr Jelinek Pierre-Emmanuel André Priit Kustala Sasha Aliashkevich Sébastien Lardière Sergey Burladyan Sergey Konoplev Shoaib Mir Steve Singer Tarvi Pillessaar Tony Arkles Zoltán Böszörményi python-skytools-3.9.2/COPYRIGHT000066400000000000000000000013451447265566700162570ustar00rootroot00000000000000 Copyright (c) 2007-2020 Skytools Authors Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted, provided that the above copyright notice and this permission notice appear in all copies. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. python-skytools-3.9.2/MANIFEST.in000066400000000000000000000002461447265566700165210ustar00rootroot00000000000000include skytools/py.typed include modules/*.[ch] include tests/*.py tests/*.ini include tox.ini .coveragerc .pylintrc include MANIFEST.in include README.rst NEWS.rst python-skytools-3.9.2/Makefile000066400000000000000000000015641447265566700164270ustar00rootroot00000000000000 VERSION = $(shell python3 setup.py --version) RXVERSION = $(shell python3 setup.py --version | sed 's/\./[.]/g') TAG = v$(VERSION) NEWS = NEWS.rst all: tox -e lint tox -e py3 clean: rm -rf build *.egg-info */__pycache__ tests/*.pyc rm -rf .pybuild MANIFEST sdist: python3 setup.py sdist lint: tox -e lint xlint: tox -e xlint xclean: clean rm -rf .tox dist checkver: @echo "Checking version" @grep -q '^Skytools $(RXVERSION)\b' $(NEWS) \ || { echo "Version '$(VERSION)' not in $(NEWS)"; exit 1; } @echo "Checking git repo" @git diff --stat --exit-code || { echo "ERROR: Unclean repo"; exit 1; } release: checkver git tag $(TAG) git push github $(TAG):$(TAG) unrelease: git push github :$(TAG) git tag -d $(TAG) shownote: gawk -v VER="$(VERSION)" -f etc/note.awk $(NEWS) \ | pandoc -f rst -t gfm --wrap=none showuses: grep uses: .github/workflows/*.yml python-skytools-3.9.2/NEWS.rst000066400000000000000000000106201447265566700162660ustar00rootroot00000000000000 NEWS ==== Skytools 3.9.2 (2023-08-27) --------------------------- Cleanups: * basetypes: add 'statusmessage' property * ci: cleanups * ci: split qemu to separate runner Skytools 3.9.1 (2023-08-25) --------------------------- Fixes: * sqltools: mark ``exists_*`` functions as returning bool * basetypes: sync DictRow with Mapping 
* basetypes: describe additional psycopg2 api Cleanups: * build: create ``abi3`` wheels * ci: drop unmaintained ``create-release``, ``upload-release-asset`` actions * ci: build aarch64 wheel * ci: test on pypy Skytools 3.9 (2023-08-23) ------------------------- Feature removal: * Drop support for Python 3.6 and earlier. Fixes: * dbstruct: fix PUBLIC grant handling. Cleanups: * Apply mypy 'strict' typing to most modules and tests. * Use ``pyproject.toml`` for project setup. Warning: next release will drop some ancient and rarely used code: * skytools.plpy_applyrow * skytools.dbservice * skytools.skylog.LogDBHandler Skytools 3.8.2 (2023-05-19) --------------------------- Fixes: * scripting: restore tracking of failed work() state Skytools 3.8.1 (2022-11-21) --------------------------- Fixes: * full_copy: use ``ONLY`` when using filter query * test_scripting: support Python 3.11 Skytools 3.8 (2022-07-11) ------------------------- Cleanups: * Lots of typing improvements * Refresh CI setup * Work around PyPy3.9 bug Skytools 3.7.3 (2021-08-03) --------------------------- Fixes: * Allow binary I/O in copy_expert signature. Skytools 3.7.2 (2021-07-06) --------------------------- Fixes: * Avoid psycopg copy_from, not usable in v2.9 Skytools 3.7.1 (2021-06-08) --------------------------- Fixes: * quoting: drop obsolete keywords from quote_ident * quoting: add COL_NAME_KEYWORDs into quote_ident list * querybuilder: use dbdict more consistently Cleanups: * basetypes: tune Protocol classes * tests: avoid 'pointless-statement' * sqltools: annotate dbdict * checker: use 'with' with files * modules: add .pyi annotations Skytools 3.7 (2021-05-17) ------------------------- Features: * config: config_format=2 switches to extended format. * querybuilder: alt SQL for missing value. * querybuilder: handle more value types in inline queries. * querybuilder/plpy: always use prepared plan. 
Prevously when GD/SD was not given, it switched to inline params, but that was problem because inline value quoting may be different that PL/Python's. Now it always uses plpy.prepare. Cleanups: * querybuilder: switch to functools.lru_cache, instead local LRU. * querybuilder: use regex for parsing, gives cleaner code. * querybuilder: improve error handling * natsort: switch to string key, instead of tuple. * style: Add type annotations to most modules. * style: use new-style super() everywhere. * ci: drop win32 repack, abi3 is now supported on win32 * ci: drop ubuntu 16.04, to be obsoleted. * ci: build wheels using manylinux2014 images. Skytools 3.6.1 (2020-09-29) --------------------------- Fixes: * scripting: Do not set .my_name on connection, does not work on plain Psycopg connection. * cquoting: Work around pypy3 PyBytes_Check bug. * modules: Use multiphase init. Skytools 3.6 (2020-08-11) ------------------------- Feature removal: * Remove ancient compat code from psycopgwrapper: - dict* and iter* methods - getattr access to fields. - Keepalive tuning from connect_database(). That is built-in to libpq since 9.0. - Require psycpopg 2.5+ Cleanups: * Switch C modules to use stable ABI only (abi3). * Remove Debian packaging. * Upgrade apipkg to 1.5. * Remove Py2 compat. Skytools 3.5 (2020-07-18) ------------------------- Fixes: * dbservice: py3 fix for row.values() * skylog: Use logging.setLogRecordFactory for adding extra fields * fileutil,sockutil: fixes for win32. * natsort: py3 fix, improve rules. Cleanups: * Set up Github Actions for CI and release. * Use "with" for opening files. * Drop py2 syntax. * Code reformat. * Convert nose+doctests to pytest. Skytools 3.4 (2019-11-14) ------------------------- * Support Postgres 10 sequences * Make full_copy text-based * Allow None fields in magic_insert * Fix iterator use in magic insert * Fix Python3 bugs * Switch off Python2 tests, to avoid wasting time. 
Skytools 3.3 (2017-09-21) ------------------------- * Separate 'skytools' module out from big package * Python 3 support Skytools 3.2 and older ---------------------- See old changes here: https://github.com/pgq/skytools-legacy/blob/master/NEWS python-skytools-3.9.2/README.rst000066400000000000000000000012471447265566700164540ustar00rootroot00000000000000 Skytools - Utilities for writing Python scripts =============================================== This is the low-level utility module split out from old Skytools meta-package. It contains various utilities for writing database scripts. Database specific utilites are mainly meant for PostgreSQL. Features -------- * Support for background scripts - Daemonizing - logging - config parsing * Database tools - Tuned connection - DB structure examining - SQL parsing - COPY I/O * Time utilities - ISO timestamp parsing - datetime to timestamp * Text utilities - Natural sort - Fast urlencode I/O TODO ---- * Move from optparse to argparse * Doc cleanup python-skytools-3.9.2/etc/000077500000000000000000000000001447265566700155345ustar00rootroot00000000000000python-skytools-3.9.2/etc/note.awk000066400000000000000000000004601447265566700172050ustar00rootroot00000000000000# extract version notes for version VER /^[-_0-9a-zA-Z]+ v?[0-9]/ { if ($2 == VER) { good = 1 next } else { good = 0 } } /^(===|---)/ { next } { if (good) { # also remove sphinx syntax print gensub(/:(\w+):`~?([^`]+)`/, "``\\2``", "g") } } python-skytools-3.9.2/etc/requirements.build.txt000066400000000000000000000000631447265566700221150ustar00rootroot00000000000000setuptools>=67 wheel>=0.41 twine==4.0.2 tox==4.8.0 python-skytools-3.9.2/modules/000077500000000000000000000000001447265566700164315ustar00rootroot00000000000000python-skytools-3.9.2/modules/cquoting.c000066400000000000000000000400421447265566700204260ustar00rootroot00000000000000/* * Fast quoting functions for Python. 
*/ #define PY_SSIZE_T_CLEAN #include #include #ifdef _MSC_VER #define inline __inline #define strcasecmp stricmp #endif /* inheritance check is broken in pypy3 */ #ifdef PYPY_VERSION_NUM #undef PyBytes_Check #define PyBytes_Check PyBytes_CheckExact #undef PyDict_Check #define PyDict_Check PyDict_CheckExact #endif #include "get_buffer.h" /* * Common buffer management. */ struct Buf { unsigned char *ptr; Py_ssize_t pos; Py_ssize_t alloc; }; static unsigned char *buf_init(struct Buf *buf, Py_ssize_t init_size) { if (init_size < 256) init_size = 256; buf->ptr = PyMem_Malloc(init_size); if (buf->ptr) { buf->pos = 0; buf->alloc = init_size; } return buf->ptr; } /* return new pos */ static unsigned char *buf_enlarge(struct Buf *buf, Py_ssize_t need_room) { Py_ssize_t alloc = buf->alloc; Py_ssize_t need_size = buf->pos + need_room; unsigned char *ptr; /* no alloc needed */ if (need_size < alloc) return buf->ptr + buf->pos; if (alloc <= need_size / 2) alloc = need_size; else alloc = alloc * 2; ptr = PyMem_Realloc(buf->ptr, alloc); if (!ptr) return NULL; buf->ptr = ptr; buf->alloc = alloc; return buf->ptr + buf->pos; } static void buf_free(struct Buf *buf) { PyMem_Free(buf->ptr); buf->ptr = NULL; buf->pos = buf->alloc = 0; } static inline unsigned char *buf_get_target_for(struct Buf *buf, Py_ssize_t len) { if (buf->pos + len <= buf->alloc) return buf->ptr + buf->pos; else return buf_enlarge(buf, len); } static inline void buf_set_target(struct Buf *buf, unsigned char *newpos) { assert(buf->ptr + buf->pos <= newpos); assert(buf->ptr + buf->alloc >= newpos); buf->pos = newpos - buf->ptr; } static inline int buf_put(struct Buf *buf, unsigned char c) { if (buf->pos < buf->alloc) { buf->ptr[buf->pos++] = c; return 1; } else if (buf_enlarge(buf, 1)) { buf->ptr[buf->pos++] = c; return 1; } return 0; } static PyObject *buf_pystr(struct Buf *buf, Py_ssize_t start_pos, unsigned char *newpos) { PyObject *res; if (newpos) buf_set_target(buf, newpos); res = 
PyUnicode_FromStringAndSize((char *)buf->ptr + start_pos, buf->pos - start_pos); buf_free(buf); return res; } /* * Common argument parsing. */ typedef PyObject *(*quote_fn)(unsigned char *src, Py_ssize_t src_len); static PyObject *common_quote(PyObject *args, quote_fn qfunc) { unsigned char *src = NULL; Py_ssize_t src_len = 0; PyObject *arg, *res, *strtmp = NULL; if (!PyArg_ParseTuple(args, "O", &arg)) return NULL; if (arg != Py_None) { src_len = get_buffer(arg, &src, &strtmp); if (src_len < 0) return NULL; } res = qfunc(src, src_len); Py_CLEAR(strtmp); return res; } /* * Simple quoting functions. */ static const char doc_quote_literal[] = "Quote a literal value for SQL.\n" "\n" "If string contains '\\', it is quoted and result is prefixed with E.\n" "Input value of None results in string \"null\" without quotes.\n" "\n" "C implementation.\n"; static PyObject *quote_literal_body(unsigned char *src, Py_ssize_t src_len) { struct Buf buf; unsigned char *esc, *dst, *src_end = src + src_len; unsigned int start_ofs = 1; if (src == NULL) return PyUnicode_FromString("null"); esc = dst = buf_init(&buf, src_len * 2 + 2 + 1); if (!dst) return NULL; *dst++ = ' '; *dst++ = '\''; while (src < src_end) { if (*src == '\\') { *dst++ = '\\'; start_ofs = 0; } else if (*src == '\'') { *dst++ = '\''; } *dst++ = *src++; } *dst++ = '\''; if (start_ofs == 0) *esc = 'E'; return buf_pystr(&buf, start_ofs, dst); } static PyObject *quote_literal(PyObject *self, PyObject *args) { return common_quote(args, quote_literal_body); } /* COPY field */ static const char doc_quote_copy[] = "Quoting for COPY data. 
None is converted to \\N.\n\n" "C implementation."; static PyObject *quote_copy_body(unsigned char *src, Py_ssize_t src_len) { unsigned char *dst, *src_end = src + src_len; struct Buf buf; if (src == NULL) return PyUnicode_FromString("\\N"); dst = buf_init(&buf, src_len * 2); if (!dst) return NULL; while (src < src_end) { switch (*src) { case '\t': *dst++ = '\\'; *dst++ = 't'; src++; break; case '\n': *dst++ = '\\'; *dst++ = 'n'; src++; break; case '\r': *dst++ = '\\'; *dst++ = 'r'; src++; break; case '\\': *dst++ = '\\'; *dst++ = '\\'; src++; break; default: *dst++ = *src++; break; } } return buf_pystr(&buf, 0, dst); } static PyObject *quote_copy(PyObject *self, PyObject *args) { return common_quote(args, quote_copy_body); } /* raw bytea for byteain() */ static const char doc_quote_bytea_raw[] = "Quoting for bytea parser. Returns None as None.\n" "\n" "C implementation."; static PyObject *quote_bytea_raw_body(unsigned char *src, Py_ssize_t src_len) { unsigned char *dst, *src_end = src + src_len; struct Buf buf; if (src == NULL) { Py_INCREF(Py_None); return Py_None; } dst = buf_init(&buf, src_len * 4); if (!dst) return NULL; while (src < src_end) { if (*src < 0x20 || *src >= 0x7F) { *dst++ = '\\'; *dst++ = '0' + (*src >> 6); *dst++ = '0' + ((*src >> 3) & 7); *dst++ = '0' + (*src & 7); src++; } else { if (*src == '\\') *dst++ = '\\'; *dst++ = *src++; } } return buf_pystr(&buf, 0, dst); } static PyObject *quote_bytea_raw(PyObject *self, PyObject *args) { return common_quote(args, quote_bytea_raw_body); } /* SQL unquote */ static const char doc_unquote_literal[] = "Unquote SQL value.\n\n" "E'..' -> extended quoting.\n" "'..' 
-> standard or extended quoting\n" "null -> None\n" "other -> returned as-is\n\n" "C implementation.\n"; static PyObject *do_sql_ext(unsigned char *src, Py_ssize_t src_len) { unsigned char *dst, *src_end = src + src_len; struct Buf buf; dst = buf_init(&buf, src_len); if (!dst) return NULL; while (src < src_end) { if (*src == '\'') { src++; if (src < src_end && *src == '\'') { *dst++ = *src++; continue; } goto failed; } if (*src != '\\') { *dst++ = *src++; continue; } if (++src >= src_end) goto failed; switch (*src) { case 't': *dst++ = '\t'; src++; break; case 'n': *dst++ = '\n'; src++; break; case 'r': *dst++ = '\r'; src++; break; case 'a': *dst++ = '\a'; src++; break; case 'b': *dst++ = '\b'; src++; break; default: if (*src >= '0' && *src <= '7') { unsigned char c = *src++ - '0'; if (src < src_end && *src >= '0' && *src <= '7') { c = (c << 3) | ((*src++) - '0'); if (src < src_end && *src >= '0' && *src <= '7') c = (c << 3) | ((*src++) - '0'); } *dst++ = c; } else { *dst++ = *src++; } } } return buf_pystr(&buf, 0, dst); failed: PyErr_Format(PyExc_ValueError, "Broken exteded SQL string"); return NULL; } static PyObject *do_sql_std(unsigned char *src, Py_ssize_t src_len) { unsigned char *dst, *src_end = src + src_len; struct Buf buf; dst = buf_init(&buf, src_len); if (!dst) return NULL; while (src < src_end) { if (*src != '\'') { *dst++ = *src++; continue; } src++; if (src >= src_end || *src != '\'') goto failed; *dst++ = *src++; } return buf_pystr(&buf, 0, dst); failed: PyErr_Format(PyExc_ValueError, "Broken standard SQL string"); return NULL; } static PyObject *do_dolq(unsigned char *src, Py_ssize_t src_len) { /* src_len >= 2, '$' in start and end */ unsigned char *src_end = src + src_len; unsigned char *p1 = src + 1, *p2 = src_end - 2; while (p1 < src_end && *p1 != '$') p1++; while (p2 > src && *p2 != '$') p2--; if (p2 <= p1) goto failed; p1++; /* position after '$' */ if ((p1 - src) != (src_end - p2)) goto failed; if (memcmp(src, p2, p1 - src) != 0) goto failed; 
return PyUnicode_FromStringAndSize((char *)p1, p2 - p1); failed: PyErr_Format(PyExc_ValueError, "Broken dollar-quoted string"); return NULL; } static PyObject *unquote_literal(PyObject *self, PyObject *args) { unsigned char *src = NULL; Py_ssize_t src_len = 0; int stdstr = 0; PyObject *value = NULL; PyObject *tmp = NULL; PyObject *res = NULL; if (!PyArg_ParseTuple(args, "O|i", &value, &stdstr)) return NULL; src_len = get_buffer(value, &src, &tmp); if (src_len < 0) return NULL; if (src_len == 4 && strcasecmp((char *)src, "null") == 0) { Py_INCREF(Py_None); res = Py_None; } else if (src_len >= 2 && src[0] == '$' && src[src_len - 1] == '$') { res = do_dolq(src, src_len); } else if (src_len < 2 || src[src_len - 1] != '\'') { /* seems invalid, return as-is */ Py_INCREF(value); res = value; } else if (src[0] == '\'') { src++; src_len -= 2; res = stdstr ? do_sql_std(src, src_len) : do_sql_ext(src, src_len); } else if (src_len > 2 && (src[0] | 0x20) == 'e' && src[1] == '\'') { src += 2; src_len -= 3; res = do_sql_ext(src, src_len); } if (tmp) Py_CLEAR(tmp); return res; } /* C unescape */ static const char doc_unescape[] = "Unescape C-style escaped string.\n\n" "C implementation."; static PyObject *unescape_body(unsigned char *src, Py_ssize_t src_len) { unsigned char *dst, *src_end = src + src_len; struct Buf buf; if (src == NULL) { PyErr_Format(PyExc_TypeError, "None not allowed"); return NULL; } dst = buf_init(&buf, src_len); if (!dst) return NULL; while (src < src_end) { if (*src != '\\') { *dst++ = *src++; continue; } if (++src >= src_end) goto failed; switch (*src) { case 't': *dst++ = '\t'; src++; break; case 'n': *dst++ = '\n'; src++; break; case 'r': *dst++ = '\r'; src++; break; case 'a': *dst++ = '\a'; src++; break; case 'b': *dst++ = '\b'; src++; break; default: if (*src >= '0' && *src <= '7') { unsigned char c = *src++ - '0'; if (src < src_end && *src >= '0' && *src <= '7') { c = (c << 3) | ((*src++) - '0'); if (src < src_end && *src >= '0' && *src <= '7') c = (c 
<< 3) | ((*src++) - '0'); } *dst++ = c; } else { *dst++ = *src++; } } } return buf_pystr(&buf, 0, dst); failed: PyErr_Format(PyExc_ValueError, "Broken string - \\ at the end"); return NULL; } static PyObject *unescape(PyObject *self, PyObject *args) { return common_quote(args, unescape_body); } /* * urlencode of dict */ static bool urlenc(struct Buf *buf, PyObject *obj) { Py_ssize_t len; unsigned char *src, *dst; PyObject *strtmp = NULL; static const unsigned char hextbl[] = "0123456789abcdef"; bool ok = false; len = get_buffer(obj, &src, &strtmp); if (len < 0) goto failed; dst = buf_get_target_for(buf, len * 3); if (!dst) goto failed; while (len--) { if ((*src >= 'a' && *src <= 'z') || (*src >= 'A' && *src <= 'Z') || (*src >= '0' && *src <= '9') || (*src == '.' || *src == '_' || *src == '-')) { *dst++ = *src++; } else if (*src == ' ') { *dst++ = '+'; src++; } else { *dst++ = '%'; *dst++ = hextbl[*src >> 4]; *dst++ = hextbl[*src & 0xF]; src++; } } buf_set_target(buf, dst); ok = true; failed: Py_CLEAR(strtmp); return ok; } /* urlencode key+val pair. 
val can be None */ static bool urlenc_keyval(struct Buf *buf, PyObject *key, PyObject *value, bool needAmp) { if (needAmp && !buf_put(buf, '&')) return false; if (!urlenc(buf, key)) return false; if (value != Py_None) { if (!buf_put(buf, '=')) return false; if (!urlenc(buf, value)) return false; } return true; } /* encode native dict using PyDict_Next */ static PyObject *encode_dict(PyObject *data) { PyObject *key, *value; Py_ssize_t pos = 0; bool needAmp = false; struct Buf buf; if (!buf_init(&buf, 1024)) return NULL; while (PyDict_Next(data, &pos, &key, &value)) { if (!urlenc_keyval(&buf, key, value, needAmp)) goto failed; needAmp = true; } return buf_pystr(&buf, 0, NULL); failed: buf_free(&buf); return NULL; } /* encode custom object using .iteritems() */ static PyObject *encode_dictlike(PyObject *data) { PyObject *key = NULL, *value = NULL, *tup, *iter; struct Buf buf; bool needAmp = false; if (!buf_init(&buf, 1024)) return NULL; iter = PyObject_CallMethod(data, "items", NULL); if (iter == NULL) { buf_free(&buf); return NULL; } while ((tup = PyIter_Next(iter))) { key = PySequence_GetItem(tup, 0); value = key ? 
PySequence_GetItem(tup, 1) : NULL; Py_CLEAR(tup); if (!key || !value) goto failed; if (!urlenc_keyval(&buf, key, value, needAmp)) goto failed; needAmp = true; Py_CLEAR(key); Py_CLEAR(value); } /* allow error from iterator */ if (PyErr_Occurred()) goto failed; Py_CLEAR(iter); return buf_pystr(&buf, 0, NULL); failed: buf_free(&buf); Py_CLEAR(iter); Py_CLEAR(key); Py_CLEAR(value); return NULL; } static const char doc_db_urlencode[] = "Urlencode for database records.\n" "If a value is None the key is output without '='.\n" "\n" "C implementation."; static PyObject *db_urlencode(PyObject *self, PyObject *args) { PyObject *data; if (!PyArg_ParseTuple(args, "O", &data)) return NULL; if (PyDict_Check(data)) { return encode_dict(data); } else { return encode_dictlike(data); } } /* * urldecode to dict */ static inline int gethex(unsigned char c) { if (c >= '0' && c <= '9') return c - '0'; c |= 0x20; if (c >= 'a' && c <= 'f') return c - 'a' + 10; return -1; } static PyObject *get_elem(unsigned char *buf, unsigned char **src_p, unsigned char *src_end) { int c1, c2; unsigned char *src = *src_p; unsigned char *dst = buf; while (src < src_end) { switch (*src) { case '%': if (++src + 2 > src_end) goto hex_incomplete; if ((c1 = gethex(*src++)) < 0) goto hex_invalid; if ((c2 = gethex(*src++)) < 0) goto hex_invalid; *dst++ = (c1 << 4) | c2; break; case '+': *dst++ = ' '; src++; break; case '&': case '=': goto gotit; default: *dst++ = *src++; } } gotit: *src_p = src; return PyUnicode_FromStringAndSize((char *)buf, dst - buf); hex_incomplete: PyErr_Format(PyExc_ValueError, "Incomplete hex code"); return NULL; hex_invalid: PyErr_Format(PyExc_ValueError, "Invalid hex code"); return NULL; } static const char doc_db_urldecode[] = "Urldecode from string to dict.\n" "NULL are detected by missing '='.\n" "Duplicate keys are ignored - only latest is kept.\n" "\n" "C implementation."; static PyObject *db_urldecode(PyObject *self, PyObject *args) { unsigned char *src, *src_end; Py_ssize_t 
src_len; PyObject *dict = NULL, *key = NULL, *value = NULL; struct Buf buf; if (!PyArg_ParseTuple(args, "s#", &src, &src_len)) return NULL; if (!buf_init(&buf, src_len)) return NULL; dict = PyDict_New(); if (!dict) { buf_free(&buf); return NULL; } src_end = src + src_len; while (src < src_end) { if (*src == '&') { src++; continue; } key = get_elem(buf.ptr, &src, src_end); if (!key) goto failed; if (src < src_end && *src == '=') { src++; value = get_elem(buf.ptr, &src, src_end); if (value == NULL) goto failed; } else { Py_INCREF(Py_None); value = Py_None; } /* lessen memory usage by intering */ PyUnicode_InternInPlace(&key); if (PyDict_SetItem(dict, key, value) < 0) goto failed; Py_CLEAR(key); Py_CLEAR(value); } buf_free(&buf); return dict; failed: buf_free(&buf); Py_CLEAR(key); Py_CLEAR(value); Py_CLEAR(dict); return NULL; } /* * Module initialization */ static PyMethodDef methods[] = { { "quote_literal", quote_literal, METH_VARARGS, doc_quote_literal }, { "quote_copy", quote_copy, METH_VARARGS, doc_quote_copy }, { "quote_bytea_raw", quote_bytea_raw, METH_VARARGS, doc_quote_bytea_raw }, { "unescape", unescape, METH_VARARGS, doc_unescape }, { "db_urlencode", db_urlencode, METH_VARARGS, doc_db_urlencode }, { "db_urldecode", db_urldecode, METH_VARARGS, doc_db_urldecode }, { "unquote_literal", unquote_literal, METH_VARARGS, doc_unquote_literal }, { NULL } }; static PyModuleDef_Slot slots[] = {{0, NULL}}; static struct PyModuleDef module = { PyModuleDef_HEAD_INIT, .m_name = "_cquoting", .m_doc = "fast quoting for skytools", .m_size = 0, .m_methods = methods, .m_slots = slots }; PyMODINIT_FUNC PyInit__cquoting(void) { return PyModuleDef_Init(&module); } python-skytools-3.9.2/modules/get_buffer.h000066400000000000000000000022451447265566700207150ustar00rootroot00000000000000 /* work around pypy3.9 v7.3.9 bug with Py_LIMITED_API */ #ifdef PYPY_VERSION #ifndef Py_None #define Py_None (&_Py_NoneStruct) #endif #endif /* * Get string data from Python object. 
*/ static Py_ssize_t get_buffer(PyObject *obj, unsigned char **buf_p, PyObject **tmp_obj_p) { PyObject *str = NULL; Py_ssize_t res; /* check for None */ if (obj == Py_None) { PyErr_Format(PyExc_TypeError, "None is not allowed"); return -1; } /* quick path for bytes */ if (PyBytes_Check(obj)) { if (PyBytes_AsStringAndSize(obj, (char**)buf_p, &res) < 0) return -1; return res; } /* convert to bytes */ if (PyUnicode_Check(obj)) { /* no direct string access in abi3 */ *tmp_obj_p = PyUnicode_AsUTF8String(obj); } else if (PyMemoryView_Check(obj) || PyByteArray_Check(obj)) { /* no direct buffer access in abi3 */ *tmp_obj_p = PyBytes_FromObject(obj); } else { /* Not a string-like object, run str() or it. */ str = PyObject_Str(obj); if (str == NULL) return -1; *tmp_obj_p = PyUnicode_AsUTF8String(str); Py_CLEAR(str); } if (*tmp_obj_p == NULL) return -1; if (PyBytes_AsStringAndSize(*tmp_obj_p, (char**)buf_p, &res) < 0) return -1; return res; } python-skytools-3.9.2/modules/hashtext.c000066400000000000000000000200061447265566700204230ustar00rootroot00000000000000/* * Postgres hashes for Python. 
*/ #define PY_SSIZE_T_CLEAN #include #include #include typedef uint32_t (*hash_fn_t)(const void *src, Py_ssize_t src_len); typedef uint8_t uint8; typedef uint16_t uint16; typedef uint32_t uint32; #define rot(x, k) (((x)<<(k)) | ((x)>>(32-(k)))) /* * Old Postgres hashtext() */ #define mix_old(a,b,c) \ { \ a -= b; a -= c; a ^= ((c)>>13); \ b -= c; b -= a; b ^= ((a)<<8); \ c -= a; c -= b; c ^= ((b)>>13); \ a -= b; a -= c; a ^= ((c)>>12); \ b -= c; b -= a; b ^= ((a)<<16); \ c -= a; c -= b; c ^= ((b)>>5); \ a -= b; a -= c; a ^= ((c)>>3); \ b -= c; b -= a; b ^= ((a)<<10); \ c -= a; c -= b; c ^= ((b)>>15); \ } static uint32_t hash_old_hashtext(const void *_k, Py_ssize_t keylen) { const unsigned char *k = _k; uint32 a, b, c; Py_ssize_t len; /* Set up the internal state */ len = keylen; a = b = 0x9e3779b9; /* the golden ratio; an arbitrary value */ c = 3923095; /* initialize with an arbitrary value */ /* handle most of the key */ while (len >= 12) { a += (k[0] + ((uint32) k[1] << 8) + ((uint32) k[2] << 16) + ((uint32) k[3] << 24)); b += (k[4] + ((uint32) k[5] << 8) + ((uint32) k[6] << 16) + ((uint32) k[7] << 24)); c += (k[8] + ((uint32) k[9] << 8) + ((uint32) k[10] << 16) + ((uint32) k[11] << 24)); mix_old(a, b, c); k += 12; len -= 12; } /* handle the last 11 bytes */ c += (uint32)keylen; switch (len) /* all the case statements fall through */ { case 11: c += ((uint32) k[10] << 24); case 10: c += ((uint32) k[9] << 16); case 9: c += ((uint32) k[8] << 8); /* the first byte of c is reserved for the length */ case 8: b += ((uint32) k[7] << 24); case 7: b += ((uint32) k[6] << 16); case 6: b += ((uint32) k[5] << 8); case 5: b += k[4]; case 4: a += ((uint32) k[3] << 24); case 3: a += ((uint32) k[2] << 16); case 2: a += ((uint32) k[1] << 8); case 1: a += k[0]; /* case 0: nothing left to add */ } mix_old(a, b, c); /* report the result */ return c; } /* * New Postgres hashtext() */ #define UINT32_ALIGN_MASK 3 #define mix_new(a,b,c) \ { \ a -= c; a ^= rot(c, 4); c += b; \ b -= a; b ^= 
rot(a, 6); a += c; \ c -= b; c ^= rot(b, 8); b += a; \ a -= c; a ^= rot(c,16); c += b; \ b -= a; b ^= rot(a,19); a += c; \ c -= b; c ^= rot(b, 4); b += a; \ } #define final_new(a,b,c) \ { \ c ^= b; c -= rot(b,14); \ a ^= c; a -= rot(c,11); \ b ^= a; b -= rot(a,25); \ c ^= b; c -= rot(b,16); \ a ^= c; a -= rot(c, 4); \ b ^= a; b -= rot(a,14); \ c ^= b; c -= rot(b,24); \ } static uint32_t hash_new_hashtext(const void *_k, Py_ssize_t keylen) { const unsigned char *k = _k; uint32_t a, b, c; Py_ssize_t len = keylen; /* Set up the internal state */ a = b = c = 0x9e3779b9 + (uint32)len + 3923095; /* If the source pointer is word-aligned, we use word-wide fetches */ if (((uintptr_t) k & UINT32_ALIGN_MASK) == 0) { /* Code path for aligned source data */ const uint32_t *ka = (const uint32_t *) k; /* handle most of the key */ while (len >= 12) { a += ka[0]; b += ka[1]; c += ka[2]; mix_new(a, b, c); ka += 3; len -= 12; } /* handle the last 11 bytes */ k = (const unsigned char *) ka; #ifdef WORDS_BIGENDIAN switch (len) { case 11: c += ((uint32) k[10] << 8); /* fall through */ case 10: c += ((uint32) k[9] << 16); /* fall through */ case 9: c += ((uint32) k[8] << 24); /* the lowest byte of c is reserved for the length */ /* fall through */ case 8: b += ka[1]; a += ka[0]; break; case 7: b += ((uint32) k[6] << 8); /* fall through */ case 6: b += ((uint32) k[5] << 16); /* fall through */ case 5: b += ((uint32) k[4] << 24); /* fall through */ case 4: a += ka[0]; break; case 3: a += ((uint32) k[2] << 8); /* fall through */ case 2: a += ((uint32) k[1] << 16); /* fall through */ case 1: a += ((uint32) k[0] << 24); /* case 0: nothing left to add */ } #else /* !WORDS_BIGENDIAN */ switch (len) { case 11: c += ((uint32) k[10] << 24); /* fall through */ case 10: c += ((uint32) k[9] << 16); /* fall through */ case 9: c += ((uint32) k[8] << 8); /* the lowest byte of c is reserved for the length */ /* fall through */ case 8: b += ka[1]; a += ka[0]; break; case 7: b += ((uint32) k[6] << 16); /* 
fall through */ case 6: b += ((uint32) k[5] << 8); /* fall through */ case 5: b += k[4]; /* fall through */ case 4: a += ka[0]; break; case 3: a += ((uint32) k[2] << 16); /* fall through */ case 2: a += ((uint32) k[1] << 8); /* fall through */ case 1: a += k[0]; /* case 0: nothing left to add */ } #endif /* WORDS_BIGENDIAN */ } else { /* Code path for non-aligned source data */ /* handle most of the key */ while (len >= 12) { #ifdef WORDS_BIGENDIAN a += (k[3] + ((uint32) k[2] << 8) + ((uint32) k[1] << 16) + ((uint32) k[0] << 24)); b += (k[7] + ((uint32) k[6] << 8) + ((uint32) k[5] << 16) + ((uint32) k[4] << 24)); c += (k[11] + ((uint32) k[10] << 8) + ((uint32) k[9] << 16) + ((uint32) k[8] << 24)); #else /* !WORDS_BIGENDIAN */ a += (k[0] + ((uint32) k[1] << 8) + ((uint32) k[2] << 16) + ((uint32) k[3] << 24)); b += (k[4] + ((uint32) k[5] << 8) + ((uint32) k[6] << 16) + ((uint32) k[7] << 24)); c += (k[8] + ((uint32) k[9] << 8) + ((uint32) k[10] << 16) + ((uint32) k[11] << 24)); #endif /* WORDS_BIGENDIAN */ mix_new(a, b, c); k += 12; len -= 12; } /* handle the last 11 bytes */ #ifdef WORDS_BIGENDIAN switch (len) /* all the case statements fall through */ { case 11: c += ((uint32) k[10] << 8); case 10: c += ((uint32) k[9] << 16); case 9: c += ((uint32) k[8] << 24); /* the lowest byte of c is reserved for the length */ case 8: b += k[7]; case 7: b += ((uint32) k[6] << 8); case 6: b += ((uint32) k[5] << 16); case 5: b += ((uint32) k[4] << 24); case 4: a += k[3]; case 3: a += ((uint32) k[2] << 8); case 2: a += ((uint32) k[1] << 16); case 1: a += ((uint32) k[0] << 24); /* case 0: nothing left to add */ } #else /* !WORDS_BIGENDIAN */ switch (len) /* all the case statements fall through */ { case 11: c += ((uint32) k[10] << 24); case 10: c += ((uint32) k[9] << 16); case 9: c += ((uint32) k[8] << 8); /* the lowest byte of c is reserved for the length */ case 8: b += ((uint32) k[7] << 24); case 7: b += ((uint32) k[6] << 16); case 6: b += ((uint32) k[5] << 8); case 5: b += k[4]; 
case 4: a += ((uint32) k[3] << 24); case 3: a += ((uint32) k[2] << 16); case 2: a += ((uint32) k[1] << 8); case 1: a += k[0]; /* case 0: nothing left to add */ } #endif /* WORDS_BIGENDIAN */ } final_new(a, b, c); /* report the result */ return c; } /* * Common argument parsing. */ static PyObject *run_hash(PyObject *args, hash_fn_t real_hash) { unsigned char *src = NULL; Py_ssize_t src_len = 0; int32_t hash; if (!PyArg_ParseTuple(args, "s#", &src, &src_len)) return NULL; hash = real_hash(src, src_len); return PyLong_FromLong(hash); } /* * Python wrappers around actual hash functions. */ static PyObject *hashtext_old(PyObject *self, PyObject *args) { return run_hash(args, hash_old_hashtext); } static PyObject *hashtext_new(PyObject *self, PyObject *args) { return run_hash(args, hash_new_hashtext); } /* * Module initialization */ static PyMethodDef methods[] = { { "hashtext_old", hashtext_old, METH_VARARGS, "Old Postgres hashtext().\n" }, { "hashtext_new", hashtext_new, METH_VARARGS, "New Postgres hashtext().\n" }, { NULL } }; static PyModuleDef_Slot slots[] = {{0, NULL}}; static struct PyModuleDef module = { PyModuleDef_HEAD_INIT, .m_name = "_chashtext", .m_doc = "String hash functions", .m_size = 0, .m_methods = methods, .m_slots = slots }; PyMODINIT_FUNC PyInit__chashtext(void) { return PyModuleDef_Init(&module); } python-skytools-3.9.2/pyproject.toml000066400000000000000000000617751447265566700177150ustar00rootroot00000000000000[project] name = "skytools" description = "Utilities for database scripts" readme = "README.rst" keywords = ["database"] dynamic = ["version"] requires-python = ">= 3.7" maintainers = [{name = "Marko Kreen", email = "markokr@gmail.com"}] classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Console", "Intended Audience :: Developers", "License :: OSI Approved :: ISC License (ISCL)", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX", "Programming Language 
:: Python :: 3", "Topic :: Database", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Utilities", ] [project.optional-dependencies] test = ["pytest", "pytest-cov", "coverage[toml]", "psycopg2-binary"] doc = ["sphinx"] [project.urls] homepage = "https://github.com/pgq/python-skytools" #documentation = "https://readthedocs.org" repository = "https://github.com/pgq/python-skytools" changelog = "https://github.com/pgq/python-skytools/blob/master/NEWS.rst" [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [tool.setuptools] packages = ["skytools"] package-data = {"skytools" = ["py.typed", "_chashtext.pyi", "_cquoting.pyi"]} zip-safe = false [tool.setuptools.dynamic.version] attr = "skytools.installer_config.package_version" # # testing # [tool.pytest] testpaths = ["tests"] [tool.coverage.paths] source = ["skytools", "**/site-packages/skytools"] [tool.coverage.report] exclude_lines = [ "pragma: no cover", "def __repr__", "if self.debug:", "if settings.DEBUG", "raise AssertionError", "raise NotImplementedError", "if 0:", "if __name__ == .__main__.:", ] # # formatting # [tool.isort] atomic = true line_length = 100 multi_line_output = 5 known_first_party = ["skytools"] known_third_party = ["pytest", "yaml"] include_trailing_comma = true balanced_wrapping = true [tool.autopep8] exclude = ".tox, git, tmp, build, cover, dist" ignore = ["E301", "E265", "W391"] max-line-length = 110 in-place = true recursive = true aggressive = 2 [tool.doc8] extensions = "rst" # # linters # [tool.mypy] python_version = "3.10" strict = true disallow_any_unimported = true disallow_any_expr = false disallow_any_decorated = false disallow_any_explicit = false disallow_any_generics = false warn_return_any = false warn_unreachable = false #warn_unused_ignores = false #exclude = ".*/test_misc.py" #[[tool.mypy.overrides]] #module = ["test_misc"] #disable_error_code = ["misc"] [[tool.mypy.overrides]] module = [ "skytools.apipkg", 
"skytools.dbservice", "skytools.plpy_applyrow" ] strict = false disallow_untyped_defs = false disallow_untyped_calls = false disallow_incomplete_defs = false [[tool.mypy.overrides]] module = ["plpy"] ignore_missing_imports = true [tool.ruff] line-length = 120 select = ["E", "F", "Q", "W", "UP", "YTT", "ANN"] ignore = [ "ANN101", # Missing type annotation for `self` in method "ANN102", # Missing type annotation for `cls` in classmethod "ANN401", # Dynamically typed expressions (typing.Any) are disallowed "UP006", # Use `dict` instead of `Dict` "UP007", # Use `X | Y` for type annotations "UP031", # Use format specifiers instead of percent format "UP032", # Use f-string instead of `format` call "UP035", # typing.List` is deprecated "UP037", # Remove quotes from type annotation "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` ] [tool.ruff.flake8-quotes] docstring-quotes = "double" # # reference links # # https://packaging.python.org/en/latest/specifications/declaring-project-metadata/ # https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html # [tool.pylint.main] # Analyse import fallback blocks. This can be used to support both Python 2 and 3 # compatible code, which means that the block might have code that exists only in # one or another interpreter, leading to false positives when analysed. # analyse-fallback-blocks = # Clear in-memory caches upon conclusion of linting. Useful if running pylint in # a server-like mode. # clear-cache-post-run = # Always return a 0 (non-error) status code, even if lint errors are found. This # is primarily useful in continuous integration scripts. # exit-zero = # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. # extension-pkg-allow-list = # A comma-separated list of package or module names from where C extensions may # be loaded. 
Extensions are loading into the active Python interpreter and may # run arbitrary code. (This is an alternative name to extension-pkg-allow-list # for backward compatibility.) # extension-pkg-whitelist = # Return non-zero exit code if any of these messages/categories are detected, # even if score is above --fail-under value. Syntax same as enable. Messages # specified are enabled, while categories only check already-enabled messages. # fail-on = # Specify a score threshold under which the program will exit with error. fail-under = 10 # Interpret the stdin as a python script, whose filename needs to be passed as # the module_or_package argument. # from-stdin = # Files or directories to be skipped. They should be base names, not paths. ignore = ["CVS", "tmp", "dist"] # Add files or directories matching the regular expressions patterns to the # ignore-list. The regex matches against paths and can be in Posix or Windows # format. Because '\\' represents the directory delimiter on Windows systems, it # can't be used as an escape character. # ignore-paths = # Files or directories matching the regular expression patterns are skipped. The # regex matches against base names, not paths. The default value ignores Emacs # file locks # ignore-patterns = # List of module names for which member attributes should not be checked (useful # for modules/projects where namespaces are manipulated during runtime and thus # existing member attributes cannot be deduced by static analysis). It supports # qualified module names, as well as Unix pattern matching. # ignored-modules = # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). # init-hook = # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use, and will cap the count on Windows to # avoid hangs. jobs = 1 # Control the amount of potential inferred values when inferring a single object. 
# This can help the performance when dealing with large functions or complex, # nested conditions. limit-inference-results = 100 # List of plugins (as comma separated values of python module names) to load, # usually to register additional checkers. # load-plugins = # Pickle collected data for later comparisons. persistent = true # Minimum Python version to use for version dependent checks. Will default to the # version used to run pylint. py-version = "3.10" # Discover python modules and packages in the file system subtree. # recursive = # Add paths to the list of the source roots. Supports globbing patterns. The # source root is an absolute path or a path relative to the current working # directory used to determine a package namespace for modules located under the # source root. # source-roots = # When enabled, pylint would attempt to guess common misconfiguration and emit # user-friendly hints instead of false-positive error messages. suggestion-mode = true # Allow loading of arbitrary C extensions. Extensions are imported into the # active Python interpreter and may run arbitrary code. # unsafe-load-any-extension = [tool.pylint.basic] # Naming style matching correct argument names. argument-naming-style = "snake_case" # Regular expression matching correct argument names. Overrides argument-naming- # style. If left empty, argument names will be checked with the set naming style. # argument-rgx = # Naming style matching correct attribute names. attr-naming-style = "snake_case" # Regular expression matching correct attribute names. Overrides attr-naming- # style. If left empty, attribute names will be checked with the set naming # style. # attr-rgx = # Bad variable names which should always be refused, separated by a comma. bad-names = ["foo", "bar", "baz", "toto", "tutu", "tata"] # Bad variable names regexes, separated by a comma. If names match any regex, # they will always be refused # bad-names-rgxs = # Naming style matching correct class attribute names. 
class-attribute-naming-style = "any" # Regular expression matching correct class attribute names. Overrides class- # attribute-naming-style. If left empty, class attribute names will be checked # with the set naming style. # class-attribute-rgx = # Naming style matching correct class constant names. class-const-naming-style = "UPPER_CASE" # Regular expression matching correct class constant names. Overrides class- # const-naming-style. If left empty, class constant names will be checked with # the set naming style. # class-const-rgx = # Naming style matching correct class names. class-naming-style = "PascalCase" # Regular expression matching correct class names. Overrides class-naming-style. # If left empty, class names will be checked with the set naming style. # class-rgx = # Naming style matching correct constant names. const-naming-style = "UPPER_CASE" # Regular expression matching correct constant names. Overrides const-naming- # style. If left empty, constant names will be checked with the set naming style. # const-rgx = # Minimum line length for functions/classes that require docstrings, shorter ones # are exempt. docstring-min-length = -1 # Naming style matching correct function names. function-naming-style = "snake_case" # Regular expression matching correct function names. Overrides function-naming- # style. If left empty, function names will be checked with the set naming style. # function-rgx = # Good variable names which should always be accepted, separated by a comma. good-names = ["i", "j", "k", "ex", "Run", "_"] # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted # good-names-rgxs = # Include a hint for the correct naming format with invalid-name. # include-naming-hint = # Naming style matching correct inline iteration names. inlinevar-naming-style = "any" # Regular expression matching correct inline iteration names. Overrides # inlinevar-naming-style. 
If left empty, inline iteration names will be checked # with the set naming style. # inlinevar-rgx = # Naming style matching correct method names. method-naming-style = "snake_case" # Regular expression matching correct method names. Overrides method-naming- # style. If left empty, method names will be checked with the set naming style. # method-rgx = # Naming style matching correct module names. module-naming-style = "snake_case" # Regular expression matching correct module names. Overrides module-naming- # style. If left empty, module names will be checked with the set naming style. # module-rgx = # Colon-delimited sets of names that determine each other's naming style when the # name regexes allow several styles. # name-group = # Regular expression which should only match function or class names that do not # require a docstring. no-docstring-rgx = "^_" # List of decorators that produce properties, such as abc.abstractproperty. Add # to this list to register other decorators that produce valid properties. These # decorators are taken in consideration only for invalid-name. property-classes = ["abc.abstractproperty"] # Regular expression matching correct type alias names. If left empty, type alias # names will be checked with the set naming style. # typealias-rgx = # Regular expression matching correct type variable names. If left empty, type # variable names will be checked with the set naming style. # typevar-rgx = # Naming style matching correct variable names. variable-naming-style = "snake_case" # Regular expression matching correct variable names. Overrides variable-naming- # style. If left empty, variable names will be checked with the set naming style. # variable-rgx = [tool.pylint.classes] # Warn about protected attribute access inside special methods # check-protected-access-in-special-methods = # List of method names used to declare (i.e. assign) instance attributes. 
defining-attr-methods = ["__init__", "__new__", "setUp"] # List of member names, which should be excluded from the protected access # warning. exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make"] # List of valid names for the first argument in a class method. valid-classmethod-first-arg = ["cls"] # List of valid names for the first argument in a metaclass class method. valid-metaclass-classmethod-first-arg = ["cls"] [tool.pylint.design] # List of regular expressions of class ancestor names to ignore when counting # public methods (see R0903) # exclude-too-few-public-methods = # List of qualified class names to ignore when counting class parents (see R0901) # ignored-parents = # Maximum number of arguments for function / method. max-args = 15 # Maximum number of attributes for a class (see R0902). max-attributes = 37 # Maximum number of boolean expressions in an if statement (see R0916). max-bool-expr = 5 # Maximum number of branch for function / method body. max-branches = 50 # Maximum number of locals for function / method body. max-locals = 45 # Maximum number of parents for a class (see R0901). max-parents = 7 # Maximum number of public methods for a class (see R0904). max-public-methods = 420 # Maximum number of return / yield for function / method body. max-returns = 16 # Maximum number of statements in function / method body. max-statements = 150 # Minimum number of public methods for a class (see R0903). min-public-methods = 0 [tool.pylint.exceptions] # Exceptions that will emit a warning when caught. overgeneral-exceptions = ["builtins.BaseException", "builtins.Exception"] [tool.pylint.format] # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. expected-line-ending-format = "LF" # Regexp for a line that is allowed to be longer than the limit. ignore-long-lines = "^\\s*(# )??$" # Number of spaces of indent required inside a hanging or continued line. indent-after-paren = 4 # String used as indentation unit. 
This is usually " " (4 spaces) or "\t" (1 # tab). indent-string = " " # Maximum number of characters on a single line. max-line-length = 190 # Maximum number of lines in a module. max-module-lines = 10000 # Allow the body of a class to be on the same line as the declaration if body # contains single statement. # single-line-class-stmt = # Allow the body of an if to be on the same line as the test if there is no else. # single-line-if-stmt = [tool.pylint.imports] # List of modules that can be imported at any level, not just the top level one. # allow-any-import-level = # Allow explicit reexports by alias from a package __init__. # allow-reexport-from-package = # Allow wildcard imports from modules that define __all__. # allow-wildcard-with-all = # Deprecated modules which should not be used, separated by a comma. deprecated-modules = ["optparse", "tkinter.tix"] # Output a graph (.gv or any supported image format) of external dependencies to # the given file (report RP0402 must not be disabled). # ext-import-graph = # Output a graph (.gv or any supported image format) of all (i.e. internal and # external) dependencies to the given file (report RP0402 must not be disabled). # import-graph = # Output a graph (.gv or any supported image format) of internal dependencies to # the given file (report RP0402 must not be disabled). # int-import-graph = # Force import order to recognize a module as part of the standard compatibility # libraries. # known-standard-library = # Force import order to recognize a module as part of a third party library. known-third-party = ["enchant"] # Couples of modules and preferred modules, separated by a comma. # preferred-modules = [tool.pylint.logging] # The type of string formatting that logging methods do. `old` means using % # formatting, `new` is for `{}` formatting. logging-format-style = "old" # Logging modules to check that the string format arguments are in logging # function parameter format. 
logging-modules = ["logging"] [tool.pylint."messages control"] # Only show warnings with the listed confidence levels. Leave empty to show all. # Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED. confidence = ["HIGH", "CONTROL_FLOW", "INFERENCE", "INFERENCE_FAILURE", "UNDEFINED"] # Disable the message, report, category or checker with the given id(s). You can # either give multiple identifiers separated by comma (,) or put this option # multiple times (only on the command line, not in the configuration file where # it should appear only once). You can also use "--disable=all" to disable # everything first and then re-enable specific checks. For example, if you want # to run only the similarities checker, you can use "--disable=all # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". disable = [ "raw-checker-failed", "bad-inline-option", "locally-disabled", "file-ignored", "suppressed-message", "useless-suppression", "deprecated-pragma", "use-symbolic-message-instead", "bare-except", "broad-exception-caught", "useless-return", "consider-using-in", "consider-using-ternary", "fixme", "global-statement", "invalid-name", "missing-module-docstring", "missing-class-docstring", "missing-function-docstring", "no-else-raise", "no-else-return", "trailing-newlines", "unused-argument", "unused-variable", "using-constant-test", "useless-object-inheritance", "duplicate-code", "singleton-comparison", "consider-using-f-string", "broad-exception-raised", "arguments-differ", "multiple-statements", "use-implicit-booleaness-not-len", "chained-comparison", "unnecessary-pass", "cyclic-import", "too-many-ancestors", "import-outside-toplevel", "protected-access", "try-except-raise", "deprecated-module", "no-else-break", "no-else-continue", # junk "trailing-newlines", "consider-using-f-string", # expected "cyclic-import", # issues 
"broad-exception-caught", "no-else-return", ] # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where it # should appear only once). See also the "--disable" option for examples. enable = ["c-extension-no-member"] [tool.pylint.method_args] # List of qualified names (i.e., library.method) which require a timeout # parameter e.g. 'requests.api.get,requests.api.post' timeout-methods = [ "requests.api.delete", "requests.api.get", "requests.api.head", "requests.api.options", "requests.api.patch", "requests.api.post", "requests.api.put", "requests.api.request" ] [tool.pylint.miscellaneous] # List of note tags to take in consideration, separated by a comma. notes = ["FIXME", "XXX", "TODO"] # Regular expression of note tags to take in consideration. # notes-rgx = [tool.pylint.refactoring] # Maximum number of nested blocks for function / method body max-nested-blocks = 10 # Complete name of functions that never returns. When checking for inconsistent- # return-statements if a never returning function is called then it will be # considered as an explicit return statement and no message will be printed. never-returning-functions = ["sys.exit"] [tool.pylint.reports] # Python expression which should return a score less than or equal to 10. You # have access to the variables 'fatal', 'error', 'warning', 'refactor', # 'convention', and 'info' which contain the number of messages in each category, # as well as 'statement' which is the total number of statements analyzed. This # score is used by the global evaluation report (RP0004). evaluation = "10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)" # Template used to display messages. This is a python new-style format string # used to format the message information. See doc for all details. 
# msg-template = # Set the output format. Available formats are text, parseable, colorized, json # and msvs (visual studio). You can also give a reporter class, e.g. # mypackage.mymodule.MyReporterClass. # output-format = # Tells whether to display a full report or only the messages. # reports = # Activate the evaluation score. # score = [tool.pylint.similarities] # Comments are removed from the similarity computation ignore-comments = true # Docstrings are removed from the similarity computation ignore-docstrings = true # Imports are removed from the similarity computation # ignore-imports = # Signatures are removed from the similarity computation ignore-signatures = true # Minimum lines number of a similarity. min-similarity-lines = 4 [tool.pylint.spelling] # Limits count of emitted suggestions for spelling mistakes. max-spelling-suggestions = 4 # Spelling dictionary name. No available dictionaries : You need to install both # the python package and the system dependency for enchant to work.. # spelling-dict = # List of comma separated words that should be considered directives if they # appear at the beginning of a comment and should not be checked. spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:" # List of comma separated words that should not be checked. spelling-ignore-words = "usr,bin,env" # A path to a file that contains the private dictionary; one word per line. spelling-private-dict-file = ".local.dict" # Tells whether to store unknown words to the private dictionary (see the # --spelling-private-dict-file option) instead of raising a message. # spelling-store-unknown-words = [tool.pylint.typecheck] # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. 
contextmanager-decorators = ["contextlib.contextmanager"] # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. # generated-members = # Tells whether missing members accessed in mixin class should be ignored. A # class is considered mixin if its name matches the mixin-class-rgx option. # Tells whether to warn about missing members when the owner of the attribute is # inferred to be None. ignore-none = true # This flag controls whether pylint should warn about no-member and similar # checks whenever an opaque object is returned when inferring. The inference can # return multiple potential results while evaluating a Python object, but some # branches might not be evaluated, which results in partial inference. In that # case, it might be useful to still emit no-member and other checks for the rest # of the inferred objects. ignore-on-opaque-inference = true # List of symbolic message names to ignore for Mixin members. ignored-checks-for-mixins = ["no-member", "not-async-context-manager", "not-context-manager", "attribute-defined-outside-init"] # List of class names for which member attributes should not be checked (useful # for classes with dynamically set attributes). This supports the use of # qualified names. ignored-classes = ["optparse.Values", "thread._local", "_thread._local"] # Show a hint with possible names when a member name was not found. The aspect of # finding the hint is based on edit distance. missing-member-hint = true # The minimum edit distance a name should have in order to be considered a # similar match for a missing member name. missing-member-hint-distance = 1 # The total number of similar names that should be taken in consideration when # showing a hint for a missing member. missing-member-max-choices = 1 # Regex pattern to define which classes are considered mixins. 
mixin-class-rgx = ".*[Mm]ixin" # List of decorators that change the signature of a decorated function. # signature-mutators = [tool.pylint.variables] # List of additional names supposed to be defined in builtins. Remember that you # should avoid defining new builtins when possible. # additional-builtins = # Tells whether unused global variables should be treated as a violation. allow-global-unused-variables = true # List of names allowed to shadow builtins # allowed-redefined-builtins = # List of strings which can identify a callback function by name. A callback name # must start or end with one of those strings. callbacks = ["cb_", "_cb"] # A regular expression matching the name of dummy variables (i.e. expected to not # be used). dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" # Argument names that match this expression will be ignored. ignored-argument-names = "_.*|^ignored_|^unused_" # Tells whether we should check for unused import in __init__ files. # init-import = # List of qualified module names which can have objects that can redefine # builtins. redefining-builtins-modules = ["six.moves", "past.builtins", "future.builtins", "builtins", "io"] python-skytools-3.9.2/setup.py000066400000000000000000000015631447265566700165000ustar00rootroot00000000000000"""Setup for skytools module. 
""" from typing import Tuple from setuptools import Extension, setup try: from wheel.bdist_wheel import bdist_wheel class bdist_wheel_abi3(bdist_wheel): def get_tag(self) -> Tuple[str, str, str]: python, abi, plat = super().get_tag() if python.startswith("cp"): return CP_VER, "abi3", plat return python, abi, plat cmdclass = {"bdist_wheel": bdist_wheel_abi3} except ImportError: cmdclass = {} CP_VER = "cp37" API_VER = ('Py_LIMITED_API', '0x03070000') setup( cmdclass = cmdclass, ext_modules = [ Extension("skytools._cquoting", ["modules/cquoting.c"], define_macros=[API_VER], py_limited_api=True), Extension("skytools._chashtext", ["modules/hashtext.c"], define_macros=[API_VER], py_limited_api=True), ] ) python-skytools-3.9.2/skytools/000077500000000000000000000000001447265566700166505ustar00rootroot00000000000000python-skytools-3.9.2/skytools/__init__.py000066400000000000000000000175311447265566700207700ustar00rootroot00000000000000 """Tools for Python database scripts.""" # pylint:disable=redefined-builtin,unused-wildcard-import,wildcard-import try: from skytools import apipkg as _apipkg except ImportError: # make pylint think everything is imported immediately from skytools.adminscript import * from skytools.config import * from skytools.dbservice import * from skytools.dbstruct import * from skytools.fileutil import * from skytools.gzlog import * from skytools.hashtext import * from skytools.natsort import * from skytools.parsing import * from skytools.psycopgwrapper import * from skytools.querybuilder import * from skytools.quoting import * from skytools.scripting import * from skytools.skylog import * from skytools.sockutil import * from skytools.sqltools import * from skytools.timeutil import * from skytools.utf8 import * _symbols = { # skytools.adminscript 'AdminScript': 'skytools.adminscript:AdminScript', # skytools.config 'Config': 'skytools.config:Config', # skytools.dbservice 'DBService': 'skytools.dbservice:DBService', 'ServiceContext': 
'skytools.dbservice:ServiceContext', 'TableAPI': 'skytools.dbservice:TableAPI', 'get_record': 'skytools.dbservice:get_record', 'get_record_list': 'skytools.dbservice:get_record_list', 'make_record': 'skytools.dbservice:make_record', 'make_record_array': 'skytools.dbservice:make_record_array', # skytools.dbstruct 'SeqStruct': 'skytools.dbstruct:SeqStruct', 'TableStruct': 'skytools.dbstruct:TableStruct', 'T_ALL': 'skytools.dbstruct:T_ALL', 'T_CONSTRAINT': 'skytools.dbstruct:T_CONSTRAINT', 'T_DEFAULT': 'skytools.dbstruct:T_DEFAULT', 'T_GRANT': 'skytools.dbstruct:T_GRANT', 'T_INDEX': 'skytools.dbstruct:T_INDEX', 'T_OWNER': 'skytools.dbstruct:T_OWNER', 'T_PARENT': 'skytools.dbstruct:T_PARENT', 'T_PKEY': 'skytools.dbstruct:T_PKEY', 'T_RULE': 'skytools.dbstruct:T_RULE', 'T_SEQUENCE': 'skytools.dbstruct:T_SEQUENCE', 'T_TABLE': 'skytools.dbstruct:T_TABLE', 'T_TRIGGER': 'skytools.dbstruct:T_TRIGGER', # skytools.fileutil 'signal_pidfile': 'skytools.fileutil:signal_pidfile', 'write_atomic': 'skytools.fileutil:write_atomic', # skytools.gzlog 'gzip_append': 'skytools.gzlog:gzip_append', # skytools.hashtext 'hashtext_old': 'skytools.hashtext:hashtext_old', 'hashtext_new': 'skytools.hashtext:hashtext_new', # skytools.natsort 'natsort': 'skytools.natsort:natsort', 'natsort_icase': 'skytools.natsort:natsort_icase', 'natsorted': 'skytools.natsort:natsorted', 'natsorted_icase': 'skytools.natsort:natsorted_icase', 'natsort_key': 'skytools.natsort:natsort_key', 'natsort_key_icase': 'skytools.natsort:natsort_key_icase', # skytools.parsing 'dedent': 'skytools.parsing:dedent', 'hsize_to_bytes': 'skytools.parsing:hsize_to_bytes', 'merge_connect_string': 'skytools.parsing:merge_connect_string', 'parse_acl': 'skytools.parsing:parse_acl', 'parse_connect_string': 'skytools.parsing:parse_connect_string', 'parse_logtriga_sql': 'skytools.parsing:parse_logtriga_sql', 'parse_pgarray': 'skytools.parsing:parse_pgarray', 'parse_sqltriga_sql': 'skytools.parsing:parse_sqltriga_sql', 'parse_statements': 
'skytools.parsing:parse_statements', 'parse_tabbed_table': 'skytools.parsing:parse_tabbed_table', 'sql_tokenizer': 'skytools.parsing:sql_tokenizer', # skytools.psycopgwrapper 'connect_database': 'skytools.psycopgwrapper:connect_database', 'DBError': 'skytools.psycopgwrapper:DBError', 'I_AUTOCOMMIT': 'skytools.psycopgwrapper:I_AUTOCOMMIT', 'I_READ_COMMITTED': 'skytools.psycopgwrapper:I_READ_COMMITTED', 'I_REPEATABLE_READ': 'skytools.psycopgwrapper:I_REPEATABLE_READ', 'I_SERIALIZABLE': 'skytools.psycopgwrapper:I_SERIALIZABLE', # skytools.querybuilder 'PLPyQuery': 'skytools.querybuilder:PLPyQuery', 'PLPyQueryBuilder': 'skytools.querybuilder:PLPyQueryBuilder', 'QueryBuilder': 'skytools.querybuilder:QueryBuilder', 'plpy_exec': 'skytools.querybuilder:plpy_exec', 'run_exists': 'skytools.querybuilder:run_exists', 'run_lookup': 'skytools.querybuilder:run_lookup', 'run_query': 'skytools.querybuilder:run_query', 'run_query_row': 'skytools.querybuilder:run_query_row', # skytools.quoting 'db_urldecode': 'skytools.quoting:db_urldecode', 'db_urlencode': 'skytools.quoting:db_urlencode', 'json_decode': 'skytools.quoting:json_decode', 'json_encode': 'skytools.quoting:json_encode', 'make_pgarray': 'skytools.quoting:make_pgarray', 'quote_bytea_copy': 'skytools.quoting:quote_bytea_copy', 'quote_bytea_literal': 'skytools.quoting:quote_bytea_literal', 'quote_bytea_raw': 'skytools.quoting:quote_bytea_raw', 'quote_copy': 'skytools.quoting:quote_copy', 'quote_fqident': 'skytools.quoting:quote_fqident', 'quote_ident': 'skytools.quoting:quote_ident', 'quote_json': 'skytools.quoting:quote_json', 'quote_literal': 'skytools.quoting:quote_literal', 'quote_statement': 'skytools.quoting:quote_statement', 'unescape': 'skytools.quoting:unescape', 'unescape_copy': 'skytools.quoting:unescape_copy', 'unquote_fqident': 'skytools.quoting:unquote_fqident', 'unquote_ident': 'skytools.quoting:unquote_ident', 'unquote_literal': 'skytools.quoting:unquote_literal', # skytools.scripting 'BaseScript': 
'skytools.scripting:BaseScript', 'daemonize': 'skytools.scripting:daemonize', 'DBScript': 'skytools.scripting:DBScript', 'UsageError': 'skytools.scripting:UsageError', # skytools.skylog 'getLogger': 'skytools.skylog:getLogger', # skytools.sockutil 'set_cloexec': 'skytools.sockutil:set_cloexec', 'set_nonblocking': 'skytools.sockutil:set_nonblocking', 'set_tcp_keepalive': 'skytools.sockutil:set_tcp_keepalive', # skytools.sqltools 'dbdict': 'skytools.sqltools:dbdict', 'CopyPipe': 'skytools.sqltools:CopyPipe', 'DBFunction': 'skytools.sqltools:DBFunction', 'DBLanguage': 'skytools.sqltools:DBLanguage', 'DBObject': 'skytools.sqltools:DBObject', 'DBSchema': 'skytools.sqltools:DBSchema', 'DBTable': 'skytools.sqltools:DBTable', 'Snapshot': 'skytools.sqltools:Snapshot', 'db_install': 'skytools.sqltools:db_install', 'exists_function': 'skytools.sqltools:exists_function', 'exists_language': 'skytools.sqltools:exists_language', 'exists_schema': 'skytools.sqltools:exists_schema', 'exists_sequence': 'skytools.sqltools:exists_sequence', 'exists_table': 'skytools.sqltools:exists_table', 'exists_temp_table': 'skytools.sqltools:exists_temp_table', 'exists_type': 'skytools.sqltools:exists_type', 'exists_view': 'skytools.sqltools:exists_view', 'fq_name': 'skytools.sqltools:fq_name', 'fq_name_parts': 'skytools.sqltools:fq_name_parts', 'full_copy': 'skytools.sqltools:full_copy', 'get_table_columns': 'skytools.sqltools:get_table_columns', 'get_table_oid': 'skytools.sqltools:get_table_oid', 'get_table_pkeys': 'skytools.sqltools:get_table_pkeys', 'installer_apply_file': 'skytools.sqltools:installer_apply_file', 'installer_find_file': 'skytools.sqltools:installer_find_file', 'magic_insert': 'skytools.sqltools:magic_insert', 'mk_delete_sql': 'skytools.sqltools:mk_delete_sql', 'mk_insert_sql': 'skytools.sqltools:mk_insert_sql', 'mk_update_sql': 'skytools.sqltools:mk_update_sql', # skytools.timeutil 'FixedOffsetTimezone': 'skytools.timeutil:FixedOffsetTimezone', 'datetime_to_timestamp': 
'skytools.timeutil:datetime_to_timestamp', 'parse_iso_timestamp': 'skytools.timeutil:parse_iso_timestamp', # skytools.utf8 'safe_utf8_decode': 'skytools.utf8:safe_utf8_decode', } __all__ = tuple(_symbols) _symbols['__version__'] = 'skytools.installer_config:package_version' # lazy-import exported vars _apipkg.initpkg(__name__, _symbols, {'apipkg': _apipkg}) # type: ignore python-skytools-3.9.2/skytools/_chashtext.pyi000066400000000000000000000002021447265566700215170ustar00rootroot00000000000000 from typing import Union def hashtext_old(v: Union[bytes, str]) -> int: ... def hashtext_new(v: Union[bytes, str]) -> int: ... python-skytools-3.9.2/skytools/_cquoting.pyi000066400000000000000000000006441447265566700213670ustar00rootroot00000000000000 from typing import Any, Optional, Mapping, Dict def quote_literal(value: Any) -> str: ... def quote_copy(value: Any) -> str: ... def quote_bytea_raw(s: Optional[bytes]) -> Optional[str]: ... def db_urlencode(dict_val: Mapping[str, Any]) -> str: ... def db_urldecode(qs: str) -> Dict[str, Optional[str]]: ... def unescape(val: str) -> str: ... def unquote_literal(val: str, stdstr: bool = False) -> Optional[str]: ... python-skytools-3.9.2/skytools/_pyquoting.py000066400000000000000000000116001447265566700214160ustar00rootroot00000000000000"""Various helpers for string quoting/unquoting. Here is pure Python that should match C code in _cquoting. """ import re from typing import Any, Dict, Mapping, Match, Optional from urllib.parse import quote_plus, unquote_plus # noqa __all__ = ( "quote_literal", "quote_copy", "quote_bytea_raw", "db_urlencode", "db_urldecode", "unescape", "unquote_literal", ) # # SQL quoting # def quote_literal(value: Any) -> str: r"""Quote a literal value for SQL. If string contains '\\', extended E'' quoting is used, otherwise standard quoting. Input value of None results in string "null" without quotes. Python implementation. 
""" if value is None: return "null" s = str(value).replace("'", "''") s2 = s.replace("\\", "\\\\") if len(s) != len(s2): return "E'" + s2 + "'" return "'" + s2 + "'" def quote_copy(value: Any) -> str: """Quoting for copy command. None is converted to \\N. Python implementation. """ if value is None: return "\\N" s = str(value) s = s.replace("\\", "\\\\") s = s.replace("\t", "\\t") s = s.replace("\n", "\\n") s = s.replace("\r", "\\r") return s _bytea_map: Optional[Dict[int, str]] = None def quote_bytea_raw(s: Optional[bytes]) -> Optional[str]: """Quoting for bytea parser. Returns None as None. Python implementation. """ global _bytea_map if s is None: return None if not isinstance(s, bytes): raise TypeError("Expect bytes") if _bytea_map is None: _bytea_map = {} for i in range(256): c = i if i < 0x20 or i >= 0x7F: _bytea_map[c] = "\\%03o" % i elif i == ord("\\"): _bytea_map[c] = "\\\\" else: _bytea_map[c] = '%c' % i return "".join([_bytea_map[b] for b in s]) # # Database specific urlencode and urldecode. # def db_urlencode(dict_val: Mapping[str, Any]) -> str: """Database specific urlencode. Encode None as key without '='. That means that in "foo&bar=", foo is NULL and bar is empty string. Python implementation. """ elem_list = [] for k, v in dict_val.items(): if v is None: elem = quote_plus(str(k)) else: elem = quote_plus(str(k)) + '=' + quote_plus(str(v)) elem_list.append(elem) return '&'.join(elem_list) def db_urldecode(qs: str) -> Dict[str, Optional[str]]: """Database specific urldecode. Decode key without '=' as None. This also does not support one key several times. Python implementation. 
""" res: Dict[str, Optional[str]] = {} for elem in qs.split('&'): if not elem: continue pair = elem.split('=', 1) name = unquote_plus(pair[0]) if len(pair) == 1: res[name] = None else: res[name] = unquote_plus(pair[1]) return res # # Remove C-like backslash escapes # _esc_re = r"\\([0-7]{1,3}|.)" _esc_rc = re.compile(_esc_re) _esc_map = { 't': '\t', 'n': '\n', 'r': '\r', 'a': '\a', 'b': '\b', "'": "'", '"': '"', '\\': '\\', } def _sub_unescape_c(m: Match[str]) -> str: """unescape single escape seq.""" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: return _esc_map[v] except KeyError: return v else: return chr(int(v, 8)) def unescape(val: str) -> str: """Removes C-style escapes from string. Python implementation. """ return _esc_rc.sub(_sub_unescape_c, val) _esql_re = r"''|\\([0-7]{1,3}|.)" _esql_rc = re.compile(_esql_re) def _sub_unescape_sqlext(m: Match[str]) -> str: """Unescape extended-quoted string.""" if m.group() == "''": return "'" v = m.group(1) if (len(v) == 1) and (v < '0' or v > '7'): try: return _esc_map[v] except KeyError: return v return chr(int(v, 8)) def unquote_literal(val: str, stdstr: bool = False) -> Optional[str]: """Unquotes SQL string. E'..' -> extended quoting. '..' 
-> standard or extended quoting null -> None other -> returned as-is """ if val[0] == "'" and val[-1] == "'": if stdstr: return val[1:-1].replace("''", "'") else: return _esql_rc.sub(_sub_unescape_sqlext, val[1:-1]) elif len(val) > 2 and val[0] in ('E', 'e') and val[1] == "'" and val[-1] == "'": return _esql_rc.sub(_sub_unescape_sqlext, val[2:-1]) elif len(val) >= 2 and val[0] == '$' and val[-1] == '$': p1 = val.find('$', 1) p2 = val.rfind('$', 1, -1) if p1 > 0 and p2 > p1: t1 = val[:p1 + 1] t2 = val[p2:] if t1 == t2: return val[len(t1):-len(t1)] raise ValueError("Bad dollar-quoted string") elif val.lower() == "null": return None return val python-skytools-3.9.2/skytools/adminscript.py000066400000000000000000000106731447265566700215460ustar00rootroot00000000000000"""Admin scripting. """ import inspect import sys from typing import Sequence, Optional, Any, Mapping, Callable import skytools from .basetypes import Connection, DictRow, ExecuteParams __all__ = ['AdminScript'] class AdminScript(skytools.DBScript): """Contains common admin script tools. Second argument (first is .ini file) is taken as command name. If class method 'cmd_' + arg exists, it is called, otherwise error is given. 
""" commands_without_pidfile: Sequence[str] = () def __init__(self, service_name: str, args: Sequence[str]) -> None: """AdminScript init.""" super().__init__(service_name, args) if len(self.args) < 2: self.log.error("need command") sys.exit(1) cmd = self.args[1] if cmd in self.commands_without_pidfile: self.pidfile = None if self.pidfile: self.pidfile = self.pidfile + ".admin" def work(self) -> Optional[int]: """Non-looping work function, calls command function.""" self.set_single_loop(1) cmd = self.args[1] cmdargs = self.args[2:] # find function fname = "cmd_" + cmd.replace('-', '_') if not hasattr(self, fname): self.log.error('bad subcommand, see --help for usage') sys.exit(1) fn = getattr(self, fname) # check if correct number of arguments ( args, varargs, ___varkw, ___defaults, ___kwonlyargs, __kwonlydefaults, ___annotations, ) = inspect.getfullargspec(fn) n_args = len(args) - 1 # drop 'self' if varargs is None and n_args != len(cmdargs): helpstr = "" if n_args: helpstr = ": " + " ".join(args[1:]) self.log.error("command '%s' got %d args, but expects %d%s", cmd, len(cmdargs), n_args, helpstr) sys.exit(1) # run command fn(*cmdargs) return None def fetch_list(self, db: Connection, sql: str, args: ExecuteParams, keycol: Optional[str] = None) -> Sequence[Any]: """Fetch a resultset from db, optionally turning it into value list.""" curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() db.commit() if not keycol: res = rows else: res = [r[keycol] for r in rows] return res def display_table(self, db: Connection, desc: str, sql: str, args: ExecuteParams = (), fields: Sequence[str] = (), fieldfmt: Optional[Mapping[str, Callable[[Any], str]]]=None) -> int: """Display multirow query as a table.""" self.log.debug("display_table: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) rows = curs.fetchall() db.commit() if len(rows) == 0: return 0 if not fieldfmt: fieldfmt = {} if not fields: fields = [f[0] for f in curs.description] 
widths = [15] * len(fields) for row in rows: for i, k in enumerate(fields): rlen = len(str(row[k])) if row[k] else 0 widths[i] = widths[i] > rlen and widths[i] or rlen widths = [w + 2 for w in widths] fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) if desc: print(desc) print(fmt % tuple(fields)) print(fmt % tuple('-' * (w - 2) for w in widths)) #print(fmt % tuple(['-'*15] * len(fields))) for row in rows: vals = [] for field in fields: val = row[field] if field in fieldfmt: val = fieldfmt[field](val) vals.append(val) print(fmt % tuple(vals)) print('\n') return 1 def exec_stmt(self, db: Connection, sql: str, args: ExecuteParams) -> None: """Run regular non-query SQL on db.""" self.log.debug("exec_stmt: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) db.commit() def exec_query(self, db: Connection, sql: str, args: ExecuteParams) -> Sequence[DictRow]: """Run regular query SQL on db.""" self.log.debug("exec_query: %s", skytools.quote_statement(sql, args)) curs = db.cursor() curs.execute(sql, args) res = curs.fetchall() db.commit() return res python-skytools-3.9.2/skytools/apipkg.py000066400000000000000000000152171447265566700205030ustar00rootroot00000000000000""" apipkg: control the exported namespace of a Python package. 
see https://pypi.python.org/pypi/apipkg (c) holger krekel, 2009 - MIT license """ #pylint: skip-file import os import sys from typing import List from types import ModuleType __version__ = "1.5" def _py_abspath(path): """ special version of abspath that will leave paths from jython jars alone """ if path.startswith('__pyclasspath__'): return path else: return os.path.abspath(path) def distribution_version(name): """try to get the version of the named distribution, returs None on failure""" from pkg_resources import DistributionNotFound, get_distribution try: dist = get_distribution(name) except DistributionNotFound: pass else: return dist.version def initpkg(pkgname, exportdefs, attr=None, eager=False): """ initialize given package from the export definitions. """ attr = attr or {} oldmod = sys.modules.get(pkgname) assert oldmod d = {} f = getattr(oldmod, '__file__', None) if f: f = _py_abspath(f) d['__file__'] = f if hasattr(oldmod, '__version__'): d['__version__'] = oldmod.__version__ if hasattr(oldmod, '__loader__'): d['__loader__'] = oldmod.__loader__ if hasattr(oldmod, '__path__'): d['__path__'] = [_py_abspath(p) for p in oldmod.__path__] if hasattr(oldmod, '__package__'): d['__package__'] = oldmod.__package__ if '__doc__' not in exportdefs and getattr(oldmod, '__doc__', None): d['__doc__'] = oldmod.__doc__ d.update(attr) if hasattr(oldmod, "__dict__"): oldmod.__dict__.update(d) mod = ApiModule(pkgname, exportdefs, implprefix=pkgname, attr=d) sys.modules[pkgname] = mod # eagerload in bypthon to avoid their monkeypatching breaking packages if 'bpython' in sys.modules or eager: for module in list(sys.modules.values()): if isinstance(module, ApiModule): module.__dict__ def importobj(modpath, attrname): """imports a module, then resolves the attrname on it""" module = __import__(modpath, None, None, ['__doc__']) if not attrname: return module retval = module names = attrname.split(".") for x in names: retval = getattr(retval, x) return retval class 
ApiModule(ModuleType): """the magical lazy-loading module standing""" def __docget(self): try: return self.__doc except AttributeError: if '__doc__' in self.__map__: return self.__makeattr('__doc__') def __docset(self, value): self.__doc = value __doc__ = property(__docget, __docset) # type: ignore def __init__(self, name, importspec, implprefix=None, attr=None): super().__init__(name) self.__all__ = [x for x in importspec if x != '__onfirstaccess__'] self.__map__ = {} self.__implprefix__ = implprefix or name if attr: for name, val in attr.items(): # print "setting", self.__name__, name, val setattr(self, name, val) for name, importspec in importspec.items(): if isinstance(importspec, dict): subname = '%s.%s' % (self.__name__, name) apimod = ApiModule(subname, importspec, implprefix) sys.modules[subname] = apimod setattr(self, name, apimod) else: parts = importspec.split(':') modpath = parts.pop(0) attrname = parts and parts[0] or "" if modpath[0] == '.': modpath = implprefix + modpath if not attrname: subname = '%s.%s' % (self.__name__, name) apimod = AliasModule(subname, modpath) sys.modules[subname] = apimod if '.' 
not in name: setattr(self, name, apimod) else: self.__map__[name] = (modpath, attrname) def __repr__(self): repr_list = [] if hasattr(self, '__version__'): repr_list.append("version=" + repr(self.__version__)) if hasattr(self, '__file__'): repr_list.append('from ' + repr(self.__file__)) if repr_list: return '' % (self.__name__, " ".join(repr_list)) return '' % (self.__name__,) def __makeattr(self, name): """lazily compute value for name or raise AttributeError if unknown.""" # print "makeattr", self.__name__, name target = None if '__onfirstaccess__' in self.__map__: target = self.__map__.pop('__onfirstaccess__') importobj(*target)() try: modpath, attrname = self.__map__[name] except KeyError: if target is not None and name != '__onfirstaccess__': # retry, onfirstaccess might have set attrs return getattr(self, name) raise AttributeError(name) else: result = importobj(modpath, attrname) setattr(self, name, result) try: del self.__map__[name] except KeyError: pass # in a recursive-import situation a double-del can happen return result __getattr__ = __makeattr @property def __dict__(self): # force all the content of the module # to be loaded when __dict__ is read dictdescr = ModuleType.__dict__['__dict__'] # type: ignore dict = dictdescr.__get__(self) if dict is not None: hasattr(self, 'some') for name in self.__all__: try: self.__makeattr(name) except AttributeError: pass return dict def AliasModule(modname, modpath, attrname=None): mod: List[ModuleType] = [] def getmod(): if not mod: x = importobj(modpath, None) if attrname is not None: x = getattr(x, attrname) mod.append(x) return mod[0] class AliasModule(ModuleType): def __repr__(self): x = modpath if attrname: x += "." 
+ attrname return '' % (modname, x) def __getattribute__(self, name): try: return getattr(getmod(), name) except ImportError: return None def __setattr__(self, name, value): setattr(getmod(), name, value) def __delattr__(self, name): delattr(getmod(), name) return AliasModule(str(modname)) python-skytools-3.9.2/skytools/basetypes.py000066400000000000000000000074621447265566700212320ustar00rootroot00000000000000"""Database tools. """ import abc import io import typing import types from typing import ( IO, Any, Mapping, Optional, Sequence, Tuple, Type, Union, KeysView, ValuesView, ItemsView, Iterator, ) try: from typing import Protocol except ImportError: Protocol = object # type: ignore __all__ = ( "ExecuteParams", "DictRow", "Cursor", "Connection", "Runnable", "HasFileno", "FileDescriptor", "FileDescriptorLike", "Buffer", ) ExecuteParams = Union[Sequence[Any], Mapping[str, Any]] class DictRow(Protocol): """Allow both key and index-based access. Both Psycopg2 DictRow and PL/Python rows support this. 
""" def keys(self) -> KeysView[str]: raise NotImplementedError def values(self) -> ValuesView[Any]: raise NotImplementedError def items(self) -> ItemsView[str, Any]: raise NotImplementedError def get(self, key: str, default: Any = None) -> Any: raise NotImplementedError def __getitem__(self, key: Union[str, int]) -> Any: raise NotImplementedError def __iter__(self) -> Iterator[str]: raise NotImplementedError def __len__(self) -> int: raise NotImplementedError def __contains__(self, key: object) -> bool: raise NotImplementedError class Cursor(Protocol): @property def rowcount(self) -> int: raise NotImplementedError @property def statusmessage(self) -> Optional[str]: raise NotImplementedError def execute(self, sql: str, params: Optional[ExecuteParams] = None) -> None: raise NotImplementedError def fetchall(self) -> Sequence[DictRow]: raise NotImplementedError def fetchone(self) -> DictRow: raise NotImplementedError def __enter__(self) -> "Cursor": raise NotImplementedError def __exit__(self, typ: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[types.TracebackType]) -> None: raise NotImplementedError def copy_expert( self, sql: str, f: Union[IO[str], IO[bytes], io.TextIOBase, io.RawIOBase], size: int = 8192 ) -> None: raise NotImplementedError def fileno(self) -> int: raise NotImplementedError @property def description(self) -> Sequence[Tuple[str, int, int, int, Optional[int], Optional[int], None]]: raise NotImplementedError @property def connection(self) -> "Connection": raise NotImplementedError class Connection(Protocol): def cursor(self) -> Cursor: raise NotImplementedError def rollback(self) -> None: raise NotImplementedError def commit(self) -> None: raise NotImplementedError def close(self) -> None: raise NotImplementedError @property def isolation_level(self) -> int: raise NotImplementedError def set_isolation_level(self, level: int) -> None: raise NotImplementedError @property def encoding(self) -> str: raise NotImplementedError def 
set_client_encoding(self, encoding: str) -> None: raise NotImplementedError @property def server_version(self) -> int: raise NotImplementedError def __enter__(self) -> "Connection": raise NotImplementedError def __exit__(self, typ: Optional[Type[BaseException]], exc: Optional[BaseException], tb: Optional[types.TracebackType]) -> None: raise NotImplementedError class Runnable(Protocol): def run(self) -> None: raise NotImplementedError class HasFileno(Protocol): def fileno(self) -> int: raise NotImplementedError FileDescriptor = int FileDescriptorLike = Union[int, HasFileno] try: from typing_extensions import Buffer except ImportError: if typing.TYPE_CHECKING: from _typeshed import Buffer # type: ignore else: try: from collections.abc import Buffer # type: ignore except ImportError: class Buffer(abc.ABC): pass Buffer.register(memoryview) Buffer.register(bytearray) Buffer.register(bytes) python-skytools-3.9.2/skytools/checker.py000066400000000000000000000534541447265566700206410ustar00rootroot00000000000000"""Catch moment when tables are in sync on master and slave. 
""" import logging import os import subprocess import sys import time from typing import IO, List, Optional, Sequence, Tuple, Dict, cast, Mapping, Any import skytools from .basetypes import Connection, Cursor DictRow = Dict[str, str] class TableRepair: """Checks that tables in two databases are in sync.""" table_name: str fq_table_name: str log: logging.Logger pkey_list: List[str] common_fields: List[str] apply_fixes: bool apply_cursor: Optional[Cursor] cnt_insert: int cnt_update: int cnt_delete: int total_src: int total_dst: int def __init__(self, table_name: str, log: logging.Logger) -> None: self.table_name = table_name self.fq_table_name = skytools.quote_fqident(table_name) self.log = log self.pkey_list = [] self.common_fields = [] self.apply_fixes = False self.apply_cursor = None self.reset() def reset(self) -> None: self.cnt_insert = 0 self.cnt_update = 0 self.cnt_delete = 0 self.total_src = 0 self.total_dst = 0 self.pkey_list = [] self.common_fields = [] self.apply_fixes = False self.apply_cursor = None def do_repair(self, src_db: Connection, dst_db: Connection, where: str, pfx: str = 'repair', apply_fixes: bool = False) -> None: """Actual comparison.""" self.reset() src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.apply_fixes = apply_fixes if apply_fixes: self.apply_cursor = dst_curs self.log.info('Checking %s', self.table_name) copy_tbl = self.gen_copy_tbl(src_curs, dst_curs, where) dump_src = "%s.%s.src" % (pfx, self.table_name) dump_dst = "%s.%s.dst" % (pfx, self.table_name) fix = "%s.%s.fix" % (pfx, self.table_name) self.log.info("Dumping src table: %s", self.table_name) self.dump_table(copy_tbl, src_curs, dump_src) src_db.commit() self.log.info("Dumping dst table: %s", self.table_name) self.dump_table(copy_tbl, dst_curs, dump_dst) dst_db.commit() self.log.info("Sorting src table: %s", self.table_name) self.do_sort(dump_src, dump_src + '.sorted') self.log.info("Sorting dst table: %s", self.table_name) self.do_sort(dump_dst, dump_dst + 
'.sorted') self.dump_compare(dump_src + ".sorted", dump_dst + ".sorted", fix) os.unlink(dump_src) os.unlink(dump_dst) os.unlink(dump_src + ".sorted") os.unlink(dump_dst + ".sorted") if apply_fixes: dst_db.commit() def do_sort(self, src: str, dst: str) -> None: with subprocess.Popen(["sort", "--version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p: s_ver = p.communicate()[0] xenv = os.environ.copy() xenv['LANG'] = 'C' xenv['LC_ALL'] = 'C' cmdline = ['sort', '-T', '.'] if s_ver.find(b"coreutils") > 0: cmdline.append('-S') cmdline.append('30%') cmdline.append('-o') cmdline.append(dst) cmdline.append(src) with subprocess.Popen(cmdline, env=xenv) as p: if p.wait() != 0: raise Exception('sort failed') def gen_copy_tbl(self, src_curs: Cursor, dst_curs: Cursor, where: str) -> str: """Create COPY expession from common fields.""" self.pkey_list = skytools.get_table_pkeys(src_curs, self.table_name) dst_pkey = skytools.get_table_pkeys(dst_curs, self.table_name) if dst_pkey != self.pkey_list: self.log.error('pkeys do not match') sys.exit(1) src_cols = skytools.get_table_columns(src_curs, self.table_name) dst_cols = skytools.get_table_columns(dst_curs, self.table_name) field_list = [] for f in self.pkey_list: field_list.append(f) for f in src_cols: if f in self.pkey_list: continue if f in dst_cols: field_list.append(f) self.common_fields = field_list fqlist = [skytools.quote_ident(col) for col in field_list] tbl_expr = "select %s from %s" % (",".join(fqlist), self.fq_table_name) if where: tbl_expr += ' where ' + where tbl_expr = "COPY (%s) TO STDOUT" % tbl_expr self.log.debug("using copy expr: %s", tbl_expr) return tbl_expr def dump_table(self, copy_cmd: str, curs: Cursor, fn: str) -> None: """Dump table to disk.""" with open(fn, "w", 64 * 1024, encoding="utf8") as f: curs.copy_expert(copy_cmd, f) self.log.info('%s: Got %d bytes', self.table_name, f.tell()) def get_row(self, ln: str) -> Optional[DictRow]: """Parse a row into dict.""" if not ln: return None t = 
ln[:-1].split('\t') row: DictRow = {} for i, n in enumerate(self.common_fields): row[n] = t[i] return row def dump_compare(self, src_fn: str, dst_fn: str, fix: str) -> None: """Dump + compare single table.""" self.log.info("Comparing dumps: %s", self.table_name) with open(src_fn, "r", 64 * 1024, encoding="utf8") as f1: with open(dst_fn, "r", 64 * 1024, encoding="utf8") as f2: self.dump_compare_streams(f1, f2, fix) def dump_compare_streams(self, f1: IO[str], f2: IO[str], fix: str) -> None: src_ln = f1.readline() dst_ln = f2.readline() if src_ln: self.total_src += 1 if dst_ln: self.total_dst += 1 if os.path.isfile(fix): os.unlink(fix) while src_ln or dst_ln: keep_src = keep_dst = 0 if src_ln != dst_ln: src_row = self.get_row(src_ln) dst_row = self.get_row(dst_ln) diff = self.cmp_keys(src_row, dst_row) if diff > 0 and dst_row: # src > dst self.got_missed_delete(dst_row, fix) keep_src = 1 elif diff < 0 and src_row: # src < dst self.got_missed_insert(src_row, fix) keep_dst = 1 elif src_row and dst_row: if self.cmp_data(src_row, dst_row) != 0: self.got_missed_update(src_row, dst_row, fix) if not keep_src: src_ln = f1.readline() if src_ln: self.total_src += 1 if not keep_dst: dst_ln = f2.readline() if dst_ln: self.total_dst += 1 self.log.info("finished %s: src: %d rows, dst: %d rows," " missed: %d inserts, %d updates, %d deletes", self.table_name, self.total_src, self.total_dst, self.cnt_insert, self.cnt_update, self.cnt_delete) f1.close() f2.close() def got_missed_insert(self, src_row: DictRow, fn: str) -> None: """Create sql for missed insert.""" self.cnt_insert += 1 fld_list = self.common_fields fq_list = [] val_list = [] for f in fld_list: fq_list.append(skytools.quote_ident(f)) v = skytools.unescape_copy(src_row[f]) val_list.append(skytools.quote_literal(v)) q = "insert into %s (%s) values (%s);" % ( self.fq_table_name, ", ".join(fq_list), ", ".join(val_list)) self.show_fix(q, 'insert', fn) def got_missed_update(self, src_row: DictRow, dst_row: DictRow, fn: str) -> 
None: """Create sql for missed update.""" self.cnt_update += 1 fld_list = self.common_fields set_list: List[str] = [] whe_list: List[str] = [] for f in self.pkey_list: self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(src_row[f])) for f in fld_list: v1 = src_row[f] v2 = dst_row[f] if self.cmp_value(v1, v2) == 0: continue self.addeq(set_list, skytools.quote_ident(f), skytools.unescape_copy(v1)) self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(v2)) q = "update only %s set %s where %s;" % ( self.fq_table_name, ", ".join(set_list), " and ".join(whe_list)) self.show_fix(q, 'update', fn) def got_missed_delete(self, dst_row: DictRow, fn: str) -> None: """Create sql for missed delete.""" self.cnt_delete += 1 whe_list: List[str] = [] for f in self.pkey_list: self.addcmp(whe_list, skytools.quote_ident(f), skytools.unescape_copy(dst_row[f])) q = "delete from only %s where %s;" % (self.fq_table_name, " and ".join(whe_list)) self.show_fix(q, 'delete', fn) def show_fix(self, q: str, desc: str, fn: str) -> None: """Print/write/apply repair sql.""" self.log.debug("missed %s: %s", desc, q) with open(fn, "a", encoding="utf8") as f: f.write("%s\n" % q) if self.apply_fixes and self.apply_cursor: self.apply_cursor.execute(q) def addeq(self, dst_list: List[str], f: str, v: Optional[str]) -> None: """Add quoted SET.""" vq = skytools.quote_literal(v) s = "%s = %s" % (f, vq) dst_list.append(s) def addcmp(self, dst_list: List[str], f: str, v: Optional[str]) -> None: """Add quoted comparison.""" if v is None: s = "%s is null" % f else: vq = skytools.quote_literal(v) s = "%s = %s" % (f, vq) dst_list.append(s) def cmp_data(self, src_row: DictRow, dst_row: DictRow) -> int: """Compare data field-by-field.""" for k in self.common_fields: v1 = src_row[k] v2 = dst_row[k] if self.cmp_value(v1, v2) != 0: return -1 return 0 def cmp_value(self, v1: str, v2: str) -> int: """Compare single field, tolerates tz vs notz dates.""" if v1 == v2: return 0 # try to work 
around tz vs. notz z1 = len(v1) z2 = len(v2) if z1 == z2 + 3 and z2 >= 19 and v1[z2] == '+': v1 = v1[:-3] if v1 == v2: return 0 elif z1 + 3 == z2 and z1 >= 19 and v2[z1] == '+': v2 = v2[:-3] if v1 == v2: return 0 return -1 def cmp_keys(self, src_row: Optional[DictRow], dst_row: Optional[DictRow]) -> int: """Compare primary keys of the rows. Returns 1 if src > dst, -1 if src < dst and 0 if src == dst""" # None means table is done. tag it larger than any existing row. if src_row is None: if dst_row is None: return 0 return 1 elif dst_row is None: return -1 for k in self.pkey_list: v1 = src_row[k] v2 = dst_row[k] if v1 < v2: return -1 elif v1 > v2: return 1 return 0 class Syncer(skytools.DBScript): """Checks that tables in two databases are in sync.""" lock_timeout: float = 10 ticker_lag_limit: int = 20 consumer_lag_limit: int = 20 def sync_table(self, cstr1: str, cstr2: str, queue_name: str, consumer_name: str, table_name: str) -> Tuple[Connection, Connection]: """Syncer main function. Returns (src_db, dst_db) that are in transaction where table should be in sync. 
""" setup_db = self.get_database('setup_db', connstr=cstr1, autocommit=1) lock_db = self.get_database('lock_db', connstr=cstr1) src_db = self.get_database('src_db', connstr=cstr1, isolation_level=skytools.I_REPEATABLE_READ) dst_db = self.get_database('dst_db', connstr=cstr2, isolation_level=skytools.I_REPEATABLE_READ) lock_curs = lock_db.cursor() setup_curs = setup_db.cursor() src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.check_consumer(setup_curs, queue_name, consumer_name) # lock table in separate connection self.log.info('Locking %s', table_name) self.set_lock_timeout(lock_curs) lock_time = time.time() lock_curs.execute("LOCK TABLE %s IN SHARE MODE" % skytools.quote_fqident(table_name)) # now wait until consumer has updated target table until locking self.log.info('Syncing %s', table_name) # consumer must get further than this tick self.force_tick(setup_curs, queue_name) # try to force second tick also self.force_tick(setup_curs, queue_name) # take server time setup_curs.execute("select to_char(now(), 'YYYY-MM-DD HH24:MI:SS.MS')") tpos = setup_curs.fetchone()[0] # now wait while True: time.sleep(0.5) q = "select now() - lag > timestamp %s, now(), lag from pgq.get_consumer_info(%s, %s)" setup_curs.execute(q, [tpos, queue_name, consumer_name]) res = setup_curs.fetchall() if len(res) == 0: raise Exception('No such consumer: %s/%s' % (queue_name, consumer_name)) row = res[0] self.log.debug("tpos=%s now=%s lag=%s ok=%s", tpos, row[1], row[2], row[0]) if row[0]: break # limit lock time if time.time() > lock_time + self.lock_timeout: self.log.error('Consumer lagging too much, exiting') lock_db.rollback() sys.exit(1) # take snapshot on provider side src_db.commit() src_curs.execute("SELECT 1") # take snapshot on subscriber side dst_db.commit() dst_curs.execute("SELECT 1") # release lock lock_db.commit() self.close_database('setup_db') self.close_database('lock_db') return (src_db, dst_db) def set_lock_timeout(self, curs: Cursor) -> None: ms = int(1000 * 
self.lock_timeout) if ms > 0: q = "SET LOCAL statement_timeout = %d" % ms self.log.debug(q) curs.execute(q) def check_consumer(self, curs: Cursor, queue_name: str, consumer_name: str) -> None: """ Before locking anything check if consumer is working ok. """ self.log.info("Queue: %s Consumer: %s", queue_name, consumer_name) curs.execute('select current_database()') self.log.info('Actual db: %s', curs.fetchone()[0]) # get ticker lag q = "select extract(epoch from ticker_lag) from pgq.get_queue_info(%s);" curs.execute(q, [queue_name]) ticker_lag = curs.fetchone()[0] self.log.info("Ticker lag: %s", ticker_lag) # get consumer lag q = "select extract(epoch from lag) from pgq.get_consumer_info(%s, %s);" curs.execute(q, [queue_name, consumer_name]) res = curs.fetchall() if len(res) == 0: self.log.error('check_consumer: No such consumer: %s/%s', queue_name, consumer_name) sys.exit(1) consumer_lag = res[0][0] # check that lag is acceptable self.log.info("Consumer lag: %s", consumer_lag) if consumer_lag > ticker_lag + 10: self.log.error('Consumer lagging too much, cannot proceed') sys.exit(1) def force_tick(self, curs: Cursor, queue_name: str) -> None: """ Force tick into source queue so that consumer can move on faster """ q = "select pgq.force_tick(%s)" curs.execute(q, [queue_name]) res = curs.fetchone() cur_pos = res[0] start = time.time() while True: time.sleep(0.5) curs.execute(q, [queue_name]) res = curs.fetchone() if res[0] != cur_pos: # new pos return res[0] # dont loop more than 10 secs dur = time.time() - start if dur > 10 and not self.options.force: raise Exception("Ticker seems dead") class Checker(Syncer): """Checks that tables in two databases are in sync. 
Config options:: ## data_checker ## confdb = dbname=confdb host=confdb.service extra_connstr = user=marko # one of: compare, repair, repair-apply, compare-repair-apply check_type = compare # random params used in queries cluster_name = instance_name = proxy_host = proxy_db = # list of tables to be compared table_list = foo, bar, baz where_expr = (hashtext(key_user_name) & %%(max_slot)s) in (%%(slots)s) # gets no args source_query = select h.hostname, d.db_name from dba.cluster c join dba.cluster_host ch on (ch.key_cluster = c.id_cluster) join conf.host h on (h.id_host = ch.key_host) join dba.database d on (d.key_host = ch.key_host) where c.db_name = '%(cluster_name)s' and c.instance_name = '%(instance_name)s' and d.mk_db_type = 'partition' and d.mk_db_status = 'active' order by d.db_name, h.hostname target_query = select db_name, hostname, slots, max_slot from dba.get_cross_targets(%%(hostname)s, %%(db_name)s, '%(proxy_host)s', '%(proxy_db)s') consumer_query = select q.queue_name, c.consumer_name from conf.host h join dba.database d on (d.key_host = h.id_host) join dba.pgq_queue q on (q.key_database = d.id_database) join dba.pgq_consumer c on (c.key_queue = q.id_queue) where h.hostname = %%(hostname)s and d.db_name = %%(db_name)s and q.queue_name like 'xm%%%%' """ def __init__(self, args: Sequence[str]) -> None: """Checker init.""" super().__init__('data_checker', args) self.set_single_loop(1) self.log.info('Checker starting %s', str(args)) self.lock_timeout = self.cf.getfloat('lock_timeout', 10) self.table_list = self.cf.getlist('table_list') def work(self) -> Optional[int]: """Syncer main function.""" source_query = self.cf.get('source_query') target_query = self.cf.get('target_query') consumer_query = self.cf.get('consumer_query') where_expr = self.cf.get('where_expr') extra_connstr = self.cf.get('extra_connstr') check = self.cf.get('check_type', 'compare') confdb = self.get_database('confdb', autocommit=1) curs = confdb.cursor() curs.execute(source_query) for 
src_row in curs.fetchall(): s_host = src_row['hostname'] s_db = src_row['db_name'] cast_row = cast(Mapping[str, Any], src_row) curs.execute(consumer_query, cast_row) r = curs.fetchone() consumer_name = r['consumer_name'] queue_name = r['queue_name'] curs.execute(target_query, cast_row) for dst_row in curs.fetchall(): d_db = dst_row['db_name'] d_host = dst_row['hostname'] cstr1 = "dbname=%s host=%s %s" % (s_db, s_host, extra_connstr) cstr2 = "dbname=%s host=%s %s" % (d_db, d_host, extra_connstr) where = where_expr % dst_row self.log.info('Source: db=%s host=%s queue=%s consumer=%s', s_db, s_host, queue_name, consumer_name) self.log.info('Target: db=%s host=%s where=%s', d_db, d_host, where) for tbl in self.table_list: src_db, dst_db = self.sync_table(cstr1, cstr2, queue_name, consumer_name, tbl) if check == 'compare': self.do_compare(tbl, src_db, dst_db, where) elif check == 'repair': tr = TableRepair(tbl, self.log) tr.do_repair(src_db, dst_db, where, 'fix.' + tbl, False) elif check == 'repair-apply': tr = TableRepair(tbl, self.log) tr.do_repair(src_db, dst_db, where, 'fix.' + tbl, True) elif check == 'compare-repair-apply': ok = self.do_compare(tbl, src_db, dst_db, where) if not ok: tr = TableRepair(tbl, self.log) tr.do_repair(src_db, dst_db, where, 'fix.' 
+ tbl, True) else: raise Exception('unknown check type') self.reset() return None def do_compare(self, tbl: str, src_db: Connection, dst_db: Connection, where: str) -> bool: """Actual comparison.""" src_curs = src_db.cursor() dst_curs = dst_db.cursor() self.log.info('Counting %s', tbl) q = "select count(1) as cnt, sum(hashtext(t.*::text)) as chksum from only _TABLE_ t where %s;" % where q = self.cf.get('compare_sql', q) q = q.replace('_TABLE_', skytools.quote_fqident(tbl)) f = "%(cnt)d rows, checksum=%(chksum)s" f = self.cf.get('compare_fmt', f) self.log.debug("srcdb: %s", q) src_curs.execute(q) src_row = src_curs.fetchone() src_str = f % src_row self.log.info("srcdb: %s", src_str) self.log.debug("dstdb: %s", q) dst_curs.execute(q) dst_row = dst_curs.fetchone() dst_str = f % dst_row self.log.info("dstdb: %s", dst_str) src_db.commit() dst_db.commit() if src_str != dst_str: self.log.warning("%s: Results do not match!", tbl) return False else: self.log.info("%s: OK!", tbl) return True if __name__ == '__main__': script = Checker(sys.argv[1:]) script.start() python-skytools-3.9.2/skytools/config.py000066400000000000000000000306731447265566700205000ustar00rootroot00000000000000"""Nicer config class. """ import os import os.path import re import socket from configparser import MAX_INTERPOLATION_DEPTH, ConfigParser from configparser import Error as ConfigError from configparser import ( ExtendedInterpolation, Interpolation, InterpolationDepthError, InterpolationError, NoOptionError, NoSectionError, RawConfigParser, ) from typing import Dict, List, Mapping, Optional, Sequence, Tuple, MutableMapping, Set import skytools __all__ = ( 'Config', 'NoOptionError', 'ConfigError', 'ConfigParser', 'ExtendedConfigParser', 'ExtendedCompatConfigParser', 'InterpolationError', 'NoOptionError', 'NoSectionError', ) def read_versioned_config(filenames: Sequence[str], main_section: str) -> ConfigParser: """Pick syntax based on "config_format" value. 
""" rcf = RawConfigParser() rcf.read(filenames) # avoid has_option here, so value can live in DEFAULT section ver = rcf.get(main_section, "config_format", fallback="1") if ver == "1": cf = ConfigParser() elif ver == "2": cf = ExtendedConfigParser() else: raise ConfigError('Unsupported config format %r in %r' % (ver, filenames)) cf.read(filenames) return cf class Config: """Bit improved ConfigParser. Additional features: - Remembers section. - Accepts defaults in get() functions. - List value support. """ main_section: str # main section filename: Optional[str] # file name that was loaded override: Mapping[str, str] # override values in config file defs: Mapping[str, str] # defaults visible in all sections cf: ConfigParser # actual ConfigParser instance def __init__(self, main_section: str, filename: Optional[str], sane_config: Optional[bytes] = None, # unused user_defs: Optional[Mapping[str, str]] = None, override: Optional[Mapping[str, str]] = None, ignore_defs: bool = False) -> None: """Initialize Config and read from file. 
""" # use config file name as default job_name if filename: job_name = os.path.splitext(os.path.basename(filename))[0] else: job_name = main_section # initialize defaults, make them usable in config file if ignore_defs: self.defs = {} else: self.defs = { 'job_name': job_name, 'service_name': main_section, 'host_name': socket.gethostname(), } if filename: self.defs['config_dir'] = os.path.dirname(filename) self.defs['config_file'] = filename if user_defs: self.defs.update(user_defs) self.main_section = main_section self.filename = filename self.override = override or {} if filename is None: self.cf = ConfigParser() self.cf.add_section(main_section) elif not os.path.isfile(filename): raise ConfigError('Config file not found: ' + filename) else: self.cf = read_versioned_config([filename], main_section) self.reload() def reload(self) -> None: """Re-reads config file.""" if self.filename: self.cf.read(self.filename) if not self.cf.has_section(self.main_section): raise NoSectionError(self.main_section) # apply default if key not set for k, v in self.defs.items(): if not self.cf.has_option(self.main_section, k): self.cf.set(self.main_section, k, v) # apply overrides if self.override: for k, v in self.override.items(): self.cf.set(self.main_section, k, v) def get(self, key: str, default: Optional[str] = None) -> str: """Reads string value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return str(self.cf.get(self.main_section, key)) def getint(self, key: str, default: Optional[int] = None) -> int: """Reads int value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getint(self.main_section, key) def getboolean(self, key: str, default: Optional[bool] = None) -> bool: """Reads boolean value, if not set then default.""" if not 
self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getboolean(self.main_section, key) def getfloat(self, key: str, default: Optional[float] = None) -> float: """Reads float value, if not set then default.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default return self.cf.getfloat(self.main_section, key) def getlist(self, key: str, default: Optional[List[str]] = None) -> List[str]: """Reads comma-separated list from key.""" if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default s = self.get(key).strip() res: List[str] = [] if not s: return res for v in s.split(","): res.append(v.strip()) return res def getdict(self, key: str, default: Optional[Dict[str, str]] = None) -> Dict[str, str]: """Reads key-value dict from parameter. Key and value are separated with ':'. If missing, key itself is taken as value. """ if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) return default s = self.get(key).strip() res: Dict[str, str] = {} if not s: return res for kv in s.split(","): tmp = kv.split(':', 1) if len(tmp) > 1: k = tmp[0].strip() v = tmp[1].strip() else: k = kv.strip() v = k res[k] = v return res def getfile(self, key: str, default: Optional[str] = None) -> str: """Reads filename from config. In addition to reading string value, expands ~ to user directory. """ fn = self.get(key, default) if fn == "" or fn == "-": return fn # simulate that the cwd is script location #path = os.path.dirname(sys.argv[0]) # seems bad idea, cwd should be cwd fn = os.path.expanduser(fn) return fn def getbytes(self, key: str, default: Optional[str] = None) -> int: """Reads a size value in human format, if not set then default. 
Examples: 1, 2 B, 3K, 4 MB """ if not self.cf.has_option(self.main_section, key): if default is None: raise NoOptionError(key, self.main_section) s = default else: s = self.cf.get(self.main_section, key) return skytools.hsize_to_bytes(s) def get_wildcard(self, key: str, values: Sequence[str] = (), default: Optional[str] = None) -> str: """Reads a wildcard property from conf and returns its string value, if not set then default.""" orig_key = key keys = [key] for wild in values: key = key.replace('*', wild, 1) keys.append(key) keys.reverse() for k in keys: if self.cf.has_option(self.main_section, k): return self.cf.get(self.main_section, k) if default is None: raise NoOptionError(orig_key, self.main_section) return default def sections(self) -> Sequence[str]: """Returns list of sections in config file, excluding DEFAULT.""" return self.cf.sections() def has_section(self, section: str) -> bool: """Checks if section is present in config file, excluding DEFAULT.""" return self.cf.has_section(section) def clone(self, main_section: str) -> "Config": """Return new Config() instance with new main section on same config file.""" return Config(main_section, self.filename) def options(self) -> Sequence[str]: """Return list of options in main section.""" return self.cf.options(self.main_section) def has_option(self, opt: str) -> bool: """Checks if option exists in main section.""" return self.cf.has_option(self.main_section, opt) def items(self) -> Sequence[Tuple[str, str]]: """Returns list of (name, value) for each option in main section.""" return self.cf.items(self.main_section) # define some aliases (short-cuts / backward compatibility cruft) getbool = getboolean ParserSection = Mapping[str, str] ParserState = MutableMapping[str, ParserSection] #ParserState = ConfigParser ParserLoop = Set[Tuple[str, str]] class ExtendedInterpolationCompat(Interpolation): _EXT_VAR_RX = r'\$\$|\$\{[^(){}]+\}' _OLD_VAR_RX = r'%%|%\([^(){}]+\)s' _var_rc = re.compile('(%s|%s)' % (_EXT_VAR_RX, 
_OLD_VAR_RX)) _bad_rc = re.compile('[%$]') def before_get(self, parser: ParserState, section: str, option: str, value: str, defaults: ParserSection) -> str: dst: List[str] = [] self._interpolate_ext(dst, parser, section, option, value, defaults, set()) return ''.join(dst) def before_set(self, parser: ParserState, section: str, option: str, value: str) -> str: sub = self._var_rc.sub('', value) if self._bad_rc.search(sub): raise ValueError("invalid interpolation syntax in %r" % value) return value def _interpolate_ext(self, dst: List[str], parser: ParserState, section: str, option: str, rawval: str, defaults: ParserSection, loop_detect: ParserLoop) -> None: if not rawval: return if len(loop_detect) > MAX_INTERPOLATION_DEPTH: raise InterpolationDepthError(option, section, rawval) xloop = (section, option) if xloop in loop_detect: raise InterpolationError(section, option, 'Loop detected: %r in %r' % (xloop, loop_detect)) loop_detect.add(xloop) parts = self._var_rc.split(rawval) for i, frag in enumerate(parts): fullkey = None use_vars: Optional[ParserSection] = defaults if i % 2 == 0: dst.append(frag) continue if frag in ('$$', '%%'): dst.append(frag[0]) continue if frag.startswith('${') and frag.endswith('}'): fullkey = frag[2:-1] # use section access only for new-style keys if ':' in fullkey: ksect, key = fullkey.split(':', 1) use_vars = None else: ksect, key = section, fullkey elif frag.startswith('%(') and frag.endswith(')s'): fullkey = frag[2:-2] ksect, key = section, fullkey else: raise InterpolationError(section, option, 'Internal parse error: %r' % frag) if isinstance(parser, RawConfigParser): key = parser.optionxform(key) newpart = parser.get(ksect, key, raw=True, vars=use_vars) if newpart is None: raise InterpolationError(ksect, key, 'Key referenced is None') self._interpolate_ext(dst, parser, ksect, key, newpart, defaults, loop_detect) loop_detect.remove(xloop) class ExtendedConfigParser(ConfigParser): """ConfigParser that uses Python3-style extended 
interpolation by default. Syntax: ${var} and ${section:var} """ _DEFAULT_INTERPOLATION: Interpolation = ExtendedInterpolation() class ExtendedCompatConfigParser(ExtendedConfigParser): r"""Support both extended "${}" syntax from python3 and old "%()s" too. New ${} syntax allows ${key} to refer key in same section, and ${sect:key} to refer key in other sections. """ _DEFAULT_INTERPOLATION: Interpolation = ExtendedInterpolationCompat() python-skytools-3.9.2/skytools/dbservice.py000066400000000000000000000623301447265566700211740ustar00rootroot00000000000000""" Class used to handle multiset receiving and returning PL/Python procedures """ import logging from typing import List, Optional, Sequence, Any, Dict, Union, Tuple import skytools from skytools import dbdict try: import plpy except ImportError: pass __all__ = ( 'DBService', 'ServiceContext', 'get_record', 'get_record_list', 'make_record', 'make_record_array', 'TableAPI', #'log_result', 'transform_fields' ) def transform_fields(rows: Sequence[Dict[str, Any]], key_fields: Sequence[str], name_field: str, data_field: str) -> List[Dict[str, Any]]: """Convert multiple-rows per key input array to one-row, multiple-column output array. The input arrays must be sorted by the key fields. """ cur_key: List[str] = [] cur_row: Dict[str, Any] = {} res = [] for r in rows: k = [r[f] for f in key_fields] if k != cur_key: cur_key = k cur_row = {} for f in key_fields: cur_row[f] = r[f] res.append(cur_row) cur_row[r[name_field]] = r[data_field] return res # render_table def render_table(rows: Sequence[Dict[str, Any]], fields: Sequence[str]) -> List[str]: """ Render result rows as a table. Returns array of lines. 
""" widths = [15] * len(fields) for row in rows: for i, k in enumerate(fields): rlen = len(str(row.get(k))) widths[i] = widths[i] > rlen and widths[i] or rlen widths = [w + 2 for w in widths] fmt = '%%-%ds' * (len(widths) - 1) + '%%s' fmt = fmt % tuple(widths[:-1]) lines = [] lines.append(fmt % tuple(fields)) lines.append(fmt % tuple(['-' * 15] * len(fields))) for row in rows: lines.append(fmt % tuple(str(row.get(k)) for k in fields)) return lines # data conversion to and from url def get_record(arg: Optional[str]) -> dbdict: """ Parse data for one urlencoded record. Useful for turning incoming serialized data into structure usable for manipulation. """ if not arg: return dbdict() # allow array of single record if arg[0] in ('{', '['): lst = skytools.parse_pgarray(arg) if not lst or len(lst) != 1: raise ValueError('get_record() expects exactly 1 row, got %d' % len(lst or [])) arg = lst[0] if not arg: return dbdict() # parse record return dbdict(skytools.db_urldecode(arg)) def get_record_list(array: Optional[Union[str, List[Optional[str]]]]) -> List[dbdict]: """ Parse array of urlencoded records. Useful for turning incoming serialized data into structure usable for manipulation. 
""" if array is None: return [] if not isinstance(array, list): array = skytools.parse_pgarray(array) or [] return [get_record(el) for el in array if el is not None] def get_record_lists(tbl: Sequence[Dict[str, Any]], field: str) -> dbdict: """ Create dictionary of lists from given list using field as grouping criteria Used for master detail operatons to group detail records according to master id """ records = dbdict() for rec in tbl: master_id = str(rec[field]) records.setdefault(master_id, []).append(rec) return records def _make_record_convert(row: Dict[str, Any]) -> str: """Converts complex values.""" d = row.copy() for k, v in d.items(): if isinstance(v, list): d[k] = skytools.make_pgarray(v) return skytools.db_urlencode(d) def make_record(row: Dict[str, Any]) -> str: """ Takes record as dict and returns it as urlencoded string. Used to send data out of db service layer.or to fake incoming calls """ for v in row.values(): if isinstance(v, list): return _make_record_convert(row) return skytools.db_urlencode(row) def make_record_array(rowlist: Sequence[Dict[str, Any]]): """ Takes list of records got from plpy execute and turns it into postgers aray string. Used to send data out of db service layer. 
""" return '{' + ','.join([make_record(row) for row in rowlist]) + '}' def get_result_items(rec_list: Sequence[Dict[str, Any]], name: str) -> Optional[List[dbdict]]: """ Get return values from result """ for r in rec_list: if r['res_code'] == name: return get_record_list(r['res_rows']) return None def log_result(log: logging.Logger, rec_list: Sequence[Dict[str, Any]]) -> None: """ Sends dbservice execution logs to logfile """ msglist = get_result_items(rec_list, "_status") if msglist is None: if rec_list: log.warning('Unhandled output result: _status res_code not present.') else: for msg in msglist: log.debug(msg['_message']) class DBService: """ Wrap parameterized query handling and multiset stored procedure writing """ ROW = "_row" # name of the fake field where internal record id is stored FIELD = "_field" # parameter name for the field in record that is related to current message PARAM = "_param" # name of the parameter to which message relates SKIP = "skip" # used when record is needed for it's data but is not been updated INSERT = "insert" UPDATE = "update" DELETE = "delete" INFO = "info" # just informative message for the user NOTICE = "notice" # more than info less than warning WARNING = "warning" # warning message, something is out of ordinary ERROR = "error" # error found but execution continues until check then error is raised FATAL = "fatal" # execution is terminated at once and all found errors returned rows_found: int = 0 _context: str _is_test: bool _retval: List[Tuple[str, List[dbdict]]] global_dict: Optional[Dict[str, Any]] sqls: Optional[List[Dict[str, str]]] can_save: bool messages: List[Dict[str, Any]] def __init__(self, context: str, global_dict: Optional[Dict[str, Any]] = None) -> None: """ This object must be initiated in the beginning of each db service """ rec = skytools.db_urldecode(context) self._context = context # used to run dbservice in retval self.global_dict = global_dict # used for cacheing query plans self._retval = [] # used to 
collect return resultsets self._is_test = 'is_test' in rec # used to convert output into human readable form self.sqls = None # if sqls stays None then no recording of sqls is done if "show_sql" in rec: # api must add exected sql to resultset self.sqls = [] # sql's executed by dbservice, used for dubugging self.can_save = True # used to keep value most severe error found so far self.messages = [] # used to hold list of messages to be returned to the user # error and message handling def tell_user(self, severity: str, code: str, message: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> None: """ Adds another message to the set of messages to be sent back to user If error message then can_save is set false If fatal message then error or found errors are raised at once """ params = params or kvargs #plpy.notice("%s %s: %s %s" % (severity, code, message, str(params))) params["_severity"] = severity params["_code"] = code params["_message"] = message self.messages.append(params) if severity == self.ERROR: self.can_save = False if severity == self.FATAL: self.can_save = False self.raise_if_errors() def raise_if_errors(self) -> None: """ To be used in places where before continuing must be chcked if errors have been found Raises found errors packing them into error message as urlencoded string """ if not self.can_save: msgs = "Dbservice error(s): " + make_record_array(self.messages) plpy.error(msgs) # run sql meant mostly for select but not limited to def create_query(self, sql: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> skytools.PLPyQueryBuilder: """ Returns initialized querybuilder object for building complex dynamic queries """ params = params or kvargs return skytools.PLPyQueryBuilder(sql, params, self.global_dict, self.sqls) def run_query(self, sql: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> List[dbdict]: """ Helper function if everything you need is just paramertisized execute Sets rows_found that is coneninet 
to use when you don't need result just want to know how many rows were affected """ params = params or kvargs rows = skytools.plpy_exec(self.global_dict, sql, params) # convert result rows to dbdict if rows: rows = [dbdict(r) for r in rows] self.rows_found = len(rows) else: self.rows_found = 0 rows = [] return rows def run_query_row(self, sql: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> Optional[dbdict]: """ Helper function if everything you need is just paramertisized execute to fetch one row only. If not found none is returned """ params = params or kvargs rows = self.run_query(sql, params) if len(rows) == 0: return None return rows[0] def run_exists(self, sql: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> int: """ Helper function to find out that record in given table exists using values in dict as criteria. Takes away all the hassle of preparing statements and processing returned result giving out just one boolean """ params = params or kvargs self.run_query(sql, params) return self.rows_found def run_lookup(self, sql: str, params: Optional[Dict[str, Any]] = None, **kvargs: Any): """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. 
Uses plan cache if used inside db service """ params = params or kvargs rows = self.run_query(sql, params) if len(rows) == 0: return None row = rows[0] return list(row.values())[0] # resultset handling def return_next(self, rows: List[dbdict], res_name: str, severity: Optional[str] = None) -> List[dbdict]: """ Adds given set of rows to resultset """ self._retval.append((res_name, rows)) if severity is not None and len(rows) == 0: self.tell_user(severity, "dbsXXXX", "No matching records found") return rows def return_next_sql(self, sql: str, params: Optional[Dict[str, Any]], res_name: str, severity: Optional[str] = None) -> List[dbdict]: """ Exectes query and adds recors resultset """ rows = self.run_query(sql, params) return self.return_next(rows, res_name, severity) def retval(self, service_name: Optional[str] = None, params: Optional[Dict[str, Any]] = None, **kvargs: Any) -> List[Tuple[str, str, str]]: """ Return collected resultsets and append to the end messages to the users Method is called usually as last statement in dbservice to return the results Also converts results into desired format """ params = params or kvargs self.raise_if_errors() if len(self.messages): self.return_next(self.messages, "_status") # type: ignore if self.sqls is not None and len(self.sqls): self.return_next(self.sqls, "_sql") # type: ignore results: List[Tuple[str, str, str]] = [] for r in self._retval: res_name = r[0] rows = r[1] res_count = str(len(rows)) if self._is_test and len(rows) > 0: results.append((res_name, res_count, res_name)) n = 1 for trow in render_table(rows, list(rows[0].keys())): results.append((res_name, str(n), trow)) n += 1 else: res_rows = make_record_array(rows) results.append((res_name, res_count, res_rows)) if service_name: sql = "select * from %s( {i_context}, {i_params} );" % skytools.quote_fqident(service_name) par = dbdict(i_context=self._context, i_params=make_record(params)) res = self.run_query(sql, par) for row in res: results.append((row.res_code, 
row.res_text, row.res_rows)) return results # miscellaneous def check_required(self, record_name, record, severity, *fields): """ Checks if all required fields are present in record Used to validate incoming data Returns list of field names that are missing or empty """ missing = [] params = {self.PARAM: record_name} if self.ROW in record: params[self.ROW] = record[self.ROW] for field in fields: params[self.FIELD] = field if field in record: if record[field] is None or (isinstance(record[field], str) and len(record[field]) == 0): self.tell_user(severity, "dbsXXXX", "Required value missing: {%s}.{%s}" % ( self.PARAM, self.FIELD), **params) missing.append(field) else: self.tell_user(severity, "dbsXXXX", "Required field missing: {%s}.{%s}" % ( self.PARAM, self.FIELD), **params) missing.append(field) return missing # TableAPI class TableAPI: """ Class for managing one record updates using primary key """ _table = None # schema name and table name _where = None # where condition used for update and delete _id = None # name of the primary key filed _id_type = None # column type of primary key _op = None # operation currently carried out _ctx = None # context object for username and version _logging = True # should tapi log data changed _row = None # row identifer from calling program def __init__(self, ctx, table, create_log=True, id_type='int8'): """ Table name is used to construct insert update and delete statements Table must have primary key field whose name is in format id_ Tablename should be in format schema.tablename """ self._ctx = ctx self._table = skytools.quote_fqident(table) self._id = "id_" + skytools.fq_name_parts(table)[1] self._id_type = id_type self._where = '%s = {%s:%s}' % (skytools.quote_ident(self._id), self._id, self._id_type) self._logging = create_log def _log(self, result, original=None): """ Log changei into table log.changelog """ if not self._logging: return assert self._ctx changes = [] for key in result.keys(): if self._op == 'update': if 
key in original: if str(original[key]) != str(result[key]): changes.append(key + ": " + str(original[key]) + " -> " + str(result[key])) else: changes.append(key + ": " + str(result[key])) self._ctx.log(self._table, result[self._id], self._op, "\n".join(changes)) def _version_check(self, original, version): assert self._ctx if version is None: self._ctx.tell_user( self._ctx.INFO, "dbsXXXX", "Record ({table}.{field}={id}) has been deleted by other user " "while you were editing. Check version ({ver}) in changelog for details.", table=self._table, field=self._id, id=original[self._id], ver=original.version, _row=self._row ) if version is not None and original.version is not None: if int(version) != int(original.version) and self._ctx: self._ctx.tell_user( self._ctx.INFO, "dbsXXXX", "Record ({table}.{field}={id}) has been changed by other user while you were editing. " "Version in db: ({db_ver}) and version sent by caller ({caller_ver}). " "See changelog for details.", table=self._table, field=self._id, id=original[self._id], db_ver=original.version, caller_ver=version, _row=self._row ) def _insert(self, data): assert self._ctx fields = [] values = [] for key in data.keys(): if data[key] is not None: # ignore empty fields.append(skytools.quote_ident(key)) values.append("{" + key + "}") sql = "insert into %s (%s) values (%s) returning *;" % ( self._table, ",".join(fields), ",".join(values) ) result = self._ctx.run_query_row(sql, data) self._log(result) return result def _update(self, data, version): assert self._ctx sql = "select * from %s where %s" % (self._table, self._where) original = self._ctx.run_query_row(sql, data) self._version_check(original, version) pairs = [] for key in data.keys(): if data[key] is None: pairs.append(key + " = NULL") else: pairs.append(key + " = {" + key + "}") sql = "update %s set %s where %s returning *;" % (self._table, ", ".join(pairs), self._where) assert self._ctx result = self._ctx.run_query_row(sql, data) self._log(result, original) 
return result def _delete(self, data, version): assert self._ctx sql = "delete from %s where %s returning *;" % (self._table, self._where) result = self._ctx.run_query_row(sql, data) self._version_check(result, version) self._log(result) return result def do(self, data): """ Do dml according to special field _op that must be given together wit data """ assert self._ctx result = data # so it is initialized for skip self._op = data.pop(self._ctx.OP) # determines operation done self._row = data.pop(self._ctx.ROW, None) # internal record id used for error reporting if self._row is None: # if no _row variable was provided self._row = data.get(self._id, None) # use id instead if self._id in data and data[self._id]: # if _id field is given if int(data[self._id]) < 0: # and it is fake key generated by ui data.pop(self._id) # remove fake key so real one can be assigned version = data.get('version', None) # version sent from caller data['version'] = self._ctx.version # current transaction id is stored in each record if self._op == self._ctx.INSERT: result = self._insert(data) elif self._op == self._ctx.UPDATE: result = self._update(data, version) elif self._op == self._ctx.DELETE: result = self._delete(data, version) elif self._op == self._ctx.SKIP: pass elif self._ctx: self._ctx.tell_user(self._ctx.ERROR, "dbsXXXX", "Unahndled _op='{op}' value in TableAPI (table={table}, id={id})", op=self._op, table=self._table, id=data[self._id]) if self._ctx: result[self._ctx.OP] = self._op result[self._ctx.ROW] = self._row return result # ServiceContext class ServiceContext(DBService): OP = "_op" # name of the fake field where record modificaton operation is stored def __init__(self, context: str, global_dict: Optional[Dict[str, Any]] = None) -> None: """ This object must be initiated in the beginning of each db service """ super().__init__(context, global_dict) rec = skytools.db_urldecode(context) if "username" not in rec: plpy.error("Username must be provided in db service context 
parameter") self.username = rec['username'] # used for logging purposes res = plpy.execute("select txid_current() as txid;") row = res[0] self.version = row["txid"] self.rows_found = 0 # Flag set by run query to inicate number of rows got # logging def log(self, _object_type, _key_object, _change_op, _payload): """ Log stuff into the changelog whatever seems relevant to be logged """ self.run_query( "select log.log_change( {version}, {username}, {object_type}, {key_object}, {change_op}, {payload} );", version=self.version, username=self.username, object_type=_object_type, key_object=_key_object, change_op=_change_op, payload=_payload) # data conversion to and from url def get_record(self, arg): """ Parse data for one urlencoded record. Useful for turning incoming serialized data into structure usable for manipulation. """ return get_record(arg) def get_record_list(self, array): """ Parse array of urlencoded records. Useful for turning incoming serialized data into structure usable for manipulation. """ return get_record_list(array) def get_list_groups(self, tbl, field): """ Create dictionary of lists from given list using field as grouping criteria Used for master detail operatons to group detail records according to master id """ return get_record_lists(tbl, field) def make_record(self, row): """ Takes record as dict and returns it as urlencoded string. Used to send data out of db service layer.or to fake incoming calls """ return make_record(row) def make_record_array(self, rowlist): """ Takes list of records got from plpy execute and turns it into postgers aray string. Used to send data out of db service layer. 
""" return make_record_array(rowlist) # tapi based dml functions def _changelog(self, fields): log = True if fields: if '_log' in fields: if not fields.pop('_log'): log = False if '_log_id' in fields: fields.pop('_log_id') if '_log_field' in fields: fields.pop('_log_field') return log def tapi_do(self, tablename, row, **fields): """ Convenience function for just doing the change without creating tapi object first Fields object may contain aditional overriding values that are applied before do """ tapi = TableAPI(self, tablename, self._changelog(fields)) row = row or dbdict() if fields: row.update(fields) return tapi.do(row) def tapi_do_set(self, tablename, rows, **fields): """ Does changes to list of detail rows Used for normal foreign keys in master detail relationships Dows first deletes then updates and then inserts to avoid uniqueness problems """ tapi = TableAPI(self, tablename, self._changelog(fields)) results, updates, inserts = [], [], [] for row in rows: if fields: row.update(fields) if row[self.OP] == self.DELETE: results.append(tapi.do(row)) elif row[self.OP] == self.UPDATE: updates.append(row) else: inserts.append(row) for row in updates: results.append(tapi.do(row)) for row in inserts: results.append(tapi.do(row)) return results # resultset handling def retval_dbservice(self, service_name: str, ctx: str, **params: Any) -> List[Tuple[str, str, str]]: """ Runs service with standard interface. 
Convenient to use for calling select services from other services For example to return data after doing save """ self.raise_if_errors() service_sql = "select * from %s( {i_context}, {i_params} );" % skytools.quote_fqident(service_name) service_params = {"i_context": ctx, "i_params": self.make_record(params)} results = self.run_query(service_sql, service_params) retval = self.retval() for r in results: retval.append((r.res_code, r.res_text, r.res_rows)) return retval # miscellaneous def field_copy(self, rec, *keys): """ Used to copy subset of fields from one record into another example: dbs.copy(record, hosting) "start_date", "key_colo", "key_rack") """ retval = dbdict() for key in keys: if key in rec: retval[key] = rec[key] return retval def field_set(self, **fields): """ Fills dict with given values and returns resulting dict If dict was not provied with call it is created """ return fields python-skytools-3.9.2/skytools/dbstruct.py000066400000000000000000000644411447265566700210650ustar00rootroot00000000000000"""Find table structure and allow CREATE/DROP elements from it. 
""" # pylint:disable=arguments-renamed import re from typing import List, Optional, Type, Tuple, TypeVar, Any from logging import Logger import skytools from skytools import quote_fqident, quote_ident from skytools.basetypes import Cursor, DictRow __all__ = ( 'TableStruct', 'SeqStruct', 'T_TABLE', 'T_CONSTRAINT', 'T_INDEX', 'T_TRIGGER', 'T_RULE', 'T_GRANT', 'T_OWNER', 'T_PKEY', 'T_ALL', 'T_SEQUENCE', 'T_PARENT', 'T_DEFAULT', ) T_TABLE = 1 << 0 T_CONSTRAINT = 1 << 1 T_INDEX = 1 << 2 T_TRIGGER = 1 << 3 T_RULE = 1 << 4 T_GRANT = 1 << 5 T_OWNER = 1 << 6 T_SEQUENCE = 1 << 7 T_PARENT = 1 << 8 T_DEFAULT = 1 << 9 T_PKEY = 1 << 20 # special, one of constraints T_ALL = (T_TABLE | T_CONSTRAINT | T_INDEX | T_SEQUENCE | T_TRIGGER | T_RULE | T_GRANT | T_OWNER | T_DEFAULT) T = TypeVar("T", bound="TElem") # # Utility functions # def find_new_name(curs: Optional[Cursor], name: str) -> str: """Create new object name for case the old exists. Needed when creating a new table besides old one. """ if curs is None: raise ValueError('Cannot new name without db cursor') # cut off previous numbers m = re.search('_[0-9]+$', name) if m: name = name[:m.start()] # now loop for i in range(1, 1000): tname = "%s_%d" % (name, i) q = "select count(1) from pg_class where relname = %s" curs.execute(q, [tname]) if curs.fetchone()[0] == 0: return tname # failed raise Exception('find_new_name failed') def rx_replace(rx: str, sql: str, new_part: str) -> str: """Find a regex match and replace that part with new_part.""" m = re.search(rx, sql, re.I) if not m: raise Exception('rx_replace failed: rx=%r sql=%r new=%r' % (rx, sql, new_part)) p1 = sql[:m.start()] p2 = sql[m.end():] return p1 + new_part + p2 # # Schema objects # class TElem: """Keeps info about one metadata object.""" SQL = "" type = 0 name: str def __init__(self, name: str, row: DictRow) -> None: raise NotImplementedError def get_create_sql(self, curs: Optional[Cursor], new_name: Optional[str] = None) -> str: """Return SQL statement for creating 
or empty string if not supported.""" return '' def get_drop_sql(self, curs: Optional[Cursor]) -> str: """Return SQL statement for dropping or empty string if not supported.""" return '' @classmethod def get_load_sql(cls, pgver: int) -> str: """Return SQL statement for finding objects.""" return cls.SQL class TConstraint(TElem): """Info about constraint.""" type = T_CONSTRAINT SQL = """ SELECT c.conname as name, pg_get_constraintdef(c.oid) as def, c.contype, i.indisclustered as is_clustered FROM pg_constraint c LEFT JOIN pg_index i ON c.conrelid = i.indrelid AND c.conname = (SELECT r.relname FROM pg_class r WHERE r.oid = i.indexrelid) WHERE c.conrelid = %(oid)s AND c.contype != 'f' """ table_name: str name: str defn: str contype: str is_clustered: bool def __init__(self, table_name: str, row: DictRow) -> None: """Init constraint.""" self.table_name = table_name self.name = row['name'] self.defn = row['def'] self.contype = row['contype'] self.is_clustered = row['is_clustered'] # tag pkeys if self.contype == 'p': self.type += T_PKEY def get_create_sql(self, curs: Optional[Cursor], new_table_name: Optional[str] = None) -> str: """Generate creation SQL.""" # no ONLY here as table with childs (only case that matters) # cannot have contraints that childs do not have fmt = "ALTER TABLE %s ADD CONSTRAINT %s\n %s;" if new_table_name: name = self.name if self.contype in ('p', 'u'): name = find_new_name(curs, self.name) qtbl = quote_fqident(new_table_name) qname = quote_ident(name) else: qtbl = quote_fqident(self.table_name) qname = quote_ident(self.name) sql = fmt % (qtbl, qname, self.defn) if self.is_clustered: sql += ' ALTER TABLE ONLY %s\n CLUSTER ON %s;' % (qtbl, qname) return sql def get_drop_sql(self, curs: Optional[Cursor]) -> str: """Generate removal sql.""" fmt = "ALTER TABLE ONLY %s\n DROP CONSTRAINT %s;" sql = fmt % (quote_fqident(self.table_name), quote_ident(self.name)) return sql class TIndex(TElem): """Info about index.""" type = T_INDEX SQL = """ SELECT 
n.nspname || '.' || c.relname as name, pg_get_indexdef(i.indexrelid) as defn, c.relname as local_name, i.indisclustered as is_clustered FROM pg_index i, pg_class c, pg_namespace n WHERE c.oid = i.indexrelid AND i.indrelid = %(oid)s AND n.oid = c.relnamespace AND NOT EXISTS (select objid from pg_depend where classid = %(pg_class_oid)s and objid = c.oid and deptype = 'i') """ table_name: str name: str defn: str is_clustered: bool local_name: str def __init__(self, table_name: str, row: DictRow) -> None: self.name = row['name'] self.defn = row['defn'].replace(' USING ', '\n USING ', 1) + ';' self.is_clustered = row['is_clustered'] self.table_name = table_name self.local_name = row['local_name'] def get_create_sql(self, curs: Optional[Cursor], new_table_name: Optional[str] = None) -> str: """Generate creation SQL.""" if new_table_name: # fixme: seems broken iname = find_new_name(curs, self.name) tname = new_table_name pnew = "INDEX %s ON %s " % (quote_ident(iname), quote_fqident(tname)) rx = r"\bINDEX[ ][a-z0-9._]+[ ]ON[ ][a-z0-9._]+[ ]" sql = rx_replace(rx, self.defn, pnew) else: sql = self.defn iname = self.local_name tname = self.table_name if self.is_clustered: sql += ' ALTER TABLE ONLY %s\n CLUSTER ON %s;' % ( quote_fqident(tname), quote_ident(iname)) return sql def get_drop_sql(self, curs: Optional[Cursor]) -> str: return 'DROP INDEX %s;' % quote_fqident(self.name) class TRule(TElem): """Info about rule.""" type = T_RULE SQL = """SELECT rw.*, pg_get_ruledef(rw.oid) as def FROM pg_rewrite rw WHERE rw.ev_class = %(oid)s AND rw.rulename <> '_RETURN'::name """ table_name: str name: str defn: str enabled: str def __init__(self, table_name: str, row: DictRow, new_name: Optional[str] = None) -> None: self.table_name = table_name self.name = row['rulename'] self.defn = row['def'] self.enabled = row.get('ev_enabled', 'O') def get_create_sql(self, curs: Optional[Cursor], new_table_name: Optional[str] = None) -> str: """Generate creation SQL.""" if not new_table_name: sql = 
self.defn table = self.table_name else: idrx = r'''([a-z0-9._]+|"([^"]+|"")+")+''' # fixme: broken / quoting rx = r"\bTO[ ]" + idrx rc = re.compile(rx, re.X) m = rc.search(self.defn) if not m: raise Exception('Cannot find table name in rule') old_tbl = m.group(1) new_tbl = quote_fqident(new_table_name) sql = self.defn.replace(old_tbl, new_tbl) table = new_table_name if self.enabled != 'O': # O - rule fires in origin and local modes # D - rule is disabled # R - rule fires in replica mode # A - rule fires always action = {'R': 'ENABLE REPLICA', 'A': 'ENABLE ALWAYS', 'D': 'DISABLE'}[self.enabled] sql += ('\nALTER TABLE %s %s RULE %s;' % (table, action, self.name)) return sql def get_drop_sql(self, curs: Optional[Cursor]) -> str: return 'DROP RULE %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) class TTrigger(TElem): """Info about trigger.""" type = T_TRIGGER table_name: str name: str defn: str def __init__(self, table_name: str, row: DictRow) -> None: self.table_name = table_name self.name = row['name'] self.defn = row['def'] + ';' self.defn = self.defn.replace('FOR EACH', '\n FOR EACH', 1) def get_create_sql(self, curs: Optional[Cursor], new_table_name: Optional[str] = None) -> str: """Generate creation SQL.""" if not new_table_name: return self.defn # fixme: broken / quoting rx = r"\bON[ ][a-z0-9._]+[ ]" pnew = "ON %s " % new_table_name return rx_replace(rx, self.defn, pnew) def get_drop_sql(self, curs: Optional[Cursor]) -> str: return 'DROP TRIGGER %s ON %s' % (quote_ident(self.name), quote_fqident(self.table_name)) @classmethod def get_load_sql(cls, pg_vers: int) -> str: """Return SQL statement for finding objects.""" sql = "SELECT tgname as name, pg_get_triggerdef(oid) as def "\ " FROM pg_trigger "\ " WHERE tgrelid = %(oid)s AND " if pg_vers >= 90000: sql += "NOT tgisinternal" else: sql += "NOT tgisconstraint" return sql class TParent(TElem): """Info about trigger.""" type = T_PARENT SQL = """ SELECT n.nspname||'.'||c.relname AS name FROM 
class TGrant(TElem):
    """Info about permissions (pg_class.relacl)."""
    type = T_GRANT
    SQL = "SELECT relacl FROM pg_class where oid = %(oid)s"

    # Sync with: src/include/utils/acl.h
    acl_map = {
        'a': 'INSERT', 'r': 'SELECT', 'w': 'UPDATE', 'd': 'DELETE',
        'D': 'TRUNCATE', 'x': 'REFERENCES', 't': 'TRIGGER', 'X': 'EXECUTE',
        'U': 'USAGE', 'C': 'CREATE', 'T': 'TEMPORARY', 'c': 'CONNECT',
        # old
        'R': 'RULE',
    }
    name: str
    acl_list: List[Tuple[Optional[str], str, Optional[str]]]

    def acl_to_grants(self, acl: str) -> Tuple[str, str]:
        """Convert ACL letter string to (plain privileges, privileges WITH GRANT OPTION)."""
        if acl == "arwdRxt":    # ALL for tables
            return "ALL", ""
        i = 0
        lst1 = []   # plain grants
        lst2 = []   # grants carrying grant option ('*' suffix in aclitem)
        while i < len(acl):
            a = self.acl_map[acl[i]]
            if i + 1 < len(acl) and acl[i + 1] == '*':
                lst2.append(a)
                i += 2
            else:
                lst1.append(a)
                i += 1
        return ", ".join(lst1), ", ".join(lst2)

    def parse_relacl(self, relacl: Optional[str]) -> List[Tuple[Optional[str], str, Optional[str]]]:
        """Parse ACL to tuple of (user, acl, who)"""
        if relacl is None:
            return []
        tup_list = []
        parsed_list = skytools.parse_pgarray(relacl) or []
        for sacl in parsed_list:
            if sacl:
                acl = skytools.parse_acl(sacl)
                if acl:
                    tup_list.append(acl)
        return tup_list

    def __init__(self, table_name: str, row: DictRow, new_name: Optional[str] = None) -> None:
        self.name = table_name
        self.acl_list = self.parse_relacl(row['relacl'])

    def get_create_sql(self, curs: Optional[Cursor], new_name: Optional[str] = None) -> str:
        """Generate GRANT statements reproducing the stored ACL."""
        if not new_name:
            new_name = self.name
        qtarget = quote_fqident(new_name)

        sql_list = []
        for role, acl, ___who in self.acl_list:
            qrole = quote_ident(role) if role else "public"
            astr1, astr2 = self.acl_to_grants(acl)
            if astr1:
                sql = "GRANT %s ON %s\n TO %s;" % (astr1, qtarget, qrole)
                sql_list.append(sql)
            if astr2:
                sql = "GRANT %s ON %s\n TO %s WITH GRANT OPTION;" % (astr2, qtarget, qrole)
                sql_list.append(sql)
        return "\n".join(sql_list)

    def get_drop_sql(self, curs: Optional[Cursor]) -> str:
        """Generate REVOKE statements removing all privileges.

        Bugfix: PostgreSQL syntax is ``REVOKE ALL ON <target> FROM <role>``;
        the previous code emitted ``REVOKE ALL FROM <role> ON <target>``,
        which is not valid SQL and would fail when executed.
        """
        sql_list = []
        for user, ___acl, ___who in self.acl_list:
            sql = "REVOKE ALL ON %s FROM %s;" % (
                quote_fqident(self.name),
                quote_ident(user) if user else "public",
            )
            sql_list.append(sql)
        return "\n".join(sql_list)
tbl = new_name or self.table_name sql = "ALTER TABLE ONLY %s ALTER COLUMN %s\n SET DEFAULT %s;" % ( quote_fqident(tbl), quote_ident(self.name), self.expr) return sql def get_drop_sql(self, curs: Optional[Cursor]) -> str: return "ALTER TABLE %s ALTER COLUMN %s\n DROP DEFAULT;" % ( quote_fqident(self.table_name), quote_ident(self.name)) class TColumn(TElem): """Info about table column.""" SQL = """ select a.attname as name, quote_ident(a.attname) as qname, format_type(a.atttypid, a.atttypmod) as dtype, a.attnotnull, (select max(char_length(aa.attname)) from pg_attribute aa where aa.attrelid = %(oid)s) as maxcol, pg_get_serial_sequence(%(fq2name)s, a.attname) as seqname from pg_attribute a left join pg_attrdef d on (d.adrelid = a.attrelid and d.adnum = a.attnum) where a.attrelid = %(oid)s and not a.attisdropped and a.attnum > 0 order by a.attnum; """ name: str column_def: str seqname: Optional[str] def __init__(self, table_name: str, row: DictRow) -> None: self.name = row['name'] fname = row['qname'].ljust(row['maxcol'] + 3) self.column_def = fname + ' ' + row['dtype'] if row['attnotnull']: self.column_def += ' not null' self.seqname = None if row['seqname']: self.seqname = skytools.unquote_fqident(row['seqname']) class TGPDistKey(TElem): """Info about GreenPlum table distribution keys""" SQL = """ select a.attname as name from pg_attribute a, gp_distribution_policy p where p.localoid = %(oid)s and a.attrelid = %(oid)s and a.attnum = any(p.attrnums) order by a.attnum; """ name: str def __init__(self, table_name: str, row: DictRow) -> None: self.name = row['name'] class TTable(TElem): """Info about table only (columns).""" type = T_TABLE name: str col_list: List[TColumn] dist_key_list: Optional[List[TGPDistKey]] def __init__(self, table_name: str, col_list: List[TColumn], dist_key_list: Optional[List[TGPDistKey]] = None) -> None: self.name = table_name self.col_list = col_list self.dist_key_list = dist_key_list def get_create_sql(self, curs: Optional[Cursor], new_name: 
Optional[str] = None) -> str: """Generate creation SQL.""" if not new_name: new_name = self.name sql = "CREATE TABLE %s (" % quote_fqident(new_name) sep = "\n " for c in self.col_list: sql += sep + c.column_def sep = ",\n " sql += "\n)" if self.dist_key_list is not None: if self.dist_key_list != []: sql += "\ndistributed by(%s)" % ','.join(c.name for c in self.dist_key_list) else: sql += '\ndistributed randomly' sql += ";" return sql def get_drop_sql(self, curs: Optional[Cursor]) -> str: return "DROP TABLE %s;" % quote_fqident(self.name) class TSeq(TElem): """Info about sequence.""" type = T_SEQUENCE SQL_PG10 = """ SELECT %(fq2name)s::name AS sequence_name, s.last_value, p.seqstart AS start_value, p.seqincrement AS increment_by, p.seqmax AS max_value, p.seqmin AS min_value, p.seqcache AS cache_value, s.log_cnt, s.is_called, p.seqcycle AS is_cycled, %(owner)s as owner FROM pg_catalog.pg_sequence p, %(fqname)s s WHERE p.seqrelid = %(fq2name)s::regclass::oid """ SQL_PG9 = """ SELECT %(fq2name)s AS sequence_name, last_value, start_value, increment_by, max_value, min_value, cache_value, log_cnt, is_called, is_cycled, %(owner)s AS "owner" FROM %(fqname)s """ name: str defn: str owner: str @classmethod def get_load_sql(cls, pg_vers: int) -> str: """Return SQL statement for finding objects.""" if pg_vers < 100000: return cls.SQL_PG9 return cls.SQL_PG10 def __init__(self, seq_name: str, row: DictRow) -> None: self.name = seq_name defn = '' self.owner = row['owner'] if row.get('increment_by', 1) != 1: defn += ' INCREMENT BY %d' % row['increment_by'] if row.get('min_value', 1) != 1: defn += ' MINVALUE %d' % row['min_value'] if row.get('max_value', 9223372036854775807) != 9223372036854775807: defn += ' MAXVALUE %d' % row['max_value'] last_value = row['last_value'] if row['is_called']: last_value += row.get('increment_by', 1) if last_value >= row.get('max_value', 9223372036854775807): raise Exception('duh, seq passed max_value') if last_value != 1: defn += ' START %d' % 
last_value if row.get('cache_value', 1) != 1: defn += ' CACHE %d' % row['cache_value'] if row.get('is_cycled'): defn += ' CYCLE ' if self.owner: defn += ' OWNED BY %s' % self.owner self.defn = defn def get_create_sql(self, curs: Optional[Cursor], new_seq_name: Optional[str] = None) -> str: """Generate creation SQL.""" # we are in table def, forget full def if self.owner: sql = "ALTER SEQUENCE %s\n OWNED BY %s;" % ( quote_fqident(self.name), self.owner) return sql name = self.name if new_seq_name: name = new_seq_name sql = 'CREATE SEQUENCE %s %s;' % (quote_fqident(name), self.defn) return sql def get_drop_sql(self, curs: Optional[Cursor]) -> str: if self.owner: return '' return 'DROP SEQUENCE %s;' % quote_fqident(self.name) # # Main table object, loads all the others # class BaseStruct: """Collects and manages all info about a higher-level db object. Allow to issue CREATE/DROP statements about any group of elements. """ object_list: List[TElem] = [] def __init__(self, curs: Optional[Cursor], name: str) -> None: """Initializes class by loading info about table_name from database.""" self.name = name self.fqname = quote_fqident(name) def _load_elem(self, curs: Cursor, name: str, args: Any, eclass: Type[T]) -> List[T]: """Fetch element(s) from db.""" elem_list = [] #print "Loading %s, name=%s, args=%s" % (repr(eclass), repr(name), repr(args)) sql = eclass.get_load_sql(curs.connection.server_version) curs.execute(sql % args) for row in curs.fetchall(): elem_list.append(eclass(name, row)) return elem_list def create(self, curs: Cursor, objs: int, new_table_name: Optional[str] = None, log: Optional[Logger] = None) -> None: """Issues CREATE statements for requested set of objects. If new_table_name is giver, creates table under that name and also tries to rename all indexes/constraints that conflict with existing table. 
""" for o in self.object_list: if o.type & objs: sql = o.get_create_sql(curs, new_table_name) if not sql: continue if log: log.info('Creating %s' % o.name) log.debug(sql) curs.execute(sql) def drop(self, curs: Cursor, objs: int, log: Optional[Logger] = None) -> None: """Issues DROP statements for requested set of objects.""" # make sure the creating & dropping happen in reverse order olist = self.object_list[:] olist.reverse() for o in olist: if o.type & objs: sql = o.get_drop_sql(curs) if not sql: continue if log: log.info('Dropping %s' % o.name) log.debug(sql) curs.execute(sql) def get_create_sql(self, objs: int) -> str: res = [] for o in self.object_list: if o.type & objs: sql = o.get_create_sql(None, None) if sql: res.append(sql) return "\n".join(res) class TableStruct(BaseStruct): """Collects and manages all info about table. Allow to issue CREATE/DROP statements about any group of elements. """ table_name: str col_list: List[TColumn] dist_key_list: Optional[List[TGPDistKey]] object_list: List[TElem] seq_list: List[TSeq] def __init__(self, curs: Cursor, table_name: str) -> None: """Initializes class by loading info about table_name from database.""" super().__init__(curs, table_name) self.table_name = table_name # fill args schema, name = skytools.fq_name_parts(table_name) args = { 'schema': schema, 'table': name, 'fqname': self.fqname, 'fq2name': skytools.quote_literal(self.fqname), 'oid': skytools.get_table_oid(curs, table_name), 'pg_class_oid': skytools.get_table_oid(curs, 'pg_catalog.pg_class'), } # load table struct self.col_list = self._load_elem(curs, self.name, args, TColumn) # if db is GP then read also table distribution keys if skytools.exists_table(curs, "pg_catalog.gp_distribution_policy"): self.dist_key_list = self._load_elem(curs, self.name, args, TGPDistKey) else: self.dist_key_list = None self.object_list = [TTable(table_name, self.col_list, self.dist_key_list)] self.seq_list = [] # load seqs for col in self.col_list: if col.seqname: fqname = 
quote_fqident(col.seqname) owner = self.fqname + '.' + quote_ident(col.name) seq_args = { 'fqname': fqname, 'fq2name': skytools.quote_literal(fqname), 'owner': skytools.quote_literal(owner), } self.seq_list += self._load_elem(curs, col.seqname, seq_args, TSeq) self.object_list += self.seq_list # load additional objects to_load = [TColumnDefault, TConstraint, TIndex, TTrigger, TRule, TGrant, TOwner, TParent] for eclass in to_load: self.object_list += self._load_elem(curs, self.name, args, eclass) def get_column_list(self) -> List[str]: """Returns list of column names the table has.""" res = [] for c in self.col_list: res.append(c.name) return res class SeqStruct(BaseStruct): """Collects and manages all info about sequence. Allow to issue CREATE/DROP statements about any group of elements. """ def __init__(self, curs: Cursor, seq_name: str) -> None: """Initializes class by loading info about table_name from database.""" super().__init__(curs, seq_name) # fill args args = { 'fqname': self.fqname, 'fq2name': skytools.quote_literal(self.fqname), 'owner': 'null', } # load table struct self.object_list = self._load_elem(curs, seq_name, args, TSeq) def manual_check() -> None: from skytools import connect_database db = connect_database("dbname=fooz") curs = db.cursor() s = TableStruct(curs, "public.data1") s.drop(curs, T_ALL) s.create(curs, T_ALL) s.create(curs, T_ALL, "data1_new") s.create(curs, T_PKEY) if __name__ == '__main__': manual_check() python-skytools-3.9.2/skytools/fileutil.py000066400000000000000000000107101447265566700210360ustar00rootroot00000000000000"""File utilities """ # pylint:disable=unspecified-encoding import errno import os import sys from typing import Optional, Union __all__ = ['write_atomic', 'signal_pidfile'] def write_atomic_unix(fn: str, data: Union[bytes, str], bakext: Optional[str] = None, mode: str = 'b') -> None: """Write file with rename.""" if mode not in ['', 'b', 't']: raise ValueError("unsupported fopen mode") if mode == "b" and not 
isinstance(data, bytes): data = data.encode('utf8') mode = 'w' + mode # write new data to tmp file fn2 = fn + '.new' if "b" in mode: with open(fn2, mode) as f: f.write(data) else: with open(fn2, mode, encoding="utf8") as f: f.write(data) # link old data to bak file if bakext: if bakext.find('/') >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: os.unlink(fnb) except OSError as e: if e.errno != errno.ENOENT: raise try: os.link(fn, fnb) except OSError as e: if e.errno != errno.ENOENT: raise # win32 does not like replace if sys.platform == 'win32': try: os.remove(fn) except BaseException: pass # atomically replace file os.rename(fn2, fn) def signal_pidfile(pidfile: str, sig: int) -> bool: """Send a signal to process whose ID is located in pidfile. Read only first line of pidfile to support multiline pidfiles like postmaster.pid. Returns True is successful, False if pidfile does not exist or process itself is dead. Any other errors will passed as exceptions.""" ln = '' try: with open(pidfile, 'r', encoding="utf8") as f: ln = f.readline().strip() pid = int(ln) if sig == 0 and sys.platform == 'win32': return win32_detect_pid(pid) os.kill(pid, sig) return True except (IOError, OSError) as ex: if ex.errno not in (errno.ESRCH, errno.ENOENT): raise except ValueError: # this leaves slight race when someone is just creating the file, # but more common case is old empty file. 
if not ln: return False raise ValueError('Corrupt pidfile: %s' % pidfile) from None return False def win32_detect_pid(pid: int) -> bool: """Process detection for win32.""" # avoid pywin32 dependecy, use ctypes instead import ctypes # win32 constants PROCESS_QUERY_INFORMATION = 1024 STILL_ACTIVE = 259 ERROR_INVALID_PARAMETER = 87 ERROR_ACCESS_DENIED = 5 # Load kernel32.dll k = ctypes.windll.kernel32 # type: ignore OpenProcess = k.OpenProcess OpenProcess.restype = ctypes.c_void_p # query pid exit code h = OpenProcess(PROCESS_QUERY_INFORMATION, 0, pid) if h is None: err = k.GetLastError() if err == ERROR_INVALID_PARAMETER: return False if err == ERROR_ACCESS_DENIED: return True raise OSError(errno.EFAULT, "Unknown win32error: " + str(err)) code = ctypes.c_int() k.GetExitCodeProcess(h, ctypes.byref(code)) k.CloseHandle(h) return code.value == STILL_ACTIVE def win32_write_atomic(fn: str, data: Union[bytes, str], bakext: Optional[str] = None, mode: str = 'b') -> None: """Write file with rename for win32.""" if mode not in ['', 'b', 't']: raise ValueError("unsupported fopen mode") if mode == "b" and not isinstance(data, bytes): data = data.encode('utf8') mode = "w" + mode # write new data to tmp file fn2 = fn + '.new' if "b" in mode: with open(fn2, mode) as f: f.write(data) else: with open(fn2, mode, encoding="utf8") as f: f.write(data) # move old data to bak file if bakext: if bakext.find('/') >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: os.remove(fnb) except OSError as e: if e.errno != errno.ENOENT: raise try: os.rename(fn, fnb) except OSError as e: if e.errno != errno.ENOENT: raise else: try: os.remove(fn) except BaseException: pass # replace file os.rename(fn2, fn) if sys.platform == 'win32': write_atomic = win32_write_atomic else: write_atomic = write_atomic_unix python-skytools-3.9.2/skytools/gzlog.py000066400000000000000000000014011447265566700203400ustar00rootroot00000000000000"""Atomic append of gzipped data. 
The point is - if several gzip streams are concatenated, they are read back as one whole stream. """ import gzip import io __all__ = ('gzip_append',) def gzip_append(filename: str, data: bytes, level: int = 6) -> None: """Append a block of data to file with safety checks.""" # compress data buf = io.BytesIO() with gzip.GzipFile(fileobj=buf, compresslevel=level, mode="w") as g: g.write(data) zdata = buf.getvalue() # append, safely with open(filename, "ab+", 0) as f: f.seek(0, 2) pos = f.tell() try: f.write(zdata) except Exception as ex: # rollback on error f.seek(pos, 0) f.truncate() raise ex python-skytools-3.9.2/skytools/hashtext.py000066400000000000000000000070621447265566700210570ustar00rootroot00000000000000""" Implementation of Postgres hashing function. hashtext_old() - used up to PostgreSQL 8.3 hashtext_new() - used since PostgreSQL 8.4 """ import struct import sys from typing import Tuple, Union try: from skytools._chashtext import hashtext_new, hashtext_old except ImportError: def hashtext_old(v: Union[bytes, str]) -> int: return hashtext_old_py(v) def hashtext_new(v: Union[bytes, str]) -> int: return hashtext_new_py(v) __all__ = ("hashtext_old", "hashtext_new") # pad for last partial block PADDING = b'\0' * 12 def uint32(x: int) -> int: """python does not have 32 bit integer so we need this hack to produce uint32 after bit operations""" return x & 0xffffffff # # Old Postgres hashtext() - lookup2 with custom initval # FMT_OLD = struct.Struct(" Tuple[int, int, int]: c = uint32(c) a = uint32((a - b - c) ^ (c >> 13)) b = uint32((b - c - a) ^ (a << 8)) c = uint32((c - a - b) ^ (b >> 13)) a = uint32((a - b - c) ^ (c >> 12)) b = uint32((b - c - a) ^ (a << 16)) c = uint32((c - a - b) ^ (b >> 5)) a = uint32((a - b - c) ^ (c >> 3)) b = uint32((b - c - a) ^ (a << 10)) c = uint32((c - a - b) ^ (b >> 15)) return a, b, c def hashtext_old_py(k: Union[bytes, str]) -> int: """Old Postgres hashtext()""" if isinstance(k, str): k = k.encode() remain = len(k) pos = 0 a = b = 
0x9e3779b9 c = 3923095 # handle most of the key while remain >= 12: a2, b2, c2 = FMT_OLD.unpack_from(k, pos) a, b, c = mix_old(a + a2, b + b2, c + c2) pos += 12 remain -= 12 # handle the last 11 bytes a2, b2, c2 = FMT_OLD.unpack_from(k[pos:] + PADDING, 0) # the lowest byte of c is reserved for the length c2 = (c2 << 8) + len(k) a, b, c = mix_old(a + a2, b + b2, c + c2) # convert to signed int if c & 0x80000000: c = -0x100000000 + c return int(c) # # New Postgres hashtext() - hacked lookup3: # - custom initval # - calls mix() when len=12 # - shifted c in last block on little-endian # FMT_NEW = struct.Struct("=LLL") def rol32(x: int, k: int) -> int: return ((x) << (k)) | (uint32(x) >> (32 - (k))) def mix_new(a: int, b: int, c: int) -> Tuple[int, int, int]: a = (a - c) ^ rol32(c, 4) c += b b = (b - a) ^ rol32(a, 6) a += c c = (c - b) ^ rol32(b, 8) b += a a = (a - c) ^ rol32(c, 16) c += b b = (b - a) ^ rol32(a, 19) a += c c = (c - b) ^ rol32(b, 4) b += a return uint32(a), uint32(b), uint32(c) def final_new(a: int, b: int, c: int) -> Tuple[int, int, int]: c = (c ^ b) - rol32(b, 14) a = (a ^ c) - rol32(c, 11) b = (b ^ a) - rol32(a, 25) c = (c ^ b) - rol32(b, 16) a = (a ^ c) - rol32(c, 4) b = (b ^ a) - rol32(a, 14) c = (c ^ b) - rol32(b, 24) return uint32(a), uint32(b), uint32(c) def hashtext_new_py(k: Union[bytes, str]) -> int: """New Postgres hashtext()""" if isinstance(k, str): k = k.encode() remain = len(k) pos = 0 a = b = c = 0x9e3779b9 + len(k) + 3923095 # handle most of the key while remain >= 12: a2, b2, c2 = FMT_NEW.unpack_from(k, pos) a, b, c = mix_new(a + a2, b + b2, c + c2) pos += 12 remain -= 12 # handle the last 11 bytes a2, b2, c2 = FMT_NEW.unpack_from(k[pos:] + PADDING, 0) if sys.byteorder == 'little': c2 = c2 << 8 a, b, c = final_new(a + a2, b + b2, c + c2) # convert to signed int if c & 0x80000000: c = -0x100000000 + c return int(c) 
python-skytools-3.9.2/skytools/installer_config.py000066400000000000000000000003141447265566700225420ustar00rootroot00000000000000 """SQL script locations.""" __all__ = ['sql_locations'] sql_locations = [ "/usr/share/skytools3", ] # PEP 440 version: [N!]N(.N)*[{a|b|rc}N][.postN][.devN] package_version = "3.9.2" skylog = 0 python-skytools-3.9.2/skytools/natsort.py000066400000000000000000000053061447265566700207200ustar00rootroot00000000000000"""Natural sort. Compares numeric parts numerically. Rules: * String consists of numeric and non-numeric parts. * Parts are compared pairwise, if both are numeric then numerically, otherwise textually. (Only first fragment can have different type) * If strings are different, they should compare differnet in natsort. Extra rules for version numbers: * In textual comparision, '~' is less than anything else, including end-of-string. * In numeric comparision, numbers that start with '0' are compared as decimal fractions, thus number that starts with more zeroes is smaller. """ import re from typing import List, Sequence __all__ = ( 'natsort_key', 'natsort', 'natsorted', 'natsort_key_icase', 'natsort_icase', 'natsorted_icase' ) _rc = re.compile(r'\d+|\D+', re.A) def natsort_key(s: str) -> str: """Returns string that sorts according to natsort rules. """ # generates four types of fragments: # 1) strings < "0", stay as-is # 2) numbers starting with 0, fragment starts with "A".."Z" # 3) numbers starting with 1..9, fragment starts with "a".."z" # 4) strings > "9", fragment starts with "|" if "~" in s: s = s.replace("~", "\0") key: List[str] = [] key_append = key.append for frag in _rc.findall(s): if frag < "0": key_append(frag) key_append("\1") elif frag < "1": nzeros = len(frag) - len(frag.lstrip('0')) mag = str(nzeros) mag = str(10**len(mag) - nzeros) key_append(chr(0x5B - len(mag))) # Z, Y, X, ... key_append(mag) key_append(frag) elif frag < ":": mag = str(len(frag)) key_append(chr(0x60 + len(mag))) # a, b, c, ... 
key_append(mag) key_append(frag) else: key_append("|") key_append(frag) key_append("\1") if not (key and key[-1] == "\1"): key_append("\1") return "".join(key) def natsort(lst: List[str]) -> None: """Natural in-place sort, case-sensitive.""" lst.sort(key=natsort_key) def natsorted(lst: Sequence[str]) -> List[str]: """Return copy of list, sorted in natural order, case-sensitive. """ return sorted(lst, key=natsort_key) # case-insensitive api def natsort_key_icase(s: str) -> str: """Split string to numeric and non-numeric fragments.""" return natsort_key(s.lower()) def natsort_icase(lst: List[str]) -> None: """Natural in-place sort, case-sensitive.""" lst.sort(key=natsort_key_icase) def natsorted_icase(lst: Sequence[str]) -> List[str]: """Return copy of list, sorted in natural order, case-sensitive. """ return sorted(lst, key=natsort_key_icase) python-skytools-3.9.2/skytools/parsing.py000066400000000000000000000351771447265566700207020ustar00rootroot00000000000000"""Various parsers for Postgres-specific data formats. """ import re from typing import Iterator, List, Optional, Sequence, Tuple, Dict, Union import skytools __all__ = ( "parse_pgarray", "parse_logtriga_sql", "parse_tabbed_table", "parse_statements", 'sql_tokenizer', 'parse_sqltriga_sql', "parse_acl", "dedent", "hsize_to_bytes", "parse_connect_string", "merge_connect_string", ) _rc_listelem = re.compile(r'( [^,"}]+ | ["] ( [^"\\]+ | [\\]. )* ["] )', re.X) def parse_pgarray(array: Optional[str]) -> Optional[List[Optional[str]]]: r"""Parse Postgres array and return list of items inside it. 
""" if array is None: return None if not array or array[0] not in ("{", "[") or array[-1] != '}': raise ValueError("bad array format: must be surrounded with {}") res = [] pos = 1 # skip optional dimensions descriptor "[a,b]={...}" if array[0] == "[": pos = array.find('{') + 1 if pos < 1: raise ValueError("bad array format 2: must be surrounded with {}") while True: m = _rc_listelem.search(array, pos) if not m: break pos2 = m.end() item = array[pos:pos2] if len(item) == 4 and item.upper() == "NULL": val = None else: if item and item[0] == '"': item = item[1:-1] val = skytools.unescape(item) res.append(val) pos = pos2 + 1 if array[pos2] == "}": break elif array[pos2] != ",": raise ValueError("bad array format: expected ,} got " + repr(array[pos2])) if pos < len(array) - 1: raise ValueError("bad array format: failed to parse completely (pos=%d len=%d)" % (pos, len(array))) return res # # parse logtriga partial sql # class _logtriga_parser: """Parses logtriga/sqltriga partial SQL to values.""" pklist: List[str] def tokenizer(self, sql: str) -> Iterator[str]: """Token generator.""" for tup in sql_tokenizer(sql, ignore_whitespace=True): yield tup[1] def parse_insert(self, tk: Iterator[str], fields: List[str], values: List[str], key_fields: List[str], key_values: List[str]) -> None: """Handler for inserts.""" # (col1, col2) values ('data', null) if next(tk) != "(": raise ValueError("syntax error") while True: fields.append(next(tk)) t = next(tk) if t == ")": break elif t != ",": raise ValueError("syntax error") if next(tk).lower() != "values": raise ValueError("syntax error, expected VALUES") if next(tk) != "(": raise ValueError("syntax error, expected (") while True: values.append(next(tk)) t = next(tk) if t == ")": break if t == ",": continue raise ValueError("expected , or ) got " + t) t = next(tk) raise ValueError("expected EOF, got " + repr(t)) def parse_update(self, tk: Iterator[str], fields: List[str], values: List[str], key_fields: List[str], key_values: 
List[str]) -> None: """Handler for updates.""" # col1 = 'data1', col2 = null where pk1 = 'pk1' and pk2 = 'pk2' while True: fields.append(next(tk)) if next(tk) != "=": raise ValueError("syntax error") values.append(next(tk)) t = next(tk) if t == ",": continue elif t.lower() == "where": break else: raise ValueError("syntax error, expected WHERE or , got " + repr(t)) while True: fld = next(tk) key_fields.append(fld) self.pklist.append(fld) if next(tk) != "=": raise ValueError("syntax error") key_values.append(next(tk)) t = next(tk) if t.lower() != "and": raise ValueError("syntax error, expected AND got " + repr(t)) def parse_delete(self, tk: Iterator[str], fields: List[str], values: List[str], key_fields: List[str], key_values: List[str]) -> None: """Handler for deletes.""" # pk1 = 'pk1' and pk2 = 'pk2' while True: fld = next(tk) key_fields.append(fld) self.pklist.append(fld) if next(tk) != "=": raise ValueError("syntax error") key_values.append(next(tk)) t = next(tk) if t.lower() != "and": raise ValueError("syntax error, expected AND, got " + repr(t)) def _create_dbdict(self, fields: List[str], values: List[str]) -> skytools.dbdict: fields2 = [skytools.unquote_ident(f) for f in fields] values2 = [skytools.unquote_literal(v) for v in values] return skytools.dbdict(zip(fields2, values2)) def parse_sql(self, op: str, sql: str, pklist: Optional[Sequence[str]] = None, splitkeys: bool = False ) -> Union[skytools.dbdict, Tuple[skytools.dbdict, skytools.dbdict]]: """Main entry point.""" if pklist is None: self.pklist = [] else: self.pklist = list(pklist) tk = self.tokenizer(sql) fields: List[str] = [] values: List[str] = [] key_fields: List[str] = [] key_values: List[str] = [] try: if op == "I": self.parse_insert(tk, fields, values, key_fields, key_values) elif op == "U": self.parse_update(tk, fields, values, key_fields, key_values) elif op == "D": self.parse_delete(tk, fields, values, key_fields, key_values) raise ValueError("syntax error") except StopIteration: pass # last 
sanity check if (len(fields) + len(key_fields) == 0 or len(fields) != len(values) or len(key_fields) != len(key_values)): raise ValueError("syntax error, fields do not match values") if splitkeys: return (self._create_dbdict(key_fields, key_values), self._create_dbdict(fields, values)) return self._create_dbdict(fields + key_fields, values + key_values) def parse_logtriga_sql( op: str, sql: str, splitkeys: bool = False ) -> Union[skytools.dbdict, Tuple[skytools.dbdict, skytools.dbdict]]: return parse_sqltriga_sql(op, sql, splitkeys=splitkeys) def parse_sqltriga_sql( op: str, sql: str, pklist: Optional[Sequence[str]] = None, splitkeys: bool = False ) -> Union[skytools.dbdict, Tuple[skytools.dbdict, skytools.dbdict]]: """Parse partial SQL used by pgq.sqltriga() back to data values. Parser has following limitations: - Expects standard_quoted_strings = off - Does not support dollar quoting. - Does not support complex expressions anywhere. (hashtext(col1) = hashtext(val1)) - WHERE expression must not contain IS (NOT) NULL - Does not support updating pk value, unless you use the splitkeys parameter. Returns dict of col->data pairs. """ return _logtriga_parser().parse_sql(op, sql, pklist, splitkeys=splitkeys) def parse_tabbed_table(txt: str) -> List[Dict[str, str]]: r"""Parse a tab-separated table into list of dicts. Expect first row to be column names. Very primitive. """ txt = txt.replace("\r\n", "\n") fields = None data: List[Dict[str, str]] = [] for ln in txt.split("\n"): if not ln: continue if not fields: fields = ln.split("\t") continue cols = ln.split("\t") if len(cols) != len(fields): continue row = dict(zip(fields, cols)) data.append(row) return data _extstr = r""" ['] (?: [^'\\]+ | \\. | [']['] )* ['] """ _stdstr = r""" ['] (?: [^']+ | [']['] )* ['] """ _name = r""" (?: [a-z_][a-z0-9_$]* | " (?: [^"]+ | "" )* " ) """ _ident = r""" (?P %s ) """ % _name _fqident = r""" (?P %s (?: \. %s )* ) """ % (_name, _name) _base_sql = r""" (?P (?P [$] (?: [_a-z][_a-z0-9]*)? 
[$] ) .*? (?P=dname) ) | (?P [0-9][0-9.e]* ) | (?P [$] [0-9]+ ) | (?P [%][(] [a-z_][a-z0-9_]* [)] [s] ) | (?P [{] [^{}]+ [}] ) | (?P (?: \s+ | [/][*] .*? [*][/] | [-][-][^\n]* )+ ) | (?P (?: [-+*~!@#^&|?/%<>=]+ | [,()\[\].:;] ) ) | (?P . )""" _base_sql_fq = r"%s | %s" % (_fqident, _base_sql) _base_sql = r"%s | %s" % (_ident, _base_sql) _std_sql = r"""(?: (?P [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql) _std_sql_fq = r"""(?: (?P [E] %s | %s ) | %s )""" % (_extstr, _stdstr, _base_sql_fq) _ext_sql = r"""(?: (?P [E]? %s ) | %s )""" % (_extstr, _base_sql) _ext_sql_fq = r"""(?: (?P [E]? %s ) | %s )""" % (_extstr, _base_sql_fq) _std_sql_rc = _ext_sql_rc = None _std_sql_fq_rc = _ext_sql_fq_rc = None def sql_tokenizer( sql: str, standard_quoting: bool = False, ignore_whitespace: bool = False, fqident: bool = False, show_location: bool = False ) -> Iterator[Union[Tuple[str, str], Tuple[str, str, int]]]: r"""Parser SQL to tokens. Iterator, returns (toktype, tokstr) tuples. """ global _std_sql_rc, _ext_sql_rc, _std_sql_fq_rc, _ext_sql_fq_rc if not _std_sql_rc: _std_sql_rc = re.compile(_std_sql, re.X | re.I | re.S) _ext_sql_rc = re.compile(_ext_sql, re.X | re.I | re.S) _std_sql_fq_rc = re.compile(_std_sql_fq, re.X | re.I | re.S) _ext_sql_fq_rc = re.compile(_ext_sql_fq, re.X | re.I | re.S) if standard_quoting: if fqident: rc = _std_sql_fq_rc else: rc = _std_sql_rc else: if fqident: rc = _ext_sql_fq_rc else: rc = _ext_sql_rc pos = 0 while True: m = rc.match(sql, pos) if not m: break pos = m.end() typ = m.lastgroup or '?' if ignore_whitespace and typ == "ws": continue tk = m.group() if show_location: yield (typ, tk, pos) else: yield (typ, tk) _copy_from_stdin_re = r"copy.*from\s+stdin" _copy_from_stdin_rc = None def parse_statements(sql: str, standard_quoting: bool = False) -> Iterator[str]: """Parse multi-statement string into separate statements. Returns list of statements. 
""" global _copy_from_stdin_rc if not _copy_from_stdin_rc: _copy_from_stdin_rc = re.compile(_copy_from_stdin_re, re.X | re.I) tokens: List[str] = [] pcount = 0 # '(' level for tmp in sql_tokenizer(sql, standard_quoting=standard_quoting): typ, t = tmp[0], tmp[1] # skip whitespace and comments before statement if len(tokens) == 0 and typ == "ws": continue # keep the rest tokens.append(t) if t == "(": pcount += 1 elif t == ")": pcount -= 1 elif t == ";" and pcount == 0: sql = "".join(tokens) if _copy_from_stdin_rc.match(sql): raise ValueError("copy from stdin not supported") yield "".join(tokens) tokens = [] if len(tokens) > 0: yield "".join(tokens) if pcount != 0: raise ValueError("syntax error - unbalanced parenthesis") _acl_name = r'(?: [0-9a-z_]+ | " (?: [^"]+ | "" )* " )' _acl_re = r''' \s* (?: group \s+ | user \s+ )? (?P %s )? (?P = [a-z*]* )? (?P / %s )? \s* $ ''' % (_acl_name, _acl_name) _acl_rc = None def parse_acl(acl: str) -> Optional[Tuple[Optional[str], str, Optional[str]]]: """Parse ACL entry. """ global _acl_rc if not _acl_rc: _acl_rc = re.compile(_acl_re, re.I | re.X) m = _acl_rc.match(acl) if not m: return None target = m.group('tgt') perm = m.group('perm') owner = m.group('owner') if target: target = skytools.unquote_ident(target) if perm: perm = perm[1:] if owner: owner = skytools.unquote_ident(owner[1:]) return (target, perm, owner) def dedent(doc: str) -> str: r"""Relaxed dedent. - takes whitespace to be removed from first indented line. 
- allows empty or non-indented lines at the start - allows first line to be unindented - skips empty lines at the start - ignores indent of empty lines - if line does not match common indent, is stays unchanged """ pfx: Optional[str] = None res: List[str] = [] for ln in doc.splitlines(): ln = ln.rstrip() if not pfx and len(res) < 2: if not ln: continue wslen = len(ln) - len(ln.lstrip()) pfx = ln[: wslen] if pfx: if ln.startswith(pfx): ln = ln[len(pfx):] res.append(ln) res.append('') return '\n'.join(res) def hsize_to_bytes(input_str: str) -> int: """ Convert sizes from human format to bytes (string to integer) """ m = re.match(r"^([0-9]+) *([KMGTPEZY]?)B?$", input_str.strip(), re.IGNORECASE) if not m: raise ValueError("cannot parse: %s" % input_str) units = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y'] unit = m.group(2) or '' nbytes = int(m.group(1)) * 1024 ** units.index(unit.upper()) assert isinstance(nbytes, int) return nbytes # # Connect string parsing # _cstr_rx = r""" \s* (\w+) \s* = \s* ( ' ( \\.| [^'\\] )* ' | \S+ ) \s* """ _cstr_unesc_rx = r"\\(.)" _cstr_badval_rx = r"[\s'\\]" _cstr_rc = None _cstr_unesc_rc = None _cstr_badval_rc = None def parse_connect_string(cstr: str) -> List[Tuple[str, str]]: r"""Parse Postgres connect string. """ global _cstr_rc, _cstr_unesc_rc if not _cstr_rc: _cstr_rc = re.compile(_cstr_rx, re.X) _cstr_unesc_rc = re.compile(_cstr_unesc_rx) pos = 0 res = [] while pos < len(cstr): m = _cstr_rc.match(cstr, pos) if not m: raise ValueError('Invalid connect string') pos = m.end() k = m.group(1) v = m.group(2) if v[0] == "'": v = _cstr_unesc_rc.sub(r"\1", v) res.append((k, v)) return res def merge_connect_string(cstr_arg_list: Sequence[Tuple[str, str]]) -> str: """Put fragments back together. 
""" global _cstr_badval_rc if not _cstr_badval_rc: _cstr_badval_rc = re.compile(_cstr_badval_rx) buf = [] for k, v in cstr_arg_list: if not v or _cstr_badval_rc.search(v): v = v.replace('\\', r'\\') v = v.replace("'", r"\'") v = "'" + v + "'" buf.append("%s=%s" % (k, v)) return ' '.join(buf) python-skytools-3.9.2/skytools/plpy_applyrow.py000066400000000000000000000151331447265566700221460ustar00rootroot00000000000000"""PLPY helper module for applying row events from pgq.logutriga(). """ from typing import Sequence, Optional import skytools try: import plpy except ImportError: pass ## TODO: automatic fkey detection # find FK columns FK_SQL = """ SELECT (SELECT array_agg( (SELECT attname::text FROM pg_attribute WHERE attrelid = conrelid AND attnum = conkey[i])) FROM generate_series(1, array_upper(conkey, 1)) i) AS kcols, (SELECT array_agg( (SELECT attname::text FROM pg_attribute WHERE attrelid = confrelid AND attnum = confkey[i])) FROM generate_series(1, array_upper(confkey, 1)) i) AS fcols, confrelid::regclass::text AS ftable FROM pg_constraint WHERE conrelid = {tbl}::regclass AND contype='f' """ class DataError(Exception): "Invalid data" def colfilter_full(rnew, rold): return rnew def colfilter_changed(rnew, rold): res = {} for k, _ in rnew: if rnew[k] != rold[k]: res[k] = rnew[k] return res def canapply_dummy(rnew, rold): return True def canapply_tstamp_helper(rnew, rold, tscol): tnew = rnew[tscol] told = rold[tscol] if not tnew[0].isdigit(): raise DataError('invalid timestamp') if not told[0].isdigit(): raise DataError('invalid timestamp') return tnew > told def applyrow(tblname, ev_type, new_row, backup_row=None, alt_pkey_cols=None, fkey_cols=None, fkey_ref_table=None, fkey_ref_cols=None, fn_canapply=canapply_dummy, fn_colfilter=colfilter_full): """Core logic. Actual decisions will be done in callback functions. - [IUD]: If row referenced by fkey does not exist, event is not applied - If pkey does not exist but alt_pkey does, row is not applied. 
@param tblname: table name, schema-qualified @param ev_type: [IUD]:pkey1,pkey2 @param alt_pkey_cols: list of alternatice columns to consuder @param fkey_cols: columns in this table that refer to other table @param fkey_ref_table: other table referenced here @param fkey_ref_cols: column in other table that must match @param fn_canapply: callback function, gets new and old row, returns whether the row should be applied @param fn_colfilter: callback function, gets new and old row, returns dict of final columns to be applied """ gd = None # parse ev_type tmp = ev_type.split(':', 1) if len(tmp) != 2 or tmp[0] not in ('I', 'U', 'D'): raise DataError('Unsupported ev_type: ' + repr(ev_type)) if not tmp[1]: raise DataError('No pkey in event') cmd = tmp[0] pkey_cols = tmp[1].split(',') qtblname = skytools.quote_fqident(tblname) # parse ev_data fields = skytools.db_urldecode(new_row) if ev_type.find('}') >= 0: raise DataError('Really suspicious activity') if ",".join(fields.keys()).find('}') >= 0: raise DataError('Really suspicious activity 2') # generate pkey expressions tmp = ["%s = {%s}" % (skytools.quote_ident(k), k) for k in pkey_cols] pkey_expr = " and ".join(tmp) alt_pkey_expr = None if alt_pkey_cols: tmp = ["%s = {%s}" % (skytools.quote_ident(k), k) for k in alt_pkey_cols] alt_pkey_expr = " and ".join(tmp) log = "data ok" # # Row data seems fine, now apply it # res: Optional[Sequence[skytools.dbdict]] oldrow: Optional[skytools.dbdict] if fkey_ref_table: tmp = [] for k, rk in zip(fkey_cols, fkey_ref_cols): tmp.append("%s = {%s}" % (skytools.quote_ident(rk), k)) fkey_expr = " and ".join(tmp) q = "select 1 from only %s where %s" % ( skytools.quote_fqident(fkey_ref_table), fkey_expr) res = skytools.plpy_exec(gd, q, fields) if not res: return "IGN: parent row does not exist" log += ", fkey ok" # fetch old row if alt_pkey_expr: q = "select * from only %s where %s for update" % (qtblname, alt_pkey_expr) res = skytools.plpy_exec(gd, q, fields) if res: oldrow = res[0] # if 
altpk matches, but pk not, then delete need_del = 0 for k in pkey_cols: # fixme: proper type cmp? if fields[k] != str(oldrow[k]): need_del = 1 break if need_del: log += ", altpk del" q = "delete from only %s where %s" % (qtblname, alt_pkey_expr) skytools.plpy_exec(gd, q, fields) res = None else: log += ", altpk ok" else: # no altpk q = "select * from only %s where %s for update" % (qtblname, pkey_expr) res = skytools.plpy_exec(None, q, fields) # got old row, with same pk and altpk if res: oldrow = res[0] log += ", old row" ok = fn_canapply(fields, oldrow) if ok: log += ", new row better" if not ok: # ignore the update return "IGN:" + log + ", current row more up-to-date" else: log += ", no old row" oldrow = None if res: if cmd == 'I': cmd = 'U' else: if cmd == 'U': cmd = 'I' # allow column changes if oldrow: fields2 = fn_colfilter(fields, oldrow) for k in pkey_cols: if k not in fields2: fields2[k] = fields[k] fields = fields2 # apply change if cmd == 'I': q = skytools.mk_insert_sql(fields, tblname, pkey_cols) elif cmd == 'U': q = skytools.mk_update_sql(fields, tblname, pkey_cols) elif cmd == 'D': q = skytools.mk_delete_sql(fields, tblname, pkey_cols) else: plpy.error('Huh') plpy.execute(q) return log def ts_conflict_handler(gd, args): """Conflict handling based on timestamp column.""" conf = skytools.db_urldecode(args[0]) timefield = conf['timefield'] ev_type = args[1] ev_data = args[2] ev_extra1 = args[3] ev_extra2 = args[4] #ev_extra3 = args[5] #ev_extra4 = args[6] altpk = None cf_altpk = conf.get('altpk') if cf_altpk: altpk = cf_altpk.split(',') def ts_canapply(rnew, rold): return canapply_tstamp_helper(rnew, rold, timefield) return applyrow( ev_extra1, ev_type, ev_data, backup_row=ev_extra2, alt_pkey_cols=altpk, fkey_ref_table=conf.get('fkey_ref_table'), fkey_ref_cols=conf.get('fkey_ref_cols'), fkey_cols=conf.get('fkey_cols'), fn_canapply=ts_canapply ) 
python-skytools-3.9.2/skytools/psycopgwrapper.py000066400000000000000000000035151447265566700223130ustar00rootroot00000000000000"""Wrapper around psycopg2. Database connection provides regular DB-API 2.0 interface. Connection object methods:: .cursor() .commit() .rollback() .close() Cursor methods:: .execute(query[, args]) .fetchone() .fetchall() Sample usage:: db = self.get_database('somedb') curs = db.cursor() # query arguments as array q = "select * from table where id = %s and name = %s" curs.execute(q, [1, 'somename']) # query arguments as dict q = "select id, name from table where id = %(id)s and name = %(name)s" curs.execute(q, {'id': 1, 'name': 'somename'}) # loop over resultset for row in curs.fetchall(): # columns can be asked by index: id = row[0] name = row[1] # and by name: id = row['id'] name = row['name'] # now commit the transaction db.commit() """ from typing import cast import psycopg2 import psycopg2.extensions import psycopg2.extras from psycopg2 import Error as DBError from .basetypes import Connection __all__ = ( 'connect_database', 'DBError', 'I_AUTOCOMMIT', 'I_READ_COMMITTED', 'I_REPEATABLE_READ', 'I_SERIALIZABLE', ) #: Isolation level for db.set_isolation_level() I_AUTOCOMMIT = psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT #: Isolation level for db.set_isolation_level() I_READ_COMMITTED = psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED #: Isolation level for db.set_isolation_level() I_REPEATABLE_READ = psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ #: Isolation level for db.set_isolation_level() I_SERIALIZABLE = psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE def connect_database(connstr: str) -> Connection: """Create a db connection with DictCursor. 
""" db = psycopg2.connect(connstr, cursor_factory=psycopg2.extras.DictCursor) return cast(Connection, db) python-skytools-3.9.2/skytools/py.typed000066400000000000000000000000001447265566700203350ustar00rootroot00000000000000python-skytools-3.9.2/skytools/querybuilder.py000066400000000000000000000313321447265566700217400ustar00rootroot00000000000000"""Helper classes for complex query generation. Main target is code execution under PL/Python. Query parameters are referenced as C{{key}} or C{{key:type}}. Type will be given to C{plpy.prepare}. If C{type} is missing, C{text} is assumed. See L{plpy_exec} for examples. """ import json import re from functools import lru_cache from typing import Any, Dict, List, Mapping, Optional, Sequence, Union, Tuple, cast import skytools from .basetypes import Cursor try: import plpy except ImportError: plpy = None __all__ = ( 'QueryBuilder', 'PLPyQueryBuilder', 'PLPyQuery', 'plpy_exec', "run_query", "run_query_row", "run_lookup", "run_exists", ) PARAM_INLINE = 0 # quote_literal() PARAM_DBAPI = 1 # %()s PARAM_PLPY = 2 # $n _RC_PARAM = re.compile(r""" \{ ( [^|{}:]* ) (?: : ( [^|{}:]+ ) )? (?: \| ( [^|{}:]+ ) )? ( \} )? """, re.X) def _inline_to_text(val: Any) -> Optional[str]: """Approx emulate PL/Python and Psycopg2 internal conversions for common types. 
""" if val is None or isinstance(val, str): return val if isinstance(val, dict): return json.dumps(val) if isinstance(val, (tuple, list)): return skytools.make_pgarray(val) if isinstance(val, bytes): return "\\x" + val.hex() return str(val) class QArgConf: """Per-query arg-type config object.""" param_type = PARAM_INLINE class QArg: """Place-holder for a query parameter.""" def __init__(self, name: str, value: Any, pos: int, conf: QArgConf): self.name = name self.value = value self.pos = pos self.conf = conf def __str__(self) -> str: if self.conf.param_type == PARAM_INLINE: return skytools.quote_literal(_inline_to_text(self.value)) elif self.conf.param_type == PARAM_DBAPI: return "%s" elif self.conf.param_type == PARAM_PLPY: return "$%d" % self.pos else: raise Exception("bad QArgConf.param_type") class PlanCache: """Cache for limited amount of plans.""" def __init__(self, maxplans: int = 100) -> None: self.maxplans = maxplans @lru_cache(maxplans) def _cached_prepare(key: Tuple[str, Tuple[str, ...]]) -> Any: sql, types = key return plpy.prepare(sql, types) self._cached_prepare = _cached_prepare def get_plan(self, sql: str, types: Sequence[str]) -> Any: """Prepare the plan and cache it.""" key = (sql, tuple(types)) return self._cached_prepare(key) class QueryBuilderCore: """Helper for query building. """ _params: Optional[Mapping[str, Any]] _arg_type_list: List[str] _arg_value_list: List[Any] _sql_parts: List[Union[str, QArg]] _arg_conf: QArgConf _nargs: int def __init__(self, sqlexpr: str, params: Optional[Mapping[str, Any]]): """Init the object. @param sqlexpr: Partial sql fragment. @param params: Dict of parameter values. """ self._params = params self._arg_type_list = [] self._arg_value_list = [] self._sql_parts = [] self._arg_conf = QArgConf() self._nargs = 0 if sqlexpr: self.add(sqlexpr, required=True) def add(self, expr: str, sql_type: str = "text", required: bool = False) -> None: """Add SQL fragment to query. 
""" self._add_expr('', expr, self._params, sql_type, required) def get_sql(self, param_type: int = PARAM_INLINE) -> str: """Return generated SQL (thus far) as string. Possible values for param_type: - 0: Insert values quoted with quote_literal() - 1: Insert %()s in place of parameters. - 2: Insert $n in place of parameters. """ self._arg_conf.param_type = param_type tmp = [str(part) for part in self._sql_parts] return "".join(tmp) def _add_expr(self, pfx: str, expr: str, params: Optional[Mapping[str, Any]], sql_type: str, required: bool) -> None: parts: List[Union[str, QArg]] = [] types: List[str] = [] values: List[Any] = [] nargs = self._nargs if pfx: parts.append(pfx) pos = 0 while True: # find next argument m = _RC_PARAM.search(expr, pos) if not m: parts.append(expr[pos:]) break # add plain sql parts.append(expr[pos:m.start()]) pos = m.end() # get arg name and type kparam, ktype, alt_frag, tag = m.groups() if not kparam or not tag: raise ValueError("invalid tag syntax: <%s>" % m.group(0)) if not ktype: ktype = sql_type # params==None means params are checked later if params is None: if alt_frag is not None: raise ValueError("alt_frag not supported with params=None") elif kparam not in params: if alt_frag is not None: parts.append(alt_frag) continue elif required: raise Exception("required parameter missing: " + kparam) # optional fragment, param missing, skip it return # got arg nargs += 1 if params is not None: val = params[kparam] else: val = kparam values.append(val) types.append(ktype) arg = QArg(kparam, val, nargs, self._arg_conf) parts.append(arg) # add interesting parts to the main sql self._sql_parts.extend(parts) if types: self._arg_type_list.extend(types) if values: self._arg_value_list.extend(values) self._nargs = nargs class QueryBuilder(QueryBuilderCore): def execute(self, curs: Cursor) -> None: """Client-side query execution on DB-API 2.0 cursor. Calls C{curs.execute()} with proper arguments. 
Does not return anything, curs.fetch* methods must be called to get result. """ q = self.get_sql(PARAM_DBAPI) args = self._params curs.execute(q, args) class PLPyQueryBuilder(QueryBuilderCore): def __init__(self, sqlexpr: str, params: Optional[Mapping[str, Any]], plan_cache: Optional[Dict[str, Any]] = None, sqls: Optional[List[Dict[str, str]]] = None): """Init the object. @param sqlexpr: Partial sql fragment. @param params: Dict of parameter values. @param plan_cache: (PL/Python) A dict object where to store the plan cache, under the key C{"plan_cache"}. If not given, plan will not be cached and values will be inserted directly to query. Usually either C{GD} or C{SD} should be given here. @param sqls: list object where to append executed sqls (used for debugging) """ super().__init__(sqlexpr, params) self._sqls = sqls if plan_cache is not None: if 'plan_cache' not in plan_cache: plan_cache['plan_cache'] = PlanCache() self._plan_cache = plan_cache['plan_cache'] else: self._plan_cache = None def execute(self) -> List[skytools.dbdict]: """Server-side query execution via plpy. Query can be run either cached or uncached, depending on C{plan_cache} setting given to L{__init__}. Returns result of plpy.execute(). """ args = self._arg_value_list types = self._arg_type_list if self._sqls is not None: self._sqls.append({"sql": self.get_sql(PARAM_INLINE)}) sql = self.get_sql(PARAM_PLPY) if self._plan_cache is not None: plan = self._plan_cache.get_plan(sql, types) else: plan = plpy.prepare(sql, types) res = plpy.execute(plan, args) if res: return [skytools.dbdict(r) for r in res] else: return [] class PLPyQuery: """Static, cached PL/Python query that uses QueryBuilder formatting. See L{plpy_exec} for simple usage. 
""" def __init__(self, sql: str) -> None: qb = QueryBuilder(sql, None) p_sql = qb.get_sql(PARAM_PLPY) p_types = qb._arg_type_list self.plan = plpy.prepare(p_sql, p_types) self.arg_map = qb._arg_value_list self.sql = sql def execute(self, arg_dict: Optional[Mapping[str, Any]], all_keys_required: bool = True) -> List[skytools.dbdict]: if arg_dict is None: arg_dict = {} try: if all_keys_required: arg_list = [arg_dict[k] for k in self.arg_map] else: arg_list = [arg_dict.get(k) for k in self.arg_map] res = plpy.execute(self.plan, arg_list) if res: return [skytools.dbdict(row) for row in res] return [] except KeyError: need = set(self.arg_map) got = set(arg_dict.keys()) missing = list(need.difference(got)) plpy.error("Missing arguments: [%s] QUERY: %s" % ( ','.join(missing), repr(self.sql))) raise ValueError("unreachable") from None def __repr__(self) -> str: return 'PLPyQuery<%s>' % self.sql def plpy_exec(gd: Optional[Dict[str, Any]], sql: str, args: Optional[Mapping[str, Any]], all_keys_required: bool = True) -> List[skytools.dbdict]: """Cached plan execution for PL/Python. @param gd: dict to store cached plans under. If None, caching is disabled. @param sql: SQL statement to execute. @param args: dict of arguments to query. @param all_keys_required: if False, missing key is taken as NULL, instead of throwing error. 
""" if gd is None: return PLPyQueryBuilder(sql, args).execute() if 'plq_cache' not in gd: gd['plq_cache'] = {} cache = cast(Dict[str, PLPyQuery], gd['plq_cache']) try: sq = cache[sql] except KeyError: sq = PLPyQuery(sql) cache[sql] = sq return sq.execute(args, all_keys_required) # some helper functions for convenient sql execution def run_query(cur: Cursor, sql: str, params: Optional[Mapping[str, Any]] = None, **kwargs: Any ) -> List[skytools.dbdict]: """ Helper function if everything you need is just paramertisized execute Sets rows_found that is coneninet to use when you don't need result just want to know how many rows were affected """ params = params or kwargs sql = QueryBuilder(sql, params).get_sql(0) cur.execute(sql) rows = cur.fetchall() # convert result rows to dbdict if rows: return [skytools.dbdict(r) for r in rows] return [] def run_query_row(cur: Cursor, sql: str, params: Optional[Mapping[str, Any]] = None, **kwargs: Any ) -> Optional[skytools.dbdict]: """ Helper function if everything you need is just paramertisized execute to fetch one row only. If not found none is returned """ params = params or kwargs rows = run_query(cur, sql, params) if len(rows) == 0: return None return rows[0] def run_lookup(cur: Cursor, sql: str, params: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> Any: """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. """ params = params or kwargs sql = QueryBuilder(sql, params).get_sql(0) cur.execute(sql) row = cur.fetchone() if row is None: return None return row[0] def run_exists(cur: Cursor, sql: str, params: Optional[Mapping[str, Any]] = None, **kwargs: Any) -> bool: """ Helper function to fetch one value Takes away all the hassle of preparing statements and processing returned result giving out just one value. 
""" params = params or kwargs val = run_lookup(cur, sql, params) return val is not None # fake plpy for testing class fake_plpy: log: List[str] = [] def prepare(self, sql: str, types: Sequence[str]) -> Tuple[str, str, Sequence[str]]: self.log.append("DBG: plpy.prepare(%s, %s)" % (repr(sql), repr(types))) return ('PLAN', sql, types) def execute(self, plan: Any, args: Any = ()) -> List[skytools.dbdict]: self.log.append("DBG: plpy.execute(%s, %s)" % (repr(plan), repr(args))) return [] def error(self, msg: str) -> None: self.log.append("DBG: plpy.error(%s)" % repr(msg)) raise Exception("plpy.error") # make plpy available if not plpy: plpy = fake_plpy() GD: Dict[str, Any] = {} python-skytools-3.9.2/skytools/quoting.py000066400000000000000000000132741447265566700207170ustar00rootroot00000000000000"""Various helpers for string quoting/unquoting. """ import json import re from typing import Any, Dict, Mapping, Match, Optional, Sequence, Union try: from skytools._cquoting import ( db_urldecode, db_urlencode, quote_bytea_raw, quote_copy, quote_literal, unescape, unquote_literal, ) except ImportError: from skytools._pyquoting import ( db_urldecode, db_urlencode, quote_bytea_raw, quote_copy, quote_literal, unescape, unquote_literal, ) __all__ = ( # _pyqoting / _cquoting "db_urldecode", "db_urlencode", "quote_bytea_raw", "quote_copy", "quote_literal", "unescape", "unquote_literal", # local "quote_bytea_literal", "quote_bytea_copy", "quote_statement", "quote_ident", "quote_fqident", "quote_json", "unescape_copy", "unquote_ident", "unquote_fqident", "json_encode", "json_decode", "make_pgarray", ) # # SQL quoting # def quote_bytea_literal(s: Optional[bytes]) -> str: """Quote bytea for regular SQL.""" return quote_literal(quote_bytea_raw(s)) def quote_bytea_copy(s: Optional[bytes]) -> str: """Quote bytea for COPY.""" return quote_copy(quote_bytea_raw(s)) def quote_statement(sql: str, dict_or_list: Union[Mapping[str, Any], Sequence[Any]]) -> str: """Quote whole statement. 
Data values are taken from dict or list or tuple. """ if hasattr(dict_or_list, 'items'): qdict: Dict[str, str] = {} for k, v in dict_or_list.items(): qdict[k] = quote_literal(v) return sql % qdict else: qvals = [quote_literal(v) for v in dict_or_list] return sql % tuple(qvals) # reserved keywords (RESERVED_KEYWORD + TYPE_FUNC_NAME_KEYWORD + COL_NAME_KEYWORD) # same list as postgres quote_ident() _ident_kwmap = frozenset(""" all analyse analyze and any array as asc asymmetric authorization between bigint binary bit boolean both case cast char character check coalesce collate collation column concurrently constraint create cross current_catalog current_date current_role current_schema current_time current_timestamp current_user dec decimal default deferrable desc distinct do else end except exists extract false fetch float for foreign freeze from full grant greatest group grouping having ilike in initially inner inout int integer intersect interval into is isnull join lateral leading least left like limit localtime localtimestamp national natural nchar none normalize not notnull null nullif numeric offset on only or order out outer overlaps overlay placing position precision primary real references returning right row select session_user setof similar smallint some substring symmetric table tablesample then time timestamp to trailing treat trim true union unique user using values varchar variadic verbose when where window with xmlattributes xmlconcat xmlelement xmlexists xmlforest xmlnamespaces xmlparse xmlpi xmlroot xmlserialize xmltable """.split()) _ident_bad = re.compile(r"[^a-z0-9_]|^[0-9]") def quote_ident(s: str) -> str: """Quote SQL identifier. If is checked against weird symbols and keywords. """ if _ident_bad.search(s) or s in _ident_kwmap: s = '"%s"' % s.replace('"', '""') elif not s: return '""' return s def quote_fqident(s: str) -> str: """Quote fully qualified SQL identifier. The '.' 
is taken as namespace separator and all parts are quoted separately """ tmp = s.split('.', 1) if len(tmp) == 1: return 'public.' + quote_ident(s) return '.'.join([quote_ident(name) for name in tmp]) # # quoting for JSON strings # _jsre = re.compile(r'[\x00-\x1F\\/"]') _jsmap = { "\b": "\\b", "\f": "\\f", "\n": "\\n", "\r": "\\r", "\t": "\\t", "\\": "\\\\", '"': '\\"', "/": "\\/", # to avoid html attacks } def _json_quote_char(m: Match[str]) -> str: """Quote single char.""" c = m.group(0) try: return _jsmap[c] except KeyError: return r"\u%04x" % ord(c) def quote_json(s: Optional[str]) -> str: """JSON style quoting.""" if s is None: return "null" return '"%s"' % _jsre.sub(_json_quote_char, s) def unescape_copy(val: str) -> Optional[str]: r"""Removes C-style escapes, also converts "\N" to None. """ if val == r"\N": return None return unescape(val) def unquote_ident(val: str) -> str: """Unquotes possibly quoted SQL identifier. """ if len(val) > 1 and val[0] == '"' and val[-1] == '"': return val[1:-1].replace('""', '"') if val.find('"') > 0: raise Exception('unsupported syntax') return val.lower() def unquote_fqident(val: str) -> str: """Unquotes fully-qualified possibly quoted SQL identifier. """ tmp = val.split('.', 1) return '.'.join([unquote_ident(i) for i in tmp]) def json_encode(val: Any = None, **kwargs: Any) -> str: """Creates JSON string from Python object. """ return json.dumps(val or kwargs) def json_decode(s: str) -> Any: """Parses JSON string into Python object. """ return json.loads(s) # # Create Postgres array # # any chars not in "good" set? 
main bad ones: [ ,{}\"] _pgarray_bad_rx = r"[^0-9a-z_.%&=()<>*/+-]" _pgarray_bad_rc = re.compile(_pgarray_bad_rx) def _quote_pgarray_elem(value: Any) -> str: if value is None: return 'NULL' s = str(value) if _pgarray_bad_rc.search(s): s = s.replace('\\', '\\\\') return '"' + s.replace('"', r'\"') + '"' elif not s: return '""' return s def make_pgarray(lst: Sequence[Any]) -> str: r"""Formats Python list as Postgres array. Reverse of parse_pgarray(). """ items = [_quote_pgarray_elem(v) for v in lst] return '{' + ','.join(items) + '}' python-skytools-3.9.2/skytools/scripting.py000066400000000000000000001212061447265566700212260ustar00rootroot00000000000000"""Useful functions and classes for database scripts. """ import argparse import errno import logging import logging.config import logging.handlers import optparse import os import select import signal import sys import time from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, Callable, Type, cast import skytools import skytools.skylog from .basetypes import Connection, Runnable, Cursor, DictRow, ExecuteParams try: import skytools.installer_config default_skylog = skytools.installer_config.skylog except ImportError: default_skylog = 0 __all__ = ( 'BaseScript', 'UsageError', 'daemonize', 'DBScript', ) class UsageError(Exception): """User induced error.""" # # daemon mode # def daemonize() -> None: """Turn the process into daemon. Goes background and disables all i/o. 
""" # launch new process, kill parent pid = os.fork() if pid != 0: os._exit(0) # start new session os.setsid() # stop i/o fd = os.open("/dev/null", os.O_RDWR) os.dup2(fd, 0) os.dup2(fd, 1) os.dup2(fd, 2) if fd > 2: os.close(fd) # # Pidfile locking+cleanup & daemonization combined # def run_single_process(runnable: Runnable, daemon: bool, pidfile: Optional[str]) -> None: """Run runnable class, possibly daemonized, locked on pidfile.""" # check if another process is running if pidfile and os.path.isfile(pidfile): if skytools.signal_pidfile(pidfile, 0): print("Pidfile exists, another process running?") sys.exit(1) else: print("Ignoring stale pidfile") # daemonize if needed if daemon: daemonize() # clean only own pidfile own_pidfile = False try: if pidfile: data = str(os.getpid()) skytools.write_atomic(pidfile, data) own_pidfile = True runnable.run() finally: if own_pidfile and pidfile: try: os.remove(pidfile) except BaseException: pass # # logging setup # _log_config_done: int = 0 _log_init_done: Dict[str, int] = {} def _init_log(job_name: str, service_name: str, cf: skytools.Config, log_level: int, is_daemon: bool) -> logging.Logger: """Logging setup happens here.""" global _log_config_done got_skylog = 0 use_skylog = cf.getint("use_skylog", default_skylog) # if non-daemon, avoid skylog if script is running on console. # set use_skylog=2 to disable. if not is_daemon and use_skylog == 1: # pylint gets spooked by it's own stdout wrapper and refuses to shut down # about it. 'noqa' tells prospector to ignore all warnings here. 
if sys.stdout.isatty(): # noqa use_skylog = 0 # load logging config if needed if use_skylog and not _log_config_done: # python logging.config braindamage: # cannot specify external classess without such hack logging.skylog = skytools.skylog # type: ignore skytools.skylog.set_service_name(service_name, job_name) # load general config flist = cf.getlist('skylog_locations', ['skylog.ini', '~/.skylog.ini', '/etc/skylog.ini']) for fn in flist: fn = os.path.expanduser(fn) if os.path.isfile(fn): defs = {'job_name': job_name, 'service_name': service_name} logging.config.fileConfig(fn, defs, False) got_skylog = 1 break _log_config_done = 1 if not got_skylog: sys.stderr.write("skylog.ini not found!\n") sys.exit(1) # avoid duplicate logging init for job_name log = logging.getLogger(job_name) if job_name in _log_init_done: return log _log_init_done[job_name] = 1 # tune level on root logger root = logging.getLogger() root.setLevel(log_level) # compatibility: specify ini file in script config def_fmt = '%(asctime)s %(process)s %(levelname)s %(message)s' def_datefmt = '' # None logfile = cf.getfile("logfile", "") if logfile: fstr = cf.get('logfmt_file', def_fmt) fstr_date = cf.get('logdatefmt_file', def_datefmt) if log_level < logging.INFO: fstr = cf.get('logfmt_file_verbose', fstr) fstr_date = cf.get('logdatefmt_file_verbose', fstr_date) fmt = logging.Formatter(fstr, fstr_date) size = cf.getint('log_size', 10 * 1024 * 1024) num = cf.getint('log_count', 3) file_hdlr = logging.handlers.RotatingFileHandler( logfile, 'a', size, num) file_hdlr.setFormatter(fmt) root.addHandler(file_hdlr) # if skylog.ini is disabled or not available, log at least to stderr if not got_skylog: fstr = cf.get('logfmt_console', def_fmt) fstr_date = cf.get('logdatefmt_console', def_datefmt) if log_level < logging.INFO: fstr = cf.get('logfmt_console_verbose', fstr) fstr_date = cf.get('logdatefmt_console_verbose', fstr_date) stream_hdlr = logging.StreamHandler() fmt = logging.Formatter(fstr, fstr_date) 
stream_hdlr.setFormatter(fmt) root.addHandler(stream_hdlr) return log class BaseScript: """Base class for service scripts. Handles logging, daemonizing, config, errors. Config template:: ## Parameters for skytools.BaseScript ## # how many seconds to sleep between work loops # if missing or 0, then instead sleeping, the script will exit loop_delay = 1.0 # where to log logfile = ~/log/%(job_name)s.log # where to write pidfile pidfile = ~/pid/%(job_name)s.pid # per-process name to use in logging #job_name = %(config_name)s # whether centralized logging should be used # search-path [ ./skylog.ini, ~/.skylog.ini, /etc/skylog.ini ] # 0 - disabled # 1 - enabled, unless non-daemon on console (os.isatty()) # 2 - always enabled #use_skylog = 0 # where to find skylog.ini #skylog_locations = skylog.ini, ~/.skylog.ini, /etc/skylog.ini # how many seconds to sleep after catching a exception #exception_sleep = 20 """ service_name: str job_name: str cf: "skytools.Config" go_daemon: bool need_reload: bool cf_defaults: Dict[str, str] = {} pidfile: Optional[str] = None # >0 - sleep time if work() requests sleep # 0 - exit if work requests sleep # <0 - run work() once [same as looping=0] loop_delay: float = 1.0 # 0 - run work() once # 1 - run work() repeatedly looping: int = 1 # result from last work() call: # 1 - there is probably more work, don't sleep # 0 - no work, sleep before calling again # -1 - exception was thrown work_state: int = 1 # setup logger here, this allows override by subclass log = logging.getLogger('skytools.BaseScript') # start time started: float = 0 # set to True to use argparse ARGPARSE: bool = False def __init__(self, service_name: str, args: Sequence[str]) -> None: """Script setup. User class should override work() and optionally __init__(), startup(), reload(), reset(), shutdown() and init_optparse(). NB: In case of daemon, __init__() and startup()/work()/shutdown() will be run in different processes. So nothing fancy should be done in __init__(). 
@param service_name: unique name for script. It will be also default job_name, if not specified in config. @param args: cmdline args (sys.argv[1:]), but can be overridden """ self.service_name = service_name self.go_daemon = False self.need_reload = False self.exception_count = 0 self.stat_dict: Dict[str, float] = {} self.log_level = logging.INFO # parse command line self.options, self.args = self.parse_args(args) # check args if self.options.version: self.print_version() sys.exit(0) if self.options.daemon: self.go_daemon = True if self.options.quiet: self.log_level = logging.WARNING if self.options.verbose: if self.options.verbose > 1: self.log_level = skytools.skylog.TRACE else: self.log_level = logging.DEBUG self.cf_override = {} if self.options.set: for a in self.options.set: k, v = a.split('=', 1) self.cf_override[k.strip()] = v.strip() if self.options.ini: self.print_ini() sys.exit(0) # read config file self.reload() # init logging _init_log(self.job_name, self.service_name, self.cf, self.log_level, self.go_daemon) # send signal, if needed if self.options.cmd == "kill": self.send_signal(signal.SIGTERM) elif self.options.cmd == "stop": self.send_signal(signal.SIGINT) elif self.options.cmd == "reload": self.send_signal(signal.SIGHUP) def parse_args(self, args: Sequence[str]) -> Tuple[Any, Sequence[str]]: if self.ARGPARSE: arg_parser = self.init_argparse() options = arg_parser.parse_args(args) args = getattr(options, "args", []) return options, args opt_parser = self.init_optparse() options2, args2 = opt_parser.parse_args(args) return options2, args2 def print_version(self) -> None: service = self.service_name ver = getattr(self, '__version__', None) if ver: service += ' version %s' % ver print('%s, Skytools version %s' % (service, getattr(skytools, '__version__'))) def print_ini(self) -> None: """Prints out ini file from doc string of the script of default for dbscript Used by --ini option on command line. 
""" # current service name print("[%s]\n" % self.service_name) # walk class hierarchy bases = [self.__class__] while len(bases) > 0: parents = [] for c in bases: for p in c.__bases__: if p not in parents: parents.append(p) doc = c.__doc__ if doc: self._print_ini_frag(doc) bases = parents def _print_ini_frag(self, doc: str) -> None: # use last '::' block as config template pos = doc and doc.rfind('::\n') or -1 if pos < 0: return doc = doc[pos + 2:].rstrip() doc = skytools.dedent(doc) # merge overrided options into output for ln in doc.splitlines(): vals = ln.split('=', 1) if len(vals) != 2: print(ln) continue k = vals[0].strip() v = vals[1].strip() if k and k[0] == '#': print(ln) k = k[1:] if k in self.cf_override: print('%s = %s' % (k, self.cf_override[k])) elif k in self.cf_override: if v: print('#' + ln) print('%s = %s' % (k, self.cf_override[k])) else: print(ln) print('') def load_config(self) -> skytools.Config: """Loads and returns skytools.Config instance. By default it uses first command-line argument as config file name. Can be overridden. """ if len(self.args) < 1: print("need config file, use --help for help.") sys.exit(1) conf_file = self.args[0] return skytools.Config(self.service_name, conf_file, user_defs=self.cf_defaults, override=self.cf_override) def init_optparse(self, parser: Optional[optparse.OptionParser] = None) -> optparse.OptionParser: """Initialize a OptionParser() instance that will be used to parse command line arguments. Note that it can be overridden both directions - either DBScript will initialize an instance and pass it to user code or user can initialize and then pass to DBScript.init_optparse(). @param parser: optional OptionParser() instance, where DBScript should attach its own arguments. @return: initialized OptionParser() instance. 
""" if parser: p = parser else: p = optparse.OptionParser() p.set_usage("%prog [options] INI") # generic options p.add_option("-q", "--quiet", action="store_true", help="log only errors and warnings") p.add_option("-v", "--verbose", action="count", help="log verbosely") p.add_option("-d", "--daemon", action="store_true", help="go background") p.add_option("-V", "--version", action="store_true", help="print version info and exit") p.add_option("", "--ini", action="store_true", help="display sample ini file") p.add_option("", "--set", action="append", help="override config setting (--set 'PARAM=VAL')") # control options g = optparse.OptionGroup(p, 'control running process') g.add_option("-r", "--reload", action="store_const", const="reload", dest="cmd", help="reload config (send SIGHUP)") g.add_option("-s", "--stop", action="store_const", const="stop", dest="cmd", help="stop program safely (send SIGINT)") g.add_option("-k", "--kill", action="store_const", const="kill", dest="cmd", help="kill program immediately (send SIGTERM)") p.add_option_group(g) return p def init_argparse(self, parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser: """Initialize a ArgumentParser() instance that will be used to parse command line arguments. Note that it can be overridden both directions - either BaseScript will initialize an instance and pass it to user code or user can initialize and then pass to BaseScript.init_optparse(). @param parser: optional ArgumentParser() instance, where BaseScript should attach its own arguments. @return: initialized ArgumentParser() instance. 
""" if parser: p = parser else: p = argparse.ArgumentParser() # generic options p.add_argument("-q", "--quiet", action="store_true", help="log only errors and warnings") p.add_argument("-v", "--verbose", action="count", help="log verbosely") p.add_argument("-d", "--daemon", action="store_true", help="go background") p.add_argument("-V", "--version", action="store_true", help="print version info and exit") p.add_argument("--ini", action="store_true", help="display sample ini file") p.add_argument("--set", action="append", help="override config setting (--set 'PARAM=VAL')") p.add_argument("args", nargs="*") # control options g = p.add_argument_group('control running process') g.add_argument("-r", "--reload", action="store_const", const="reload", dest="cmd", help="reload config (send SIGHUP)") g.add_argument("-s", "--stop", action="store_const", const="stop", dest="cmd", help="stop program safely (send SIGINT)") g.add_argument("-k", "--kill", action="store_const", const="kill", dest="cmd", help="kill program immediately (send SIGTERM)") return p def send_signal(self, sig: int) -> None: if not self.pidfile: self.log.warning("No pidfile in config, nothing to do") elif os.path.isfile(self.pidfile): alive = skytools.signal_pidfile(self.pidfile, sig) if not alive: self.log.warning("pidfile exists, but process not running") else: self.log.warning("No pidfile, process not running") sys.exit(0) def set_single_loop(self, do_single_loop: int) -> None: """Changes whether the script will loop or not.""" if do_single_loop: self.looping = 0 else: self.looping = 1 def _boot_daemon(self) -> None: run_single_process(self, self.go_daemon, self.pidfile) def start(self) -> None: """This will launch main processing thread.""" if self.go_daemon: if not self.pidfile: self.log.error("Daemon needs pidfile") sys.exit(1) self.run_func_safely(self._boot_daemon) def stop(self) -> None: """Safely stops processing loop.""" self.looping = 0 def reload(self) -> None: "Reload config." 
# avoid double loading on startup if not getattr(self, "cf", None): self.cf = self.load_config() else: self.cf.reload() self.log.info("Config reloaded") self.job_name = self.cf.get("job_name") self.pidfile = self.cf.getfile("pidfile", '') self.loop_delay = self.cf.getfloat("loop_delay", self.loop_delay) self.exception_sleep = self.cf.getfloat("exception_sleep", 20) self.exception_quiet = self.cf.getlist("exception_quiet", []) self.exception_grace = self.cf.getfloat("exception_grace", 5 * 60) self.exception_reset = self.cf.getfloat("exception_reset", 15 * 60) def hook_sighup(self, sig: int, frame: Any) -> None: "Internal SIGHUP handler. Minimal code here." self.need_reload = True last_sigint: float = 0 def hook_sigint(self, sig: int, frame: Any) -> None: "Internal SIGINT handler. Minimal code here." self.stop() t = time.time() if t - self.last_sigint < 1: self.log.warning("Double ^C, fast exit") sys.exit(1) self.last_sigint = t def stat_get(self, key: str) -> Optional[float]: """Reads a stat value.""" try: return self.stat_dict[key] except KeyError: return None def stat_put(self, key: str, value: float) -> None: """Sets a stat value.""" self.stat_dict[key] = value def stat_increase(self, key: str, increase: float = 1) -> None: """Increases a stat value.""" try: self.stat_dict[key] += increase except KeyError: self.stat_dict[key] = increase def send_stats(self) -> None: "Send statistics to log." res = [] for k, v in self.stat_dict.items(): res.append("%s: %s" % (k, v)) if len(res) == 0: return logmsg = "{%s}" % ", ".join(res) self.log.info(logmsg) self.stat_dict = {} def reset(self) -> None: "Something bad happened, reset all state." pass def run(self) -> None: "Thread main loop." 
# run startup, safely self.run_func_safely(self.startup) while True: # reload config, if needed if self.need_reload: self.reload() self.need_reload = False # do some work work = self.run_once() if not self.looping or self.loop_delay < 0: break # remember work state if work: self.work_state = 1 if work > 0 else -1 else: self.work_state = 0 # should sleep? if not work: if self.loop_delay > 0: self.sleep(self.loop_delay) if not self.looping: break else: break # run shutdown, safely? self.shutdown() def run_once(self) -> int: state = self.run_func_safely(self.work, True) # send stats that was added self.send_stats() return state last_func_fail: Optional[float] = None def run_func_safely(self, func: Callable[[], Optional[int]], prefer_looping: bool = False) -> int: "Run users work function, safely." try: r = func() if self.last_func_fail and time.time() > self.last_func_fail + self.exception_reset: self.last_func_fail = None # set exception count to 0 after success self.exception_count = 0 return 1 if r else 0 except UsageError as d: self.log.error(str(d)) sys.exit(1) except MemoryError: try: # complex logging may not succeed self.log.exception("Job %s out of memory, exiting", self.job_name) except MemoryError: self.log.fatal("Out of memory") sys.exit(1) except SystemExit as d: self.send_stats() if prefer_looping and self.looping and self.loop_delay > 0: self.log.info("got SystemExit(%s), exiting", str(d)) self.reset() raise d except KeyboardInterrupt: self.send_stats() if prefer_looping and self.looping and self.loop_delay > 0: self.log.info("got KeyboardInterrupt, exiting") self.reset() sys.exit(1) except Exception as d: try: # this may fail too self.send_stats() except BaseException: pass if self.last_func_fail is None: self.last_func_fail = time.time() emsg = str(d).rstrip() self.reset() self.exception_hook(d, emsg) # reset and sleep self.reset() if prefer_looping and self.looping and self.loop_delay > 0: # increase exception count & sleep self.exception_count += 1 
self.sleep_on_exception() return -1 sys.exit(1) def sleep(self, secs: float) -> None: """Make script sleep for some amount of time.""" try: time.sleep(secs) except IOError as ex: if ex.errno != errno.EINTR: raise def sleep_on_exception(self) -> None: """Make script sleep for some amount of time when an exception occurs. To implement more advance exception sleeping like exponential backoff you can override this method. Also note that you can use self.exception_count to track the number of consecutive exceptions. """ self.sleep(self.exception_sleep) def _is_quiet_exception(self, ex: Exception) -> bool: if "ALL" in self.exception_quiet or ex.__class__.__name__ in self.exception_quiet: if self.last_func_fail and time.time() < self.last_func_fail + self.exception_grace: return True return False def exception_hook(self, det: Exception, emsg: str) -> None: """Called on after exception processing. Can do additional logging. @param det: exception details @param emsg: exception msg """ lm = "Job %s crashed: %s" % (self.job_name, emsg) if self._is_quiet_exception(det): self.log.warning(lm) else: self.log.exception(lm) def work(self) -> Optional[int]: """Here should user's processing happen. Return value is taken as boolean - if true, the next loop starts immediately. If false, DBScript sleeps for a loop_delay. """ raise Exception("Nothing implemented?") def startup(self) -> None: """Will be called just before entering main loop. In case of daemon, if will be called in same process as work(), unlike __init__(). """ self.started = time.time() # set signals if hasattr(signal, 'SIGHUP'): signal.signal(signal.SIGHUP, self.hook_sighup) if hasattr(signal, 'SIGINT'): signal.signal(signal.SIGINT, self.hook_sigint) def shutdown(self) -> None: """Will be called just after exiting main loop. In case of daemon, if will be called in same process as work(), unlike __init__(). 
""" pass # define some aliases (short-cuts / backward compatibility cruft) stat_add = stat_put # Old, deprecated function. stat_inc = stat_increase ## ## DBScript ## #: how old connections need to be closed DEF_CONN_AGE = 20 * 60 # 20 min class DBScript(BaseScript): """Base class for database scripts. Handles database connection state. Config template:: ## Parameters for skytools.DBScript ## # default lifetime for database connections (in seconds) #connection_lifetime = 1200 """ db_cache: Dict[str, "DBCachedConn"] _db_defaults: Dict[str, Mapping[str, int]] _listen_map: Dict[str, List[str]] def __init__(self, service_name: str, args: Sequence[str]) -> None: """Script setup. User class should override work() and optionally __init__(), startup(), reload(), reset() and init_optparse(). NB: in case of daemon, the __init__() and startup()/work() will be run in different processes. So nothing fancy should be done in __init__(). @param service_name: unique name for script. It will be also default job_name, if not specified in config. @param args: cmdline args (sys.argv[1:]), but can be overridden """ self.db_cache: Dict[str, DBCachedConn] = {} self._db_defaults = {} self._listen_map = {} # dbname: channel_list super().__init__(service_name, args) def connection_hook(self, dbname: str, conn: Connection) -> None: pass def set_database_defaults(self, dbname: str, **kwargs: int) -> None: self._db_defaults[dbname] = kwargs def add_connect_string_profile(self, connstr: str, profile: Optional[str]) -> str: """Add extra profile info to connect string. """ if profile: extra = self.cf.get("%s_extra_connstr" % profile, '') if extra: connstr += ' ' + extra return connstr def get_database(self, dbname: str, autocommit: int = 0, isolation_level: int = -1, cache: Optional[str] = None, connstr: Optional[str] = None, profile: Optional[str] = None) -> Connection: """Load cached database connection. 
User must not store it permanently somewhere, as all connections will be invalidated on reset. """ max_age = self.cf.getint('connection_lifetime', DEF_CONN_AGE) if not cache: cache = dbname params: Dict[str, int] = {} defs = self._db_defaults.get(cache, {}) for k in defs: params[k] = defs[k] if isolation_level >= 0: params['isolation_level'] = isolation_level elif autocommit: params['isolation_level'] = 0 elif params.get('autocommit', 0): params['isolation_level'] = 0 elif 'isolation_level' not in params: params['isolation_level'] = skytools.I_READ_COMMITTED if 'max_age' not in params: params['max_age'] = max_age if cache in self.db_cache: dbc = self.db_cache[cache] if connstr is None: connstr = self.cf.get(dbname, '') if connstr: connstr = self.add_connect_string_profile(connstr, profile) dbc.check_connstr(connstr) else: if not connstr: connstr = self.cf.get(dbname) connstr = self.add_connect_string_profile(connstr, profile) # connstr might contain password, it is not a good idea to log it filtered_connstr = connstr pos = connstr.lower().find('password') if pos >= 0: filtered_connstr = connstr[:pos] + ' [...]' self.log.debug("Connect '%s' to '%s'", cache, filtered_connstr) dbc = DBCachedConn(cache, connstr, params['max_age'], setup_func=self.connection_hook) self.db_cache[cache] = dbc clist = [] if cache in self._listen_map: clist = self._listen_map[cache] return dbc.get_connection(params['isolation_level'], clist) def close_database(self, dbname: str) -> None: """Explicitly close a cached connection. Next call to get_database() will reconnect. """ if dbname in self.db_cache: dbc = self.db_cache[dbname] dbc.reset() del self.db_cache[dbname] def reset(self) -> None: "Something bad happened, reset all connections." 
for dbc in self.db_cache.values(): dbc.reset() self.db_cache = {} super().reset() def run_once(self) -> int: state = super().run_once() # reconnect if needed for dbc in self.db_cache.values(): dbc.refresh() return state def exception_hook(self, det: Exception, emsg: str) -> None: """Log database and query details from exception.""" curs = getattr(det, 'cursor', None) conn = getattr(curs, 'connection', None) if conn: # db connection sql = getattr(curs, 'query', None) or '?' if isinstance(sql, bytes): sql = sql.decode('utf8') if len(sql) > 200: # avoid logging londiste huge batched queries sql = sql[:60] + " ..." lm = "Job %s got error on connection: %s. Query: %s" % ( self.job_name, emsg, sql) if self._is_quiet_exception(det): self.log.warning(lm) else: self.log.exception(lm) else: super().exception_hook(det, emsg) def sleep(self, secs: float) -> None: """Make script sleep for some amount of time.""" fdlist = [] for dbname in self._listen_map: if dbname not in self.db_cache: continue fd = self.db_cache[dbname].fileno() if fd is None: continue fdlist.append(fd) if not fdlist: return super().sleep(secs) try: if hasattr(select, 'poll'): p = select.poll() for fd in fdlist: p.register(fd, select.POLLIN) p.poll(int(secs * 1000)) else: select.select(fdlist, [], [], secs) except select.error: self.log.info('wait canceled') return None def _exec_cmd(self, curs: Cursor, sql: str, args: ExecuteParams, quiet: bool = False, prefix: Optional[str] = None) -> Tuple[bool, Sequence[DictRow]]: """Internal tool: Run SQL on cursor.""" if self.options.verbose: self.log.debug("exec_cmd: %s", skytools.quote_statement(sql, args)) _pfx = "" if prefix: _pfx = "[%s] " % prefix curs.execute(sql, args) ok = True rows = curs.fetchall() for row in rows: try: code = row['ret_code'] msg = row['ret_note'] except KeyError: self.log.error("Query does not conform to exec_cmd API:") self.log.error("SQL: %s", skytools.quote_statement(sql, args)) self.log.error("Row: %s", repr(dict(row.items()))) 
sys.exit(1) level = code // 100 if level == 1: self.log.debug("%s%d %s", _pfx, code, msg) elif level == 2: if quiet: self.log.debug("%s%d %s", _pfx, code, msg) else: self.log.info("%s%s", _pfx, msg) elif level == 3: self.log.warning("%s%s", _pfx, msg) else: self.log.error("%s%s", _pfx, msg) self.log.debug("Query was: %s", skytools.quote_statement(sql, args)) ok = False return (ok, rows) def _exec_cmd_many(self, curs: Cursor, sql: str, baseargs: List[Any], extra_list: Sequence[Any], quiet:bool=False, prefix:Optional[str]=None) -> Tuple[bool, Sequence[DictRow]]: """Internal tool: Run SQL on cursor multiple times.""" ok = True rows: List[DictRow] = [] for a in extra_list: (tmp_ok, tmp_rows) = self._exec_cmd(curs, sql, baseargs + [a], quiet, prefix) if not tmp_ok: ok = False rows += tmp_rows return (ok, rows) def exec_cmd(self, db_or_curs: Union[Connection, Cursor], q: str, args: ExecuteParams, commit: bool = True, quiet: bool = False, prefix: Optional[str] = None) -> Sequence[DictRow]: """Run SQL on db with code/value error handling.""" db: Optional[Connection] curs: Cursor if hasattr(db_or_curs, 'cursor'): db = cast(Connection, db_or_curs) curs = db.cursor() else: db = None curs = db_or_curs (ok, rows) = self._exec_cmd(curs, q, args, quiet, prefix) if ok: if commit and db: db.commit() return rows else: if db: db.rollback() if self.options.verbose: raise Exception("db error") # error is already logged sys.exit(1) def exec_cmd_many(self, db_or_curs: Union[Connection, Cursor], sql: str, baseargs: List[Any], extra_list: Sequence[Any], commit: bool = True, quiet: bool = False, prefix: Optional[str] = None) -> Sequence[DictRow]: """Run SQL on db multiple times.""" if hasattr(db_or_curs, 'cursor'): db = cast(Connection, db_or_curs) curs = db.cursor() else: db = None curs = db_or_curs (ok, rows) = self._exec_cmd_many(curs, sql, baseargs, extra_list, quiet, prefix) if ok: if commit and db: db.commit() return rows else: if db: db.rollback() if self.options.verbose: raise 
Exception("db error") # error is already logged sys.exit(1) def execute_with_retry(self, dbname: str, stmt: str, args: List[Any], exceptions: Optional[Sequence[Type[Exception]]] = None) -> Tuple[int, Cursor]: """ Execute SQL and retry if it fails. Return number of retries and current valid cursor, or raise an exception. """ sql_retry = self.cf.getbool("sql_retry", False) sql_retry_max_count = self.cf.getint("sql_retry_max_count", 10) sql_retry_max_time = self.cf.getint("sql_retry_max_time", 300) sql_retry_formula_a = self.cf.getint("sql_retry_formula_a", 1) sql_retry_formula_b = self.cf.getint("sql_retry_formula_b", 5) sql_retry_formula_cap = self.cf.getint("sql_retry_formula_cap", 60) elist = tuple(exceptions) if exceptions else () stime = time.time() tried = 0 dbc: Optional[DBCachedConn] = None while True: try: if dbc is None: if dbname not in self.db_cache: self.get_database(dbname, autocommit=1) dbc = self.db_cache[dbname] if dbc.isolation_level != skytools.I_AUTOCOMMIT: raise skytools.UsageError("execute_with_retry: autocommit required") else: dbc.reset() curs = dbc.get_connection(dbc.isolation_level).cursor() curs.execute(stmt, args) break except elist as e: if not sql_retry or tried >= sql_retry_max_count or time.time() - stime >= sql_retry_max_time: raise self.log.info("Job %s got error on connection %s: %s", self.job_name, dbname, e) except BaseException: raise # y = a + bx , apply cap y = sql_retry_formula_a + sql_retry_formula_b * tried if sql_retry_formula_cap is not None and y > sql_retry_formula_cap: y = sql_retry_formula_cap tried += 1 self.log.info("Retry #%i in %i seconds ...", tried, y) self.sleep(y) return tried, curs def listen(self, dbname: str, channel: str) -> None: """Make connection listen for specific event channel. Listening will be activated on next .get_database() call. Basically this means that DBScript.sleep() will poll for events on that db connection, so when event appears, script will be woken up. 
""" if dbname not in self._listen_map: self._listen_map[dbname] = [] clist = self._listen_map[dbname] if channel not in clist: clist.append(channel) def unlisten(self, dbname: str, channel: str = '*') -> None: """Stop connection for listening on specific event channel. Listening will stop on next .get_database() call. """ if dbname not in self._listen_map: return if channel == '*': del self._listen_map[dbname] return clist = self._listen_map[dbname] try: clist.remove(channel) except ValueError: pass SetupFunc = Callable[[str, Connection], None] class DBCachedConn: """Cache a db connection.""" name: str loc: str conn: Optional[Connection] conn_time: float max_age: int isolation_level: int verbose: bool setup_func: Optional[SetupFunc] listen_channel_list: Sequence[str] def __init__(self, name: str, loc: str, max_age:int=DEF_CONN_AGE, verbose:bool=False, setup_func: Optional[SetupFunc] = None, channels: Sequence[str] = ()) -> None: self.name = name self.loc = loc self.conn = None self.conn_time = 0 self.max_age = max_age self.isolation_level = -1 self.verbose = verbose self.setup_func = setup_func self.listen_channel_list = [] def fileno(self) -> Optional[int]: if not self.conn: return None return self.conn.cursor().fileno() def get_connection(self, isolation_level:int=-1, listen_channel_list: Sequence[str]=()) -> Connection: # default isolation_level is READ COMMITTED if isolation_level < 0: isolation_level = skytools.I_READ_COMMITTED # new conn? 
if not self.conn: self.isolation_level = isolation_level conn = skytools.connect_database(self.loc) conn = skytools.connect_database(self.loc) conn.set_isolation_level(isolation_level) self.conn = conn self.conn_time = time.time() if self.setup_func: self.setup_func(self.name, conn) else: if self.isolation_level != isolation_level: raise Exception("Conflict in isolation_level") conn = self.conn self._sync_listen(listen_channel_list) # done return conn def _sync_listen(self, new_clist: Sequence[str]) -> None: if not new_clist and not self.listen_channel_list: return if not self.conn: return curs = self.conn.cursor() for ch in self.listen_channel_list: if ch not in new_clist: curs.execute("UNLISTEN %s" % skytools.quote_ident(ch)) for ch in new_clist: if ch not in self.listen_channel_list: curs.execute("LISTEN %s" % skytools.quote_ident(ch)) if self.isolation_level != skytools.I_AUTOCOMMIT: self.conn.commit() self.listen_channel_list = new_clist[:] def refresh(self) -> None: if not self.conn: return #for row in self.conn.notifies(): # if row[0].lower() == "reload": # self.reset() # return if not self.max_age: return if time.time() - self.conn_time >= self.max_age: self.reset() def reset(self) -> None: if not self.conn: return # drop reference conn = self.conn self.conn = None self.listen_channel_list = [] # close try: conn.close() except BaseException: pass def check_connstr(self, connstr: str) -> None: """Drop connection if connect string has changed. """ if self.loc != connstr: self.reset() python-skytools-3.9.2/skytools/skylog.py000066400000000000000000000264321447265566700205410ustar00rootroot00000000000000"""Our log handlers for Python's logging package. 
""" import logging import logging.handlers import os import socket import time from logging import LoggerAdapter from typing import Any, Optional, Union, Dict import skytools import skytools.tnetstrings from skytools.basetypes import Buffer __all__ = ['getLogger'] # add TRACE level TRACE = 5 logging.TRACE = TRACE # type: ignore logging.addLevelName(TRACE, 'TRACE') # extra info to be added to each log record _service_name = 'unknown_svc' _job_name = 'unknown_job' _hostname = socket.gethostname() try: _hostaddr = socket.gethostbyname(_hostname) except BaseException: _hostaddr = "0.0.0.0" _log_extra = { 'job_name': _job_name, 'service_name': _service_name, 'hostname': _hostname, 'hostaddr': _hostaddr, } def set_service_name(service_name: str, job_name: str) -> None: """Set info about current script.""" global _service_name, _job_name _service_name = service_name _job_name = job_name _log_extra['job_name'] = _job_name _log_extra['service_name'] = _service_name # Make extra fields available to all log records _old_factory = logging.getLogRecordFactory() def _new_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: record = _old_factory(*args, **kwargs) record.__dict__.update(_log_extra) return record logging.setLogRecordFactory(_new_factory) # configurable file logger class EasyRotatingFileHandler(logging.handlers.RotatingFileHandler): """Easier setup for RotatingFileHandler.""" def __init__(self, filename: str, maxBytes: int = 10 * 1024 * 1024, backupCount: int = 3) -> None: """Args same as for RotatingFileHandler, but in filename '~' is expanded.""" fn = os.path.expanduser(filename) super().__init__(fn, maxBytes=maxBytes, backupCount=backupCount) # send JSON message over UDP class UdpLogServerHandler(logging.handlers.DatagramHandler): """Sends log records over UDP to logserver in JSON format.""" # map logging levels to logserver levels _level_map = { logging.DEBUG: 'DEBUG', logging.INFO: 'INFO', logging.WARNING: 'WARN', logging.ERROR: 'ERROR', logging.CRITICAL: 
'FATAL', } # JSON message template _log_template = '{\n\t'\ '"logger": "skytools.UdpLogServer",\n\t'\ '"timestamp": %.0f,\n\t'\ '"level": "%s",\n\t'\ '"thread": null,\n\t'\ '"message": %s,\n\t'\ '"properties": {"application":"%s", "apptype": "%s", "type": "sys", "hostname":"%s", "hostaddr": "%s"}\n'\ '}\n' # cut longer msgs MAXMSG = 1024 def makePickle(self, record: logging.LogRecord) -> bytes: """Create message in JSON format.""" # get & cut msg msg = self.format(record) if len(msg) > self.MAXMSG: msg = msg[:self.MAXMSG] txt_level = self._level_map.get(record.levelno, "ERROR") hostname = _hostname hostaddr = _hostaddr jobname = _job_name svcname = _service_name pkt = self._log_template % ( time.time() * 1000, txt_level, skytools.quote_json(msg), jobname, svcname, hostname, hostaddr ) return pkt.encode("utf8") def send(self, s: Buffer) -> None: """Disable socket caching.""" sock = self.makeSocket() sock.sendto(s, (self.host, self.port)) sock.close() # send TNetStrings message over UDP class UdpTNetStringsHandler(logging.handlers.DatagramHandler): """ Sends log records in TNetStrings format over UDP. """ # LogRecord fields to send send_fields = [ 'created', 'exc_text', 'levelname', 'levelno', 'message', 'msecs', 'name', 'hostaddr', 'hostname', 'job_name', 'service_name'] _udp_reset: float = 0 def makePickle(self, record: logging.LogRecord) -> bytes: """ Create message in TNetStrings format. """ msg = {} self.format(record) # render 'message' attribute and others for k in self.send_fields: msg[k] = record.__dict__[k] tnetstr = skytools.tnetstrings.dumps(msg) return tnetstr def send(self, s: Buffer) -> None: """ Cache socket for a moment, then recreate it. """ now = time.time() if now - 1 > self._udp_reset: if self.sock: self.sock.close() self.sock = self.makeSocket() self._udp_reset = now if self.sock: self.sock.sendto(s, (self.host, self.port)) class LogDBHandler(logging.handlers.SocketHandler): """Sends log records into PostgreSQL server. 
Additionally, does some statistics aggregating, to avoid overloading log server. It subclasses SocketHandler to get throtthling for failed connections. """ # map codes to string _level_map = { logging.DEBUG: 'DEBUG', logging.INFO: 'INFO', logging.WARNING: 'WARNING', logging.ERROR: 'ERROR', logging.CRITICAL: 'FATAL', } sock: Any closeOnError: bool connect_string: str stat_cache: Dict[str, Union[int, float]] stat_flush_period: float last_stat_flush: float def __init__(self, connect_string: str) -> None: """ Initializes the handler with a specific connection string. """ super().__init__("localhost", 1) self.closeOnError = True self.connect_string = connect_string self.stat_cache = {} self.stat_flush_period = 60 # send first stat line immediately self.last_stat_flush = 0 def createSocket(self) -> None: try: super().createSocket() except BaseException: self.sock = self.makeSocket() def makeSocket(self, timeout: float = 1) -> Any: """Create server connection. In this case its not socket but database connection.""" db = skytools.connect_database(self.connect_string) db.set_isolation_level(0) # autocommit return db def emit(self, record: logging.LogRecord) -> None: """Process log record.""" # we do not want log debug messages if record.levelno < logging.INFO: return try: self.process_rec(record) except (SystemExit, KeyboardInterrupt): raise except BaseException: self.handleError(record) def process_rec(self, record: logging.LogRecord) -> None: """Aggregate stats if needed, and send to logdb.""" # render msg msg = self.format(record) # dont want to send stats too ofter if record.levelno == logging.INFO and msg and msg[0] == "{": self.aggregate_stats(msg) if time.time() - self.last_stat_flush >= self.stat_flush_period: self.flush_stats(_job_name) return if record.levelno < logging.INFO: self.flush_stats(_job_name) # dont send more than one line ln = msg.find('\n') if ln > 0: msg = msg[:ln] txt_level = self._level_map.get(record.levelno, "ERROR") self.send_to_logdb(_job_name, 
txt_level, msg) def aggregate_stats(self, msg: str) -> None: """Sum stats together, to lessen load on logdb.""" msg = msg[1:-1] for rec in msg.split(", "): k, v = rec.split(": ") agg = self.stat_cache.get(k, 0) if v.find('.') >= 0: agg += float(v) else: agg += int(v) self.stat_cache[k] = agg def flush_stats(self, service: str) -> None: """Send acquired stats to logdb.""" res = [] for k, v in self.stat_cache.items(): res.append("%s: %s" % (k, str(v))) if len(res) > 0: logmsg = "{%s}" % ", ".join(res) self.send_to_logdb(service, "INFO", logmsg) self.stat_cache = {} self.last_stat_flush = time.time() def send_to_logdb(self, service: str, level: str, msg: str) -> None: """Actual sending is done here.""" if self.sock is None: self.createSocket() if self.sock: logcur = self.sock.cursor() query = "select * from log.add(%s, %s, %s)" logcur.execute(query, [level, service, msg]) # fix unicode bug in SysLogHandler class SysLogHandler(logging.handlers.SysLogHandler): """Fixes unicode bug in logging.handlers.SysLogHandler.""" # be compatible with both 2.6 and 2.7 socktype = socket.SOCK_DGRAM _udp_reset: float = 0 def _custom_format(self, record: logging.LogRecord) -> str: msg = self.format(record) + '\000' # We need to convert record level to lowercase, maybe this will # change in the future. prio = '<%d>' % self.encodePriority(self.facility, self.mapPriority(record.levelname)) msg = prio + msg return msg def emit(self, record: logging.LogRecord) -> None: """ Emit a record. The record is formatted, and then sent to the syslog server. If exception information is present, it is NOT sent to the server. 
""" xmsg = self._custom_format(record) msg = xmsg if isinstance(xmsg, bytes) else xmsg.encode("utf8") try: if self.unixsocket: try: self.socket.send(msg) # type: ignore[has-type] except socket.error: self._connect_unixsocket(self.address) # type: ignore[attr-defined] self.socket.send(msg) # type: ignore[has-type] elif self.socktype == socket.SOCK_DGRAM: now = time.time() if now - 1 > self._udp_reset: self.socket.close() # type: ignore[has-type] self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._udp_reset = now self.socket.sendto(msg, self.address) else: self.socket.sendall(msg) except (KeyboardInterrupt, SystemExit): raise except BaseException: self.handleError(record) class SysLogHostnameHandler(SysLogHandler): """Slightly modified standard SysLogHandler - sends also hostname and service type""" def _custom_format(self, record: logging.LogRecord) -> str: msg = self.format(record) format_string = '<%d> %s %s %s\000' msg = format_string % (self.encodePriority(self.facility, self.mapPriority(record.levelname)), _hostname, _service_name, msg) return msg # add missing aliases (that are in Logger class) if not hasattr(LoggerAdapter, 'fatal'): LoggerAdapter.fatal = LoggerAdapter.critical # type: ignore class SkyLogger(LoggerAdapter): """Adds API to existing Logger. """ def trace(self, msg: str, *args: Any, **kwargs: Any) -> None: """Log message with severity TRACE.""" self.log(TRACE, msg, *args, **kwargs) def getLogger(name: Optional[str] = None, **kwargs_extra: Any) -> SkyLogger: """Get logger with extra functionality. Adds additional log levels, and extra fields to log record. 
name - name for logging.getLogger() kwargs_extra - extra fields to add to log record """ log = logging.getLogger(name) return SkyLogger(log, kwargs_extra) python-skytools-3.9.2/skytools/sockutil.py000066400000000000000000000074121447265566700210630ustar00rootroot00000000000000"""Various low-level utility functions for sockets.""" import os import socket import sys from typing import Optional, Union try: import fcntl except ImportError: fcntl = None # type: ignore from .basetypes import FileDescriptorLike __all__ = ( 'set_tcp_keepalive', 'set_nonblocking', 'set_cloexec', ) SocketLike = Union[socket.socket, FileDescriptorLike] def set_tcp_keepalive(fd: SocketLike, keepalive: bool = True, tcp_keepidle: int = 4 * 60, tcp_keepcnt: int = 4, tcp_keepintvl: int = 15) -> None: """Turn on TCP keepalive. The fd can be either numeric or socket object with 'fileno' method. OS defaults for SO_KEEPALIVE=1: - Linux: (7200, 9, 75) - can configure all. - MacOS: (7200, 8, 75) - can configure only tcp_keepidle. - Win32: (7200, 5|10, 1) - can configure tcp_keepidle and tcp_keepintvl. Our defaults: (240, 4, 15). """ # usable on this OS? if not hasattr(socket, 'SO_KEEPALIVE') or not hasattr(socket, 'fromfd'): return # need socket object if isinstance(fd, socket.SocketType): s = fd else: if not isinstance(fd, int): fd = fd.fileno() s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) # skip if unix socket if getattr(socket, 'AF_UNIX', None): if not isinstance(s.getsockname(), tuple): return # no keepalive? 
if not keepalive: s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 0) return # basic keepalive s.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) # detect available options TCP_KEEPCNT = getattr(socket, 'TCP_KEEPCNT', None) TCP_KEEPINTVL = getattr(socket, 'TCP_KEEPINTVL', None) TCP_KEEPIDLE = getattr(socket, 'TCP_KEEPIDLE', None) TCP_KEEPALIVE = getattr(socket, 'TCP_KEEPALIVE', None) SIO_KEEPALIVE_VALS = getattr(socket, 'SIO_KEEPALIVE_VALS', None) if TCP_KEEPIDLE is None and TCP_KEEPALIVE is None and sys.platform == 'darwin': TCP_KEEPALIVE = 0x10 # configure if TCP_KEEPCNT is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPCNT, tcp_keepcnt) if TCP_KEEPINTVL is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPINTVL, tcp_keepintvl) if TCP_KEEPIDLE is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPIDLE, tcp_keepidle) elif TCP_KEEPALIVE is not None: s.setsockopt(socket.IPPROTO_TCP, TCP_KEEPALIVE, tcp_keepidle) elif SIO_KEEPALIVE_VALS is not None and fcntl: fcntl.ioctl(s.fileno(), SIO_KEEPALIVE_VALS, (1, tcp_keepidle * 1000, tcp_keepintvl * 1000)) # type: ignore def set_nonblocking(fd: SocketLike, onoff: Optional[bool] = True) -> Optional[bool]: """Toggle the O_NONBLOCK flag. If onoff==None then return current setting. Actual sockets from 'socket' module should use .setblocking() method, this is for situations where it is not available. Eg. pipes from 'subprocess' module. """ if fcntl is None: return onoff flags = fcntl.fcntl(fd, fcntl.F_GETFL) if onoff is None: return (flags & os.O_NONBLOCK) > 0 if onoff: flags |= os.O_NONBLOCK else: flags &= ~os.O_NONBLOCK fcntl.fcntl(fd, fcntl.F_SETFL, flags) return onoff def set_cloexec(fd: SocketLike, onoff: Optional[bool] = True) -> bool: """Toggle the FD_CLOEXEC flag. If onoff==None then return current setting. Some libraries do it automatically (eg. libpq). Others do not (Python stdlib). 
""" if fcntl is None: return True flags = fcntl.fcntl(fd, fcntl.F_GETFD) if onoff is None: return (flags & fcntl.FD_CLOEXEC) > 0 if onoff: flags |= fcntl.FD_CLOEXEC else: flags &= ~fcntl.FD_CLOEXEC fcntl.fcntl(fd, fcntl.F_SETFD, flags) return onoff python-skytools-3.9.2/skytools/sqltools.py000066400000000000000000000503501447265566700211050ustar00rootroot00000000000000"""Database tools. """ import io import logging import os from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union, cast, Callable import skytools from .basetypes import Connection, Cursor __all__ = ( "fq_name_parts", "fq_name", "get_table_oid", "get_table_pkeys", "get_table_columns", "exists_schema", "exists_table", "exists_type", "exists_sequence", "exists_temp_table", "exists_view", "exists_function", "exists_language", "Snapshot", "magic_insert", "CopyPipe", "full_copy", "DBObject", "DBSchema", "DBTable", "DBFunction", "DBLanguage", "db_install", "installer_find_file", "installer_apply_file", "dbdict", "mk_insert_sql", "mk_update_sql", "mk_delete_sql", ) class dbdict(Dict[str, Any]): """Wrapper on actual dict that allows accessing dict keys as attributes. """ # obj.foo access def __getattr__(self, k: str) -> Any: "Return attribute." try: return self[k] except KeyError: raise AttributeError(k) from None def __setattr__(self, k: str, v: Any) -> None: "Set attribute." self[k] = v def __delattr__(self, k: str) -> None: "Remove attribute." del self[k] def merge(self, other: Dict[str, Any]) -> None: for key in other: if key not in self: self[key] = other[key] # # Fully qualified table name # def fq_name_parts(tbl: str) -> Tuple[str, str]: """Return fully qualified name parts. """ tmp = tbl.split('.', 1) if len(tmp) == 1: return ('public', tbl) return (tmp[0], tmp[1]) def fq_name(tbl: str) -> str: """Return fully qualified name. 
""" return '.'.join(fq_name_parts(tbl)) # # info about table # def get_table_oid(curs: Cursor, table_name: str) -> int: """Find Postgres OID for table.""" schema, name = fq_name_parts(table_name) q = """select c.oid from pg_namespace n, pg_class c where c.relnamespace = n.oid and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchall() if len(res) == 0: raise Exception('Table not found: ' + table_name) return cast(int, res[0][0]) def get_table_pkeys(curs: Cursor, tbl: str) -> List[str]: """Return list of pkey column names.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_index i, pg_attribute k"\ " WHERE i.indrelid = %s AND k.attrelid = i.indexrelid"\ " AND i.indisprimary AND k.attnum > 0 AND NOT k.attisdropped"\ " ORDER BY k.attnum" curs.execute(q, [oid]) return [row[0] for row in curs.fetchall()] def get_table_columns(curs: Cursor, tbl: str) -> List[str]: """Return list of column names for table.""" oid = get_table_oid(curs, tbl) q = "SELECT k.attname FROM pg_attribute k"\ " WHERE k.attrelid = %s"\ " AND k.attnum > 0 AND NOT k.attisdropped"\ " ORDER BY k.attnum" curs.execute(q, [oid]) return [row[0] for row in curs.fetchall()] # # exist checks # def exists_schema(curs: Cursor, schema: str) -> bool: """Does schema exists?""" q = "select count(1) from pg_namespace where nspname = %s" curs.execute(q, [schema]) res = curs.fetchone() return bool(res[0]) def exists_table(curs: Cursor, table_name: str) -> bool: """Does table exists?""" schema, name = fq_name_parts(table_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'r' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return bool(res[0]) def exists_sequence(curs: Cursor, seq_name: str) -> bool: """Does sequence exists?""" schema, name = fq_name_parts(seq_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'S' and 
n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return bool(res[0]) def exists_view(curs: Cursor, view_name: str) -> bool: """Does view exists?""" schema, name = fq_name_parts(view_name) q = """select count(1) from pg_namespace n, pg_class c where c.relnamespace = n.oid and c.relkind = 'v' and n.nspname = %s and c.relname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return bool(res[0]) def exists_type(curs: Cursor, type_name: str) -> bool: """Does type exists?""" schema, name = fq_name_parts(type_name) q = """select count(1) from pg_namespace n, pg_type t where t.typnamespace = n.oid and n.nspname = %s and t.typname = %s""" curs.execute(q, [schema, name]) res = curs.fetchone() return bool(res[0]) def exists_function(curs: Cursor, function_name: str, nargs: int) -> bool: """Does function exists?""" # this does not check arg types, so may match several functions schema, name = fq_name_parts(function_name) q = """select count(1) from pg_namespace n, pg_proc p where p.pronamespace = n.oid and p.pronargs = %s and n.nspname = %s and p.proname = %s""" curs.execute(q, [nargs, schema, name]) res = curs.fetchone() # if unqualified function, check builtin functions too if not res[0] and function_name.find('.') < 0: name = "pg_catalog." + function_name return exists_function(curs, name, nargs) return bool(res[0]) def exists_language(curs: Cursor, lang_name: str) -> bool: """Does PL exists?""" q = """select count(1) from pg_language where lanname = %s""" curs.execute(q, [lang_name]) res = curs.fetchone() return bool(res[0]) def exists_temp_table(curs: Cursor, tbl: str) -> bool: """Does temp table exists?""" # correct way, works only on 8.2 q = "select 1 from pg_class where relname = %s and relnamespace = pg_my_temp_schema()" curs.execute(q, [tbl]) tmp = curs.fetchall() return len(tmp) > 0 # # Support for PostgreSQL snapshot # class Snapshot: """Represents a PostgreSQL snapshot. 
""" def __init__(self, str_val: str): "Create snapshot from string." self.sn_str = str_val tmp = str_val.split(':') if len(tmp) != 3: raise ValueError('Unknown format for snapshot') self.xmin = int(tmp[0]) self.xmax = int(tmp[1]) self.txid_list = [] if tmp[2] != "": for s in tmp[2].split(','): self.txid_list.append(int(s)) def contains(self, txid: int) -> bool: "Is txid visible in snapshot." txid = int(txid) if txid < self.xmin: return True if txid >= self.xmax: return False if txid in self.txid_list: return False return True # # Copy helpers # def _gen_dict_copy(tbl: str, row: Mapping[str, Any], fields: Sequence[str], qfields: Sequence[str]) -> str: tmp: List[str] = [] for f in fields: v = row.get(f) tmp.append(skytools.quote_copy(v)) return "\t".join(tmp) def _gen_dict_insert(tbl: str, row: Mapping[str, Any], fields: Sequence[str], qfields: Sequence[str]) -> str: tmp: List[str] = [] for f in fields: v = row.get(f) tmp.append(skytools.quote_literal(v)) fmt = "insert into %s (%s) values (%s);" return fmt % (tbl, ",".join(qfields), ",".join(tmp)) def _gen_list_copy(tbl: str, row: Sequence[Any], fields: Sequence[str], qfields: Sequence[str]) -> str: tmp: List[str] = [] for i in range(len(fields)): try: v = row[i] except IndexError: v = None tmp.append(skytools.quote_copy(v)) return "\t".join(tmp) def _gen_list_insert(tbl: str, row: Sequence[Any], fields: Sequence[str], qfields: Sequence[str]) -> str: tmp: List[str] = [] for i in range(len(fields)): try: v = row[i] except IndexError: v = None tmp.append(skytools.quote_literal(v)) fmt = "insert into %s (%s) values (%s);" return fmt % (tbl, ",".join(qfields), ",".join(tmp)) DictRow = Mapping[str, Any] ListRow = Sequence[Any] DictRows = Sequence[DictRow] ListRows = Sequence[ListRow] def magic_insert(curs: Optional[Cursor], tablename: str, data: Union[ListRows, DictRows], fields: Optional[Sequence[str]] = None, use_insert: bool = False, quoted_table: bool = False) -> Optional[str]: r"""Copy/insert a list of dict/list data 
to database. If curs is None, then the copy or insert statements are returned as string. For list of dict the field list is optional, as its possible to guess them from dict keys. """ if len(data) == 0: return None if fields is not None: fields = list(fields) # get rid of iterator if quoted_table: qtablename = tablename else: qtablename = skytools.quote_fqident(tablename) # decide how to process if hasattr(data[0], 'keys'): if fields is None: fields = data[0].keys() # type: ignore if use_insert: row_func = _gen_dict_insert else: row_func = _gen_dict_copy else: if fields is None: raise Exception("Non-dict data needs field list") if use_insert: row_func = _gen_list_insert # type: ignore else: row_func = _gen_list_copy # type: ignore qfields = [skytools.quote_ident(f) for f in fields] # type: ignore # init processing buf = io.StringIO() if curs is None and use_insert == 0: fmt = "COPY %s (%s) FROM STDIN;\n" buf.write(fmt % (qtablename, ",".join(qfields))) # process data for row in data: buf.write(row_func(qtablename, row, fields, qfields)) # type: ignore buf.write("\n") # if user needs only string, return it if curs is None: if use_insert == 0: buf.write("\\.\n") return buf.getvalue() # do the actual copy/inserts if use_insert: curs.execute(buf.getvalue()) else: buf.seek(0) sql = "COPY %s (%s) FROM STDIN" % (qtablename, ",".join(qfields)) curs.copy_expert(sql, buf) return None # # Full COPY of table from one db to another # class CopyPipe(io.TextIOBase): """Splits one big COPY to chunks. 
""" tablename: Optional[str] sql_from: Optional[str] dstcurs: Cursor buf: io.StringIO limit: int write_hook: Optional[Callable[["CopyPipe", str], str]] flush_hook: Optional[Callable[["CopyPipe"], None]] total_rows: int total_bytes: int def __init__(self, dstcurs: Cursor, tablename: Optional[str] = None, limit: int = 512 * 1024, sql_from: Optional[str] = None): super().__init__() self.tablename = tablename self.sql_from = sql_from self.dstcurs = dstcurs self.buf = io.StringIO() self.limit = limit #hook for new data, hook func should return new data #def write_hook(obj, data): # return data self.write_hook = None #hook for flush, hook func result is discarded # def flush_hook(obj): # return None self.flush_hook = None self.total_rows = 0 self.total_bytes = 0 def write(self, data: str) -> int: """New row from psycopg """ if self.write_hook: data = self.write_hook(self, data) self.total_bytes += len(data) # it's chars now... self.total_rows += 1 n = self.buf.write(data) if self.buf.tell() >= self.limit: self.flush() return n def flush(self) -> None: """Send data out. 
""" if self.flush_hook: self.flush_hook(self) if self.buf.tell() <= 0: return self.buf.seek(0) if self.sql_from: sql = self.sql_from else: sql = "COPY %s FROM STDIN" % (self.tablename or "missing_table_name",) self.dstcurs.copy_expert(sql, self.buf) self.buf.seek(0) self.buf.truncate() def full_copy(tablename: str, src_curs: Cursor, dst_curs: Cursor, column_list: Sequence[str] = (), condition: Optional[str] = None, dst_tablename: Optional[str] = None, dst_column_list: Optional[Sequence[str]] = None, write_hook: Optional[Callable[[CopyPipe, str], str]] = None, flush_hook: Optional[Callable[[CopyPipe], None]] = None) -> Tuple[int, int]: """COPY table from one db to another.""" # default dst table and dst columns to source ones dst_tablename = dst_tablename or tablename dst_column_list = dst_column_list or column_list[:] if len(dst_column_list) != len(column_list): raise Exception('src and dst column lists must match in length') def build_qfields(cols: Sequence[str]) -> str: if cols: return ",".join([skytools.quote_ident(f) for f in cols]) else: return "*" def build_statement(table: str, cols: Sequence[str]) -> str: qtable = skytools.quote_fqident(table) if cols: qfields = build_qfields(cols) return "%s (%s)" % (qtable, qfields) else: return qtable dst = build_statement(dst_tablename, dst_column_list) if condition: src = "(SELECT %s FROM ONLY %s WHERE %s)" % (build_qfields(column_list), skytools.quote_fqident(tablename), condition) else: src = build_statement(tablename, column_list) sql_to = "COPY %s TO stdout" % src sql_from = "COPY %s FROM stdin" % dst buf = CopyPipe(dst_curs, sql_from=sql_from) buf.write_hook = write_hook buf.flush_hook = flush_hook src_curs.copy_expert(sql_to, buf) buf.flush() return (buf.total_bytes, buf.total_rows) # # SQL installer # class DBObject: """Base class for installable DB objects.""" name: str sql: Optional[str] = None sql_file: Optional[str] = None def __init__(self, name: str, sql: Optional[str] = None, sql_file: Optional[str] = 
None) -> None: """Generic dbobject init.""" self.name = name self.sql = sql self.sql_file = sql_file def create(self, curs: Cursor, log: Optional[logging.Logger] = None) -> None: """Create a dbobject.""" if log: log.info('Installing %s' % self.name) if self.sql: sql = self.sql elif self.sql_file: fn = self.find_file() if log: log.info(" Reading from %s" % fn) with open(fn, "r", encoding="utf8") as f: sql = f.read() else: raise Exception('object not defined') for stmt in skytools.parse_statements(sql): #if log: log.debug(repr(stmt)) curs.execute(stmt) def find_file(self) -> str: """Find install script file.""" if not self.sql_file: raise ValueError("sql_file not set") return installer_find_file(self.sql_file) def exists(self, curs: Cursor) -> int: raise NotImplementedError class DBSchema(DBObject): """Handles db schema.""" def exists(self, curs: Cursor) -> int: """Does schema exists.""" return exists_schema(curs, self.name) class DBTable(DBObject): """Handles db table.""" def exists(self, curs: Cursor) -> int: """Does table exists.""" return exists_table(curs, self.name) class DBFunction(DBObject): """Handles db function.""" def __init__(self, name: str, nargs: int, sql: Optional[str] = None, sql_file: Optional[str] = None) -> None: """Function object - number of args is significant.""" super().__init__(name, sql, sql_file) self.nargs = nargs def exists(self, curs: Cursor) -> int: """Does function exists.""" return exists_function(curs, self.name, self.nargs) class DBLanguage(DBObject): """Handles db language.""" def __init__(self, name: str) -> None: """PL object - creation happens with CREATE LANGUAGE.""" super().__init__(name, sql="create language %s" % name) def exists(self, curs: Cursor) -> int: """Does PL exists.""" return exists_language(curs, self.name) def db_install(curs: Cursor, obj_list: Sequence[DBObject], log: Optional[logging.Logger] = None) -> None: """Installs list of objects into db.""" for obj in obj_list: if not obj.exists(curs): obj.create(curs, 
log) else: if log: log.info('%s is installed' % obj.name) def installer_find_file(filename: str) -> str: """Find SQL script from pre-defined paths.""" full_fn = None if filename[0] == "/": if os.path.isfile(filename): full_fn = filename else: from skytools.installer_config import sql_locations dir_list = sql_locations for fdir in dir_list: fn = os.path.join(fdir, filename) if os.path.isfile(fn): full_fn = fn break if not full_fn: raise Exception('File not found: ' + filename) return full_fn def installer_apply_file(db: Connection, filename: str, log: logging.Logger) -> None: """Find SQL file and apply it to db, statement-by-statement.""" fn = installer_find_file(filename) with open(fn, "r", encoding="utf8") as f: sql = f.read() if log: log.info("applying %s" % fn) curs = db.cursor() for stmt in skytools.parse_statements(sql): #log.debug(repr(stmt)) curs.execute(stmt) # # Generate INSERT/UPDATE/DELETE statement # def mk_insert_sql(row: DictRow, tbl: str, pkey_list: Optional[Sequence[str]] = None, field_map: Optional[Mapping[str, str]] = None) -> str: """Generate INSERT statement from dict data. """ col_list = [] val_list = [] if field_map: for src, dst in field_map.items(): col_list.append(skytools.quote_ident(dst)) val_list.append(skytools.quote_literal(row[src])) else: for c, v in row.items(): col_list.append(skytools.quote_ident(c)) val_list.append(skytools.quote_literal(v)) col_str = ", ".join(col_list) val_str = ", ".join(val_list) return "insert into %s (%s) values (%s);" % ( skytools.quote_fqident(tbl), col_str, val_str) def mk_update_sql(row: DictRow, tbl: str, pkey_list: Sequence[str], field_map: Optional[Mapping[str, str]] = None) -> str: """Generate UPDATE statement from dict data. 
""" if len(pkey_list) < 1: raise Exception("update needs pkeys") set_list = [] whe_list = [] pkmap = {} for k in pkey_list: pkmap[k] = 1 new_k = field_map and field_map[k] or k col = skytools.quote_ident(new_k) val = skytools.quote_literal(row[k]) whe_list.append("%s = %s" % (col, val)) if field_map: for src, dst in field_map.items(): if src not in pkmap: col = skytools.quote_ident(dst) val = skytools.quote_literal(row[src]) set_list.append("%s = %s" % (col, val)) else: for col, val in row.items(): if col not in pkmap: col = skytools.quote_ident(col) val = skytools.quote_literal(val) set_list.append("%s = %s" % (col, val)) return "update only %s set %s where %s;" % (skytools.quote_fqident(tbl), ", ".join(set_list), " and ".join(whe_list)) def mk_delete_sql(row: DictRow, tbl: str, pkey_list: Sequence[str], field_map: Optional[Mapping[str, str]] = None) -> str: """Generate DELETE statement from dict data. """ if len(pkey_list) < 1: raise Exception("delete needs pkeys") whe_list = [] for k in pkey_list: new_k = field_map and field_map[k] or k col = skytools.quote_ident(new_k) val = skytools.quote_literal(row[k]) whe_list.append("%s = %s" % (col, val)) whe_str = " and ".join(whe_list) return "delete from only %s where %s;" % (skytools.quote_fqident(tbl), whe_str) python-skytools-3.9.2/skytools/timeutil.py000066400000000000000000000072231447265566700210620ustar00rootroot00000000000000"""Fill gaps in Python time API-s. parse_iso_timestamp: Parse reasonable subset of ISO_8601 timestamp formats. [ http://en.wikipedia.org/wiki/ISO_8601 ] datetime_to_timestamp: Get POSIX timestamp from datetime() object. 
""" import re import time from datetime import datetime, timedelta, tzinfo from typing import Optional, Pattern __all__ = ( 'parse_iso_timestamp', 'FixedOffsetTimezone', 'datetime_to_timestamp', ) class FixedOffsetTimezone(tzinfo): """Fixed offset in minutes east from UTC.""" __slots__ = ('__offset', '__name') __offset: timedelta __name: str def __init__(self, offset: int) -> None: super().__init__() self.__offset = timedelta(minutes=offset) # numeric tz name h, m = divmod(abs(offset), 60) if offset < 0: h = -h if m: self.__name = "%+03d:%02d" % (h, m) else: self.__name = "%+03d" % h def utcoffset(self, dt: Optional[datetime]) -> Optional[timedelta]: return self.__offset def tzname(self, dt: Optional[datetime]) -> Optional[str]: return self.__name def dst(self, dt: Optional[datetime]) -> Optional[timedelta]: return ZERO ZERO = timedelta(0) # # Parse ISO_8601 timestamps. # """ TODO: - support more combinations from ISO 8601 (only reasonable ones) - cache TZ objects - make it faster? """ _iso_regex = r""" \s* (?P \d\d\d\d) [-] (?P \d\d) [-] (?P \d\d) [ T] (?P \d\d) [:] (?P \d\d) (?: [:] (?P \d\d ) (?: [.,] (?P \d+))? )? (?: \s* (?P [-+]) (?P \d\d) (?: [:]? (?P \d\d))? | (?P Z ) )? \s* $ """ _iso_rc: Optional[Pattern[str]] = None def parse_iso_timestamp(s: str, default_tz: Optional[tzinfo] = None) -> datetime: """Parse ISO timestamp to datetime object. YYYY-MM-DD[ T]HH:MM[:SS[.ss]][-+HH[:MM]] Assumes that second fractions are zero-trimmed from the end, so '.15' means 150000 microseconds. If the timezone offset is not present, use default_tz as tzinfo. By default its None, meaning the datetime object will be without tz. Only fixed offset timezones are supported. 
""" global _iso_rc if _iso_rc is None: _iso_rc = re.compile(_iso_regex, re.X) m = _iso_rc.match(s) if not m: raise ValueError('Date not in ISO format: %s' % repr(s)) tz = default_tz if m.group('tzsign'): tzofs = int(m.group('tzhr')) * 60 if m.group('tzmin'): tzofs += int(m.group('tzmin')) if m.group('tzsign') == '-': tzofs = -tzofs tz = FixedOffsetTimezone(tzofs) elif m.group('tzname'): tz = UTC return datetime( int(m.group('year')), int(m.group('month')), int(m.group('day')), int(m.group('hour')), int(m.group('min')), m.group('sec') and int(m.group('sec')) or 0, m.group('ss') and int(m.group('ss').ljust(6, '0')) or 0, tz ) # # POSIX timestamp from datetime() # UTC = FixedOffsetTimezone(0) TZ_EPOCH = datetime.fromtimestamp(0, UTC) UTC_NOTZ_EPOCH = datetime.utcfromtimestamp(0) def datetime_to_timestamp(dt: datetime, local_time: bool = True) -> float: """Get posix timestamp from datetime() object. if dt is without timezone, then local_time specifies whether it's UTC or local time. Returns seconds since epoch as float. """ if dt.tzinfo: delta = dt - TZ_EPOCH return delta.total_seconds() elif local_time: s = time.mktime(dt.timetuple()) return s + (dt.microsecond / 1000000.0) else: delta = dt - UTC_NOTZ_EPOCH return delta.total_seconds() python-skytools-3.9.2/skytools/tnetstrings.py000066400000000000000000000070231447265566700216100ustar00rootroot00000000000000"""TNetStrings. 
""" import codecs from typing import Any, List __all__ = ['loads', 'dumps'] _memstr_types = (str, bytes, memoryview) _struct_types = (list, tuple, dict) _inttypes = (int,) _decode_utf8 = codecs.getdecoder('utf8') def _dumps(dst: List[bytes], val: Any) -> int: if isinstance(val, _struct_types): tlenpos = len(dst) tlen = 0 dst.append(b'') if isinstance(val, dict): for k in val: tlen += _dumps(dst, k) tlen += _dumps(dst, val[k]) dst.append(b'}') else: for v in val: tlen += _dumps(dst, v) dst.append(b']') dst[tlenpos] = b'%d:' % tlen return len(dst[tlenpos]) + tlen + 1 elif isinstance(val, _memstr_types): if isinstance(val, str): bval = val.encode('utf8') elif isinstance(val, bytes): bval = val else: bval = memoryview(val).tobytes() tval = b'%d:%s,' % (len(bval), bval) elif isinstance(val, bool): tval = val and b'4:true!' or b'5:false!' elif isinstance(val, _inttypes): bval = b'%d' % val tval = b'%d:%s#' % (len(bval), bval) elif isinstance(val, float): bval = b'%r' % val tval = b'%d:%s^' % (len(bval), bval) elif val is None: tval = b'0:~' else: raise TypeError("Object type not supported: %r" % val) dst.append(tval) return len(tval) def _loads(buf: memoryview) -> Any: pos = 0 maxlen = min(len(buf), 9) while buf[pos:pos + 1] != b':': pos += 1 if pos > maxlen: raise ValueError("Too large length") lenbytes = buf[: pos].tobytes() tlen = int(lenbytes) ofs = len(lenbytes) + 1 endofs = ofs + tlen val = buf[ofs: endofs] code = buf[endofs: endofs + 1] rest = buf[endofs + 1:] if len(val) + 1 != tlen + len(code): raise ValueError("failed to load value, invalid length") if code == b',': return _decode_utf8(val)[0], rest elif code == b'#': return int(val.tobytes(), 10), rest elif code == b'^': return float(val.tobytes()), rest elif code == b']': listobj = [] while val: elem, val = _loads(val) listobj.append(elem) return listobj, rest elif code == b'}': dictobj = {} while val: k, val = _loads(val) if not isinstance(k, str): raise ValueError("failed to load value, invalid key type") 
dictobj[k], val = _loads(val) return dictobj, rest elif code == b'!': if val == b'true': return True, rest if val == b'false': return False, rest raise ValueError("failed to load value, invalid boolean value") elif code == b'~': if val == b'': return None, rest raise ValueError("failed to load value, invalid null value") else: raise ValueError("failed to load value, invalid value code") # # Public API # def dumps(val: Any) -> bytes: """Dump object tree as TNetString value. """ dst: List[bytes] = [] _dumps(dst, val) return b''.join(dst) def loads(binval: bytes) -> Any: """Parse TNetstring from byte string. """ if not isinstance(binval, (bytes, memoryview)): raise TypeError("Bytes or memoryview required") obj, rest = _loads(memoryview(binval)) if rest: raise ValueError("Not all data processed") return obj # old compat? parse = loads dump = dumps python-skytools-3.9.2/skytools/utf8.py000066400000000000000000000045411447265566700201140ustar00rootroot00000000000000r"""UTF-8 sanitizer. Python's UTF-8 parser is quite relaxed, this creates problems when talking with other software that uses stricter parsers. """ import codecs import re from typing import Match, Optional, Pattern, Tuple __all__ = ('safe_utf8_decode', 'sanitize_unicode') # by default, use same symbol as 'replace' REPLACEMENT_SYMBOL = chr(0xFFFD) # 65533 _urc: Optional[Pattern[str]] = None def _fix_utf8(m: Match[str]) -> str: """Merge UTF16 surrogates, replace others""" u = m.group() if len(u) == 2: # merge into single symbol c1 = ord(u[0]) c2 = ord(u[1]) c = 0x10000 + ((c1 & 0x3FF) << 10) + (c2 & 0x3FF) return chr(c) else: # use replacement symbol return REPLACEMENT_SYMBOL def sanitize_unicode(u: str) -> str: """Fix invalid symbols in unicode string.""" global _urc if not isinstance(u, str): raise TypeError('Need unicode string') # regex for finding invalid chars, works on unicode string if not _urc: rx = "[\uD800-\uDBFF] [\uDC00-\uDFFF]? 
| [\0\uDC00-\uDFFF]" _urc = re.compile(rx, re.X) # now find and fix UTF16 surrogates m = _urc.search(u) if m: u = _urc.sub(_fix_utf8, u) return u def safe_replace(exc: UnicodeError) -> Tuple[str, int]: """Replace only one symbol at a time. Builtin .decode('xxx', 'replace') replaces several symbols together, which is unsafe. """ c2 = REPLACEMENT_SYMBOL # we could assume latin1 #if 0: # c1 = exc.object[exc.start] # c2 = chr(ord(c1)) assert isinstance(exc, UnicodeDecodeError) return c2, exc.start + 1 # register, it will be globally available codecs.register_error("safe_replace", safe_replace) def safe_utf8_decode(s: bytes) -> Tuple[bool, str]: """Decode UTF-8 safely. Acts like str.decode('utf8', 'replace') but also fixes UTF16 surrogates and NUL bytes, which Python's default decoder does not do. @param s: utf8-encoded byte string @return: tuple of (was_valid_utf8, unicode_string) """ # decode with error detection ok = True try: # expect no errors by default u = s.decode('utf8') except UnicodeDecodeError: u = s.decode('utf8', 'safe_replace') ok = False u2 = sanitize_unicode(u) if u is not u2: ok = False return (ok, u2) python-skytools-3.9.2/tests/000077500000000000000000000000001447265566700161235ustar00rootroot00000000000000python-skytools-3.9.2/tests/config.ini000066400000000000000000000007561447265566700201010ustar00rootroot00000000000000[base] foo = 1 bar = %(foo)s bool-true1 = 1 bool-true2 = true bool-false1 = 0 bool-false2 = false float-val = 2.0 list-val1 = list-val2 = a, 1, asd, ppp dict-val1 = dict-val2 = a : 1, b : 2, z file-val1 = - file-val2 = ~/foo bytes-val1 = 4 bytes-val2 = 2k wild-*-* = w2 wild-a-* = w.a wild-a-b = w.a.b vars1 = V2=%(vars2)s vars2 = V3=%(vars3)s vars3 = Q3 bad1 = B2=%(bad2)s bad2 = %(missing1)s %(missing2)s [DEFAULT] all = yes [other] test = try [testscript] opt = test db = python-skytools-3.9.2/tests/config2.ini000066400000000000000000000001771447265566700201600ustar00rootroot00000000000000[fmt1] foo = 1 bar = %(foo)s [fmt2] 
config_format = 2 foo = 1 bar1 = %(foo)s bar2 = ${foo} [fmt3] config_format = 3 foo = 1 python-skytools-3.9.2/tests/test_api.py000066400000000000000000000002621447265566700203050ustar00rootroot00000000000000 import skytools def test_version() -> None: a = skytools.natsort_key(getattr(skytools, "__version__")) assert a b = skytools.natsort_key('3.3') assert a >= b python-skytools-3.9.2/tests/test_config.py000066400000000000000000000135511447265566700210060ustar00rootroot00000000000000 import io import os.path import sys import pytest from skytools.config import ( Config, ConfigError, ExtendedCompatConfigParser, InterpolationError, NoOptionError, NoSectionError, ) TOP = os.path.dirname(__file__) CONFIG = os.path.join(TOP, 'config.ini') CONFIG2 = os.path.join(TOP, 'config2.ini') def test_config_str() -> None: cf = Config('base', CONFIG) assert cf.get('foo') == '1' assert cf.get('missing', 'q') == 'q' with pytest.raises(NoOptionError): cf.get('missing') def test_config_int() -> None: cf = Config('base', CONFIG) assert cf.getint('foo') == 1 assert cf.getint('missing', 2) == 2 with pytest.raises(NoOptionError): cf.getint('missing') def test_config_float() -> None: cf = Config('base', CONFIG) assert cf.getfloat('float-val') == 2.0 assert cf.getfloat('missing', 3.0) == 3.0 with pytest.raises(NoOptionError): cf.getfloat('missing') def test_config_bool() -> None: cf = Config('base', CONFIG) assert cf.getboolean('bool-true1') == True assert cf.getboolean('bool-true2') == True assert cf.getboolean('missing', True) == True with pytest.raises(NoOptionError): cf.getboolean('missing') assert cf.getboolean('bool-false1') == False assert cf.getboolean('bool-false2') == False assert cf.getboolean('missing', False) == False with pytest.raises(NoOptionError): cf.getbool('missing') def test_config_list() -> None: cf = Config('base', CONFIG) assert cf.getlist('list-val1') == [] assert cf.getlist('list-val2') == ['a', '1', 'asd', 'ppp'] assert cf.getlist('missing', ["a"]) == ["a"] with 
pytest.raises(NoOptionError): cf.getlist('missing') def test_config_dict() -> None: cf = Config('base', CONFIG) assert cf.getdict('dict-val1') == {} assert cf.getdict('dict-val2') == {'a': '1', 'b': '2', 'z': 'z'} assert cf.getdict('missing', {'a': '1'}) == {'a': '1'} with pytest.raises(NoOptionError): cf.getdict('missing') def test_config_file() -> None: cf = Config('base', CONFIG) assert cf.getfile('file-val1') == '-' if sys.platform != 'win32': assert cf.getfile('file-val2')[0] == '/' assert cf.getfile('missing', 'qwe') == 'qwe' with pytest.raises(NoOptionError): cf.getfile('missing') def test_config_bytes() -> None: cf = Config('base', CONFIG) assert cf.getbytes('bytes-val1') == 4 assert cf.getbytes('bytes-val2') == 2048 assert cf.getbytes('missing', '3k') == 3072 with pytest.raises(NoOptionError): cf.getbytes('missing') def test_config_wildcard() -> None: cf = Config('base', CONFIG) assert cf.get_wildcard('wild-*-*', ['a', 'b']) == 'w.a.b' assert cf.get_wildcard('wild-*-*', ['a', 'x']) == 'w.a' assert cf.get_wildcard('wild-*-*', ['q', 'b']) == 'w2' assert cf.get_wildcard('missing-*-*', ['1', '2'], 'def') == 'def' with pytest.raises(NoOptionError): cf.get_wildcard('missing-*-*', ['1', '2']) def test_config_default() -> None: cf = Config('base', CONFIG) assert cf.get('all') == 'yes' def test_config_other() -> None: cf = Config('base', CONFIG) assert sorted(cf.sections()) == ['base', 'other', 'testscript'] assert cf.has_section('base') == True assert cf.has_section('other') == True assert cf.has_section('missing') == False assert cf.has_section('DEFAULT') == False assert cf.has_option('missing') == False assert cf.has_option('all') == True assert cf.has_option('foo') == True cf2 = cf.clone('other') opts = list(sorted(cf2.options())) assert opts == [ 'all', 'config_dir', 'config_file', 'host_name', 'job_name', 'service_name', 'test' ] assert len(cf2.items()) == len(cf2.options()) def test_loading() -> None: with pytest.raises(NoSectionError): Config('random', 
CONFIG) with pytest.raises(NoSectionError): Config('random', CONFIG) with pytest.raises(ConfigError): Config('random', 'random.ini') def test_nofile() -> None: cf = Config('base', None, user_defs={'a': '1'}) assert cf.sections() == ['base'] assert cf.get('a') == '1' cf = Config('base', None, user_defs={'a': '1'}, ignore_defs=True) assert cf.get('a', '2') == '2' def test_override() -> None: cf = Config('base', CONFIG, override={'foo': 'overrided'}) assert cf.get('foo') == 'overrided' def test_vars() -> None: cf = Config('base', CONFIG) assert cf.get('vars1') == 'V2=V3=Q3' with pytest.raises(InterpolationError): cf.get('bad1') def test_extended_compat() -> None: config = u'[foo]\nkey = ${sub} $${nosub}\nsub = 2\n[bar]\nkey = ${foo:key}\n' cf = ExtendedCompatConfigParser() cf.read_file(io.StringIO(config), 'conf.ini') assert cf.get('bar', 'key') == '2 ${nosub}' config = u'[foo]\nloop1= ${loop1}\nloop2 = ${loop3}\nloop3 = ${loop2}\n' cf = ExtendedCompatConfigParser() cf.read_file(io.StringIO(config), 'conf.ini') with pytest.raises(InterpolationError): cf.get('foo', 'loop1') with pytest.raises(InterpolationError): cf.get('foo', 'loop2') config = u'[foo]\nkey = %(sub)s ${sub}\nsub = 2\n[bar]\nkey = %(foo:key)s\nkey2 = ${foo:key}\n' cf = ExtendedCompatConfigParser() cf.read_file(io.StringIO(config), 'conf.ini') assert cf.get('bar', 'key2') == '2 2' with pytest.raises(NoOptionError): cf.get('bar', 'key') config = u'[foo]\nkey = ${bad:xxx}\n[bad]\nsub = 1\n' cf = ExtendedCompatConfigParser() cf.read_file(io.StringIO(config), 'conf.ini') with pytest.raises(NoOptionError): cf.get('foo', 'key') def test_config_format() -> None: cf1 = Config("fmt1", CONFIG2) cf2 = Config("fmt2", CONFIG2) with pytest.raises(ConfigError): Config("fmt3", CONFIG2) assert cf1.get("bar") == "1" assert cf2.get("bar1") == "%(foo)s" assert cf2.get("bar2") == "1" python-skytools-3.9.2/tests/test_dbservice.py000066400000000000000000000010611447265566700215000ustar00rootroot00000000000000 from 
skytools.dbservice import transform_fields def test_transform_fields() -> None: rows = [] rows.append({'time': '22:00', 'metric': 'count', 'value': 100}) rows.append({'time': '22:00', 'metric': 'dur', 'value': 7}) rows.append({'time': '23:00', 'metric': 'count', 'value': 200}) rows.append({'time': '23:00', 'metric': 'dur', 'value': 5}) res = list(transform_fields(rows, ['time'], 'metric', 'value')) assert res[0] == {'count': 100, 'dur': 7, 'time': '22:00'} assert res[1] == {'count': 200, 'dur': 5, 'time': '23:00'} python-skytools-3.9.2/tests/test_fileutil.py000066400000000000000000000006101447265566700213460ustar00rootroot00000000000000 import os import tempfile from skytools.fileutil import write_atomic def test_write_atomic() -> None: pidfn = tempfile.mktemp('.pid') write_atomic(pidfn, "1") write_atomic(pidfn, "2") os.remove(pidfn) def test_write_atomic_bak() -> None: pidfn = tempfile.mktemp('.pid') write_atomic(pidfn, "1", '.bak') write_atomic(pidfn, "2", '.bak') os.remove(pidfn) python-skytools-3.9.2/tests/test_gzlog.py000066400000000000000000000011751447265566700206620ustar00rootroot00000000000000 import gzip import os import tempfile from skytools.gzlog import gzip_append def test_gzlog() -> None: fd, tmpname = tempfile.mkstemp(suffix='.gz') os.close(fd) try: blk = b'1234567890' * 100 write_total = 0 for i in range(5): gzip_append(tmpname, blk) write_total += len(blk) read_total = 0 with gzip.open(tmpname) as rfd: while True: blk = rfd.read(512) if not blk: break read_total += len(blk) finally: os.remove(tmpname) assert read_total == write_total python-skytools-3.9.2/tests/test_hashtext.py000066400000000000000000000047011447265566700213660ustar00rootroot00000000000000from skytools.hashtext import ( hashtext_new, hashtext_new_py, hashtext_old, hashtext_old_py, ) def test_hashtext_new_const() -> None: c0 = [hashtext_new_py(b'x' * (0 * 5 + j)) for j in range(5)] c1 = [hashtext_new_py(b'x' * (1 * 5 + j)) for j in range(5)] c2 = [hashtext_new_py(b'x' * (2 * 5 + 
j)) for j in range(5)] assert c0 == [-1477818771, 1074944137, -1086392228, -1992236649, -1379736791] assert c1 == [-370454118, 1489915569, -66683019, -2126973000, 1651296771] assert c2 == [755764456, -1494243903, 631527812, 28686851, -9498641] def test_hashtext_old_const() -> None: c0 = [hashtext_old_py(b'x' * (0 * 5 + j)) for j in range(5)] c1 = [hashtext_old_py(b'x' * (1 * 5 + j)) for j in range(5)] c2 = [hashtext_old_py(b'x' * (2 * 5 + j)) for j in range(5)] assert c0 == [-863449762, 37835117, 294739542, -320432768, 1007638138] assert c1 == [1422906842, -261065348, 59863994, -162804943, 1736144510] assert c2 == [-682756517, 317827663, -495599455, -1411793989, 1739997714] def test_hashtext_new_impl() -> None: data = b'HypficUjFitraxlumCitcemkiOkIkthi' p = [hashtext_new_py(data[:l]) for l in range(len(data) + 1)] c = [hashtext_new(data[:l]) for l in range(len(data) + 1)] assert p == c, '%s <> %s' % (p, c) def test_hashtext_old_impl() -> None: data = b'HypficUjFitraxlumCitcemkiOkIkthi' p = [hashtext_old_py(data[:l]) for l in range(len(data) + 1)] c = [hashtext_old(data[:l]) for l in range(len(data) + 1)] assert p == c, '%s <> %s' % (p, c) def test_hashtext_new_input_types() -> None: data = b'HypficUjFitraxlumCitcemkiOkIkthi' exp = hashtext_new(data) assert hashtext_new(data.decode("utf8")) == exp #assert hashtext_new(memoryview(data)) == exp #assert hashtext_new(bytearray(data)) == exp assert hashtext_new_py(data) == exp assert hashtext_new_py(data.decode("utf8")) == exp #assert hashtext_new_py(memoryview(data)) == exp #assert hashtext_new_py(bytearray(data)) == exp def test_hashtext_old_input_types() -> None: data = b'HypficUjFitraxlumCitcemkiOkIkthi' exp = hashtext_old(data) assert hashtext_old(data.decode("utf8")) == exp #assert hashtext_old(memoryview(data)) == exp #assert hashtext_old(bytearray(data)) == exp assert hashtext_old_py(data) == exp assert hashtext_old_py(data.decode("utf8")) == exp #assert hashtext_old_py(memoryview(data)) == exp #assert 
hashtext_old_py(bytearray(data)) == exp python-skytools-3.9.2/tests/test_kwcheck.py000066400000000000000000000046551447265566700211650ustar00rootroot00000000000000"""Check if SQL keywords are up-to-date. """ from typing import Dict, List, Tuple import os.path import re import skytools.quoting versions = [ "94", "95", "96", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19" ] locations = [ #"/usr/include/postgresql/{VER}/server/parser/kwlist.h", "~/src/pgsql/pg{VER}/src/include/parser/kwlist.h", #"~/src/pgsql/postgresql/src/include/parser/kwlist.h", ] known_cats = [ "COL_NAME_KEYWORD", "UNRESERVED_KEYWORD", "RESERVED_KEYWORD", "TYPE_FUNC_NAME_KEYWORD", ] def addkw(kmap: Dict[str, str], kw: str, cat: str) -> None: assert cat in known_cats if kw not in kmap: kmap[kw] = cat elif kmap[kw] != cat and cat not in kmap[kw]: kmap[kw] += "," + cat def _load_kwlist(fn: str, full_map: Dict[str, str], cur_map: Dict[str, str]) -> None: fn = os.path.expanduser(fn) if not os.path.isfile(fn): return with open(fn, 'rt') as f: data = f.read() rc = re.compile(r'PG_KEYWORD[(]"(.*)" , \s* \w+ , \s* (\w+) [)]', re.X) for kw, cat in rc.findall(data): addkw(full_map, kw, cat) if cat == 'UNRESERVED_KEYWORD': continue #if cat == 'COL_NAME_KEYWORD': # continue addkw(cur_map, kw, cat) def test_kwcheck() -> None: """Compare keyword list in quoting.py to the one in postgres sources """ kwset = set(skytools.quoting._ident_kwmap) full_map: Dict[str, str] = {} # all types from kwlist.h cur_map: Dict[str, str] = {} # only kwlist.h new_list: List[Tuple[str, str]] = [] # missing from kwset obsolete_list: List[Tuple[str, str]] = [] # in kwset, but not in cur_map done = set() for loc in locations: for ver in versions: fn = loc.format(VER=ver) if fn not in done: _load_kwlist(fn, full_map, cur_map) done.add(fn) if not full_map: return for kw in sorted(cur_map): if kw not in kwset: new_list.append((kw, cur_map[kw])) kwset.add(kw) for k in sorted(kwset): if k not in full_map: # especially obsolete 
obsolete_list.append((k, '!FULL')) elif k not in cur_map: # slightly obsolete obsolete_list.append((k, '!CUR')) assert new_list == [] # here we need to keep older keywords around longer #assert obsolete_list == [] python-skytools-3.9.2/tests/test_natsort.py000066400000000000000000000031101447265566700212210ustar00rootroot00000000000000 from skytools.natsort import ( natsort, natsort_icase, natsort_key, natsorted, natsorted_icase, ) def test_natsorted() -> None: res = natsorted(['1', 'ver-1.11', '', 'ver-1.0']) assert res == ['', '1', 'ver-1.0', 'ver-1.11'] def test_natsort() -> None: res = ['a1', '2a', '.1'] natsort(res) assert res == ['.1', '2a', 'a1'] def test_natsorted_icase() -> None: res = natsorted_icase(['Ver-1.1', 'vEr-1.11', '', 'veR-1.0']) assert res == ['', 'veR-1.0', 'Ver-1.1', 'vEr-1.11'] def test_natsort_icase() -> None: res = ['Ver-1.1', 'vEr-1.11', '', 'veR-1.0'] natsort_icase(res) assert res == ['', 'veR-1.0', 'Ver-1.1', 'vEr-1.11'] def _natcmp(a: str, b: str) -> str: k1 = natsort_key(a) k2 = natsort_key(b) if k1 < k2: return 'ok' return f"fail: a='{a}' > b='{b}'" def test_natsort_order() -> None: assert _natcmp('1', '2') == 'ok' assert _natcmp('2', '11') == 'ok' assert _natcmp('.', '1') == 'ok' assert _natcmp('1', 'a') == 'ok' assert _natcmp('a~1', 'ab') == 'ok' assert _natcmp('a~1', 'a') == 'ok' assert _natcmp('a~1', 'a1') == 'ok' assert _natcmp('00', '0') == 'ok' assert _natcmp('001', '0') == 'ok' assert _natcmp('0', '01') == 'ok' assert _natcmp('011', '02') == 'ok' assert _natcmp('00~1', '0~1') == 'ok' assert _natcmp('~~~', '~~') == 'ok' assert _natcmp('1~beta0', '1') == 'ok' assert _natcmp('1', '1.0') == 'ok' assert _natcmp('~', '') == 'ok' assert _natcmp('', '0') == 'ok' assert _natcmp('', '1') == 'ok' assert _natcmp('', 'a') == 'ok' python-skytools-3.9.2/tests/test_parsing.py000066400000000000000000000163551447265566700212110ustar00rootroot00000000000000 import pytest from skytools.parsing import ( dedent, hsize_to_bytes, 
merge_connect_string, parse_acl, parse_connect_string, parse_logtriga_sql, parse_pgarray, parse_sqltriga_sql, parse_statements, parse_tabbed_table, sql_tokenizer, ) def test_parse_pgarray() -> None: assert parse_pgarray('{}') == [] assert parse_pgarray('{a,b,null,"null"}') == ['a', 'b', None, 'null'] assert parse_pgarray(r'{"a,a","b\"b","c\\c"}') == ['a,a', 'b"b', 'c\\c'] assert parse_pgarray("[0,3]={1,2,3}") == ['1', '2', '3'] assert parse_pgarray(None) is None with pytest.raises(ValueError): parse_pgarray('}{') with pytest.raises(ValueError): parse_pgarray('[1]=}') with pytest.raises(ValueError): parse_pgarray('{"..." , }') with pytest.raises(ValueError): parse_pgarray('{"..." ; }') with pytest.raises(ValueError): parse_pgarray('{"}') with pytest.raises(ValueError): parse_pgarray('{"..."}zzz') with pytest.raises(ValueError): parse_pgarray('{"..."}z') def test_parse_sqltriga_sql() -> None: # Insert event row1 = parse_logtriga_sql('I', '(id, data) values (1, null)') assert row1 == {'data': None, 'id': '1'} row2 = parse_sqltriga_sql('I', '(id, data) values (1, null)', pklist=["id"]) assert row2 == {'data': None, 'id': '1'} # Update event row3 = parse_logtriga_sql('U', "data='foo' where id = 1") assert row3 == {'data': 'foo', 'id': '1'} # Delete event row4 = parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'") assert row4 == {'id': '1', 'id2': "str'val"} # Insert event, splitkeys keys5, row5 = parse_logtriga_sql('I', '(id, data) values (1, null)', splitkeys=True) assert keys5 == {} assert row5 == {'data': None, 'id': '1'} keys6, row6 = parse_logtriga_sql('I', '(id, data) values (1, null)', splitkeys=True) assert keys6 == {} assert row6 == {'data': None, 'id': '1'} # Update event, splitkeys keys7, row7 = parse_logtriga_sql('U', "data='foo' where id = 1", splitkeys=True) assert keys7 == {'id': '1'} assert row7 == {'data': 'foo'} keys8, row8 = parse_logtriga_sql('U', "data='foo',type=3 where id = 1", splitkeys=True) assert keys8 == {'id': '1'} assert row8 == {'data': 
'foo', 'type': '3'} # Delete event, splitkeys keys9, row9 = parse_logtriga_sql('D', "id = 1 and id2 = 'str''val'", splitkeys=True) assert keys9 == {'id': '1', 'id2': "str'val"} # generic with pytest.raises(ValueError): parse_logtriga_sql('J', "(id, data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id) values (1, null)") # insert errors with pytest.raises(ValueError): parse_logtriga_sql('I', "insert (id, data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id; data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) select (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) values of (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) values (1; null)") with pytest.raises(ValueError): parse_logtriga_sql('I', "(id, data) values (1, null) ;") with pytest.raises(ValueError, match="EOF"): parse_logtriga_sql('I', "(id, data) values (1, null) , insert") # update errors with pytest.raises(ValueError): parse_logtriga_sql('U', "(id,data) values (1, null)") with pytest.raises(ValueError): parse_logtriga_sql('U', "id,data") with pytest.raises(ValueError): parse_logtriga_sql('U', "data='foo';type=3 where id = 1") with pytest.raises(ValueError): parse_logtriga_sql('U', "data='foo' where id>1") with pytest.raises(ValueError): parse_logtriga_sql('U', "data='foo' where id=1 or true") # delete errors with pytest.raises(ValueError): parse_logtriga_sql('D', "foo,1") with pytest.raises(ValueError): parse_logtriga_sql('D', "foo = 1 ,") def test_parse_tabbed_table() -> None: assert parse_tabbed_table('col1\tcol2\nval1\tval2\n') == [ {'col1': 'val1', 'col2': 'val2'} ] # skip rows with different size assert parse_tabbed_table('col1\tcol2\nval1\tval2\ntmp\n') == [ {'col1': 'val1', 'col2': 'val2'} ] def test_sql_tokenizer() -> None: res = sql_tokenizer("select * from a.b", ignore_whitespace=True) assert list(res) == [ ('ident', 'select'), 
('sym', '*'), ('ident', 'from'), ('ident', 'a'), ('sym', '.'), ('ident', 'b') ] res = sql_tokenizer("\"c olumn\",'str''val'") assert list(res) == [ ('ident', '"c olumn"'), ('sym', ','), ('str', "'str''val'") ] res = sql_tokenizer('a.b a."b "" c" a.1', fqident=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'a.b'), ('ident', 'a."b "" c"'), ('ident', 'a'), ('sym', '.'), ('num', '1') ] res = sql_tokenizer(r"set 'a''\' + E'\''", standard_quoting=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'set'), ('str', "'a''\\'"), ('sym', '+'), ('str', "E'\\''") ] res = sql_tokenizer('a.b a."b "" c" a.1', fqident=True, standard_quoting=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'a.b'), ('ident', 'a."b "" c"'), ('ident', 'a'), ('sym', '.'), ('num', '1') ] res = sql_tokenizer('a.b\nc;', show_location=True, ignore_whitespace=True) assert list(res) == [ ('ident', 'a', 1), ('sym', '.', 2), ('ident', 'b', 3), ('ident', 'c', 5), ('sym', ';', 6) ] def test_parse_statements() -> None: res = parse_statements("begin; select 1; select 'foo'; end;") assert list(res) == ['begin;', 'select 1;', "select 'foo';", 'end;'] res = parse_statements("select (select 2+(select 3;);) ; select 4;") assert list(res) == ['select (select 2+(select 3;);) ;', 'select 4;'] with pytest.raises(ValueError): list(parse_statements('select ());')) with pytest.raises(ValueError): list(parse_statements('copy from stdin;')) def test_parse_acl() -> None: assert parse_acl('user=rwx/owner') == ('user', 'rwx', 'owner') assert parse_acl('" ""user"=rwx/" ""owner"') == (' "user', 'rwx', ' "owner') assert parse_acl('user=rwx') == ('user', 'rwx', None) assert parse_acl('=/f') == (None, '', 'f') # is this ok? 
assert parse_acl('?') is None def test_dedent() -> None: assert dedent(' Line1:\n Line 2\n') == 'Line1:\n Line 2\n' res = dedent(' \nLine1:\n Line 2\n Line 3\n Line 4') assert res == 'Line1:\nLine 2\n Line 3\n Line 4\n' def test_hsize_to_bytes() -> None: assert hsize_to_bytes('10G') == 10737418240 assert hsize_to_bytes('12k') == 12288 with pytest.raises(ValueError): hsize_to_bytes("x") def test_parse_connect_string() -> None: assert parse_connect_string("host=foo") == [('host', 'foo')] res = parse_connect_string(r" host = foo password = ' f\\\o\'o ' ") assert res == [('host', 'foo'), ('password', "' f\\o'o '")] with pytest.raises(ValueError): parse_connect_string(r" host = ") def test_merge_connect_string() -> None: res = merge_connect_string([('host', 'ip'), ('pass', ''), ('x', ' ')]) assert res == "host=ip pass='' x=' '" python-skytools-3.9.2/tests/test_querybuilder.py000066400000000000000000000074701447265566700222600ustar00rootroot00000000000000 from typing import Dict, Any import pytest from skytools.querybuilder import ( # type: ignore[attr-defined] PARAM_DBAPI, PARAM_INLINE, PARAM_PLPY, PlanCache, QueryBuilder, plpy, plpy_exec, ) def test_cached_plan() -> None: cache = PlanCache(3) p1 = cache.get_plan('sql1', ['text']) assert p1 is cache.get_plan('sql1', ['text']) p2 = cache.get_plan('sql1', ['int']) assert p2 is cache.get_plan('sql1', ['int']) assert p1 is not p2 p3 = cache.get_plan('sql3', ['text']) assert p3 is cache.get_plan('sql3', ['text']) p4 = cache.get_plan('sql4', ['text']) assert p4 is cache.get_plan('sql4', ['text']) p1x = cache.get_plan('sql1', ['text']) assert p1 is not p1x def test_querybuilder_core() -> None: args = {'success': 't', 'total': 45, 'ccy': 'EEK', 'id': 556} q = QueryBuilder("update orders set total = {total} where id = {id}", args) q.add(" and optional = {non_exist}") q.add(" and final = {success}") exp = "update orders set total = '45' where id = '556' and final = 't'" assert q.get_sql(PARAM_INLINE) == exp exp = "update orders 
set total = %s where id = %s and final = %s" assert q.get_sql(PARAM_DBAPI) == exp exp = "update orders set total = $1 where id = $2 and final = $3" assert q.get_sql(PARAM_PLPY) == exp def test_querybuilder_parse_errors() -> None: args = {'id': 1} with pytest.raises(ValueError): QueryBuilder("values ({{id)", args) with pytest.raises(ValueError): QueryBuilder("values ({id)", args) with pytest.raises(ValueError): QueryBuilder("values ({id::})", args) with pytest.raises(ValueError): QueryBuilder("values ({id||})", args) def test_querybuilder_inline() -> None: from decimal import Decimal args = { 'list': [1, 2], 'tup': ('s', 'x'), 'dict': {'a': 1}, 'none': None, 'bin': b'bin', 'str': 's', 'dec': Decimal('1.1'), } q = QueryBuilder("values ({list}, {dict}, {tup}, {none}, {str}, {bin}, {dec})", args) exp = r"""values ('{1,2}', '{"a": 1}', '{s,x}', null, 's', E'\\x62696e', '1.1')""" assert q.get_sql(PARAM_INLINE) == exp def test_querybuilder_altsql() -> None: args = {'id': 1} q = QueryBuilder("values ({id|XX}, {missing|DEFAULT})", args) exp = "values ('1', DEFAULT)" assert q.get_sql(PARAM_INLINE) == exp with pytest.raises(ValueError): QueryBuilder("values ({missing|DEFAULT})", None) def test_plpy_exec() -> None: GD: Dict[str, Any] = {} plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '1', 'arg2': '2'}) assert plpy.log == [ "DBG: plpy.prepare('select $1, $2, $3', ['text', 'int4', 'text'])", "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['1', '2', '1'])", ] plpy.log.clear() plpy_exec(None, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '1', 'arg2': '2'}) assert plpy.log == [ "DBG: plpy.prepare('select $1, $2, $3', ['text', 'int4', 'text'])", "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['1', '2', '1'])", ] plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3', 'arg2': '4'}) assert plpy.log == [ "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 
'text']), ['3', '4', '3'])" ] plpy.log.clear() with pytest.raises(Exception): plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3'}) assert plpy.log == [ """DBG: plpy.error("Missing arguments: [arg2] QUERY: 'select {arg1}, {arg2:int4}, {arg1}'")""" ] plpy.log.clear() plpy_exec(GD, "select {arg1}, {arg2:int4}, {arg1}", {'arg1': '3'}, False) assert plpy.log == [ "DBG: plpy.execute(('PLAN', 'select $1, $2, $3', ['text', 'int4', 'text']), ['3', None, '3'])" ] python-skytools-3.9.2/tests/test_quoting.py000066400000000000000000000223401447265566700212230ustar00rootroot00000000000000"""Extra tests for quoting module. """ from typing import Dict, Callable, Any, List, Type, Tuple, Union, Optional from decimal import Decimal try: import psycopg2.extras _have_psycopg2 = True except ImportError: _have_psycopg2 = False import pytest import skytools._cquoting import skytools._pyquoting from skytools.quoting import ( json_decode, json_encode, make_pgarray, quote_fqident, unescape_copy, unquote_fqident, ) QuoteTests = List[Tuple[Any, str]] QuoteFunc = Callable[[Any], str] UnquoteTests = List[Tuple[str, Optional[str]]] UnquoteFunc = Callable[[str], Optional[str]] class fake_cursor: """create a DictCursor row""" index = {'id': 0, 'data': 1} description = ['x', 'x'] if _have_psycopg2: dbrow = psycopg2.extras.DictRow(fake_cursor()) dbrow[0] = '123' dbrow[1] = 'value' def try_quote(func: QuoteFunc, data_list: QuoteTests) -> None: for val, exp in data_list: got = func(val) assert got == exp def try_unquote(func: UnquoteFunc, data_list: UnquoteTests) -> None: for val, exp in data_list: got = func(val) assert got == exp def test_quote_literal() -> None: sql_literal = [ (None, "null"), ("", "''"), ("a'b", "'a''b'"), (r"a\'b", r"E'a\\''b'"), (1, "'1'"), (True, "'True'"), (Decimal(1), "'1'"), (u'qwe', "'qwe'") ] def quote_literal_c(x: Any) -> str: return skytools._cquoting.quote_literal(x) def quote_literal_py(x: Any) -> str: return skytools._pyquoting.quote_literal(x) def 
quote_literal_default(x: Any) -> str: return skytools.quote_literal(x) try_quote(quote_literal_default, sql_literal) try_quote(quote_literal_default, sql_literal) try_quote(quote_literal_default, sql_literal) qliterals_common = [ (r"""null""", None), (r"""NULL""", None), (r"""123""", "123"), (r"""''""", r""""""), (r"""'a''b''c'""", r"""a'b'c"""), (r"""'foo'""", r"""foo"""), (r"""E'foo'""", r"""foo"""), (r"""E'a\n\t\a\b\0\z\'b'""", "a\n\t\x07\x08\x00z'b"), (r"""$$$$""", r""), (r"""$$qw$e$z$$""", r"qw$e$z"), (r"""$qq$$aa$$$'"\\$qq$""", '$aa$$$\'"\\\\'), (u"'qwe'", 'qwe'), ] bad_dol_literals = [ ('$$', '$$'), #('$$q', '$$q'), ('$$q$', '$$q$'), ('$q$q$', '$q$q$'), ('$q$q$x$', '$q$q$x$'), ] def test_unquote_literal() -> None: qliterals_nonstd = qliterals_common + [ (r"""'a\\b\\c'""", r"""a\b\c"""), (r"""e'a\\b\\c'""", r"""a\b\c"""), ] def unquote_literal_c(x: str) -> Optional[str]: return skytools._cquoting.unquote_literal(x) def unquote_literal_py(x: str) -> Optional[str]: return skytools._pyquoting.unquote_literal(x) def unquote_literal_default(x: str) -> Optional[str]: return skytools.unquote_literal(x) try_unquote(unquote_literal_c, qliterals_nonstd) try_unquote(unquote_literal_py, qliterals_nonstd) try_unquote(unquote_literal_default, qliterals_nonstd) for v1, v2 in bad_dol_literals: with pytest.raises(ValueError): skytools._pyquoting.unquote_literal(v1) with pytest.raises(ValueError): skytools._cquoting.unquote_literal(v1) with pytest.raises(ValueError): skytools.unquote_literal(v1) def test_unquote_literal_std() -> None: qliterals_std = qliterals_common + [ (r"''", r""), (r"'foo'", r"foo"), (r"E'foo'", r"foo"), (r"'\\''z'", r"\\'z"), ] for val, exp in qliterals_std: assert skytools._cquoting.unquote_literal(val, True) == exp assert skytools._pyquoting.unquote_literal(val, True) == exp assert skytools.unquote_literal(val, True) == exp def test_quote_copy() -> None: sql_copy = [ (None, "\\N"), ("", ""), ("a'\tb", "a'\\tb"), (r"a\'b", r"a\\'b"), (1, "1"), (True, 
"True"), (u"qwe", "qwe"), (Decimal(1), "1"), ] try_quote(skytools._cquoting.quote_copy, sql_copy) try_quote(skytools._pyquoting.quote_copy, sql_copy) try_quote(skytools.quote_copy, sql_copy) def test_quote_bytea_raw() -> None: sql_bytea_raw: List[Tuple[Optional[bytes], Optional[str]]] = [ (None, None), (b"", ""), (b"a'\tb", "a'\\011b"), (b"a\\'b", r"a\\'b"), (b"\t\344", r"\011\344"), ] for val, exp in sql_bytea_raw: assert skytools._cquoting.quote_bytea_raw(val) == exp assert skytools._pyquoting.quote_bytea_raw(val) == exp assert skytools.quote_bytea_raw(val) == exp def test_quote_bytea_raw_fail() -> None: with pytest.raises(TypeError): skytools._pyquoting.quote_bytea_raw(u'qwe') # type: ignore[arg-type] #assert_raises(TypeError, skytools._cquoting.quote_bytea_raw, u'qwe') #assert_raises(TypeError, skytools.quote_bytea_raw, 'qwe') def test_quote_ident() -> None: sql_ident = [ ('', '""'), ("a'\t\\\"b", '"a\'\t\\""b"'), ('abc_19', 'abc_19'), ('from', '"from"'), ('0foo', '"0foo"'), ('mixCase', '"mixCase"'), (u'utf', 'utf'), ] try_quote(skytools.quote_ident, sql_ident) def test_fqident() -> None: assert quote_fqident('tbl') == 'public.tbl' assert quote_fqident('Baz.Foo.Bar') == '"Baz"."Foo.Bar"' def _sort_urlenc(func: Callable[[Any], str]) -> Callable[[Any], str]: def wrapper(data: Any) -> str: res = func(data) return '&'.join(sorted(res.split('&'))) return wrapper def test_db_urlencode() -> None: t_urlenc = [ ({}, ""), ({'a': 1}, "a=1"), ({'a': None}, "a"), ({'qwe': 1, u'zz': u"qwe"}, 'qwe=1&zz=qwe'), ({'qwe': 1, u'zz': u"qwe"}, 'qwe=1&zz=qwe'), ({'a': '\000%&'}, "a=%00%25%26"), ({'a': Decimal("1")}, "a=1"), ] if _have_psycopg2: t_urlenc.append((dbrow, 'data=value&id=123')) try_quote(_sort_urlenc(skytools._cquoting.db_urlencode), t_urlenc) try_quote(_sort_urlenc(skytools._pyquoting.db_urlencode), t_urlenc) try_quote(_sort_urlenc(skytools.db_urlencode), t_urlenc) def test_db_urldecode() -> None: t_urldec = [ ("", {}), ("a=b&c", {'a': 'b', 'c': None}), ("&&b=f&&", {'b': 
'f'}), (u"abc=qwe", {'abc': 'qwe'}), ("b=", {'b': ''}), ("b=%00%45", {'b': '\x00E'}), ] for val, exp in t_urldec: assert skytools._cquoting.db_urldecode(val) == exp assert skytools._pyquoting.db_urldecode(val) == exp assert skytools.db_urldecode(val) == exp def test_unescape() -> None: t_unesc: List[Tuple[str, str]] = [ ("", ""), ("\\N", "N"), ("abc", "abc"), (r"\0\000\001\01\1", "\0\000\001\001\001"), (r"a\001b\tc\r\n", "a\001b\tc\r\n"), ] for val, exp in t_unesc: assert skytools._cquoting.unescape(val) == exp assert skytools._pyquoting.unescape(val) == exp assert skytools.unescape(val) == exp def test_quote_bytea_literal() -> None: bytea_raw = [ (None, "null"), (b"", "''"), (b"a'\tb", "E'a''\\\\011b'"), (b"a\\'b", r"E'a\\\\''b'"), (b"\t\344", r"E'\\011\\344'"), ] try_quote(skytools.quote_bytea_literal, bytea_raw) def test_quote_bytea_copy() -> None: bytea_raw = [ (None, "\\N"), (b"", ""), (b"a'\tb", "a'\\\\011b"), (b"a\\'b", r"a\\\\'b"), (b"\t\344", r"\\011\\344"), ] try_quote(skytools.quote_bytea_copy, bytea_raw) def test_quote_statement() -> None: sql = "set a=%s, b=%s, c=%s" args = [None, u"qwe'qwe", 6.6] assert skytools.quote_statement(sql, args) == "set a=null, b='qwe''qwe', c='6.6'" sql2 = "set a=%(a)s, b=%(b)s, c=%(c)s" args2 = dict(a=None, b="qwe'qwe", c=6.6) assert skytools.quote_statement(sql2, args2) == "set a=null, b='qwe''qwe', c='6.6'" def test_quote_json() -> None: json_string_vals = [ (None, "null"), ('', '""'), (u'xx', '"xx"'), ('qwe"qwe\t', '"qwe\\"qwe\\t"'), ('\x01', '"\\u0001"'), ] try_quote(skytools.quote_json, json_string_vals) def test_unquote_ident() -> None: idents = [ ('qwe', 'qwe'), ('"qwe"', 'qwe'), ('"q""w\\\\e"', 'q"w\\\\e'), ('Foo', 'foo'), ('"Wei "" rd"', 'Wei " rd'), ] for val, exp in idents: assert skytools.unquote_ident(val) == exp def test_unquote_ident_fail() -> None: with pytest.raises(Exception): skytools.unquote_ident('asd"asd') def test_unescape_copy() -> None: assert unescape_copy(r'baz\tfo\'o') == "baz\tfo'o" assert 
unescape_copy(r'\N') is None def test_unquote_fqident() -> None: assert unquote_fqident('Foo') == 'foo' assert unquote_fqident('"Foo"."Bar "" z"') == 'Foo.Bar " z' def test_json_encode() -> None: assert json_encode({'a': 1}) == '{"a": 1}' assert json_encode('a') == '"a"' assert json_encode(['a']) == '["a"]' assert json_encode(a=1) == '{"a": 1}' def test_json_decode() -> None: assert json_decode('[1]') == [1] def test_make_pgarray() -> None: assert make_pgarray([]) == '{}' assert make_pgarray(['foo_3', 1, '', None]) == '{foo_3,1,"",NULL}' res = make_pgarray([None, ',', '\\', "'", '"', "{", "}", '_']) exp = '{NULL,",","\\\\","\'","\\"","{","}",_}' assert res == exp python-skytools-3.9.2/tests/test_scripting.py000066400000000000000000000107121447265566700215370ustar00rootroot00000000000000 import os import signal import sys import time import pathlib import pytest import skytools from skytools.scripting import run_single_process WIN32 = sys.platform == "win32" CONF = os.path.join(os.path.dirname(__file__), "config.ini") TEST_DB = os.environ.get("TEST_DB") def checklog(log: str, word: str) -> bool: with open(log, 'r') as f: return word in f.read() class Runner: def __init__(self, logfile: str, word: str, sleep: int = 0) -> None: self.logfile = logfile self.word = word self.sleep = sleep def run(self) -> None: with open(self.logfile, "a") as f: f.write(self.word + "\n") time.sleep(self.sleep) @pytest.mark.skipif(WIN32, reason="cannot daemonize on win32") def test_bg_process(tmp_path: pathlib.Path) -> None: pidfile = str(tmp_path / "proc.pid") logfile = str(tmp_path / "proc.log") run_single_process(Runner(logfile, "STEP1"), False, pidfile) while skytools.signal_pidfile(pidfile, 0): time.sleep(1) assert checklog(logfile, "STEP1") # daemonize from other process pid = os.fork() if pid == 0: run_single_process(Runner(logfile, "STEP2", 10), True, pidfile) time.sleep(2) with pytest.raises(SystemExit): run_single_process(Runner(logfile, "STEP3"), False, pidfile) 
skytools.signal_pidfile(pidfile, signal.SIGTERM) while skytools.signal_pidfile(pidfile, 0): time.sleep(1) assert checklog(logfile, "STEP2") assert not checklog(logfile, "STEP3") class OptScript(skytools.BaseScript): ARGPARSE = False looping = 0 def send_signal(self, code: int) -> None: print("signal: %s" % code) sys.exit(0) def work(self) -> None: print("opt=%s" % self.cf.get("opt")) class ArgScript(OptScript): ARGPARSE = True def test_optparse_script(capsys: pytest.CaptureFixture[str]) -> None: with pytest.raises(SystemExit): OptScript("testscript", ["-h"]) res = capsys.readouterr() assert "display" in res.out def test_argparse_script(capsys: pytest.CaptureFixture[str]) -> None: with pytest.raises(SystemExit): ArgScript("testscript", ["-h"]) res = capsys.readouterr() assert "display" in res.out @pytest.mark.skipif(WIN32, reason="use signals on win32") def test_optparse_signals(capsys: pytest.CaptureFixture[str]) -> None: with pytest.raises(SystemExit): OptScript("testscript", ["-s", CONF]) res = capsys.readouterr() assert "SIGINT" in res.out or f"signal: {signal.SIGINT}" in res.out with pytest.raises(SystemExit): OptScript("testscript", ["-r", CONF]) res = capsys.readouterr() assert "SIGHUP" in res.out or f"signal: {signal.SIGHUP}" in res.out with pytest.raises(SystemExit): OptScript("testscript", ["-k", CONF]) res = capsys.readouterr() assert "SIGTERM" in res.out or f"signal: {signal.SIGTERM}" in res.out @pytest.mark.skipif(WIN32, reason="need to use signals") def test_argparse_signals(capsys: pytest.CaptureFixture[str]) -> None: with pytest.raises(SystemExit): ArgScript("testscript", ["-s", CONF]) res = capsys.readouterr() assert "SIGINT" in res.out or f"signal: {signal.SIGINT}" in res.out with pytest.raises(SystemExit): ArgScript("testscript", ["-r", CONF]) res = capsys.readouterr() assert "SIGHUP" in res.out or f"signal: {signal.SIGHUP}" in res.out with pytest.raises(SystemExit): ArgScript("testscript", ["-k", CONF]) res = capsys.readouterr() assert "SIGTERM" 
in res.out or f"signal: {signal.SIGTERM}" in res.out def test_optparse_confopt(capsys: pytest.CaptureFixture[str]) -> None: s = ArgScript("testscript", [CONF]) s.start() res = capsys.readouterr() assert "opt=test" in res.out def test_argparse_confopt(capsys: pytest.CaptureFixture[str]) -> None: s = ArgScript("testscript", [CONF]) s.start() res = capsys.readouterr() assert "opt=test" in res.out class DBScript(skytools.DBScript): ARGPARSE = True looping = 0 def work(self) -> None: db = self.get_database("db", connstr=TEST_DB) curs = db.cursor() curs.execute("select 1") curs.fetchall() self.close_database("db") print("OK") @pytest.mark.skipif(not TEST_DB, reason="need database config") def test_get_database(capsys: pytest.CaptureFixture[str]) -> None: s = DBScript("testscript", [CONF]) s.start() res = capsys.readouterr() assert "OK" in res.out python-skytools-3.9.2/tests/test_skylog.py000066400000000000000000000006371447265566700210520ustar00rootroot00000000000000 import logging import skytools from skytools import skylog def test_trace_setup() -> None: assert skylog.TRACE < logging.DEBUG assert skylog.TRACE == logging.TRACE # type: ignore assert logging.getLevelName(skylog.TRACE) == "TRACE" def test_skylog() -> None: log = skytools.getLogger("test.skylog") log.trace("tracemsg") assert not log.isEnabledFor(logging.TRACE) # type: ignore python-skytools-3.9.2/tests/test_sockutil.py000066400000000000000000000017311447265566700213730ustar00rootroot00000000000000import os import socket import sys import pytest from skytools.sockutil import set_cloexec, set_nonblocking, set_tcp_keepalive def test_set_tcp_keepalive() -> None: with socket.socket() as s: set_tcp_keepalive(s) @pytest.mark.skipif( sys.platform == 'win32', reason="set_nonblocking on fd does not work on win32" ) def test_set_nonblocking() -> None: with socket.socket() as s: assert set_nonblocking(s, None) == False assert set_nonblocking(s, True) == True assert set_nonblocking(s, None) == True def 
test_set_cloexec_file() -> None: with open(os.devnull, 'rb') as f: assert set_cloexec(f, None) in (True, False) assert set_cloexec(f, True) == True assert set_cloexec(f, None) == True def test_set_cloexec_socket() -> None: with socket.socket() as s: assert set_cloexec(s, None) in (True, False) assert set_cloexec(s, True) == True assert set_cloexec(s, None) == True python-skytools-3.9.2/tests/test_sqltools.py000066400000000000000000000106701447265566700214200ustar00rootroot00000000000000 import pytest from skytools.sqltools import ( Snapshot, dbdict, fq_name, fq_name_parts, magic_insert, mk_delete_sql, mk_insert_sql, mk_update_sql, ) def test_dbdict() -> None: row = dbdict(a=1, b=2) assert (row.a, row.b, row['a'], row['b']) == (1, 2, 1, 2) row.c = 3 assert row['c'] == 3 del row.c with pytest.raises(AttributeError): assert row.c with pytest.raises(KeyError): assert row['c'] row.merge({'q': 4}) assert row.q == 4 def test_fq_name_parts() -> None: assert fq_name_parts('tbl') == ('public', 'tbl') assert fq_name_parts('foo.tbl') == ('foo', 'tbl') assert fq_name_parts('foo.tbl.baz') == ('foo', 'tbl.baz') def test_fq_name() -> None: assert fq_name('tbl') == 'public.tbl' assert fq_name('foo.tbl') == 'foo.tbl' assert fq_name('foo.tbl.baz') == 'foo.tbl.baz' def test_snapshot() -> None: sn = Snapshot('11:20:11,12,15') assert sn.contains(9) assert not sn.contains(11) assert sn.contains(17) assert not sn.contains(20) with pytest.raises(ValueError): Snapshot(':') def test_magic_insert() -> None: res = magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2']) exp = 'COPY public.tbl (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' assert res == exp res = magic_insert(None, 'tbl', [[1, '1'], [2, '2']], ['col1', 'col2'], use_insert=True) exp = "insert into public.tbl (col1,col2) values ('1','1');\ninsert into public.tbl (col1,col2) values ('2','2');\n" assert res == exp assert magic_insert(None, 'tbl', [], ['col1', 'col2']) is None res = magic_insert(None, 'tbl."1"', [[1, '1'], [2, 
'2']], ['col1', 'col2'], quoted_table=True) exp = 'COPY tbl."1" (col1,col2) FROM STDIN;\n1\t1\n2\t2\n\\.\n' assert res == exp with pytest.raises(Exception): magic_insert(None, 'tbl."1"', [[1, '1'], [2, '2']]) res = magic_insert(None, 'a.tbl', [{'a': 1}, {'a': 2}]) exp = 'COPY a.tbl (a) FROM STDIN;\n1\n2\n\\.\n' assert res == exp res = magic_insert(None, 'a.tbl', [{'a': 1}, {'a': 2}], use_insert=True) exp = "insert into a.tbl (a) values ('1');\ninsert into a.tbl (a) values ('2');\n" assert res == exp # More fields than data res = magic_insert(None, 'tbl', [[1, 'a']], ['col1', 'col2', 'col3']) exp = 'COPY public.tbl (col1,col2,col3) FROM STDIN;\n1\ta\t\\N\n\\.\n' assert res == exp res = magic_insert(None, 'tbl', [[1, 'a']], ['col1', 'col2', 'col3'], use_insert=True) exp = "insert into public.tbl (col1,col2,col3) values ('1','a',null);\n" assert res == exp res = magic_insert(None, 'tbl', [{'a': 1}, {'b': 2}], ['a', 'b'], use_insert=False) exp = 'COPY public.tbl (a,b) FROM STDIN;\n1\t\\N\n\\N\t2\n\\.\n' assert res == exp res = magic_insert(None, 'tbl', [{'a': 1}, {'b': 2}], ['a', 'b'], use_insert=True) exp = "insert into public.tbl (a,b) values ('1',null);\ninsert into public.tbl (a,b) values (null,'2');\n" assert res == exp def test_mk_insert_sql() -> None: row = {'id': 1, 'data': None} res = mk_insert_sql(row, 'tbl') exp = "insert into public.tbl (id, data) values ('1', null);" assert res == exp fmap = {'id': 'id_', 'data': 'data_'} res = mk_insert_sql(row, 'tbl', ['x'], fmap) exp = "insert into public.tbl (id_, data_) values ('1', null);" assert res == exp def test_mk_update_sql() -> None: res = mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2']) exp = 'update only public."Table" set data = E\'str\\\\\' where id = \'0\' and id2 = \'2\';' assert res == exp res = mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', ['id', 'id2'], {'id': '_id', 'id2': '_id2', 'data': '_data'}) exp = 'update only public."Table" set _data = 
E\'str\\\\\' where _id = \'0\' and _id2 = \'2\';' assert res == exp with pytest.raises(Exception): mk_update_sql({'id': 0, 'id2': '2', 'data': 'str\\'}, 'Table', []) def test_mk_delete_sql() -> None: res = mk_delete_sql({'a': 1, 'b': 2, 'c': 3}, 'tablename', ['a', 'b']) exp = "delete from only public.tablename where a = '1' and b = '2';" assert res == exp res = mk_delete_sql({'a': 1, 'b': 2, 'c': 3}, 'tablename', ['a', 'b'], {'a': 'aa', 'b': 'bb'}) exp = "delete from only public.tablename where aa = '1' and bb = '2';" assert res == exp with pytest.raises(Exception): mk_delete_sql({'a': 1, 'b': 2, 'c': 3}, 'tablename', []) python-skytools-3.9.2/tests/test_timeutil.py000066400000000000000000000034731447265566700213770ustar00rootroot00000000000000 from datetime import datetime import pytest from skytools.timeutil import UTC, datetime_to_timestamp, parse_iso_timestamp def test_parse_iso_timestamp() -> None: res = str(parse_iso_timestamp('2005-06-01 15:00')) assert res == '2005-06-01 15:00:00' res = str(parse_iso_timestamp(' 2005-06-01T15:00 +02 ')) assert res == '2005-06-01 15:00:00+02:00' res = str(parse_iso_timestamp('2005-06-01 15:00:33+02:00')) assert res == '2005-06-01 15:00:33+02:00' d = parse_iso_timestamp('2005-06-01 15:00:59.33 +02') assert d.strftime("%z %Z") == '+0200 +02' assert str(parse_iso_timestamp(str(d))) == '2005-06-01 15:00:59.330000+02:00' res = parse_iso_timestamp('2005-06-01 15:00-0530').strftime('%Y-%m-%d %H:%M %z %Z') assert res == '2005-06-01 15:00 -0530 -05:30' res = parse_iso_timestamp('2014-10-27T11:59:13Z').strftime('%Y-%m-%d %H:%M:%S %z %Z') assert res == '2014-10-27 11:59:13 +0000 +00' with pytest.raises(ValueError): parse_iso_timestamp('2014.10.27') def test_datetime_to_timestamp() -> None: res = datetime_to_timestamp(parse_iso_timestamp("2005-06-01 15:00:59.5 +02")) assert res == 1117630859.5 res = datetime_to_timestamp(datetime.fromtimestamp(1117630859.5, UTC)) assert res == 1117630859.5 res = 
datetime_to_timestamp(datetime.fromtimestamp(1117630859.5)) assert res == 1117630859.5 now = datetime.utcnow() now2 = datetime.utcfromtimestamp(datetime_to_timestamp(now, False)) assert abs(now2.microsecond - now.microsecond) < 100 now2 = now2.replace(microsecond=now.microsecond) assert now == now2 now = datetime.now() now2 = datetime.fromtimestamp(datetime_to_timestamp(now)) assert abs(now2.microsecond - now.microsecond) < 100 now2 = now2.replace(microsecond=now.microsecond) assert now == now2 python-skytools-3.9.2/tests/test_tnetstrings.py000066400000000000000000000037351447265566700221300ustar00rootroot00000000000000 from typing import Any import pytest from skytools.tnetstrings import dumps, loads def ustr(v: Any) -> str: return repr(v).replace("u'", "'") def nstr(b: bytes) -> str: return b.decode('utf8') def test_dumps_simple_values() -> None: vals = (None, False, True, 222, 333.0, b"foo", u"bar") tnvals = [nstr(dumps(v)) for v in vals] assert tnvals == [ '0:~', '5:false!', '4:true!', '3:222#', '5:333.0^', '3:foo,', '3:bar,' ] def test_dumps_complex_values() -> None: vals2 = [[], (), {}, [1, 2], {'a': 'b'}] tnvals2 = [nstr(dumps(v)) for v in vals2] assert tnvals2 == ['0:]', '0:]', '0:}', '8:1:1#1:2#]', '8:1:a,1:b,}'] def test_loads_simple() -> None: vals = (None, False, True, 222, 333.0, b"foo", u"bar") res = ustr([loads(dumps(v)) for v in vals]) assert res == "[None, False, True, 222, 333.0, 'foo', 'bar']" def test_loads_complex() -> None: vals2 = [[], (), {}, [1, 2], {'a': 'b'}] res = ustr([loads(dumps(v)) for v in vals2]) assert res == "[[], [], {}, [1, 2], {'a': 'b'}]" def test_dumps_mview() -> None: res = nstr(dumps([memoryview(b'zzz'), b'qqq'])) assert res == '12:3:zzz,3:qqq,]' def test_loads_errors() -> None: with pytest.raises(ValueError): loads(b'4:qwez!') with pytest.raises(ValueError): loads(b'4:') with pytest.raises(ValueError): loads(b'4:qwez') with pytest.raises(ValueError): loads(b'4') with pytest.raises(ValueError): loads(b'') with 
pytest.raises(ValueError): loads(b'999999999999999999:z,') with pytest.raises(TypeError): loads(u'qweqwe') # type: ignore[arg-type] with pytest.raises(ValueError): loads(b'4:true!0:~') with pytest.raises(ValueError): loads(b'1:X~') with pytest.raises(ValueError): loads(b'8:1:1#1:2#}') with pytest.raises(ValueError): loads(b'1:Xz') def test_dumps_errors() -> None: with pytest.raises(TypeError): dumps(open) python-skytools-3.9.2/tests/test_utf8.py000066400000000000000000000014511447265566700204230ustar00rootroot00000000000000 import pytest from skytools.utf8 import safe_utf8_decode, sanitize_unicode def test_safe_decode() -> None: assert safe_utf8_decode(b"foobar") == (True, "foobar") assert safe_utf8_decode(b'X\0Z') == (False, "X\uFFFDZ") assert safe_utf8_decode(b"OK") == (True, "OK") assert safe_utf8_decode(b'X\xF1Y') == (False, "X\uFFFDY") assert sanitize_unicode(u'\uD801\uDC01') == "\U00010401" with pytest.raises(TypeError): sanitize_unicode(b'qwe') # type: ignore[arg-type] ## these give different results in py27 and py35 # >>> _norm(safe_utf8_decode(b'X\xed\xa0\x80Y\xed\xb0\x89Z')) # (False, ['X', 65533, 65533, 65533, 'Y', 65533, 65533, 65533, 'Z']) # >>> _norm(safe_utf8_decode(b'X\xed\xa0\x80\xed\xb0\x89Z')) # (False, ['X', 65533, 65533, 65533, 65533, 65533, 65533, 'Z']) python-skytools-3.9.2/tox.ini000066400000000000000000000026351447265566700163020ustar00rootroot00000000000000 [tox] envlist = lint,xlint,py3 [package] name = skytools deps = psycopg2-binary==2.9.7; platform_python_implementation != 'PyPy' test_deps = #coverage==7.3.0 coverage==7.2.7 pytest==7.4.0 pytest-cov==4.1.0 lint_deps = mypy==1.5.1 pylint==2.17.5 pyflakes==3.1.0 #pytype==2023.8.14 typing-extensions==4.7.1 types-setuptools==68.1.0.0 types-psycopg2==2.9.21.11; platform_python_implementation != 'PyPy' doc_deps = sphinx==7.2.2 docutils==0.20.1 [testenv] changedir = {toxinidir} deps = {[package]deps} {[package]test_deps} setenv = PYTHONDEVMODE=1 COVERAGE_RCFILE={toxinidir}/.coveragerc passenv 
= TEST_DB commands = python --version pytest \ --cov \ --cov-report=term \ --cov-report=html:{toxinidir}/cover/{envname} \ --rootdir={toxinidir} \ {posargs} [testenv:lint] #changedir = {toxinidir} basepython = python3 deps = {[package]deps} {[package]lint_deps} {[package]test_deps} commands = mypy skytools tests [testenv:xlint] #changedir = {envsitepackagesdir} basepython = python3 deps = {[package]deps} {[package]lint_deps} {[package]test_deps} commands = pylint skytools #pytype skytools [testenv:docs] basepython = python3 deps = {[package]deps} {[package]doc_deps} changedir = doc commands = sphinx-build -q -W -b html -d {envtmpdir}/doctrees . ../tmp/dochtml