sqlalchemy-utils-0.36.1/.editorconfig

# EditorConfig helps developers define and maintain consistent
# coding styles between different editors and IDEs
# editorconfig.org

root = true

[*]
indent_style = space
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true
indent_size = 4

sqlalchemy-utils-0.36.1/.gitignore

*.py[cod]

# C extensions
*.so

# Packages
*.egg
*.egg-info
dist
build
eggs
parts
bin
var
sdist
develop-eggs
.installed.cfg
.cache
.eggs
lib
lib64
docs/_build

# Installer logs
pip-log.txt

# Unit test / coverage reports
.coverage
.tox
nosetests.xml

# Translations
*.mo

# Mr Developer
.mr.developer.cfg
.project
.pydevproject

# vim
[._]*.s[a-w][a-z]
[._]s[a-w][a-z]
*.un~
Session.vim
.netrwhist
*~

# Sublime Text
*.sublime-*

sqlalchemy-utils-0.36.1/.isort.cfg

[settings]
known_first_party=sqlalchemy_utils
known_third_party=flexmock
line_length=79
multi_line_output=3
not_skip=__init__.py
order_by_type=false

sqlalchemy-utils-0.36.1/.travis.yml

language: python
sudo: required
dist: xenial
addons:
  postgresql: "9.4"
services:
  - docker
  - mysql
before_script:
  - psql -c 'create database sqlalchemy_utils_test;' -U postgres
  - psql -c 'create extension hstore;' -U postgres -d sqlalchemy_utils_test
  - mysql -e 'create database sqlalchemy_utils_test;'
matrix:
  include:
    - python: 2.7
      env:
        - "TOXENV=py27"
    - python: 3.5
      env:
        - "TOXENV=py35"
    - python: 3.6
      env:
        - "TOXENV=py36"
    - python: 3.7
      env:
        - "TOXENV=py37"
    - python: 3.7
      env:
        - "TOXENV=lint"
install:
  - source $TRAVIS_BUILD_DIR/.travis/install_mssql.sh
  - pip install tox
script:
  - tox

sqlalchemy-utils-0.36.1/.travis/install_mssql.sh

#!/usr/bin/env bash
wget http://www.unixodbc.org/unixODBC-2.3.1.tar.gz
tar xvf unixODBC-2.3.1.tar.gz
cd unixODBC-2.3.1/
./configure --disable-gui \
    --disable-drivers \
    --enable-iconv \
    --with-iconv-char-enc=UTF8 \
    --with-iconv-ucode-enc=UTF16LE
make
sudo make install
sudo ldconfig
sudo su < /etc/apt/sources.list.d/mssql-release.list
EOF
sudo apt-get update
sudo ACCEPT_EULA=Y apt-get install msodbcsql17
docker run -e 'ACCEPT_EULA=Y' -e 'SA_PASSWORD=Strong!Passw0rd' -p 1433:1433 -d mcr.microsoft.com/mssql/server:2017-latest

sqlalchemy-utils-0.36.1/CHANGES.rst

Changelog
---------

Here you can see the full list of changes between each SQLAlchemy-Utils
release.

0.36.1 (2019-12-23)
^^^^^^^^^^^^^^^^^^^

- Added support for CASCADE option when dropping views (#406, pull request
  courtesy of amicks)
- Added `aliases` parameter to create_materialized_view function.

0.36.0 (2019-12-08)
^^^^^^^^^^^^^^^^^^^

- Removed explain and explain_analyze due to the internal changes in
  SQLAlchemy version 1.3.
0.35.0 (2019-11-01) ^^^^^^^^^^^^^^^^^^^ - Removed some deprecation warnings - Added Int8RangeType (#401, pull request courtesy of lpsinger) 0.34.2 (2019-08-20) ^^^^^^^^^^^^^^^^^^^ - Remove ABC deprecation warnings (#386, pull request courtesy of VizualAbstract) 0.34.1 (2019-07-15) ^^^^^^^^^^^^^^^^^^^ - Remove deprecation warnings (#379, pull request courtesy of Le-Stagiaire) - Drop py34 support 0.34.0 (2019-06-09) ^^^^^^^^^^^^^^^^^^^ - Removed array_agg compilation which was never a good idea and collided with the latest version of SA. (#374) - Removed deprecation warnings (#373, pull request courtesy of pbasista) 0.33.12 (2019-02-02) ^^^^^^^^^^^^^^^^^^^^ - Added ordering support for Country primitive (#361, pull request courtesy of TrilceAC) 0.33.11 (2019-01-13) ^^^^^^^^^^^^^^^^^^^^ - Added support for creating and dropping a PostgreSQL database when using pg8000 driver (#303, pull request courtesy of mohaseeb) 0.33.10 (2018-12-27) ^^^^^^^^^^^^^^^^^^^^ - Removed optional dependency to Flask-Babel. Now using Babel instead. (#333, pull request courtesy of aveuiller) 0.33.9 (2018-11-19) ^^^^^^^^^^^^^^^^^^^ - Fixed SQLite database_exists to check for correct file format (#306, pull request courtesy of jomasti) 0.33.8 (2018-11-19) ^^^^^^^^^^^^^^^^^^^ - Added support of short-code in PhoneNumberType (#348, pull request courtesy of anandtripathi5) 0.33.7 (2018-11-19) ^^^^^^^^^^^^^^^^^^^ - Added MSSQL support for create_database and drop_database (#337, pull request courtesy of jomasti) 0.33.6 (2018-10-14) ^^^^^^^^^^^^^^^^^^^ - Fixed passlib compatibility issue (again) (#342) - Added support for SQL VIEWs 0.33.5 (2018-09-19) ^^^^^^^^^^^^^^^^^^^ - Added optional attr parameter for locale calleble in TranslationHybrid - Fixed an issue with PasswordType so that it is compatible with older versions of passlib (#342) 0.33.4 (2018-09-11) ^^^^^^^^^^^^^^^^^^^ - Made PasswordType use `hash` function instead of deprecated `encrypt` function (#341, pull request courtesy of libre-man) 0.33.3 (2018-04-29) ^^^^^^^^^^^^^^^^^^^ - Added new AesGcmEngine (#322, pull request courtesy of manishahluwalia) 0.33.2 (2018-04-02) ^^^^^^^^^^^^^^^^^^^ - Added support for universal wheels (#312, pull request courtesy of nsoranzo) - Fixed usage of template0 and template1 with postgres database functions. 
(#286, pull request courtesy of funkybob) 0.33.1 (2018-03-19) ^^^^^^^^^^^^^^^^^^^ - Fixed EncryptedType for Oracle padding attack (#316, pull request courtesy of manishahluwalia) 0.33.0 (2018-02-18) ^^^^^^^^^^^^^^^^^^^ - Added support for materialized views in PostgreSQL - Added Ltree.descendant_of and Ltree.ancestor_of (#311, pull request courtesy of kageurufu) - Dropped Python 3.3 support - Fixed EncryptedType padding (#301, pull request courtesy of konstantinoskostis) 0.32.21 (2017-11-11) ^^^^^^^^^^^^^^^^^^^^ - Close connections on exists, create and drop database functions (#295, pull request courtesy of Terseus) 0.32.20 (2017-11-04) ^^^^^^^^^^^^^^^^^^^^ - Added `__hash__` method for choice objects (#294, pull request courtesy of havelock) 0.32.19 (2017-10-17) ^^^^^^^^^^^^^^^^^^^^ - Fixed select_correlated_expression order by for intermediate table aliases 0.32.18 (2017-10-06) ^^^^^^^^^^^^^^^^^^^^ - Made aggregated attributes to work with subclass objects (#287, pull request courtesy of fayazkhan) 0.32.17 (2017-09-29) ^^^^^^^^^^^^^^^^^^^^ - Added support for MSSQL uniqueidentifier type (#283, pull request courtesy of nHurD) 0.32.16 (2017-09-01) ^^^^^^^^^^^^^^^^^^^^ - Added more hints when decrypting AES with an invalid key (#275, pull request courtesy of xrmx) 0.32.15 (2017-08-31) ^^^^^^^^^^^^^^^^^^^^ - Added better handling of date types for EncryptedType (#184, pull request courtesy of konstantinoskostis) 0.32.14 (2017-03-27) ^^^^^^^^^^^^^^^^^^^^ - Fixed drop_database version comparison 0.32.13 (2017-03-12) ^^^^^^^^^^^^^^^^^^^^ - Fixed a DeprecationWarning by using LargeBinary instead of Binary (#263, pull request courtesy of jacquerie) 0.32.12 (2016-12-18) ^^^^^^^^^^^^^^^^^^^^ - Added generic_repr decorator 0.32.11 (2016-11-19) ^^^^^^^^^^^^^^^^^^^^ - TimeZoneType support for static timezones (#244, pull request courtesy of fuhrysteve) - Added SQLite support for PasswordType (#254, pull request courtesy of frol) 0.32.10 (2016-10-20) ^^^^^^^^^^^^^^^^^^^^ - Added PhoneNumber as the python_type for PhoneNumberType (#248) - Made auto_delete_orphans support backref tuples (#234, pull request courtesy of vToMy) 0.32.9 (2016-07-17) ^^^^^^^^^^^^^^^^^^^ - Added support for multi-column observers (#231, pull request courtesy of quantus) 0.32.8 (2016-05-20) ^^^^^^^^^^^^^^^^^^^ - Fixed EmailType to respect constructor args (#230, pull request courtesy of quantus) 0.32.7 (2016-05-20) ^^^^^^^^^^^^^^^^^^^ - Made PhoneNumber exceptions inherit SQLAlchemy's DontWrapMixin (#219, pull request courtesy of JackWink) 0.32.6 (2016-05-11) ^^^^^^^^^^^^^^^^^^^ - Added support for timezones with ArrowType (#218, pull request courtesy of jmagnusson) 0.32.5 (2016-04-29) ^^^^^^^^^^^^^^^^^^^ - Fixed import issue with latest version of SQLAlchemy (#215) 0.32.4 (2016-04-25) ^^^^^^^^^^^^^^^^^^^ - Added LtreeType for PostgreSQL ltree extension - Added Ltree primitive data type 0.32.3 (2016-04-20) ^^^^^^^^^^^^^^^^^^^ - Added support for PhoneNumber objects as composites 0.32.2 (2016-04-20) ^^^^^^^^^^^^^^^^^^^ - Fixed PasswordType to not access LazyCryptContext on type init (#211, pull request courtesy of olegpidsadnyi) 0.32.1 (2016-03-30) ^^^^^^^^^^^^^^^^^^^ - Fixed database helpers for sqlite (#208, pull request courtesy of RobertDeRose) - Fixed TranslationHybrid aliased entities handling (#198, pull request courtesy of jmagnusson) 0.32.0 (2016-03-17) ^^^^^^^^^^^^^^^^^^^ - Dropped py26 support - Made PasswordType to use LazyCryptContext by default (#204, courtesy of olegpidsadnyi) 0.31.6 (2016-01-21) 
^^^^^^^^^^^^^^^^^^^ - Added literal parameter processing for ArrowType (#182, pull request courtesy of jmagnusson) 0.31.5 (2016-01-14) ^^^^^^^^^^^^^^^^^^^ - Fixed scalar parsing of LocaleType (#173) 0.31.4 (2015-12-06) ^^^^^^^^^^^^^^^^^^^ - Fixed column alias handling with assert_* functions (#175) 0.31.3 (2015-11-09) ^^^^^^^^^^^^^^^^^^^ - Fixed non-ascii string handling in composite types (#170) 0.31.2 (2015-10-30) ^^^^^^^^^^^^^^^^^^^ - Fixed observes crashing when observable root_obj is ``None`` (#168) 0.31.1 (2015-10-26) ^^^^^^^^^^^^^^^^^^^ - Column observers only notified when actual changes have been made to underlying columns (#138) 0.31.0 (2015-09-17) ^^^^^^^^^^^^^^^^^^^ - Made has_index allow fk constraint as parameter - Made has_unique_index allow fk constraint as parameter - Made the extra packages in setup.py to be returned in deterministic order (courtesy of thomasgoirand) - Removed is_indexed_foreign_key (superceded by more versatile has_index) - Fixed LocaleType territory parsing (courtesy of dahlia) 0.30.17 (2015-08-16) ^^^^^^^^^^^^^^^^^^^^ - Added correlate parameter to select_correlated_expression function 0.30.16 (2015-08-04) ^^^^^^^^^^^^^^^^^^^^ - Fixed sort_query handling of aliased classes with hybrid properties 0.30.15 (2015-07-28) ^^^^^^^^^^^^^^^^^^^^ - Added support for aliased classes in get_hybrid_properties utility function 0.30.14 (2015-07-23) ^^^^^^^^^^^^^^^^^^^^ - Added cast_if utility function 0.30.13 (2015-07-21) ^^^^^^^^^^^^^^^^^^^^ - Added support for InstrumentedAttributes, ColumnProperties and Columns in get_columns function 0.30.12 (2015-07-05) ^^^^^^^^^^^^^^^^^^^^ - Added support for PhoneNumber extensions (#121) 0.30.11 (2015-06-18) ^^^^^^^^^^^^^^^^^^^^ - Fix None type handling of ChoiceType - Make locale casting for translation hybrid expressions cast locales on compilation phase. This extra lazy locale casting is needed in some cases where translation hybrid expressions are used before get_locale function is available. 0.30.10 (2015-06-17) ^^^^^^^^^^^^^^^^^^^^ - Added better support for dynamic locales in translation_hybrid - Make babel dependent primitive types to use Locale('en') for data validation instead of current locale. Using current locale leads to infinite recursion in cases where the loaded data has dependency to the loaded object's locale. 0.30.9 (2015-06-09) ^^^^^^^^^^^^^^^^^^^ - Added get_type utility function - Added default parameter for array_agg function 0.30.8 (2015-06-05) ^^^^^^^^^^^^^^^^^^^ - Added Asterisk compiler - Added row_to_json GenericFunction - Added array_agg GenericFunction - Made quote function accept dialect object as the first paremeter - Made has_index work with tables without primary keys (#148) 0.30.7 (2015-05-28) ^^^^^^^^^^^^^^^^^^^ - Fixed CompositeType null handling 0.30.6 (2015-05-28) ^^^^^^^^^^^^^^^^^^^ - Made psycopg2 requirement optional (#145, #146) - Made CompositeArray work with tuples given as bind parameters 0.30.5 (2015-05-27) ^^^^^^^^^^^^^^^^^^^ - Fixed CompositeType bind parameter processing when one of the fields is of TypeDecorator type and CompositeType is used inside ARRAY type. 0.30.4 (2015-05-27) ^^^^^^^^^^^^^^^^^^^ - Fixed CompositeType bind parameter processing when one of the fields is of TypeDecorator type. 
0.30.3 (2015-05-27) ^^^^^^^^^^^^^^^^^^^ - Added length property to range types - Added CompositeType for PostgreSQL 0.30.2 (2015-05-21) ^^^^^^^^^^^^^^^^^^^ - Fixed ``assert_max_length``, ``assert_non_nullable``, ``assert_min_value`` and ``assert_max_value`` not properly raising an ``AssertionError`` when the assertion failed. 0.30.1 (2015-05-06) ^^^^^^^^^^^^^^^^^^^ - Drop undocumented batch fetch feature. Let's wait until the inner workings of SQLAlchemy loading API is well-documented. - Fixed GenericRelationshipProperty comparator to work with SA 1.0.x (#139) - Make all foreign key helpers SA 1.0 compliant - Make translation_hybrid expression work the same way as SQLAlchemy-i18n translation expressions - Update SQLAlchemy dependency to 1.0 0.30.0 (2015-04-15) ^^^^^^^^^^^^^^^^^^^ - Added __hash__ method to Country class - Made Country validate itself during object initialization - Made Country string coercible - Removed deprecated function generates - Fixed observes function to work with simple column properties 0.29.9 (2015-04-07) ^^^^^^^^^^^^^^^^^^^ - Added CurrencyType (#19) and Currency class 0.29.8 (2015-03-03) ^^^^^^^^^^^^^^^^^^^ - Added get_class_by_table ORM utility function 0.29.7 (2015-03-01) ^^^^^^^^^^^^^^^^^^^ - Added Enum representation support for ChoiceType 0.29.6 (2015-02-03) ^^^^^^^^^^^^^^^^^^^ - Added customizable TranslationHybrid default value 0.29.5 (2015-02-03) ^^^^^^^^^^^^^^^^^^^ - Made assert_max_length support PostgreSQL array type 0.29.4 (2015-01-31) ^^^^^^^^^^^^^^^^^^^ - Made CaseInsensitiveComparator not cast already lowercased types to lowercase 0.29.3 (2015-01-24) ^^^^^^^^^^^^^^^^^^^ - Fixed analyze function runtime property handling for PostgreSQL >= 9.4 - Fixed drop_database and create_database identifier quoting (#122) 0.29.2 (2015-01-08) ^^^^^^^^^^^^^^^^^^^ - Removed deprecated defer_except (SQLAlchemy's own load_only should be used from now on) - Added json_sql PostgreSQL helper function 0.29.1 (2015-01-03) ^^^^^^^^^^^^^^^^^^^ - Added assert_min_value and assert_max_value testing functions 0.29.0 (2015-01-02) ^^^^^^^^^^^^^^^^^^^ - Removed TSVectorType.match_tsquery (now replaced by TSVectorType.match to be compatible with SQLAlchemy) - Removed undocumented function tsvector_concat - Added support for TSVectorType concatenation through OR operator - Added documentation for TSVectorType (#102) 0.28.3 (2014-12-17) ^^^^^^^^^^^^^^^^^^^ - Made aggregated fully support column aliases - Changed test matrix to run all tests without any optional dependencies (as well as with all optional dependencies) 0.28.2 (2014-12-13) ^^^^^^^^^^^^^^^^^^^ - Fixed issue with Color importing (#104) 0.28.1 (2014-12-13) ^^^^^^^^^^^^^^^^^^^ - Improved EncryptedType to support more underlying_type's; now supports: Integer, Boolean, Date, Time, DateTime, ColorType, PhoneNumberType, Unicode(Text), String(Text), Enum - Allow a callable to be used to lookup the key for EncryptedType 0.28.0 (2014-12-12) ^^^^^^^^^^^^^^^^^^^ - Fixed PhoneNumber string coercion (#93) - Added observes decorator (generates decorator will be deprecated later) 0.27.11 (2014-12-06) ^^^^^^^^^^^^^^^^^^^^ - Added loose typed column checking support for get_column_key - Made get_column_key throw UnmappedColumnError to be consistent with SQLAlchemy 0.27.10 (2014-12-03) ^^^^^^^^^^^^^^^^^^^^ - Fixed column alias handling in dependent_objects 0.27.9 (2014-12-01) ^^^^^^^^^^^^^^^^^^^ - Fixed aggregated decorator many-to-many relationship handling - Fixed aggregated column alias handling 0.27.8 (2014-11-13) 
^^^^^^^^^^^^^^^^^^^ - Added is_loaded utility function - Removed deprecated has_any_changes 0.27.7 (2014-11-03) ^^^^^^^^^^^^^^^^^^^ - Added support for Column and ColumnEntity objects in get_mapper - Made make_order_by_deterministic add deterministic column more aggressively 0.27.6 (2014-10-29) ^^^^^^^^^^^^^^^^^^^ - Fixed assert_max_length not working with non nullable columns - Add PostgreSQL < 9.2 support for drop_database 0.27.5 (2014-10-24) ^^^^^^^^^^^^^^^^^^^ - Made assert_* functions automatically rollback session - Changed make_order_by_deterministic attach order by primary key for queries without order by - Fixed alias handling in has_unique_index - Fixed alias handling in has_index - Fixed alias handling in make_order_by_deterministic 0.27.4 (2014-10-23) ^^^^^^^^^^^^^^^^^^^ - Added assert_non_nullable, assert_nullable and assert_max_length testing functions 0.27.3 (2014-10-22) ^^^^^^^^^^^^^^^^^^^ - Added supported for various SQLAlchemy objects in make_order_by_deterministic (previosly this function threw exceptions for other than Column objects) 0.27.2 (2014-10-21) ^^^^^^^^^^^^^^^^^^^ - Fixed MapperEntity handling in get_mapper and get_tables utility functions - Fixed make_order_by_deterministic handling for queries without order by (no just silently ignores those rather than throws exception) - Made make_order_by_deterministic if given query uses strings as order by args 0.27.1 (2014-10-20) ^^^^^^^^^^^^^^^^^^^ - Added support for more SQLAlchemy based objects and classes in get_tables function - Added has_unique_index utility function - Added make_order_by_deterministic utility function 0.27.0 (2014-10-14) ^^^^^^^^^^^^^^^^^^^ - Added EncryptedType 0.26.17 (2014-10-07) ^^^^^^^^^^^^^^^^^^^^ - Added explain and explain_analyze expressions - Added analyze function 0.26.16 (2014-09-09) ^^^^^^^^^^^^^^^^^^^^ - Fix aggregate value handling for cascade deleted objects - Fix ambiguous column sorting with join table inheritance in sort_query 0.26.15 (2014-08-28) ^^^^^^^^^^^^^^^^^^^^ - Fix sort_query support for queries using mappers (not declarative classes) with calculated column properties 0.26.14 (2014-08-26) ^^^^^^^^^^^^^^^^^^^^ - Added count method to QueryChain class 0.26.13 (2014-08-23) ^^^^^^^^^^^^^^^^^^^^ - Added template parameter to create_database function 0.26.12 (2014-08-22) ^^^^^^^^^^^^^^^^^^^^ - Added quote utility function 0.26.11 (2014-08-21) ^^^^^^^^^^^^^^^^^^^^ - Fixed dependent_objects support for single table inheritance 0.26.10 (2014-08-13) ^^^^^^^^^^^^^^^^^^^^ - Fixed dependent_objects support for multiple dependencies 0.26.9 (2014-08-06) ^^^^^^^^^^^^^^^^^^^ - Fixed PasswordType with Oracle dialect - Added support for sort_query and attributes on mappers using with_polymorphic 0.26.8 (2014-07-30) ^^^^^^^^^^^^^^^^^^^ - Fixed order by column property handling in sort_query when using polymorphic inheritance - Added support for synonym properties in sort_query 0.26.7 (2014-07-29) ^^^^^^^^^^^^^^^^^^^ - Made sort_query support hybrid properties where function name != property name - Made get_hybrid_properties return a dictionary of property keys and hybrid properties - Added documentation for get_hybrid_properties 0.26.6 (2014-07-22) ^^^^^^^^^^^^^^^^^^^ - Added exclude parameter to has_changes - Made has_changes accept multiple attributes as second parameter 0.26.5 (2014-07-11) ^^^^^^^^^^^^^^^^^^^ - Added get_column_key - Added Timestamp model mixin 0.26.4 (2014-06-25) ^^^^^^^^^^^^^^^^^^^ - Added auto_delete_orphans 0.26.3 (2014-06-25) ^^^^^^^^^^^^^^^^^^^ - Added 
has_any_changes 0.26.2 (2014-05-29) ^^^^^^^^^^^^^^^^^^^ - Added various fixes for bugs found in use of psycopg2 - Added has_index 0.26.1 (2014-05-14) ^^^^^^^^^^^^^^^^^^^ - Added get_bind - Added group_foreign_keys - Added get_mapper - Added merge_references 0.26.0 (2014-05-07) ^^^^^^^^^^^^^^^^^^^ - Added get_referencing_foreign_keys - Added get_tables - Added QueryChain - Added dependent_objects 0.25.4 (2014-04-22) ^^^^^^^^^^^^^^^^^^^ - Added ExpressionParser 0.25.3 (2014-04-21) ^^^^^^^^^^^^^^^^^^^ - Added support for primary key aliases in get_primary_keys function - Added get_columns utility function 0.25.2 (2014-03-25) ^^^^^^^^^^^^^^^^^^^ - Fixed sort_query handling of regular properties (no longer throws exceptions) 0.25.1 (2014-03-20) ^^^^^^^^^^^^^^^^^^^ - Added more import json as a fallback if anyjson package is not installed for JSONType - Fixed query_entities labeled select handling 0.25.0 (2014-03-05) ^^^^^^^^^^^^^^^^^^^ - Added single table inheritance support for generic_relationship - Added support for comparing class super types with generic relationships - BC break: In order to support different inheritance strategies generic_relationship now uses class names as discriminators instead of table names. 0.24.4 (2014-03-05) ^^^^^^^^^^^^^^^^^^^ - Added hybrid_property support for generic_relationship 0.24.3 (2014-03-05) ^^^^^^^^^^^^^^^^^^^ - Added string argument support for generic_relationship - Added composite primary key support for generic_relationship 0.24.2 (2014-03-04) ^^^^^^^^^^^^^^^^^^^ - Remove toolz from dependencies - Add step argument support for all range types - Optional intervals dependency updated to 0.2.4 0.24.1 (2014-02-21) ^^^^^^^^^^^^^^^^^^^ - Made identity return a tuple in all cases - Added support for declarative model classes as identity function's first argument 0.24.0 (2014-02-18) ^^^^^^^^^^^^^^^^^^^ - Added getdotattr - Added Path and AttrPath classes - SQLAlchemy dependency updated to 0.9.3 - Optional intervals dependency updated to 0.2.2 0.23.5 (2014-02-15) ^^^^^^^^^^^^^^^^^^^ - Fixed ArrowType timezone handling 0.23.4 (2014-01-30) ^^^^^^^^^^^^^^^^^^^ - Added force_instant_defaults function - Added force_auto_coercion function - Added source paramater for generates function 0.23.3 (2014-01-21) ^^^^^^^^^^^^^^^^^^^ - Fixed backref handling for aggregates - Added support for many-to-many aggregates 0.23.2 (2014-01-21) ^^^^^^^^^^^^^^^^^^^ - Fixed issues with ColorType and ChoiceType string bound parameter processing - Fixed inheritance handling with aggregates - Fixed generic relationship nullifying 0.23.1 (2014-01-14) ^^^^^^^^^^^^^^^^^^^ - Added support for membership operators 'in' and 'not in' in range types - Added support for contains and contained_by operators in range types - Added range types to main module import 0.23.0 (2014-01-14) ^^^^^^^^^^^^^^^^^^^ - Deprecated NumberRangeType, NumberRange - Added IntRangeType, NumericRangeType, DateRangeType, DateTimeRangeType - Moved NumberRange functionality to intervals package 0.22.1 (2014-01-06) ^^^^^^^^^^^^^^^^^^^ - Fixed in issue where NumberRange would not always raise RangeBoundsException with object initialization 0.22.0 (2014-01-04) ^^^^^^^^^^^^^^^^^^^ - Added SQLAlchemy 0.9 support - Made JSONType use sqlalchemy.dialects.postgresql.JSON if available - Updated psycopg requirement to 2.5.1 - Deprecated NumberRange classmethod constructors 0.21.0 (2013-11-11) ^^^^^^^^^^^^^^^^^^^ - Added support for cached aggregates 0.20.0 (2013-10-24) ^^^^^^^^^^^^^^^^^^^ - Added JSONType - NumberRangeType now 
supports coercing of integer values 0.19.0 (2013-10-24) ^^^^^^^^^^^^^^^^^^^ - Added ChoiceType 0.18.0 (2013-10-24) ^^^^^^^^^^^^^^^^^^^ - Added LocaleType 0.17.1 (2013-10-23) ^^^^^^^^^^^^^^^^^^^ - Removed compat module, added total_ordering package to Python 2.6 requirements - Enhanced render_statement function 0.17.0 (2013-10-23) ^^^^^^^^^^^^^^^^^^^ - Added URLType 0.16.25 (2013-10-18) ^^^^^^^^^^^^^^^^^^^^ - Added __ne__ operator implementation for Country object - New utility function: naturally_equivalent 0.16.24 (2013-10-04) ^^^^^^^^^^^^^^^^^^^^ - Renamed match operator of TSVectorType to match_tsquery in order to avoid confusion with existing match operator - Added catalog parameter support for match_tsquery operator 0.16.23 (2013-10-04) ^^^^^^^^^^^^^^^^^^^^ - Added match operator for TSVectorType 0.16.22 (2013-10-03) ^^^^^^^^^^^^^^^^^^^^ - Added optional columns and options parameter for TSVectorType 0.16.21 (2013-09-29) ^^^^^^^^^^^^^^^^^^^^ - Fixed an issue with sort_query where sort by relationship property would cause an exception. 0.16.20 (2013-09-26) ^^^^^^^^^^^^^^^^^^^^ - Fixed an issue with sort_query where sort by main entity's attribute would fail if joins where applied. 0.16.19 (2013-09-21) ^^^^^^^^^^^^^^^^^^^^ - Added configuration for silent mode in sort_query - Added support for aliased entity hybrid properties in sort_query 0.16.18 (2013-09-19) ^^^^^^^^^^^^^^^^^^^^ - Fixed sort_query hybrid property handling (again) 0.16.17 (2013-09-19) ^^^^^^^^^^^^^^^^^^^^ - Added support for relation hybrid property sorting in sort_query 0.16.16 (2013-09-18) ^^^^^^^^^^^^^^^^^^^^ - Fixed fatal bug in batch fetch join table inheritance handling (not handling one-to-many relations properly) 0.16.15 (2013-09-17) ^^^^^^^^^^^^^^^^^^^^ - Fixed sort_query hybrid property handling (now supports both ascending and descending sorting) 0.16.14 (2013-09-17) ^^^^^^^^^^^^^^^^^^^^ - More pythonic __init__ for Country allowing Country(Country('fi')) == Country('fi') - Better equality operator for Country 0.16.13 (2013-09-17) ^^^^^^^^^^^^^^^^^^^^ - Added i18n module for configuration of locale dependant types 0.16.12 (2013-09-17) ^^^^^^^^^^^^^^^^^^^^ - Fixed remaining Python 3 issues with WeekDaysType - Better bound method handling for WeekDay get_locale 0.16.11 (2013-09-17) ^^^^^^^^^^^^^^^^^^^^ - Python 3 support for WeekDaysType - Fixed a bug in batch fetch for situations where joined paths contain zero entitites 0.16.10 (2013-09-16) ^^^^^^^^^^^^^^^^^^^^ - Added WeekDaysType 0.16.9 (2013-08-21) ^^^^^^^^^^^^^^^^^^^ - Support for many-to-one directed relationship properties batch fetching 0.16.8 (2013-08-21) ^^^^^^^^^^^^^^^^^^^ - PasswordType support for PostgreSQL - Hybrid property for sort_query 0.16.7 (2013-08-18) ^^^^^^^^^^^^^^^^^^^ - Added better handling of local column names in batch_fetch - PasswordType gets default length even if no crypt context schemes provided 0.16.6 (2013-08-16) ^^^^^^^^^^^^^^^^^^^ - Rewritten batch_fetch schematics, new syntax for backref population 0.16.5 (2013-08-08) ^^^^^^^^^^^^^^^^^^^ - Initial backref population forcing support for batch_fetch 0.16.4 (2013-08-08) ^^^^^^^^^^^^^^^^^^^ - Initial many-to-many relations support for batch_fetch 0.16.3 (2013-08-05) ^^^^^^^^^^^^^^^^^^^ - Added batch_fetch function 0.16.2 (2013-08-01) ^^^^^^^^^^^^^^^^^^^ - Added to_tsquery and plainto_tsquery sql function expressions 0.16.1 (2013-08-01) ^^^^^^^^^^^^^^^^^^^ - Added tsvector_concat and tsvector_match sql function expressions 0.16.0 (2013-07-25) ^^^^^^^^^^^^^^^^^^^ - Added 
ArrowType 0.15.1 (2013-07-22) ^^^^^^^^^^^^^^^^^^^ - Added utility functions declarative_base, identity and is_auto_assigned_date_column 0.15.0 (2013-07-22) ^^^^^^^^^^^^^^^^^^^ - Added PasswordType 0.14.7 (2013-07-22) ^^^^^^^^^^^^^^^^^^^ - Lazy import for ipaddress package 0.14.6 (2013-07-22) ^^^^^^^^^^^^^^^^^^^ - Fixed UUID import issues 0.14.5 (2013-07-22) ^^^^^^^^^^^^^^^^^^^ - Added UUID type 0.14.4 (2013-07-03) ^^^^^^^^^^^^^^^^^^^ - Added TSVector type 0.14.3 (2013-07-03) ^^^^^^^^^^^^^^^^^^^ - Added non_indexed_foreign_keys utility function 0.14.2 (2013-07-02) ^^^^^^^^^^^^^^^^^^^ - Fixed py3 bug introduced in 0.14.1 0.14.1 (2013-07-02) ^^^^^^^^^^^^^^^^^^^ - Made sort_query support column_property selects with labels 0.14.0 (2013-07-02) ^^^^^^^^^^^^^^^^^^^ - Python 3 support, dropped python 2.5 support 0.13.3 (2013-06-11) ^^^^^^^^^^^^^^^^^^^ - Initial support for psycopg 2.5 NumericRange objects 0.13.2 (2013-06-11) ^^^^^^^^^^^^^^^^^^^ - QuerySorter now threadsafe. 0.13.1 (2013-06-11) ^^^^^^^^^^^^^^^^^^^ - Made sort_query function support multicolumn sorting. 0.13.0 (2013-06-05) ^^^^^^^^^^^^^^^^^^^ - Added table_name utility function. 0.12.5 (2013-06-05) ^^^^^^^^^^^^^^^^^^^ - ProxyDict now contains None values in cache - more efficient contains method. 0.12.4 (2013-06-01) ^^^^^^^^^^^^^^^^^^^ - Fixed ProxyDict contains method 0.12.3 (2013-05-30) ^^^^^^^^^^^^^^^^^^^ - Proxy dict expiration listener from function scope to global scope 0.12.2 (2013-05-29) ^^^^^^^^^^^^^^^^^^^ - Added automatic expiration of proxy dicts 0.12.1 (2013-05-18) ^^^^^^^^^^^^^^^^^^^ - Added utility functions remove_property and primary_keys 0.12.0 (2013-05-17) ^^^^^^^^^^^^^^^^^^^ - Added ProxyDict 0.11.0 (2013-05-08) ^^^^^^^^^^^^^^^^^^^ - Added coercion_listener 0.10.0 (2013-04-29) ^^^^^^^^^^^^^^^^^^^ - Added ColorType 0.9.1 (2013-04-15) ^^^^^^^^^^^^^^^^^^ - Renamed Email to EmailType and ScalarList to ScalarListType (unified type class naming convention) 0.9.0 (2013-04-11) ^^^^^^^^^^^^^^^^^^ - Added CaseInsensitiveComparator - Added Email type 0.8.4 (2013-04-08) ^^^^^^^^^^^^^^^^^^ - Added sort by aliased and joined entity 0.8.3 (2013-04-03) ^^^^^^^^^^^^^^^^^^ - sort_query now supports labeled and subqueried scalars 0.8.2 (2013-04-03) ^^^^^^^^^^^^^^^^^^ - Fixed empty ScalarList handling 0.8.1 (2013-04-03) ^^^^^^^^^^^^^^^^^^ - Removed unnecessary print statement form ScalarList - Documentation for ScalarList and NumberRange 0.8.0 (2013-04-02) ^^^^^^^^^^^^^^^^^^ - Added ScalarList type - Fixed NumberRange bind param and result value processing 0.7.7 (2013-03-27) ^^^^^^^^^^^^^^^^^^ - Changed PhoneNumber string representation to the national phone number format 0.7.6 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - NumberRange now wraps ValueErrors as NumberRangeExceptions 0.7.5 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Fixed defer_except - Better string representations for NumberRange 0.7.4 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Fixed NumberRange upper bound parsing 0.7.3 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Enabled PhoneNumberType None value storing 0.7.2 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Enhanced string parsing for NumberRange 0.7.1 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Fixed requirements (now supports SQLAlchemy 0.8) 0.7.0 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Added NumberRange type 0.6.0 (2013-03-26) ^^^^^^^^^^^^^^^^^^ - Extended PhoneNumber class from python-phonenumbers library 0.5.0 (2013-03-20) ^^^^^^^^^^^^^^^^^^ - Added PhoneNumberType type decorator 0.4.0 (2013-03-01) ^^^^^^^^^^^^^^^^^^ - Renamed SmartList to InstrumentedList - Added 
  instrumented_list decorator

0.3.0 (2013-03-01)
^^^^^^^^^^^^^^^^^^

- Added new collection class SmartList

0.2.0 (2013-03-01)
^^^^^^^^^^^^^^^^^^

- Added new function defer_except()

0.1.0 (2013-01-12)
^^^^^^^^^^^^^^^^^^

- Initial public release

sqlalchemy-utils-0.36.1/LICENSE

Copyright (c) 2012, Konsta Vesterinen

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* The names of the contributors may not be used to endorse or promote products
  derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

sqlalchemy-utils-0.36.1/MANIFEST.in

include CHANGES.rst LICENSE README.rst conftest.py .isort.cfg
recursive-include tests *
recursive-exclude tests *.pyc
recursive-include docs *
recursive-exclude docs *.pyc
prune docs/_build
exclude docs/_themes/.git

sqlalchemy-utils-0.36.1/README.rst

SQLAlchemy-Utils
================

|Build Status| |Version Status| |Downloads|

Various utility functions, new data types and helpers for SQLAlchemy.

Resources
---------

- `Documentation `_
- `Issue Tracker `_
- `Code `_

.. |Build Status| image:: https://travis-ci.org/kvesteri/sqlalchemy-utils.svg?branch=master
   :target: https://travis-ci.org/kvesteri/sqlalchemy-utils
.. |Version Status| image:: https://img.shields.io/pypi/v/SQLAlchemy-Utils.svg
   :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/
.. |Downloads| image:: https://img.shields.io/pypi/dm/SQLAlchemy-Utils.svg
   :target: https://pypi.python.org/pypi/SQLAlchemy-Utils/

sqlalchemy-utils-0.36.1/ROADMAP.rst

* Add efficient pagination support
  http://www.depesz.com/2011/05/20/pagination-with-fixed-order/
  http://stackoverflow.com/questions/6618366/improving-offset-performance-in-postgresql
* Generic file model
  https://github.com/jpvanhal/silo
* Query to Postgres JSON converter
  http://hashrocket.com/blog/posts/faster-json-generation-with-postgresql
* Postgres Cube datatype: http://www.postgresql.org/docs/9.4/static/cube.html

sqlalchemy-utils-0.36.1/conftest.py

import os
import warnings

import pytest
import sqlalchemy as sa
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base, synonym_for
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import close_all_sessions

from sqlalchemy_utils import (
    aggregates,
    coercion_listener,
    i18n,
    InstrumentedList
)
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners


@sa.event.listens_for(sa.engine.Engine, 'before_cursor_execute')
def count_sql_calls(conn, cursor, statement, parameters, context, executemany):
    try:
        conn.query_count += 1
    except AttributeError:
        conn.query_count = 0


warnings.simplefilter('error', sa.exc.SAWarning)

sa.event.listen(sa.orm.mapper, 'mapper_configured', coercion_listener)


def get_locale():
    class Locale():
        territories = {'FI': 'Finland'}

    return Locale()


@pytest.fixture(scope='session')
def db_name():
    return os.environ.get('SQLALCHEMY_UTILS_TEST_DB', 'sqlalchemy_utils_test')


@pytest.fixture(scope='session')
def postgresql_db_user():
    return os.environ.get('SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER', 'postgres')


@pytest.fixture(scope='session')
def mysql_db_user():
    return os.environ.get('SQLALCHEMY_UTILS_TEST_MYSQL_USER', 'root')


@pytest.fixture
def postgresql_dsn(postgresql_db_user, db_name):
    return 'postgresql://{0}@localhost/{1}'.format(postgresql_db_user, db_name)


@pytest.fixture
def mysql_dsn(mysql_db_user, db_name):
    return 'mysql+pymysql://{0}@localhost/{1}'.format(mysql_db_user, db_name)


@pytest.fixture
def sqlite_memory_dsn():
    return 'sqlite:///:memory:'


@pytest.fixture
def sqlite_none_database_dsn():
    return 'sqlite://'


@pytest.fixture
def sqlite_file_dsn(db_name):
    return 'sqlite:///{0}.db'.format(db_name)


@pytest.fixture
def mssql_db_user():
    return os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_USER', 'sa')


@pytest.fixture
def mssql_db_password():
    return os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_PASSWORD', 'Strong!Passw0rd')


@pytest.fixture
def mssql_db_driver():
    driver = os.environ.get('SQLALCHEMY_UTILS_TEST_MSSQL_DRIVER', 'ODBC Driver 17 for SQL Server')
    return driver.replace(' ', '+')


@pytest.fixture
def mssql_dsn(mssql_db_user, mssql_db_password, mssql_db_driver, db_name):
    return 'mssql+pyodbc://{0}:{1}@localhost/{2}?driver={3}'\
        .format(mssql_db_user, mssql_db_password, db_name, mssql_db_driver)


@pytest.fixture
def dsn(request):
    if 'postgresql_dsn' in request.fixturenames:
        return request.getfixturevalue('postgresql_dsn')
    elif 'mysql_dsn' in request.fixturenames:
        return request.getfixturevalue('mysql_dsn')
    elif 'mssql_dsn' in request.fixturenames:
        return request.getfixturevalue('mssql_dsn')
    elif 'sqlite_file_dsn' in request.fixturenames:
        return request.getfixturevalue('sqlite_file_dsn')
    elif 'sqlite_memory_dsn' in request.fixturenames:
        pass
    # Return default
    return request.getfixturevalue('sqlite_memory_dsn')


@pytest.fixture
def engine(dsn):
    engine = create_engine(dsn)
    # engine.echo = True
    return engine


@pytest.fixture
def connection(engine):
    return engine.connect()


@pytest.fixture
def Base():
    return declarative_base()


@pytest.fixture
def User(Base):
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
        name = sa.Column(sa.Unicode(255))
    return User


@pytest.fixture
def Category(Base):
    class Category(Base):
        __tablename__ = 'category'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        title = sa.Column(sa.Unicode(255))

        @hybrid_property
        def full_name(self):
            return u'%s %s' % (self.title, self.name)

        @full_name.expression
        def full_name(self):
            return sa.func.concat(self.title, ' ', self.name)

        @hybrid_property
        def articles_count(self):
            return len(self.articles)

        @articles_count.expression
        def articles_count(cls):
            Article = Base._decl_class_registry['Article']
            return (
                sa.select([sa.func.count(Article.id)])
                .where(Article.category_id == cls.id)
                .correlate(Article.__table__)
                .label('article_count')
            )

        @property
        def name_alias(self):
            return self.name

        @synonym_for('name')
        @property
        def name_synonym(self):
            return self.name
    return Category


@pytest.fixture
def Article(Base, Category):
    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255), index=True)
        category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
        category = sa.orm.relationship(
            Category,
            primaryjoin=category_id == Category.id,
            backref=sa.orm.backref(
                'articles',
                collection_class=InstrumentedList
            )
        )
    return Article


@pytest.fixture
def init_models(User, Category, Article):
    pass


@pytest.fixture
def session(request, engine, connection, Base, init_models):
    sa.orm.configure_mappers()
    Base.metadata.create_all(connection)
    Session = sessionmaker(bind=connection)
    session = Session()
    i18n.get_locale = get_locale

    def teardown():
        aggregates.manager.reset()
        close_all_sessions()
        Base.metadata.drop_all(connection)
        remove_composite_listeners()
        connection.close()
        engine.dispose()

    request.addfinalizer(teardown)

    return session

sqlalchemy-utils-0.36.1/docs/Makefile

# Makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
PAPER         =
BUILDDIR      = _build

# Internal variables.
PAPEROPT_a4     = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS   = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS  = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: -rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/SQLAlchemy-Utils.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/SQLAlchemy-Utils.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/SQLAlchemy-Utils" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/SQLAlchemy-Utils" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. 
The text files are in $(BUILDDIR)/text." man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." sqlalchemy-utils-0.36.1/docs/aggregates.rst000066400000000000000000000001671360007755400207070ustar00rootroot00000000000000Aggregated attributes ===================== .. automodule:: sqlalchemy_utils.aggregates .. autofunction:: aggregated sqlalchemy-utils-0.36.1/docs/conf.py000066400000000000000000000177561360007755400173570ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # SQLAlchemy-Utils documentation build configuration file, created by # sphinx-quickstart on Tue Feb 19 11:16:09 2013. # # This file is execfile()d with the current directory set to its containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. import sys, os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath('..')) from sqlalchemy_utils import __version__ # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix of source filenames. source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. project = u'SQLAlchemy-Utils' copyright = u'2013, Konsta Vesterinen' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. 
version = __version__ # The full version, including alpha/beta/rc tags. release = version # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. #language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: #today = '' # Else, today_fmt is used as the format for a strftime call. #today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ['_build'] # The reST default role (used for this markup: `text`) to use for all documents. #default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. #add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). #add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = 'default' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. #html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. #html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. #html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. #html_use_smartypants = True # Custom sidebar templates, maps document names to template names. #html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If false, no module index is generated. #html_domain_indices = True # If false, no index is generated. #html_use_index = True # If true, the index is split into individual pages for each letter. #html_split_index = False # If true, links to the reST sources are added to the pages. #html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. #html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 
#html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). #html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = 'SQLAlchemy-Utilsdoc' # -- Options for LaTeX output -------------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ('index', 'SQLAlchemy-Utils.tex', u'SQLAlchemy-Utils Documentation', u'Konsta Vesterinen', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # If true, show page references after internal links. #latex_show_pagerefs = False # If true, show URL addresses after external links. #latex_show_urls = False # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'sqlalchemy-utils', u'SQLAlchemy-Utils Documentation', [u'Konsta Vesterinen'], 1) ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ('index', 'SQLAlchemy-Utils', u'SQLAlchemy-Utils Documentation', u'Konsta Vesterinen', 'SQLAlchemy-Utils', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'http://docs.python.org/': None} sqlalchemy-utils-0.36.1/docs/data_types.rst000066400000000000000000000047161360007755400207370ustar00rootroot00000000000000Data types ========== SQLAlchemy-Utils provides various new data types for SQLAlchemy. In order to gain full advantage of these datatypes you should use automatic data coercion. See :func:`force_auto_coercion` for how to set up this feature. .. module:: sqlalchemy_utils.types ArrowType --------- .. module:: sqlalchemy_utils.types.arrow .. autoclass:: ArrowType ChoiceType ---------- .. module:: sqlalchemy_utils.types.choice .. autoclass:: ChoiceType ColorType --------- .. module:: sqlalchemy_utils.types.color .. autoclass:: ColorType CompositeType ------------- .. automodule:: sqlalchemy_utils.types.pg_composite .. autoclass:: CompositeType CountryType ----------- .. 
module:: sqlalchemy_utils.types.country
.. autoclass:: CountryType

.. module:: sqlalchemy_utils.primitives.country
.. autoclass:: Country

CurrencyType
------------
.. module:: sqlalchemy_utils.types.currency
.. autoclass:: CurrencyType

.. module:: sqlalchemy_utils.primitives.currency
.. autoclass:: Currency

EmailType
---------
.. automodule:: sqlalchemy_utils.types.email
.. autoclass:: EmailType

EncryptedType
-------------
.. module:: sqlalchemy_utils.types.encrypted.encrypted_type
.. autoclass:: EncryptedType

JSONType
--------
.. module:: sqlalchemy_utils.types.json
.. autoclass:: JSONType

LocaleType
----------
.. module:: sqlalchemy_utils.types.locale
.. autoclass:: LocaleType

LtreeType
---------
.. module:: sqlalchemy_utils.types.ltree
.. autoclass:: LtreeType

.. module:: sqlalchemy_utils.primitives.ltree
.. autoclass:: Ltree

IPAddressType
-------------
.. module:: sqlalchemy_utils.types.ip_address
.. autoclass:: IPAddressType

PasswordType
------------
.. module:: sqlalchemy_utils.types.password
.. autoclass:: PasswordType

PhoneNumberType
---------------
.. module:: sqlalchemy_utils.types.phone_number
.. autoclass:: PhoneNumber
.. autoclass:: PhoneNumberType

ScalarListType
--------------
.. module:: sqlalchemy_utils.types.scalar_list
.. autoclass:: ScalarListType

TimezoneType
------------
.. module:: sqlalchemy_utils.types.timezone
.. autoclass:: TimezoneType

TSVectorType
------------
.. module:: sqlalchemy_utils.types.ts_vector
.. autoclass:: TSVectorType

URLType
-------
.. module:: sqlalchemy_utils.types.url
.. autoclass:: URLType

UUIDType
--------
.. module:: sqlalchemy_utils.types.uuid
.. autoclass:: UUIDType

WeekDaysType
------------
.. module:: sqlalchemy_utils.types.weekdays
.. autoclass:: WeekDaysType

sqlalchemy-utils-0.36.1/docs/database_helpers.rst

Database helpers
================
.. module:: sqlalchemy_utils.functions

analyze
-------
.. autofunction:: analyze

database_exists
---------------
.. autofunction:: database_exists

create_database
---------------
.. autofunction:: create_database

drop_database
-------------
.. autofunction:: drop_database

has_index
---------
.. autofunction:: has_index

has_unique_index
----------------
.. autofunction:: has_unique_index

json_sql
--------
.. autofunction:: json_sql

render_expression
-----------------
.. autofunction:: render_expression

render_statement
----------------
.. autofunction:: render_statement

sqlalchemy-utils-0.36.1/docs/foreign_key_helpers.rst

Foreign key helpers
===================
.. module:: sqlalchemy_utils.functions

dependent_objects
-----------------
.. autofunction:: dependent_objects

get_referencing_foreign_keys
----------------------------
.. autofunction:: get_referencing_foreign_keys

group_foreign_keys
------------------
.. autofunction:: group_foreign_keys

is_indexed_foreign_key
----------------------
.. autofunction:: is_indexed_foreign_key

merge_references
----------------
.. autofunction:: merge_references

non_indexed_foreign_keys
------------------------
.. autofunction:: non_indexed_foreign_keys

sqlalchemy-utils-0.36.1/docs/generic_relationship.rst

Generic relationships
=====================

Generic relationship is a form of relationship that supports creating a 1 to
many relationship to any target model.
:: from sqlalchemy_utils import generic_relationship class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) class Customer(Base): __tablename__ = 'customer' id = sa.Column(sa.Integer, primary_key=True) class Event(Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) # This is used to discriminate between the linked tables. object_type = sa.Column(sa.Unicode(255)) # This is used to point to the primary key of the linked row. object_id = sa.Column(sa.Integer) object = generic_relationship(object_type, object_id) # Some general usage to attach an event to a user. user = User() customer = Customer() session.add_all([user, customer]) session.commit() ev = Event() ev.object = user session.add(ev) session.commit() # Find the event we just made. session.query(Event).filter_by(object=user).first() # Find any events that are bound to users. session.query(Event).filter(Event.object.is_type(User)).all() Inheritance ----------- :: class Employee(self.Base): __tablename__ = 'employee' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(50)) type = sa.Column(sa.String(20)) __mapper_args__ = { 'polymorphic_on': type, 'polymorphic_identity': 'employee' } class Manager(Employee): __mapper_args__ = { 'polymorphic_identity': 'manager' } class Engineer(Employee): __mapper_args__ = { 'polymorphic_identity': 'engineer' } class Activity(self.Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) object_type = sa.Column(sa.Unicode(255)) object_id = sa.Column(sa.Integer, nullable=False) object = generic_relationship(object_type, object_id) Now same as before we can add some objects:: manager = Manager() session.add(manager) session.commit() activity = Activity() activity.object = manager session.add(activity) session.commit() # Find the activity we just made. session.query(Event).filter_by(object=manager).first() We can even test super types:: session.query(Activity).filter(Event.object.is_type(Employee)).all() Abstract base classes --------------------- Generic relationships also allows using string arguments. When using generic_relationship with abstract base classes you need to set up the relationship using declared_attr decorator and string arguments. :: class Building(self.Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) class User(self.Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) class EventBase(self.Base): __abstract__ = True object_type = sa.Column(sa.Unicode(255)) object_id = sa.Column(sa.Integer, nullable=False) @declared_attr def object(cls): return generic_relationship('object_type', 'object_id') class Event(EventBase): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) Composite keys -------------- For some very rare cases you may need to use generic_relationships with composite primary keys. There is a limitation here though: you can only set up generic_relationship for similar composite primary key types. In other words you can't mix generic relationship to both composite keyed objects and single keyed objects. :: from sqlalchemy_utils import generic_relationship class Customer(Base): __tablename__ = 'customer' code1 = sa.Column(sa.Integer, primary_key=True) code2 = sa.Column(sa.Integer, primary_key=True) class Event(Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) # This is used to discriminate between the linked tables. 
object_type = sa.Column(sa.Unicode(255)) object_code1 = sa.Column(sa.Integer) object_code2 = sa.Column(sa.Integer) object = generic_relationship( object_type, (object_code1, object_code2) ) sqlalchemy-utils-0.36.1/docs/index.rst000066400000000000000000000006351360007755400177050ustar00rootroot00000000000000SQLAlchemy-Utils ================ SQLAlchemy-Utils provides custom data types and various utility functions for SQLAlchemy. .. toctree:: :maxdepth: 2 installation listeners data_types range_data_types aggregates observers internationalization generic_relationship database_helpers foreign_key_helpers orm_helpers utility_classes models view testing license sqlalchemy-utils-0.36.1/docs/installation.rst000066400000000000000000000025661360007755400213040ustar00rootroot00000000000000Installation ============ This part of the documentation covers the installation of SQLAlchemy-Utils. Supported platforms ------------------- SQLAlchemy-Utils has been tested against the following Python platforms. - cPython 2.6 (unsupported since 0.32) - cPython 2.7 - cPython 3.3 - cPython 3.4 - cPython 3.5 - cPython 3.6 Installing an official release ------------------------------ You can install the most recent official SQLAlchemy-Utils version using pip_:: pip install sqlalchemy-utils # Use `pip3` instead of `pip` for Python 3.x .. _pip: http://www.pip-installer.org/ Installing the development version ---------------------------------- To install the latest version of SQLAlchemy-Utils, you need first obtain a copy of the source. You can do that by cloning the git_ repository:: git clone git://github.com/kvesteri/sqlalchemy-utils.git Then you can install the source distribution using pip:: cd sqlalchemy-utils pip install -e . # Use `pip3` instead of `pip` for Python 3.x .. _git: http://git-scm.org/ Checking the installation ------------------------- To check that SQLAlchemy-Utils has been properly installed, type ``python`` from your shell. Then at the Python prompt, try to import SQLAlchemy-Utils, and check the installed version: .. parsed-literal:: >>> import sqlalchemy_utils >>> sqlalchemy_utils.__version__ |release| sqlalchemy-utils-0.36.1/docs/internationalization.rst000066400000000000000000000105701360007755400230420ustar00rootroot00000000000000Internationalization ==================== SQLAlchemy-Utils provides a way for modeling translatable models. Model is translatable if one or more of its columns can be displayed in various languages. .. note:: The implementation is currently highly PostgreSQL specific since it needs a dict-compatible column type (PostgreSQL HSTORE and JSON are such types). If you want database-agnostic way of modeling i18n see `SQLAlchemy-i18n`_. TranslationHybrid vs SQLAlchemy-i18n ------------------------------------ Compared to SQLAlchemy-i18n the TranslationHybrid has the following pros and cons: * Usually faster since no joins are needed for fetching the data * Less magic * Easier to understand data model * Only PostgreSQL supported for now Quickstart ---------- Let's say we have an Article model with translatable name and content. First we need to define the TranslationHybrid. :: from sqlalchemy_utils import TranslationHybrid # For testing purposes we define this as simple function which returns # locale 'fi'. Usually you would define this function as something that # returns the user's current locale. 
def get_locale(): return 'fi' translation_hybrid = TranslationHybrid( current_locale=get_locale, default_locale='en' ) Then we can define the model.:: from sqlalchemy import * from sqlalchemy.dialects.postgresql import HSTORE class Article(Base): __tablename__ = 'article' id = Column(Integer, primary_key=True) name_translations = Column(HSTORE) content_translations = Column(HSTORE) name = translation_hybrid(name_translations) content = translation_hybrid(content_translations) Now we can start using our translatable model. By assigning things to translatable hybrids you are assigning them to the locale returned by the `current_locale`. :: article = Article(name='Joku artikkeli') article.name_translations['fi'] # Joku artikkeli article.name # Joku artikkeli If you access the hybrid with a locale that doesn't exist the hybrid tries to fetch a the locale returned by `default_locale`. :: article = Article(name_translations={'en': 'Some article'}) article.name # Some article article.name_translations['fi'] = 'Joku artikkeli' article.name # Joku artikkeli Translation hybrids can also be used as expressions. :: session.query(Article).filter(Article.name_translations['en'] == 'Some article') By default if no value is found for either current or default locale the translation hybrid returns `None`. You can customize this value with `default_value` parameter of translation_hybrid. In the following example we make translation hybrid fallback to empty string instead of `None`. :: translation_hybrid = TranslationHybrid( current_locale=get_locale, default_locale='en', default_value='' ) class Article(Base): __tablename__ = 'article' id = Column(Integer, primary_key=True) name_translations = Column(HSTORE) name = translation_hybrid(name_translations, default) Article().name # '' Dynamic locales --------------- Sometimes locales need to be dynamic. The following example illustrates how to setup dynamic locales. You can pass a callable of either 0, 1 or 2 args as a constructor parameter for TranslationHybrid. The first argument should be the associated object and second parameter the name of the translations attribute. :: translation_hybrid = TranslationHybrid( current_locale=get_locale, default_locale=lambda obj: obj.locale, ) class Article(Base): __tablename__ = 'article' id = Column(Integer, primary_key=True) name_translations = Column(HSTORE) name = translation_hybrid(name_translations, default) locale = Column(String) article = Article(name_translations={'en': 'Some article'}) article.locale = 'en' session.add(article) session.commit() article.name # Some article (even if current locale is other than 'en') The locales can also be attribute dependent so you can set up translation hybrid in a way that it is guaranteed to return a translation. :: translation_hybrid.default_locale = lambda obj, attr: sorted(getattr(obj, attr).keys())[0] article.name # Some article .. _SQLAlchemy-i18n: https://github.com/kvesteri/sqlalchemy-i18n sqlalchemy-utils-0.36.1/docs/license.rst000066400000000000000000000000511360007755400202100ustar00rootroot00000000000000License ======= .. include:: ../LICENSE sqlalchemy-utils-0.36.1/docs/listeners.rst000066400000000000000000000005061360007755400206030ustar00rootroot00000000000000Listeners ========= .. module:: sqlalchemy_utils.listeners Automatic data coercion ----------------------- .. autofunction:: force_auto_coercion Instant defaults ---------------- .. autofunction:: force_instant_defaults Many-to-many orphan deletion ---------------------------- .. 
autofunction:: auto_delete_orphans sqlalchemy-utils-0.36.1/docs/make.bat000066400000000000000000000117741360007755400174570ustar00rootroot00000000000000@ECHO OFF REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set BUILDDIR=_build set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . set I18NSPHINXOPTS=%SPHINXOPTS% . if NOT "%PAPER%" == "" ( set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% ) if "%1" == "" goto help if "%1" == "help" ( :help echo.Please use `make ^` where ^ is one of echo. html to make standalone HTML files echo. dirhtml to make HTML files named index.html in directories echo. singlehtml to make a single large HTML file echo. pickle to make pickle files echo. json to make JSON files echo. htmlhelp to make HTML files and a HTML help project echo. qthelp to make HTML files and a qthelp project echo. devhelp to make HTML files and a Devhelp project echo. epub to make an epub echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. text to make text files echo. man to make manual pages echo. texinfo to make Texinfo files echo. gettext to make PO message catalogs echo. changes to make an overview over all changed/added/deprecated items echo. linkcheck to check all external links for integrity echo. doctest to run all doctests embedded in the documentation if enabled goto end ) if "%1" == "clean" ( for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i del /q /s %BUILDDIR%\* goto end ) if "%1" == "html" ( %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/html. goto end ) if "%1" == "dirhtml" ( %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. goto end ) if "%1" == "singlehtml" ( %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. goto end ) if "%1" == "pickle" ( %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the pickle files. goto end ) if "%1" == "json" ( %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the JSON files. goto end ) if "%1" == "htmlhelp" ( %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run HTML Help Workshop with the ^ .hhp project file in %BUILDDIR%/htmlhelp. goto end ) if "%1" == "qthelp" ( %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run "qcollectiongenerator" with the ^ .qhcp project file in %BUILDDIR%/qthelp, like this: echo.^> qcollectiongenerator %BUILDDIR%\qthelp\SQLAlchemy-Utils.qhcp echo.To view the help file: echo.^> assistant -collectionFile %BUILDDIR%\qthelp\SQLAlchemy-Utils.ghc goto end ) if "%1" == "devhelp" ( %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp if errorlevel 1 exit /b 1 echo. echo.Build finished. goto end ) if "%1" == "epub" ( %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub if errorlevel 1 exit /b 1 echo. echo.Build finished. The epub file is in %BUILDDIR%/epub. 
goto end ) if "%1" == "latex" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex if errorlevel 1 exit /b 1 echo. echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. goto end ) if "%1" == "text" ( %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text if errorlevel 1 exit /b 1 echo. echo.Build finished. The text files are in %BUILDDIR%/text. goto end ) if "%1" == "man" ( %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man if errorlevel 1 exit /b 1 echo. echo.Build finished. The manual pages are in %BUILDDIR%/man. goto end ) if "%1" == "texinfo" ( %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo if errorlevel 1 exit /b 1 echo. echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. goto end ) if "%1" == "gettext" ( %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale if errorlevel 1 exit /b 1 echo. echo.Build finished. The message catalogs are in %BUILDDIR%/locale. goto end ) if "%1" == "changes" ( %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes if errorlevel 1 exit /b 1 echo. echo.The overview file is in %BUILDDIR%/changes. goto end ) if "%1" == "linkcheck" ( %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck if errorlevel 1 exit /b 1 echo. echo.Link check complete; look for any errors in the above output ^ or in %BUILDDIR%/linkcheck/output.txt. goto end ) if "%1" == "doctest" ( %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest if errorlevel 1 exit /b 1 echo. echo.Testing of doctests in the sources finished, look at the ^ results in %BUILDDIR%/doctest/output.txt. goto end ) :end sqlalchemy-utils-0.36.1/docs/models.rst000066400000000000000000000003201360007755400200500ustar00rootroot00000000000000Model mixins ============ Timestamp --------- .. module:: sqlalchemy_utils.models .. autoclass:: Timestamp generic_repr ------------ .. module:: sqlalchemy_utils.models .. autofunction:: generic_repr sqlalchemy-utils-0.36.1/docs/observers.rst000066400000000000000000000001331360007755400206010ustar00rootroot00000000000000Observers ========= .. automodule:: sqlalchemy_utils.observer .. autofunction:: observes sqlalchemy-utils-0.36.1/docs/orm_helpers.rst000066400000000000000000000024621360007755400211150ustar00rootroot00000000000000ORM helpers =========== .. module:: sqlalchemy_utils.functions cast_if ------- .. autofunction:: cast_if escape_like ----------- .. autofunction:: escape_like get_bind -------- .. autofunction:: get_bind get_class_by_table ------------------ .. autofunction:: get_class_by_table get_column_key -------------- .. autofunction:: get_column_key get_columns ----------- .. autofunction:: get_columns get_declarative_base -------------------- .. autofunction:: get_declarative_base get_hybrid_properties --------------------- .. autofunction:: get_hybrid_properties get_mapper ---------- .. autofunction:: get_mapper get_query_entities ------------------ .. autofunction:: get_query_entities get_primary_keys ---------------- .. autofunction:: get_primary_keys get_tables ---------- .. autofunction:: get_tables get_type -------- .. autofunction:: get_type has_changes ----------- .. autofunction:: has_changes identity -------- .. autofunction:: identity is_loaded --------- .. autofunction:: is_loaded make_order_by_deterministic --------------------------- .. autofunction:: make_order_by_deterministic naturally_equivalent -------------------- .. autofunction:: naturally_equivalent quote ----- .. autofunction:: quote sort_query ---------- .. 
autofunction:: sort_query sqlalchemy-utils-0.36.1/docs/range_data_types.rst000066400000000000000000000006441360007755400221070ustar00rootroot00000000000000Range data types ================ .. automodule:: sqlalchemy_utils.types.range DateRangeType ------------- .. autoclass:: DateRangeType DateTimeRangeType ----------------- .. autoclass:: DateTimeRangeType IntRangeType ------------ .. autoclass:: IntRangeType NumericRangeType ---------------- .. autoclass:: NumericRangeType RangeComparator --------------- .. autoclass:: RangeComparator :members: sqlalchemy-utils-0.36.1/docs/testing.rst000066400000000000000000000006471360007755400202560ustar00rootroot00000000000000Testing ======= .. automodule:: sqlalchemy_utils.asserts assert_min_value ---------------- .. autofunction:: assert_min_value assert_max_length ----------------- .. autofunction:: assert_max_length assert_max_value ---------------- .. autofunction:: assert_max_value assert_nullable --------------- .. autofunction:: assert_nullable assert_non_nullable ------------------- .. autofunction:: assert_non_nullable sqlalchemy-utils-0.36.1/docs/utility_classes.rst000066400000000000000000000002271360007755400220130ustar00rootroot00000000000000Utility classes =============== QueryChain ---------- .. automodule:: sqlalchemy_utils.query_chain API --- .. autoclass:: QueryChain :members: sqlalchemy-utils-0.36.1/docs/view.rst000066400000000000000000000004701360007755400175450ustar00rootroot00000000000000View utilities ============== .. module:: sqlalchemy_utils create_view ----------- .. autofunction:: create_view create_materialized_view ------------------------ .. autofunction:: create_materialized_view refresh_materialized_view ------------------------- .. autofunction:: refresh_materialized_view sqlalchemy-utils-0.36.1/setup.cfg000066400000000000000000000000341360007755400167260ustar00rootroot00000000000000[bdist_wheel] universal = 1 sqlalchemy-utils-0.36.1/setup.py000066400000000000000000000055111360007755400166240ustar00rootroot00000000000000""" SQLAlchemy-Utils ---------------- Various utility functions and custom data types for SQLAlchemy. """ from setuptools import setup, find_packages import os import re import sys HERE = os.path.dirname(os.path.abspath(__file__)) PY3 = sys.version_info[0] == 3 def get_version(): filename = os.path.join(HERE, 'sqlalchemy_utils', '__init__.py') with open(filename) as f: contents = f.read() pattern = r"^__version__ = '(.*?)'$" return re.search(pattern, contents, re.MULTILINE).group(1) extras_require = { 'test': [ 'pytest>=2.7.1', 'Pygments>=1.2', 'Jinja2>=2.3', 'docutils>=0.10', 'flexmock>=0.9.7', 'mock==2.0.0', 'psycopg2>=2.5.1', 'pg8000>=1.12.4', 'pytz>=2014.2', 'python-dateutil>=2.6', 'pymysql', 'flake8>=2.4.0', 'isort>=4.2.2', 'pyodbc', ], 'anyjson': ['anyjson>=0.3.3'], 'babel': ['Babel>=1.3'], 'arrow': ['arrow>=0.3.4'], 'intervals': ['intervals>=0.7.1'], 'phone': ['phonenumbers>=5.9.2'], 'password': ['passlib >= 1.6, < 2.0'], 'color': ['colour>=0.0.4'], 'ipaddress': ['ipaddr'] if not PY3 else [], 'enum': ['enum34'] if sys.version_info < (3, 4) else [], 'timezone': ['python-dateutil'], 'url': ['furl >= 0.4.1'], 'encrypted': ['cryptography>=0.6'] } # Add all optional dependencies to testing requirements. 
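# (The combined list is exposed below as the 'test_all' extra, so the full
# test tool-chain can be installed with: pip install SQLAlchemy-Utils[test_all])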
test_all = [] for name, requirements in sorted(extras_require.items()): test_all += requirements extras_require['test_all'] = test_all setup( name='SQLAlchemy-Utils', version=get_version(), url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen, Ryan Leckey, Janne Vanhala, Vesa Uimonen', author_email='konsta@fastmonkeys.com', description=( 'Various utility functions for SQLAlchemy.' ), long_description=__doc__, packages=find_packages('.', exclude=['tests', 'tests.*']), zip_safe=False, include_package_data=True, platforms='any', install_requires=[ 'six', 'SQLAlchemy>=1.0' ], extras_require=extras_require, classifiers=[ 'Environment :: Web Environment', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', 'Topic :: Software Development :: Libraries :: Python Modules' ] ) sqlalchemy-utils-0.36.1/sqlalchemy_utils/000077500000000000000000000000001360007755400204725ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/sqlalchemy_utils/__init__.py000066400000000000000000000044371360007755400226130ustar00rootroot00000000000000from .aggregates import aggregated # noqa from .asserts import ( # noqa assert_max_length, assert_max_value, assert_min_value, assert_non_nullable, assert_nullable ) from .exceptions import ImproperlyConfigured # noqa from .expressions import Asterisk, row_to_json # noqa from .functions import ( # noqa cast_if, create_database, create_mock_engine, database_exists, dependent_objects, drop_database, escape_like, get_bind, get_class_by_table, get_column_key, get_columns, get_declarative_base, get_fk_constraint_for_columns, get_hybrid_properties, get_mapper, get_primary_keys, get_query_entities, get_referencing_foreign_keys, get_tables, get_type, group_foreign_keys, has_changes, has_index, has_unique_index, identity, is_loaded, json_sql, merge_references, mock_engine, naturally_equivalent, render_expression, render_statement, sort_query, table_name ) from .generic import generic_relationship # noqa from .i18n import TranslationHybrid # noqa from .listeners import ( # noqa auto_delete_orphans, coercion_listener, force_auto_coercion, force_instant_defaults ) from .models import generic_repr, Timestamp # noqa from .observer import observes # noqa from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa from .proxy_dict import proxy_dict, ProxyDict # noqa from .query_chain import QueryChain # noqa from .types import ( # noqa ArrowType, Choice, ChoiceType, ColorType, CompositeArray, CompositeType, CountryType, CurrencyType, DateRangeType, DateTimeRangeType, EmailType, EncryptedType, instrumented_list, InstrumentedList, Int8RangeType, IntRangeType, IPAddressType, JSONType, LocaleType, LtreeType, NumericRangeType, Password, PasswordType, PhoneNumber, PhoneNumberParseException, PhoneNumberType, register_composites, remove_composite_listeners, ScalarListException, ScalarListType, TimezoneType, TSVectorType, URLType, UUIDType, WeekDaysType ) from .view import ( # noqa create_materialized_view, create_view, refresh_materialized_view ) __version__ = '0.36.1' 
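# A minimal usage sketch (illustration only, kept as a comment so nothing is
# executed on import): the helpers re-exported above are intended to be
# imported straight from the package namespace, for example
#
#   from sqlalchemy_utils import create_database, database_exists
#
#   url = 'sqlite:///example.db'  # placeholder URL; any SQLAlchemy URL works
#   if not database_exists(url):
#       create_database(url)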
sqlalchemy-utils-0.36.1/sqlalchemy_utils/aggregates.py000066400000000000000000000360221360007755400231600ustar00rootroot00000000000000""" SQLAlchemy-Utils provides way of automatically calculating aggregate values of related models and saving them to parent model. This solution is inspired by RoR counter cache, `counter_culture`_ and `stackoverflow reply by Michael Bayer`_. Why? ---- Many times you may have situations where you need to calculate dynamically some aggregate value for given model. Some simple examples include: - Number of products in a catalog - Average rating for movie - Latest forum post - Total price of orders for given customer Now all these aggregates can be elegantly implemented with SQLAlchemy column_property_ function. However when your data grows calculating these values on the fly might start to hurt the performance of your application. The more aggregates you are using the more performance penalty you get. This module provides way of calculating these values automatically and efficiently at the time of modification rather than on the fly. Features -------- * Automatically updates aggregate columns when aggregated values change * Supports aggregate values through arbitrary number levels of relations * Highly optimized: uses single query per transaction per aggregate column * Aggregated columns can be of any data type and use any selectable scalar expression .. _column_property: http://docs.sqlalchemy.org/en/latest/orm/mapper_config.html#using-column-property .. _counter_culture: https://github.com/magnusvk/counter_culture .. _stackoverflow reply by Michael Bayer: http://stackoverflow.com/questions/13693872/ Simple aggregates ----------------- :: from sqlalchemy_utils import aggregated class Thread(Base): __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('comments', sa.Column(sa.Integer)) def comment_count(self): return sa.func.count('1') comments = sa.orm.relationship( 'Comment', backref='thread' ) class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.UnicodeText) thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id)) thread = Thread(name=u'SQLAlchemy development') thread.comments.append(Comment(u'Going good!')) thread.comments.append(Comment(u'Great new features!')) session.add(thread) session.commit() thread.comment_count # 2 Custom aggregate expressions ---------------------------- Aggregate expression can be virtually any SQL expression not just a simple function taking one parameter. You can try things such as subqueries and different kinds of functions. In the following example we have a Catalog of products where each catalog knows the net worth of its products. 
:: from sqlalchemy_utils import aggregated class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('products', sa.Column(sa.Integer)) def net_worth(self): return sa.func.sum(Product.price) products = sa.orm.relationship('Product') class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) Now the net_worth column of Catalog model will be automatically whenever: * A new product is added to the catalog * A product is deleted from the catalog * The price of catalog product is changed :: from decimal import Decimal product1 = Product(name='Some product', price=Decimal(1000)) product2 = Product(name='Some other product', price=Decimal(500)) catalog = Catalog( name=u'My first catalog', products=[ product1, product2 ] ) session.add(catalog) session.commit() session.refresh(catalog) catalog.net_worth # 1500 session.delete(product2) session.commit() session.refresh(catalog) catalog.net_worth # 1000 product1.price = 2000 session.commit() session.refresh(catalog) catalog.net_worth # 2000 Multiple aggregates per class ----------------------------- Sometimes you may need to define multiple aggregate values for same class. If you need to define lots of relationships pointing to same class, remember to define the relationships as viewonly when possible. :: from sqlalchemy_utils import aggregated class Customer(Base): __tablename__ = 'customer' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('orders', sa.Column(sa.Integer)) def orders_sum(self): return sa.func.sum(Order.price) @aggregated('invoiced_orders', sa.Column(sa.Integer)) def invoiced_orders_sum(self): return sa.func.sum(Order.price) orders = sa.orm.relationship('Order') invoiced_orders = sa.orm.relationship( 'Order', primaryjoin= 'sa.and_(Order.customer_id == Customer.id, Order.invoiced)', viewonly=True ) class Order(Base): __tablename__ = 'order' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) invoiced = sa.Column(sa.Boolean, default=False) customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id)) Many-to-Many aggregates ----------------------- Aggregate expressions also support many-to-many relationships. The usual use scenarios includes things such as: 1. Friend count of a user 2. Group count where given user belongs to :: user_group = sa.Table('user_group', Base.metadata, sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) ) class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('groups', sa.Column(sa.Integer, default=0)) def group_count(self): return sa.func.count('1') groups = sa.orm.relationship( 'Group', backref='users', secondary=user_group ) class Group(Base): __tablename__ = 'group' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) user = User(name=u'John Matrix') user.groups = [Group(name=u'Group A'), Group(name=u'Group B')] session.add(user) session.commit() session.refresh(user) user.group_count # 2 Multi-level aggregates ---------------------- Aggregates can span across multiple relationships. In the following example each Catalog has a net_worth which is the sum of all products in all categories. 
:: from sqlalchemy_utils import aggregated class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('categories.products', sa.Column(sa.Integer)) def net_worth(self): return sa.func.sum(Product.price) categories = sa.orm.relationship('Category') class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id)) products = sa.orm.relationship('Product') class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) Examples -------- Average movie rating ^^^^^^^^^^^^^^^^^^^^ :: from sqlalchemy_utils import aggregated class Movie(Base): __tablename__ = 'movie' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('ratings', sa.Column(sa.Numeric)) def avg_rating(self): return sa.func.avg(Rating.stars) ratings = sa.orm.relationship('Rating') class Rating(Base): __tablename__ = 'rating' id = sa.Column(sa.Integer, primary_key=True) stars = sa.Column(sa.Integer) movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id)) movie = Movie('Terminator 2') movie.ratings.append(Rating(stars=5)) movie.ratings.append(Rating(stars=4)) movie.ratings.append(Rating(stars=3)) session.add(movie) session.commit() movie.avg_rating # 4 TODO ---- * Special consideration should be given to `deadlocks`_. .. _deadlocks: http://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html """ from collections import defaultdict from weakref import WeakKeyDictionary import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.sql.functions import _FunctionGenerator from .functions.orm import get_column_key from .relationships import ( chained_join, path_to_relationships, select_correlated_expression ) aggregated_attrs = WeakKeyDictionary() class AggregatedAttribute(declared_attr): def __init__( self, fget, relationship, column, *args, **kwargs ): super(AggregatedAttribute, self).__init__(fget, *args, **kwargs) self.__doc__ = fget.__doc__ self.column = column self.relationship = relationship def __get__(desc, self, cls): value = (desc.fget, desc.relationship, desc.column) if cls not in aggregated_attrs: aggregated_attrs[cls] = [value] else: aggregated_attrs[cls].append(value) return desc.column def local_condition(prop, objects): pairs = prop.local_remote_pairs if prop.secondary is not None: parent_column = pairs[1][0] fetched_column = pairs[1][0] else: parent_column = pairs[0][0] fetched_column = pairs[0][1] key = get_column_key(prop.mapper, fetched_column) values = [] for obj in objects: try: values.append(getattr(obj, key)) except sa.orm.exc.ObjectDeletedError: pass if values: return parent_column.in_(values) def aggregate_expression(expr, class_): if isinstance(expr, sa.sql.visitors.Visitable): return expr elif isinstance(expr, _FunctionGenerator): return expr(sa.sql.text('1')) else: return expr(class_) class AggregatedValue(object): def __init__(self, class_, attr, path, expr): self.class_ = class_ self.attr = attr self.path = path self.relationships = list( reversed(path_to_relationships(path, class_)) ) self.expr = aggregate_expression(expr, class_) @property def aggregate_query(self): query = select_correlated_expression( self.class_, self.expr, self.path, 
self.relationships[0].mapper.class_ ) return query.as_scalar() def update_query(self, objects): table = self.class_.__table__ query = table.update().values( {self.attr: self.aggregate_query} ) if len(self.relationships) == 1: prop = self.relationships[-1].property condition = local_condition(prop, objects) if condition is not None: return query.where(condition) else: # Builds query such as: # # UPDATE catalog SET product_count = (aggregate_query) # WHERE id IN ( # SELECT catalog_id # FROM category # INNER JOIN sub_category # ON category.id = sub_category.category_id # WHERE sub_category.id IN (product_sub_category_ids) # ) property_ = self.relationships[-1].property remote_pairs = property_.local_remote_pairs local = remote_pairs[0][0] remote = remote_pairs[0][1] condition = local_condition( self.relationships[0].property, objects ) if condition is not None: return query.where( local.in_( sa.select( [remote], from_obj=[ chained_join(*reversed(self.relationships)) ] ).where( condition ) ) ) class AggregationManager(object): def __init__(self): self.reset() def reset(self): self.generator_registry = defaultdict(list) def register_listeners(self): sa.event.listen( sa.orm.mapper, 'after_configured', self.update_generator_registry ) sa.event.listen( sa.orm.session.Session, 'after_flush', self.construct_aggregate_queries ) def update_generator_registry(self): for class_, attrs in aggregated_attrs.items(): for expr, path, column in attrs: value = AggregatedValue( class_=class_, attr=column, path=path, expr=expr(class_) ) key = value.relationships[0].mapper.class_ self.generator_registry[key].append( value ) def construct_aggregate_queries(self, session, ctx): object_dict = defaultdict(list) for obj in session: for class_ in self.generator_registry: if isinstance(obj, class_): object_dict[class_].append(obj) for class_, objects in object_dict.items(): for aggregate_value in self.generator_registry[class_]: query = aggregate_value.update_query(objects) if query is not None: session.execute(query) manager = AggregationManager() manager.register_listeners() def aggregated( relationship, column ): """ Decorator that generates an aggregated attribute. The decorated function should return an aggregate select expression. :param relationship: Defines the relationship of which the aggregate is calculated from. The class needs to have given relationship in order to calculate the aggregate. :param column: SQLAlchemy Column object. The column definition of this aggregate attribute. """ def wraps(func): return AggregatedAttribute( func, relationship, column ) return wraps sqlalchemy-utils-0.36.1/sqlalchemy_utils/asserts.py000066400000000000000000000124021360007755400225270ustar00rootroot00000000000000""" The functions in this module can be used for testing that the constraints of your models. Each assert function runs SQL UPDATEs that check for the existence of given constraint. 
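A minimal sketch of how these helpers are typically called from a test
function (``user`` stands for whatever fixture or setup code your test suite
provides; the model and data are the ones created in the example below)::

    from sqlalchemy_utils import (
        assert_max_length,
        assert_non_nullable,
        assert_nullable
    )

    def test_user_constraints(user):
        assert_non_nullable(user, 'email')
        assert_nullable(user, 'name')
        assert_max_length(user, 'name', 200)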
Consider the following model:: class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(200), nullable=True) email = sa.Column(sa.String(255), nullable=False) user = User(name='John Doe', email='john@example.com') session.add(user) session.commit() We can easily test the constraints by assert_* functions:: from sqlalchemy_utils import ( assert_nullable, assert_non_nullable, assert_max_length ) assert_nullable(user, 'name') assert_non_nullable(user, 'email') assert_max_length(user, 'name', 200) # raises AssertionError because the max length of email is 255 assert_max_length(user, 'email', 300) """ from decimal import Decimal import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.exc import DataError, IntegrityError def _update_field(obj, field, value): session = sa.orm.object_session(obj) column = sa.inspect(obj.__class__).columns[field] query = column.table.update().values(**{column.key: value}) session.execute(query) session.flush() def _expect_successful_update(obj, field, value, reraise_exc): try: _update_field(obj, field, value) except (reraise_exc) as e: session = sa.orm.object_session(obj) session.rollback() assert False, str(e) def _expect_failing_update(obj, field, value, expected_exc): try: _update_field(obj, field, value) except expected_exc: pass else: raise AssertionError('Expected update to raise %s' % expected_exc) finally: session = sa.orm.object_session(obj) session.rollback() def _repeated_value(type_): if isinstance(type_, ARRAY): if isinstance(type_.item_type, sa.Integer): return [0] elif isinstance(type_.item_type, sa.String): return [u'a'] elif isinstance(type_.item_type, sa.Numeric): return [Decimal('0')] else: raise TypeError('Unknown array item type') else: return u'a' def _expected_exception(type_): if isinstance(type_, ARRAY): return IntegrityError else: return DataError def assert_nullable(obj, column): """ Assert that given column is nullable. This is checked by running an SQL update that assigns given column as None. :param obj: SQLAlchemy declarative model object :param column: Name of the column """ _expect_successful_update(obj, column, None, IntegrityError) def assert_non_nullable(obj, column): """ Assert that given column is not nullable. This is checked by running an SQL update that assigns given column as None. :param obj: SQLAlchemy declarative model object :param column: Name of the column """ _expect_failing_update(obj, column, None, IntegrityError) def assert_max_length(obj, column, max_length): """ Assert that the given column is of given max length. This function supports string typed columns as well as PostgreSQL array typed columns. 
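For a plain string column the assertion simply exercises the declared length
limit; for instance, with the ``User`` model from the module docstring above
(where ``name`` is ``sa.String(200)``)::

    assert_max_length(user, 'name', 200)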
In the following example we add a check constraint that user can have a maximum of 5 favorite colors and then test this.:: class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) favorite_colors = sa.Column(ARRAY(sa.String), nullable=False) __table_args__ = ( sa.CheckConstraint( sa.func.array_length(favorite_colors, 1) <= 5 ) ) user = User(name='John Doe', favorite_colors=['red', 'blue']) session.add(user) session.commit() assert_max_length(user, 'favorite_colors', 5) :param obj: SQLAlchemy declarative model object :param column: Name of the column :param max_length: Maximum length of given column """ type_ = sa.inspect(obj.__class__).columns[column].type _expect_successful_update( obj, column, _repeated_value(type_) * max_length, _expected_exception(type_) ) _expect_failing_update( obj, column, _repeated_value(type_) * (max_length + 1), _expected_exception(type_) ) def assert_min_value(obj, column, min_value): """ Assert that the given column must have a minimum value of `min_value`. :param obj: SQLAlchemy declarative model object :param column: Name of the column :param min_value: The minimum allowed value for given column """ _expect_successful_update(obj, column, min_value, IntegrityError) _expect_failing_update(obj, column, min_value - 1, IntegrityError) def assert_max_value(obj, column, min_value): """ Assert that the given column must have a minimum value of `max_value`. :param obj: SQLAlchemy declarative model object :param column: Name of the column :param max_value: The maximum allowed value for given column """ _expect_successful_update(obj, column, min_value, IntegrityError) _expect_failing_update(obj, column, min_value + 1, IntegrityError) sqlalchemy-utils-0.36.1/sqlalchemy_utils/exceptions.py000066400000000000000000000003451360007755400232270ustar00rootroot00000000000000""" Global SQLAlchemy-Utils exception classes. """ class ImproperlyConfigured(Exception): """ SQLAlchemy-Utils is improperly configured; normally due to usage of a utility that depends on a missing library. """ sqlalchemy-utils-0.36.1/sqlalchemy_utils/expressions.py000066400000000000000000000031221360007755400234240ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.dialects import postgresql from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import ColumnElement, FunctionElement from sqlalchemy.sql.functions import GenericFunction from .functions.orm import quote class array_get(FunctionElement): name = 'array_get' @compiles(array_get) def compile_array_get(element, compiler, **kw): args = list(element.clauses) if len(args) != 2: raise Exception( "Function 'array_get' expects two arguments (%d given)." % len(args) ) if not hasattr(args[1], 'value') or not isinstance(args[1].value, int): raise Exception( "Second argument should be an integer." 
) return '(%s)[%s]' % ( compiler.process(args[0]), sa.text(str(args[1].value + 1)) ) class row_to_json(GenericFunction): name = 'row_to_json' type = postgresql.JSON @compiles(row_to_json, 'postgresql') def compile_row_to_json(element, compiler, **kw): return "%s(%s)" % (element.name, compiler.process(element.clauses)) class json_array_length(GenericFunction): name = 'json_array_length' type = sa.Integer @compiles(json_array_length, 'postgresql') def compile_json_array_length(element, compiler, **kw): return "%s(%s)" % (element.name, compiler.process(element.clauses)) class Asterisk(ColumnElement): def __init__(self, selectable): self.selectable = selectable @compiles(Asterisk) def compile_asterisk(element, compiler, **kw): return '%s.*' % quote(compiler.dialect, element.selectable.name) sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/000077500000000000000000000000001360007755400225025ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/__init__.py000066400000000000000000000017611360007755400246200ustar00rootroot00000000000000from .database import ( # noqa create_database, database_exists, drop_database, escape_like, has_index, has_unique_index, is_auto_assigned_date_column, json_sql ) from .foreign_keys import ( # noqa dependent_objects, get_fk_constraint_for_columns, get_referencing_foreign_keys, group_foreign_keys, merge_references, non_indexed_foreign_keys ) from .mock import create_mock_engine, mock_engine # noqa from .orm import ( # noqa cast_if, get_bind, get_class_by_table, get_column_key, get_columns, get_declarative_base, get_hybrid_properties, get_mapper, get_primary_keys, get_query_entities, get_tables, get_type, getdotattr, has_changes, identity, is_loaded, naturally_equivalent, quote, table_name ) from .render import render_expression, render_statement # noqa from .sort_query import ( # noqa make_order_by_deterministic, QuerySorterException, sort_query ) sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/database.py000066400000000000000000000415551360007755400246320ustar00rootroot00000000000000try: from collections.abc import Mapping, Sequence except ImportError: # For python 2.7 support from collections import Mapping, Sequence import itertools import os from copy import copy import sqlalchemy as sa from sqlalchemy.engine.url import make_url from sqlalchemy.exc import OperationalError, ProgrammingError from ..utils import starts_with from .orm import quote def escape_like(string, escape_char='*'): """ Escape the string paremeter used in SQL LIKE expressions. :: from sqlalchemy_utils import escape_like query = session.query(User).filter( User.name.ilike(escape_like('John')) ) :param string: a string to escape :param escape_char: escape character """ return ( string .replace(escape_char, escape_char * 2) .replace('%', escape_char + '%') .replace('_', escape_char + '_') ) def json_sql(value, scalars_to_json=True): """ Convert python data structures to PostgreSQL specific SQLAlchemy JSON constructs. This function is extremly useful if you need to build PostgreSQL JSON on python side. .. 
note:: This function needs PostgreSQL >= 9.4 Scalars are converted to to_json SQLAlchemy function objects :: json_sql(1) # Equals SQL: to_json(1) json_sql('a') # to_json('a') Mappings are converted to json_build_object constructs :: json_sql({'a': 'c', '2': 5}) # json_build_object('a', 'c', '2', 5) Sequences (other than strings) are converted to json_build_array constructs :: json_sql([1, 2, 3]) # json_build_array(1, 2, 3) You can also nest these data structures :: json_sql({'a': [1, 2, 3]}) # json_build_object('a', json_build_array[1, 2, 3]) :param value: value to be converted to SQLAlchemy PostgreSQL function constructs """ scalar_convert = sa.text if scalars_to_json: def scalar_convert(a): return sa.func.to_json(sa.text(a)) if isinstance(value, Mapping): return sa.func.json_build_object( *( json_sql(v, scalars_to_json=False) for v in itertools.chain(*value.items()) ) ) elif isinstance(value, str): return scalar_convert("'{0}'".format(value)) elif isinstance(value, Sequence): return sa.func.json_build_array( *( json_sql(v, scalars_to_json=False) for v in value ) ) elif isinstance(value, (int, float)): return scalar_convert(str(value)) return value def has_index(column_or_constraint): """ Return whether or not given column or the columns of given foreign key constraint have an index. A column has an index if it has a single column index or it is the first column in compound column index. A foreign key constraint has an index if the constraint columns are the first columns in compound column index. :param column_or_constraint: SQLAlchemy Column object or SA ForeignKeyConstraint object .. versionadded: 0.26.2 .. versionchanged: 0.30.18 Added support for foreign key constaints. :: from sqlalchemy_utils import has_index class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) title = sa.Column(sa.String(100)) is_published = sa.Column(sa.Boolean, index=True) is_deleted = sa.Column(sa.Boolean) is_archived = sa.Column(sa.Boolean) __table_args__ = ( sa.Index('my_index', is_deleted, is_archived), ) table = Article.__table__ has_index(table.c.is_published) # True has_index(table.c.is_deleted) # True has_index(table.c.is_archived) # False Also supports primary key indexes :: from sqlalchemy_utils import has_index class ArticleTranslation(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) locale = sa.Column(sa.String(10), primary_key=True) title = sa.Column(sa.String(100)) table = ArticleTranslation.__table__ has_index(table.c.locale) # False has_index(table.c.id) # True This function supports foreign key constraints as well :: class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), sa.Index( 'my_index', author_first_name, author_last_name ) ) table = Article.__table__ constraint = list(table.foreign_keys)[0].constraint has_index(constraint) # True """ table = column_or_constraint.table if not isinstance(table, sa.Table): raise TypeError( 'Only columns belonging to Table objects are supported. Given ' 'column belongs to %r.' 
% table ) primary_keys = table.primary_key.columns.values() if isinstance(column_or_constraint, sa.ForeignKeyConstraint): columns = list(column_or_constraint.columns.values()) else: columns = [column_or_constraint] return ( (primary_keys and starts_with(primary_keys, columns)) or any( starts_with(index.columns.values(), columns) for index in table.indexes ) ) def has_unique_index(column_or_constraint): """ Return whether or not given column or given foreign key constraint has a unique index. A column has a unique index if it has a single column primary key index or it has a single column UniqueConstraint. A foreign key constraint has a unique index if the columns of the constraint are the same as the columns of table primary key or the coluns of any unique index or any unique constraint of the given table. :param column: SQLAlchemy Column object .. versionadded: 0.27.1 .. versionchanged: 0.30.18 Added support for foreign key constaints. Fixed support for unique indexes (previously only worked for unique constraints) :: from sqlalchemy_utils import has_unique_index class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) title = sa.Column(sa.String(100)) is_published = sa.Column(sa.Boolean, unique=True) is_deleted = sa.Column(sa.Boolean) is_archived = sa.Column(sa.Boolean) table = Article.__table__ has_unique_index(table.c.is_published) # True has_unique_index(table.c.is_deleted) # False has_unique_index(table.c.id) # True This function supports foreign key constraints as well :: class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), sa.Index( 'my_index', author_first_name, author_last_name, unique=True ) ) table = Article.__table__ constraint = list(table.foreign_keys)[0].constraint has_unique_index(constraint) # True :raises TypeError: if given column does not belong to a Table object """ table = column_or_constraint.table if not isinstance(table, sa.Table): raise TypeError( 'Only columns belonging to Table objects are supported. Given ' 'column belongs to %r.' % table ) primary_keys = list(table.primary_key.columns.values()) if isinstance(column_or_constraint, sa.ForeignKeyConstraint): columns = list(column_or_constraint.columns.values()) else: columns = [column_or_constraint] return ( (columns == primary_keys) or any( columns == list(constraint.columns.values()) for constraint in table.constraints if isinstance(constraint, sa.sql.schema.UniqueConstraint) ) or any( columns == list(index.columns.values()) for index in table.indexes if index.unique ) ) def is_auto_assigned_date_column(column): """ Returns whether or not given SQLAlchemy Column object's is auto assigned DateTime or Date. :param column: SQLAlchemy Column object """ return ( ( isinstance(column.type, sa.DateTime) or isinstance(column.type, sa.Date) ) and ( column.default or column.server_default or column.onupdate or column.server_onupdate ) ) def database_exists(url): """Check if a database exists. :param url: A SQLAlchemy engine URL. Performs backend-specific testing to quickly determine if a database exists on the server. 
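When the URL points to SQLite, the check inspects the database file itself
rather than contacting a server (``example.db`` below is only a placeholder
path)::

    database_exists('sqlite:///example.db')
    #=> True only if the file exists and contains a SQLite database

The same check against a PostgreSQL server looks like this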
:: database_exists('postgresql://postgres@localhost/name') #=> False create_database('postgresql://postgres@localhost/name') database_exists('postgresql://postgres@localhost/name') #=> True Supports checking against a constructed URL as well. :: engine = create_engine('postgresql://postgres@localhost/name') database_exists(engine.url) #=> False create_database(engine.url) database_exists(engine.url) #=> True """ def get_scalar_result(engine, sql): result_proxy = engine.execute(sql) result = result_proxy.scalar() result_proxy.close() engine.dispose() return result def sqlite_file_exists(database): if not os.path.isfile(database) or os.path.getsize(database) < 100: return False with open(database, 'rb') as f: header = f.read(100) return header[:16] == b'SQLite format 3\x00' url = copy(make_url(url)) database = url.database if url.drivername.startswith('postgres'): url.database = 'postgres' else: url.database = None engine = sa.create_engine(url) if engine.dialect.name == 'postgresql': text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database return bool(get_scalar_result(engine, text)) elif engine.dialect.name == 'mysql': text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA " "WHERE SCHEMA_NAME = '%s'" % database) return bool(get_scalar_result(engine, text)) elif engine.dialect.name == 'sqlite': if database: return database == ':memory:' or sqlite_file_exists(database) else: # The default SQLAlchemy database is in memory, # and :memory is not required, thus we should support that use-case return True else: engine.dispose() engine = None text = 'SELECT 1' try: url.database = database engine = sa.create_engine(url) result = engine.execute(text) result.close() return True except (ProgrammingError, OperationalError): return False finally: if engine is not None: engine.dispose() def create_database(url, encoding='utf8', template=None): """Issue the appropriate CREATE DATABASE statement. :param url: A SQLAlchemy engine URL. :param encoding: The encoding to create the database as. :param template: The name of the template from which to create the new database. At the moment only supported by PostgreSQL driver. To create a database, you can pass a simple URL that would have been passed to ``create_engine``. :: create_database('postgresql://postgres@localhost/name') You may also pass the url from an existing engine. :: create_database(engine.url) Has full support for mysql, postgres, and sqlite. In theory, other database engines should be supported. 
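On PostgreSQL the new database can also be cloned from a specific template
(``template0`` below is just an illustration; ``template1`` is used when no
template is given)::

    create_database(
        'postgresql://postgres@localhost/name',
        template='template0'
    )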
""" url = copy(make_url(url)) database = url.database if url.drivername.startswith('postgres'): url.database = 'postgres' elif url.drivername.startswith('mssql'): url.database = 'master' elif not url.drivername.startswith('sqlite'): url.database = None if url.drivername == 'mssql+pyodbc': engine = sa.create_engine(url, connect_args={'autocommit': True}) elif url.drivername == 'postgresql+pg8000': engine = sa.create_engine(url, isolation_level='AUTOCOMMIT') else: engine = sa.create_engine(url) result_proxy = None if engine.dialect.name == 'postgresql': if engine.driver == 'psycopg2': from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT engine.raw_connection().set_isolation_level( ISOLATION_LEVEL_AUTOCOMMIT ) if not template: template = 'template1' text = "CREATE DATABASE {0} ENCODING '{1}' TEMPLATE {2}".format( quote(engine, database), encoding, quote(engine, template) ) result_proxy = engine.execute(text) elif engine.dialect.name == 'mysql': text = "CREATE DATABASE {0} CHARACTER SET = '{1}'".format( quote(engine, database), encoding ) result_proxy = engine.execute(text) elif engine.dialect.name == 'sqlite' and database != ':memory:': if database: engine.execute("CREATE TABLE DB(id int);") engine.execute("DROP TABLE DB;") else: text = 'CREATE DATABASE {0}'.format(quote(engine, database)) result_proxy = engine.execute(text) if result_proxy is not None: result_proxy.close() engine.dispose() def drop_database(url): """Issue the appropriate DROP DATABASE statement. :param url: A SQLAlchemy engine URL. Works similar to the :ref:`create_database` method in that both url text and a constructed url are accepted. :: drop_database('postgresql://postgres@localhost/name') drop_database(engine.url) """ url = copy(make_url(url)) database = url.database if url.drivername.startswith('postgres'): url.database = 'postgres' elif url.drivername.startswith('mssql'): url.database = 'master' elif not url.drivername.startswith('sqlite'): url.database = None if url.drivername == 'mssql+pyodbc': engine = sa.create_engine(url, connect_args={'autocommit': True}) elif url.drivername == 'postgresql+pg8000': engine = sa.create_engine(url, isolation_level='AUTOCOMMIT') else: engine = sa.create_engine(url) conn_resource = None if engine.dialect.name == 'sqlite' and database != ':memory:': if database: os.remove(database) elif engine.dialect.name == 'postgresql' and engine.driver == 'psycopg2': from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT connection = engine.connect() connection.connection.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) # Disconnect all users from the database we are dropping. version = connection.dialect.server_version_info pid_column = ( 'pid' if (version >= (9, 2)) else 'procpid' ) text = ''' SELECT pg_terminate_backend(pg_stat_activity.%(pid_column)s) FROM pg_stat_activity WHERE pg_stat_activity.datname = '%(database)s' AND %(pid_column)s <> pg_backend_pid(); ''' % {'pid_column': pid_column, 'database': database} connection.execute(text) # Drop the database. 
text = 'DROP DATABASE {0}'.format(quote(connection, database)) connection.execute(text) conn_resource = connection else: text = 'DROP DATABASE {0}'.format(quote(engine, database)) conn_resource = engine.execute(text) if conn_resource is not None: conn_resource.close() engine.dispose() sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/foreign_keys.py000066400000000000000000000243171360007755400255470ustar00rootroot00000000000000from collections import defaultdict from itertools import groupby import sqlalchemy as sa from sqlalchemy.exc import NoInspectionAvailable from sqlalchemy.orm import object_session from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table from ..query_chain import QueryChain from .database import has_index from .orm import get_column_key, get_mapper, get_tables def get_foreign_key_values(fk, obj): return dict( ( fk.constraint.columns.values()[index].key, getattr(obj, element.column.key) ) for index, element in enumerate(fk.constraint.elements) ) def group_foreign_keys(foreign_keys): """ Return a groupby iterator that groups given foreign keys by table. :param foreign_keys: a sequence of foreign keys :: foreign_keys = get_referencing_foreign_keys(User) for table, fks in group_foreign_keys(foreign_keys): # do something pass .. seealso:: :func:`get_referencing_foreign_keys` .. versionadded: 0.26.1 """ foreign_keys = sorted( foreign_keys, key=lambda key: key.constraint.table.name ) return groupby(foreign_keys, lambda key: key.constraint.table) def get_referencing_foreign_keys(mixed): """ Returns referencing foreign keys for given Table object or declarative class. :param mixed: SA Table object or SA declarative class :: get_referencing_foreign_keys(User) # set([ForeignKey('user.id')]) get_referencing_foreign_keys(User.__table__) This function also understands inheritance. This means it returns all foreign keys that reference any table in the class inheritance tree. Let's say you have three classes which use joined table inheritance, namely TextItem, Article and BlogPost with Article and BlogPost inheriting TextItem. :: # This will check all foreign keys that reference either article table # or textitem table. get_referencing_foreign_keys(Article) .. seealso:: :func:`get_tables` """ if isinstance(mixed, sa.Table): tables = [mixed] else: tables = get_tables(mixed) referencing_foreign_keys = set() for table in mixed.metadata.tables.values(): if table not in tables: for constraint in table.constraints: if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): for fk in constraint.elements: if any(fk.references(t) for t in tables): referencing_foreign_keys.add(fk) return referencing_foreign_keys def merge_references(from_, to, foreign_keys=None): """ Merge the references of an entity into another entity. 
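    In practice, every row that references ``from_`` through a foreign key is updated so that it references ``to`` instead.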
Consider the following models:: class User(self.Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(255)) def __repr__(self): return 'User(name=%r)' % self.name class BlogPost(self.Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) title = sa.Column(sa.String(255)) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) Now lets add some data:: john = self.User(name='John') jack = self.User(name='Jack') post = self.BlogPost(title='Some title', author=john) post2 = self.BlogPost(title='Other title', author=jack) self.session.add_all([ john, jack, post, post2 ]) self.session.commit() If we wanted to merge all John's references to Jack it would be as easy as :: merge_references(john, jack) self.session.commit() post.author # User(name='Jack') post2.author # User(name='Jack') :param from_: an entity to merge into another entity :param to: an entity to merge another entity into :param foreign_keys: A sequence of foreign keys. By default this is None indicating all referencing foreign keys should be used. .. seealso: :func:`dependent_objects` .. versionadded: 0.26.1 """ if from_.__tablename__ != to.__tablename__: raise TypeError('The tables of given arguments do not match.') session = object_session(from_) foreign_keys = get_referencing_foreign_keys(from_) for fk in foreign_keys: old_values = get_foreign_key_values(fk, from_) new_values = get_foreign_key_values(fk, to) criteria = ( getattr(fk.constraint.table.c, key) == value for key, value in old_values.items() ) try: mapper = get_mapper(fk.constraint.table) except ValueError: query = ( fk.constraint.table .update() .where(sa.and_(*criteria)) .values(new_values) ) session.execute(query) else: ( session.query(mapper.class_) .filter_by(**old_values) .update( new_values, 'evaluate' ) ) def dependent_objects(obj, foreign_keys=None): """ Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates through all dependent objects for given SQLAlchemy object. Consider a User object is referenced in various articles and also in various orders. Getting all these dependent objects is as easy as:: from sqlalchemy_utils import dependent_objects dependent_objects(user) If you expect an object to have lots of dependent_objects it might be good to limit the results:: dependent_objects(user).limit(5) The common use case is checking for all restrict dependent objects before deleting parent object and inform the user if there are dependent objects with ondelete='RESTRICT' foreign keys. If this kind of checking is not used it will lead to nasty IntegrityErrors being raised. In the following example we delete given user if it doesn't have any foreign key restricted dependent objects:: from sqlalchemy_utils import get_referencing_foreign_keys user = session.query(User).get(some_user_id) deps = list( dependent_objects( user, ( fk for fk in get_referencing_foreign_keys(User) # On most databases RESTRICT is the default mode hence we # check for None values also if fk.ondelete == 'RESTRICT' or fk.ondelete is None ) ).limit(5) ) if deps: # Do something to inform the user pass else: session.delete(user) :param obj: SQLAlchemy declarative model object :param foreign_keys: A sequence of foreign keys to use for searching the dependent_objects for given object. By default this is None, indicating that all foreign keys referencing the object will be used. .. note:: This function does not support exotic mappers that use multiple tables .. 
seealso:: :func:`get_referencing_foreign_keys` .. seealso:: :func:`merge_references` .. versionadded: 0.26.0 """ if foreign_keys is None: foreign_keys = get_referencing_foreign_keys(obj) session = object_session(obj) chain = QueryChain([]) classes = obj.__class__._decl_class_registry for table, keys in group_foreign_keys(foreign_keys): keys = list(keys) for class_ in classes.values(): try: mapper = sa.inspect(class_) except NoInspectionAvailable: continue parent_mapper = mapper.inherits if ( table in mapper.tables and not (parent_mapper and table in parent_mapper.tables) ): query = session.query(class_).filter( sa.or_(*_get_criteria(keys, class_, obj)) ) chain.queries.append(query) return chain def _get_criteria(keys, class_, obj): criteria = [] visited_constraints = [] for key in keys: if key.constraint in visited_constraints: continue visited_constraints.append(key.constraint) subcriteria = [] for index, column in enumerate(key.constraint.columns): foreign_column = ( key.constraint.elements[index].column ) subcriteria.append( getattr(class_, get_column_key(class_, column)) == getattr( obj, sa.inspect(type(obj)) .get_property_by_column( foreign_column ).key ) ) criteria.append(sa.and_(*subcriteria)) return criteria def non_indexed_foreign_keys(metadata, engine=None): """ Finds all non indexed foreign keys from all tables of given MetaData. Very useful for optimizing postgresql database and finding out which foreign keys need indexes. :param metadata: MetaData object to inspect tables from """ reflected_metadata = MetaData() if metadata.bind is None and engine is None: raise Exception( 'Either pass a metadata object with bind or ' 'pass engine as a second parameter' ) constraints = defaultdict(list) for table_name in metadata.tables.keys(): table = Table( table_name, reflected_metadata, autoload=True, autoload_with=metadata.bind or engine ) for constraint in table.constraints: if not isinstance(constraint, ForeignKeyConstraint): continue if not has_index(constraint): constraints[table.name].append(constraint) return dict(constraints) def get_fk_constraint_for_columns(table, *columns): for constraint in table.constraints: if list(constraint.columns.values()) == list(columns): return constraint sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/mock.py000066400000000000000000000060321360007755400240060ustar00rootroot00000000000000import contextlib import datetime import inspect import re import six import sqlalchemy as sa def create_mock_engine(bind, stream=None): """Create a mock SQLAlchemy engine from the passed engine or bind URL. :param bind: A SQLAlchemy engine or bind URL to mock. :param stream: Render all DDL operations to the stream. 
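    For example, the SQL emitted for a statement can be captured without touching a real database. This is a minimal sketch; the statement is arbitrary and the exact rendering depends on the dialect::

        import six
        import sqlalchemy as sa

        stream = six.moves.cStringIO()
        engine = create_mock_engine('sqlite://', stream)
        engine.execute(sa.select([sa.column('id')]))
        print(stream.getvalue())  # SELECT id;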
""" if not isinstance(bind, six.string_types): bind_url = str(bind.url) else: bind_url = bind if stream is not None: def dump(sql, *args, **kwargs): class Compiler(type(sql._compiler(engine.dialect))): def visit_bindparam(self, bindparam, *args, **kwargs): return self.render_literal_value( bindparam.value, bindparam.type) def render_literal_value(self, value, type_): if isinstance(value, six.integer_types): return str(value) elif isinstance(value, (datetime.date, datetime.datetime)): return "'%s'" % value return super(Compiler, self).render_literal_value( value, type_) text = str(Compiler(engine.dialect, sql).process(sql)) text = re.sub(r'\n+', '\n', text) text = text.strip('\n').strip() stream.write('\n%s;' % text) else: def dump(*args, **kw): return None engine = sa.create_engine(bind_url, strategy='mock', executor=dump) return engine @contextlib.contextmanager def mock_engine(engine, stream=None): """Mocks out the engine specified in the passed bind expression. Note this function is meant for convenience and protected usage. Do NOT blindly pass user input to this function as it uses exec. :param engine: A python expression that represents the engine to mock. :param stream: Render all DDL operations to the stream. """ # Create a stream if not present. if stream is None: stream = six.moves.cStringIO() # Navigate the stack and find the calling frame that allows the # expression to execuate. for frame in inspect.stack()[1:]: try: frame = frame[0] expression = '__target = %s' % engine six.exec_(expression, frame.f_globals, frame.f_locals) target = frame.f_locals['__target'] break except Exception: pass else: raise ValueError('Not a valid python expression', engine) # Evaluate the expression and get the target engine. frame.f_locals['__mock'] = create_mock_engine(target, stream) # Replace the target with our mock. six.exec_('%s = __mock' % engine, frame.f_globals, frame.f_locals) # Give control back. yield stream # Put the target engine back. frame.f_locals['__target'] = target six.exec_('%s = __target' % engine, frame.f_globals, frame.f_locals) six.exec_('del __target', frame.f_globals, frame.f_locals) six.exec_('del __mock', frame.f_globals, frame.f_locals) sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/orm.py000066400000000000000000000643011360007755400236550ustar00rootroot00000000000000from collections import OrderedDict from functools import partial from inspect import isclass from operator import attrgetter import six import sqlalchemy as sa from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import mapperlib from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.exc import UnmappedInstanceError from sqlalchemy.orm.properties import ColumnProperty, RelationshipProperty from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp from ..utils import is_sequence def get_class_by_table(base, table, data=None): """ Return declarative class associated with given table. If no class is found this function returns `None`. If multiple classes were found (polymorphic cases) additional `data` parameter can be given to hint which class to return. :: class User(Base): __tablename__ = 'entity' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) get_class_by_table(Base, User.__table__) # User class This function also supports models using single table inheritance. 
Additional data paratemer should be provided in these case. :: class Entity(Base): __tablename__ = 'entity' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) type = sa.Column(sa.String) __mapper_args__ = { 'polymorphic_on': type, 'polymorphic_identity': 'entity' } class User(Entity): __mapper_args__ = { 'polymorphic_identity': 'user' } # Entity class get_class_by_table(Base, Entity.__table__, {'type': 'entity'}) # User class get_class_by_table(Base, Entity.__table__, {'type': 'user'}) :param base: Declarative model base :param table: SQLAlchemy Table object :param data: Data row to determine the class in polymorphic scenarios :return: Declarative class or None. """ found_classes = set( c for c in base._decl_class_registry.values() if hasattr(c, '__table__') and c.__table__ is table ) if len(found_classes) > 1: if not data: raise ValueError( "Multiple declarative classes found for table '{0}'. " "Please provide data parameter for this function to be able " "to determine polymorphic scenarios.".format( table.name ) ) else: for cls in found_classes: mapper = sa.inspect(cls) polymorphic_on = mapper.polymorphic_on.name if polymorphic_on in data: if data[polymorphic_on] == mapper.polymorphic_identity: return cls raise ValueError( "Multiple declarative classes found for table '{0}'. Given " "data row does not match any polymorphic identity of the " "found classes.".format( table.name ) ) elif found_classes: return found_classes.pop() return None def get_type(expr): """ Return the associated type with given Column, InstrumentedAttribute, ColumnProperty, RelationshipProperty or other similar SQLAlchemy construct. For constructs wrapping columns this is the column type. For relationships this function returns the relationship mapper class. :param expr: SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other similar SA construct. :: class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship(User) get_type(User.__table__.c.name) # sa.String() get_type(User.name) # sa.String() get_type(User.name.property) # sa.String() get_type(Article.author) # User .. versionadded: 0.30.9 """ if hasattr(expr, 'type'): return expr.type elif isinstance(expr, InstrumentedAttribute): expr = expr.property if isinstance(expr, ColumnProperty): return expr.columns[0].type elif isinstance(expr, RelationshipProperty): return expr.mapper.class_ raise TypeError("Couldn't inspect type.") def cast_if(expression, type_): """ Produce a CAST expression but only if given expression is not of given type already. Assume we have a model with two fields id (Integer) and name (String). :: import sqlalchemy as sa from sqlalchemy_utils import cast_if cast_if(User.id, sa.Integer) # "user".id cast_if(User.name, sa.String) # "user".name cast_if(User.id, sa.String) # CAST("user".id AS TEXT) This function supports scalar values as well. :: cast_if(1, sa.Integer) # 1 cast_if('text', sa.String) # 'text' cast_if(1, sa.String) # CAST(1 AS TEXT) :param expression: A SQL expression, such as a ColumnElement expression or a Python string which will be coerced into a bound literal value. :param type_: A TypeEngine class or instance indicating the type to which the CAST should apply. .. 
versionadded: 0.30.14 """ try: expr_type = get_type(expression) except TypeError: expr_type = expression check_type = type_().python_type else: check_type = type_ return ( sa.cast(expression, type_) if not isinstance(expr_type, check_type) else expression ) def get_column_key(model, column): """ Return the key for given column in given model. :param model: SQLAlchemy declarative model object :: class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column('_name', sa.String) get_column_key(User, User.__table__.c._name) # 'name' .. versionadded: 0.26.5 .. versionchanged: 0.27.11 Throws UnmappedColumnError instead of ValueError when no property was found for given column. This is consistent with how SQLAlchemy works. """ mapper = sa.inspect(model) try: return mapper.get_property_by_column(column).key except sa.orm.exc.UnmappedColumnError: for key, c in mapper.columns.items(): if c.name == column.name and c.table is column.table: return key raise sa.orm.exc.UnmappedColumnError( 'No column %s is configured on mapper %s...' % (column, mapper) ) def get_mapper(mixed): """ Return related SQLAlchemy Mapper for given SQLAlchemy object. :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object :: from sqlalchemy_utils import get_mapper get_mapper(User) get_mapper(User()) get_mapper(User.__table__) get_mapper(User.__mapper__) get_mapper(sa.orm.aliased(User)) get_mapper(sa.orm.aliased(User.__table__)) Raises: ValueError: if multiple mappers were found for given argument .. versionadded: 0.26.1 """ if isinstance(mixed, sa.orm.query._MapperEntity): mixed = mixed.expr elif isinstance(mixed, sa.Column): mixed = mixed.table elif isinstance(mixed, sa.orm.query._ColumnEntity): mixed = mixed.expr if isinstance(mixed, sa.orm.Mapper): return mixed if isinstance(mixed, sa.orm.util.AliasedClass): return sa.inspect(mixed).mapper if isinstance(mixed, sa.sql.selectable.Alias): mixed = mixed.element if isinstance(mixed, AliasedInsp): return mixed.mapper if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): mixed = mixed.class_ if isinstance(mixed, sa.Table): mappers = [ mapper for mapper in mapperlib._mapper_registry if mixed in mapper.tables ] if len(mappers) > 1: raise ValueError( "Multiple mappers found for table '%s'." % mixed.name ) elif not mappers: raise ValueError( "Could not get mapper for table '%s'." % mixed.name ) else: return mappers[0] if not isclass(mixed): mixed = type(mixed) return sa.inspect(mixed) def get_bind(obj): """ Return the bind for given SQLAlchemy Engine / Connection / declarative model object. :param obj: SQLAlchemy Engine / Connection / declarative model object :: from sqlalchemy_utils import get_bind get_bind(session) # Connection object get_bind(user) """ if hasattr(obj, 'bind'): conn = obj.bind else: try: conn = object_session(obj).bind except UnmappedInstanceError: conn = obj if not hasattr(conn, 'execute'): raise TypeError( 'This method accepts only Session, Engine, Connection and ' 'declarative model objects.' ) return conn def get_primary_keys(mixed): """ Return an OrderedDict of all primary keys for given Table object, declarative class or declarative class instance. :param mixed: SA Table object, SA declarative class or SA declarative class instance :: get_primary_keys(User) get_primary_keys(User()) get_primary_keys(User.__table__) get_primary_keys(User.__mapper__) get_primary_keys(sa.orm.aliased(User)) get_primary_keys(sa.orm.aliased(User.__table__)) .. 
versionchanged: 0.25.3 Made the function return an ordered dictionary instead of generator. This change was made to support primary key aliases. Renamed this function to 'get_primary_keys', formerly 'primary_keys' .. seealso:: :func:`get_columns` """ return OrderedDict( ( (key, column) for key, column in get_columns(mixed).items() if column.primary_key ) ) def get_tables(mixed): """ Return a set of tables associated with given SQLAlchemy object. Let's say we have three classes which use joined table inheritance TextItem, Article and BlogPost. Article and BlogPost inherit TextItem. :: get_tables(Article) # set([Table('article', ...), Table('text_item')]) get_tables(Article()) get_tables(Article.__mapper__) If the TextItem entity is using with_polymorphic='*' then this function returns all child tables (article and blog_post) as well. :: get_tables(TextItem) # set([Table('text_item', ...)], ...]) .. versionadded: 0.26.0 :param mixed: SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or a SA Alias object wrapping any of these objects. """ if isinstance(mixed, sa.Table): return [mixed] elif isinstance(mixed, sa.Column): return [mixed.table] elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): return mixed.parent.tables elif isinstance(mixed, sa.orm.query._ColumnEntity): mixed = mixed.expr mapper = get_mapper(mixed) polymorphic_mappers = get_polymorphic_mappers(mapper) if polymorphic_mappers: tables = sum((m.tables for m in polymorphic_mappers), []) else: tables = mapper.tables return tables def get_columns(mixed): """ Return a collection of all Column objects for given SQLAlchemy object. The type of the collection depends on the type of the object to return the columns from. :: get_columns(User) get_columns(User()) get_columns(User.__table__) get_columns(User.__mapper__) get_columns(sa.orm.aliased(User)) get_columns(sa.orm.alised(User.__table__)) :param mixed: SA Table object, SA Mapper, SA declarative class, SA declarative class instance or an alias of any of these objects """ if isinstance(mixed, sa.sql.selectable.Selectable): return mixed.c if isinstance(mixed, sa.orm.util.AliasedClass): return sa.inspect(mixed).mapper.columns if isinstance(mixed, sa.orm.Mapper): return mixed.columns if isinstance(mixed, InstrumentedAttribute): return mixed.property.columns if isinstance(mixed, ColumnProperty): return mixed.columns if isinstance(mixed, sa.Column): return [mixed] if not isclass(mixed): mixed = mixed.__class__ return sa.inspect(mixed).columns def table_name(obj): """ Return table name of given target, declarative class or the table name where the declarative attribute is bound to. """ class_ = getattr(obj, 'class_', obj) try: return class_.__tablename__ except AttributeError: pass try: return class_.__table__.name except AttributeError: pass def getattrs(obj, attrs): return map(partial(getattr, obj), attrs) def quote(mixed, ident): """ Conditionally quote an identifier. :: from sqlalchemy_utils import quote engine = create_engine('sqlite:///:memory:') quote(engine, 'order') # '"order"' quote(engine, 'some_other_identifier') # 'some_other_identifier' :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object. :param ident: identifier to conditionally quote """ if isinstance(mixed, Dialect): dialect = mixed else: dialect = get_bind(mixed).dialect return dialect.preparer(dialect).quote(ident) def query_labels(query): """ Return all labels for given SQLAlchemy query object. 
Example:: query = session.query( Category, sa.func.count(Article.id).label('articles') ) query_labels(query) # ['articles'] :param query: SQLAlchemy Query object """ return [ entity._label_name for entity in query._entities if isinstance(entity, _ColumnEntity) and entity._label_name ] def get_query_entities(query): """ Return a list of all entities present in given SQLAlchemy query object. Examples:: from sqlalchemy_utils import get_query_entities query = session.query(Category) get_query_entities(query) # [<class 'Category'>] query = session.query(Category.id) get_query_entities(query) # [<class 'Category'>] This function also supports queries with joins. :: query = session.query(Category).join(Article) get_query_entities(query) # [<class 'Category'>, <class 'Article'>
] .. versionchanged: 0.26.7 This function now returns a list instead of generator :param query: SQLAlchemy Query object """ exprs = [ d['expr'] if is_labeled_query(d['expr']) or isinstance(d['expr'], sa.Column) else d['entity'] for d in query.column_descriptions ] return [ get_query_entity(expr) for expr in exprs ] + [ get_query_entity(entity) for entity in query._join_entities ] def is_labeled_query(expr): return ( isinstance(expr, sa.sql.elements.Label) and isinstance( list(expr.base_columns)[0], (sa.sql.selectable.Select, sa.sql.selectable.ScalarSelect) ) ) def get_query_entity(expr): if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): return expr.parent.class_ elif isinstance(expr, sa.Column): return expr.table elif isinstance(expr, AliasedInsp): return expr.entity return expr def get_query_entity_by_alias(query, alias): entities = get_query_entities(query) if not alias: return entities[0] for entity in entities: if isinstance(entity, sa.orm.util.AliasedClass): name = sa.inspect(entity).name else: name = get_mapper(entity).tables[0].name if name == alias: return entity def get_polymorphic_mappers(mixed): if isinstance(mixed, AliasedInsp): return mixed.with_polymorphic_mappers else: return mixed.polymorphic_map.values() def get_query_descriptor(query, entity, attr): if attr in query_labels(query): return attr else: entity = get_query_entity_by_alias(query, entity) if entity: descriptor = get_descriptor(entity, attr) if ( hasattr(descriptor, 'property') and isinstance(descriptor.property, sa.orm.RelationshipProperty) ): return return descriptor def get_descriptor(entity, attr): mapper = sa.inspect(entity) for key, descriptor in get_all_descriptors(mapper).items(): if attr == key: prop = ( descriptor.property if hasattr(descriptor, 'property') else None ) if isinstance(prop, ColumnProperty): if isinstance(entity, sa.orm.util.AliasedClass): for c in mapper.selectable.c: if c.key == attr: return c else: # If the property belongs to a class that uses # polymorphic inheritance we have to take into account # situations where the attribute exists in child class # but not in parent class. return getattr(prop.parent.class_, attr) else: # Handle synonyms, relationship properties and hybrid # properties if isinstance(entity, sa.orm.util.AliasedClass): return getattr(entity, attr) try: return getattr(mapper.class_, attr) except AttributeError: pass def get_all_descriptors(expr): if isinstance(expr, sa.sql.selectable.Selectable): return expr.c insp = sa.inspect(expr) try: polymorphic_mappers = get_polymorphic_mappers(insp) except sa.exc.NoInspectionAvailable: return get_mapper(expr).all_orm_descriptors else: attrs = dict(get_mapper(expr).all_orm_descriptors) for submapper in polymorphic_mappers: for key, descriptor in submapper.all_orm_descriptors.items(): if key not in attrs: attrs[key] = descriptor return attrs def get_hybrid_properties(model): """ Returns a dictionary of hybrid property keys and hybrid properties for given SQLAlchemy declarative model / mapper. 
Consider the following model :: from sqlalchemy.ext.hybrid import hybrid_property class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @hybrid_property def lowercase_name(self): return self.name.lower() @lowercase_name.expression def lowercase_name(cls): return sa.func.lower(cls.name) You can now easily get a list of all hybrid property names :: from sqlalchemy_utils import get_hybrid_properties get_hybrid_properties(Category).keys() # ['lowercase_name'] This function also supports aliased classes :: get_hybrid_properties( sa.orm.aliased(Category) ).keys() # ['lowercase_name'] .. versionchanged: 0.26.7 This function now returns a dictionary instead of generator .. versionchanged: 0.30.15 Added support for aliased classes :param model: SQLAlchemy declarative model or mapper """ return dict( (key, prop) for key, prop in get_mapper(model).all_orm_descriptors.items() if isinstance(prop, hybrid_property) ) def get_declarative_base(model): """ Returns the declarative base for given model class. :param model: SQLAlchemy declarative model """ for parent in model.__bases__: try: parent.metadata return get_declarative_base(parent) except AttributeError: pass return model def getdotattr(obj_or_class, dot_path, condition=None): """ Allow dot-notated strings to be passed to `getattr`. :: getdotattr(SubSection, 'section.document') getdotattr(subsection, 'section.document') :param obj_or_class: Any object or class :param dot_path: Attribute path with dot mark as separator """ last = obj_or_class for path in str(dot_path).split('.'): getter = attrgetter(path) if is_sequence(last): tmp = [] for element in last: value = getter(element) if is_sequence(value): tmp.extend(value) else: tmp.append(value) last = tmp elif isinstance(last, InstrumentedAttribute): last = getter(last.property.mapper.class_) elif last is None: return None else: last = getter(last) if condition is not None: if is_sequence(last): last = [v for v in last if condition(v)] else: if not condition(last): return None return last def is_deleted(obj): return obj in sa.orm.object_session(obj).deleted def has_changes(obj, attrs=None, exclude=None): """ Simple shortcut function for checking if given attributes of given declarative model object have changed during the session. Without parameters this checks if given object has any modificiations. Additionally exclude parameter can be given to check if given object has any changes in any attributes other than the ones given in exclude. :: from sqlalchemy_utils import has_changes user = User() has_changes(user, 'name') # False user.name = u'someone' has_changes(user, 'name') # True has_changes(user) # True You can check multiple attributes as well. :: has_changes(user, ['age']) # True has_changes(user, ['name', 'age']) # True This function also supports excluding certain attributes. :: has_changes(user, exclude=['name']) # False has_changes(user, exclude=['age']) # True .. versionchanged: 0.26.6 Added support for multiple attributes and exclude parameter. 
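    The check is based on SQLAlchemy's attribute history, so it only reflects modifications made since the object was loaded or last flushed.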
:param obj: SQLAlchemy declarative model object :param attrs: Names of the attributes :param exclude: Names of the attributes to exclude """ if attrs: if isinstance(attrs, six.string_types): return ( sa.inspect(obj) .attrs .get(attrs) .history .has_changes() ) else: return any(has_changes(obj, attr) for attr in attrs) else: if exclude is None: exclude = [] return any( attr.history.has_changes() for key, attr in sa.inspect(obj).attrs.items() if key not in exclude ) def is_loaded(obj, prop): """ Return whether or not given property of given object has been loaded. :: class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) content = sa.orm.deferred(sa.Column(sa.String)) article = session.query(Article).get(5) # name gets loaded since its not a deferred property assert is_loaded(article, 'name') # content has not yet been loaded since its a deferred property assert not is_loaded(article, 'content') .. versionadded: 0.27.8 :param obj: SQLAlchemy declarative model object :param prop: Name of the property or InstrumentedAttribute """ return not isinstance( getattr(sa.inspect(obj).attrs, prop).loaded_value, sa.util.langhelpers._symbol ) def identity(obj_or_class): """ Return the identity of given sqlalchemy declarative model class or instance as a tuple. This differs from obj._sa_instance_state.identity in a way that it always returns the identity even if object is still in transient state ( new object that is not yet persisted into database). Also for classes it returns the identity attributes. :: from sqlalchemy import inspect from sqlalchemy_utils import identity user = User(name=u'John Matrix') session.add(user) identity(user) # None inspect(user).identity # None session.flush() # User now has id but is still in transient state identity(user) # (1,) inspect(user).identity # None session.commit() identity(user) # (1,) inspect(user).identity # (1, ) You can also use identity for classes:: identity(User) # (User.id, ) .. versionadded: 0.21.0 :param obj: SQLAlchemy declarative model object """ return tuple( getattr(obj_or_class, column_key) for column_key in get_primary_keys(obj_or_class).keys() ) def naturally_equivalent(obj, obj2): """ Returns whether or not two given SQLAlchemy declarative instances are naturally equivalent (all their non primary key properties are equivalent). :: from sqlalchemy_utils import naturally_equivalent user = User(name=u'someone') user2 = User(name=u'someone') user == user2 # False naturally_equivalent(user, user2) # True :param obj: SQLAlchemy declarative model object :param obj2: SQLAlchemy declarative model object to compare with `obj` """ for column_key, column in sa.inspect(obj.__class__).columns.items(): if column.primary_key: continue if not (getattr(obj, column_key) == getattr(obj2, column_key)): return False return True sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/render.py000066400000000000000000000036751360007755400243460ustar00rootroot00000000000000import inspect import six import sqlalchemy as sa from .mock import create_mock_engine def render_expression(expression, bind, stream=None): """Generate a SQL expression from the passed python expression. Only the global variable, `engine`, is available for use in the expression. Additional local variables may be passed in the context parameter. Note this function is meant for convenience and protected usage. Do NOT blindly pass user input to this function as it uses exec. :param bind: A SQLAlchemy engine or bind URL. 
:param stream: Render all DDL operations to the stream. """ # Create a stream if not present. if stream is None: stream = six.moves.cStringIO() engine = create_mock_engine(bind, stream) # Navigate the stack and find the calling frame that allows the # expression to execuate. for frame in inspect.stack()[1:]: try: frame = frame[0] local = dict(frame.f_locals) local['engine'] = engine six.exec_(expression, frame.f_globals, local) break except Exception: pass else: raise ValueError('Not a valid python expression', engine) return stream def render_statement(statement, bind=None): """ Generate an SQL expression string with bound parameters rendered inline for the given SQLAlchemy statement. :param statement: SQLAlchemy Query object. :param bind: Optional SQLAlchemy bind, if None uses the bind of the given query object. """ if isinstance(statement, sa.orm.query.Query): if bind is None: bind = statement.session.get_bind(statement._mapper_zero()) statement = statement.statement elif bind is None: bind = statement.bind stream = six.moves.cStringIO() engine = create_mock_engine(bind.engine, stream=stream) engine.execute(statement) return stream.getvalue() sqlalchemy-utils-0.36.1/sqlalchemy_utils/functions/sort_query.py000066400000000000000000000125271360007755400252770ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.sql.expression import asc, desc from .database import has_unique_index from .orm import get_query_descriptor, get_tables class QuerySorterException(Exception): pass class QuerySorter(object): def __init__(self, silent=True, separator='-'): self.separator = separator self.silent = silent def assign_order_by(self, entity, attr, func): expr = get_query_descriptor(self.query, entity, attr) if expr is not None: return self.query.order_by(func(expr)) if not self.silent: raise QuerySorterException( "Could not sort query with expression '%s'" % attr ) return self.query def parse_sort_arg(self, arg): if arg[0] == self.separator: func = desc arg = arg[1:] else: func = asc parts = arg.split(self.separator) return { 'entity': parts[0] if len(parts) > 1 else None, 'attr': parts[1] if len(parts) > 1 else arg, 'func': func } def __call__(self, query, *args): self.query = query for sort in args: if not sort: continue self.query = self.assign_order_by( **self.parse_sort_arg(sort) ) return self.query def sort_query(query, *args, **kwargs): """ Applies an sql ORDER BY for given query. This function can be easily used with user-defined sorting. The examples use the following model definition: :: import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import sort_query engine = create_engine( 'sqlite:///' ) Base = declarative_base() Session = sessionmaker(bind=engine) session = Session() class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) category = sa.orm.relationship( Category, primaryjoin=category_id == Category.id ) 1. Applying simple ascending sort :: query = session.query(Article) query = sort_query(query, 'name') 2. Applying descending sort :: query = sort_query(query, '-name') 3. 
Applying sort to custom calculated label :: query = session.query( Category, sa.func.count(Article.id).label('articles') ) query = sort_query(query, 'articles') 4. Applying sort to joined table column :: query = session.query(Article).join(Article.category) query = sort_query(query, 'category-name') :param query: query to be modified :param sort: string that defines the label or column to sort the query by :param silent: Whether or not to raise exceptions if unknown sort column is passed. By default this is `True` indicating that no errors should be raised for unknown columns. """ return QuerySorter(**kwargs)(query, *args) def make_order_by_deterministic(query): """ Make query order by deterministic (if it isn't already). Order by is considered deterministic if it contains column that is unique index ( either it is a primary key or has a unique index). Many times it is design flaw to order by queries in nondeterministic manner. Consider a User model with three fields: id (primary key), favorite color and email (unique).:: from sqlalchemy_utils import make_order_by_deterministic query = session.query(User).order_by(User.favorite_color) query = make_order_by_deterministic(query) print query # 'SELECT ... ORDER BY "user".favorite_color, "user".id' query = session.query(User).order_by(User.email) query = make_order_by_deterministic(query) print query # 'SELECT ... ORDER BY "user".email' query = session.query(User).order_by(User.id) query = make_order_by_deterministic(query) print query # 'SELECT ... ORDER BY "user".id' .. versionadded: 0.27.1 """ order_by_func = sa.asc if not query._order_by: column = None else: order_by = query._order_by[0] if isinstance(order_by, sa.sql.expression.UnaryExpression): if order_by.modifier == sa.sql.operators.desc_op: order_by_func = sa.desc else: order_by_func = sa.asc column = order_by.get_children()[0] else: column = order_by # Skip queries that are ordered by an already deterministic column if isinstance(column, sa.Column): try: if has_unique_index(column): return query except TypeError: pass base_table = get_tables(query._entities[0])[0] query = query.order_by( *(order_by_func(c) for c in base_table.c if c.primary_key) ) return query sqlalchemy-utils-0.36.1/sqlalchemy_utils/generic.py000066400000000000000000000143701360007755400224650ustar00rootroot00000000000000try: from collections.abc import Iterable except ImportError: # For python 2.7 support from collections import Iterable import six import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import attributes, class_mapper, ColumnProperty from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.util import set_creation_order from .exceptions import ImproperlyConfigured from .functions import identity class GenericAttributeImpl(attributes.ScalarAttributeImpl): def get(self, state, dict_, passive=attributes.PASSIVE_OFF): if self.key in dict_: return dict_[self.key] # Retrieve the session bound to the state in order to perform # a lazy query for the attribute. session = _state_session(state) if session is None: # State is not bound to a session; we cannot proceed. return None # Find class for discriminator. # TODO: Perhaps optimize with some sort of lookup? discriminator = self.get_state_discriminator(state) target_class = state.class_._decl_class_registry.get(discriminator) if target_class is None: # Unknown discriminator; return nothing. 
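            # This happens when the discriminator column stores a class name
            # that is not (or is no longer) present in the declarative registry.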
return None id = self.get_state_id(state) target = session.query(target_class).get(id) # Return found (or not found) target. return target def get_state_discriminator(self, state): discriminator = self.parent_token.discriminator if isinstance(discriminator, hybrid_property): return getattr(state.obj(), discriminator.__name__) else: return state.attrs[discriminator.key].value def get_state_id(self, state): # Lookup row with the discriminator and id. return tuple(state.attrs[id.key].value for id in self.parent_token.id) def set(self, state, dict_, initiator, passive=attributes.PASSIVE_OFF, check_old=None, pop=False): # Set us on the state. dict_[self.key] = initiator if initiator is None: # Nullify relationship args for id in self.parent_token.id: dict_[id.key] = None dict_[self.parent_token.discriminator.key] = None else: # Get the primary key of the initiator and ensure we # can support this assignment. class_ = type(initiator) mapper = class_mapper(class_) pk = mapper.identity_key_from_instance(initiator)[1] # Set the identifier and the discriminator. discriminator = six.text_type(class_.__name__) for index, id in enumerate(self.parent_token.id): dict_[id.key] = pk[index] dict_[self.parent_token.discriminator.key] = discriminator class GenericRelationshipProperty(MapperProperty): """A generic form of the relationship property. Creates a 1 to many relationship between the parent model and any other models using a descriminator (the table name). :param discriminator Field to discriminate which model we are referring to. :param id: Field to point to the model we are referring to. """ def __init__(self, discriminator, id, doc=None): super(GenericRelationshipProperty, self).__init__() self._discriminator_col = discriminator self._id_cols = id self._id = None self._discriminator = None self.doc = doc set_creation_order(self) def _column_to_property(self, column): if isinstance(column, hybrid_property): attr_key = column.__name__ for key, attr in self.parent.all_orm_descriptors.items(): if key == attr_key: return attr else: for key, attr in self.parent.attrs.items(): if isinstance(attr, ColumnProperty): if attr.columns[0].name == column.name: return attr def init(self): def convert_strings(column): if isinstance(column, six.string_types): return self.parent.columns[column] return column self._discriminator_col = convert_strings(self._discriminator_col) self._id_cols = convert_strings(self._id_cols) if isinstance(self._id_cols, Iterable): self._id_cols = list(map(convert_strings, self._id_cols)) else: self._id_cols = [self._id_cols] self.discriminator = self._column_to_property(self._discriminator_col) if self.discriminator is None: raise ImproperlyConfigured( 'Could not find discriminator descriptor.' 
) self.id = list(map(self._column_to_property, self._id_cols)) class Comparator(PropComparator): def __init__(self, prop, parentmapper): self.property = prop self._parententity = parentmapper def __eq__(self, other): discriminator = six.text_type(type(other).__name__) q = self.property._discriminator_col == discriminator other_id = identity(other) for index, id in enumerate(self.property._id_cols): q &= id == other_id[index] return q def __ne__(self, other): return ~(self == other) def is_type(self, other): mapper = sa.inspect(other) # Iterate through the weak sequence in order to get the actual # mappers class_names = [six.text_type(other.__name__)] class_names.extend([ six.text_type(submapper.class_.__name__) for submapper in mapper._inheriting_mappers ]) return self.property._discriminator_col.in_(class_names) def instrument_class(self, mapper): attributes.register_attribute( mapper.class_, self.key, comparator=self.Comparator(self, mapper), parententity=mapper, doc=self.doc, impl_class=GenericAttributeImpl, parent_token=self ) def generic_relationship(*args, **kwargs): return GenericRelationshipProperty(*args, **kwargs) sqlalchemy-utils-0.36.1/sqlalchemy_utils/i18n.py000066400000000000000000000104531360007755400216260ustar00rootroot00000000000000import inspect import six import sqlalchemy as sa from sqlalchemy.ext.compiler import compiles from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.sql.expression import ColumnElement from .exceptions import ImproperlyConfigured try: import babel import babel.dates except ImportError: babel = None def get_locale(): try: return babel.Locale('en') except AttributeError: # As babel is optional, we may raise an AttributeError accessing it raise ImproperlyConfigured( 'Could not load get_locale function using Babel. Either ' 'install Babel or make a similar function and override it ' 'in this module.' ) if six.PY2: def get_args_count(func): if ( callable(func) and not inspect.isfunction(func) and not inspect.ismethod(func) ): func = func.__call__ args = inspect.getargspec(func).args return len(args) - 1 if inspect.ismethod(func) else len(args) else: def get_args_count(func): return len(inspect.signature(func).parameters) def cast_locale(obj, locale, attr): """ Cast given locale to string. Supports also callbacks that return locales. :param obj: Object or class to use as a possible parameter to locale callable :param locale: Locale object or string or callable that returns a locale. """ if callable(locale): args_count = get_args_count(locale) if args_count == 0: locale = locale() elif args_count == 1: locale = locale(obj) elif args_count == 2: locale = locale(obj, attr.key) if isinstance(locale, babel.Locale): return str(locale) return locale class cast_locale_expr(ColumnElement): def __init__(self, cls, locale, attr): self.cls = cls self.locale = locale self.attr = attr @compiles(cast_locale_expr) def compile_cast_locale_expr(element, compiler, **kw): locale = cast_locale(element.cls, element.locale, element.attr) if isinstance(locale, six.string_types): return "'{0}'".format(locale) return compiler.process(locale) class TranslationHybrid(object): def __init__(self, current_locale, default_locale, default_value=None): if babel is None: raise ImproperlyConfigured( 'You need to install babel in order to use TranslationHybrid.' ) self.current_locale = current_locale self.default_locale = default_locale self.default_value = default_value def getter_factory(self, attr): """ Return a hybrid_property getter function for given attribute. 
The returned getter first checks if object has translation for current locale. If not it tries to get translation for default locale. If there is no translation found for default locale it returns None. """ def getter(obj): current_locale = cast_locale(obj, self.current_locale, attr) try: return getattr(obj, attr.key)[current_locale] except (TypeError, KeyError): default_locale = cast_locale(obj, self.default_locale, attr) try: return getattr(obj, attr.key)[default_locale] except (TypeError, KeyError): return self.default_value return getter def setter_factory(self, attr): def setter(obj, value): if getattr(obj, attr.key) is None: setattr(obj, attr.key, {}) locale = cast_locale(obj, self.current_locale, attr) getattr(obj, attr.key)[locale] = value return setter def expr_factory(self, attr): def expr(cls): cls_attr = getattr(cls, attr.key) current_locale = cast_locale_expr(cls, self.current_locale, attr) default_locale = cast_locale_expr(cls, self.default_locale, attr) return sa.func.coalesce( cls_attr[current_locale], cls_attr[default_locale] ) return expr def __call__(self, attr): return hybrid_property( fget=self.getter_factory(attr), fset=self.setter_factory(attr), expr=self.expr_factory(attr) ) sqlalchemy-utils-0.36.1/sqlalchemy_utils/listeners.py000066400000000000000000000164751360007755400230710ustar00rootroot00000000000000import sqlalchemy as sa from .exceptions import ImproperlyConfigured def coercion_listener(mapper, class_): """ Auto assigns coercing listener for all class properties which are of coerce capable type. """ for prop in mapper.iterate_properties: try: listener = prop.columns[0].type.coercion_listener except AttributeError: continue sa.event.listen( getattr(class_, prop.key), 'set', listener, retval=True ) def instant_defaults_listener(target, args, kwargs): for key, column in sa.inspect(target.__class__).columns.items(): if hasattr(column, 'default') and column.default is not None: if callable(column.default.arg): setattr(target, key, column.default.arg(target)) else: setattr(target, key, column.default.arg) def force_auto_coercion(mapper=None): """ Function that assigns automatic data type coercion for all classes which are of type of given mapper. The coercion is applied to all coercion capable properties. By default coercion is applied to all SQLAlchemy mappers. Before initializing your models you need to call force_auto_coercion. :: from sqlalchemy_utils import force_auto_coercion force_auto_coercion() Then define your models the usual way:: class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(50)) background_color = sa.Column(ColorType) Now scalar values for coercion capable data types will convert to appropriate value objects:: document = Document() document.background_color = 'F5F5F5' document.background_color # Color object session.commit() A useful side-effect of this is that additional validation of data will be done on the moment it is being assigned to model objects. For example without auto coerction set, an invalid :class:`sqlalchemy_utils.types.IPAddressType` (eg. ``10.0.0 255.255``) would get through without an exception being raised. The database wouldn't notice this (as most databases don't have a native type for an IP address, so they're usually just stored as a string), and the ``ipaddress/ipaddr`` package uses a string field as well. 
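    With coercion enabled the same value is rejected as soon as it is assigned. The snippet below is only a sketch; it assumes a ``User`` model with an ``ip_address`` column of type :class:`sqlalchemy_utils.types.IPAddressType`::

        user = User()
        user.ip_address = u'10.0.0 255.255'  # raises an error at assignment time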
:param mapper: The mapper which the automatic data type coercion should be applied to """ if mapper is None: mapper = sa.orm.mapper sa.event.listen(mapper, 'mapper_configured', coercion_listener) def force_instant_defaults(mapper=None): """ Function that assigns object column defaults on object initialization time. By default calling this function applies instant defaults to all your models. Setting up instant defaults:: from sqlalchemy_utils import force_instant_defaults force_instant_defaults() Example usage:: class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(50)) created_at = sa.Column(sa.DateTime, default=datetime.now) document = Document() document.created_at # datetime object :param mapper: The mapper which the automatic instant defaults forcing should be applied to """ if mapper is None: mapper = sa.orm.mapper sa.event.listen(mapper, 'init', instant_defaults_listener) def auto_delete_orphans(attr): """ Delete orphans for given SQLAlchemy model attribute. This function can be used for deleting many-to-many associated orphans easily. For more information see https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/ManyToManyOrphan. Consider the following model definition: :: from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy import * from sqlalchemy.orm import * from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import event Base = declarative_base() tagging = Table( 'tagging', Base.metadata, Column( 'tag_id', Integer, ForeignKey('tag.id', ondelete='CASCADE'), primary_key=True ), Column( 'entry_id', Integer, ForeignKey('entry.id', ondelete='CASCADE'), primary_key=True ) ) class Tag(Base): __tablename__ = 'tag' id = Column(Integer, primary_key=True) name = Column(String(100), unique=True, nullable=False) def __init__(self, name=None): self.name = name class Entry(Base): __tablename__ = 'entry' id = Column(Integer, primary_key=True) tags = relationship( 'Tag', secondary=tagging, backref='entries' ) Now lets say we want to delete the tags if all their parents get deleted ( all Entry objects get deleted). This can be achieved as follows: :: from sqlalchemy_utils import auto_delete_orphans auto_delete_orphans(Entry.tags) After we've set up this listener we can see it in action. :: e = create_engine('sqlite://') Base.metadata.create_all(e) s = Session(e) r1 = Entry() r2 = Entry() r3 = Entry() t1, t2, t3, t4 = Tag('t1'), Tag('t2'), Tag('t3'), Tag('t4') r1.tags.extend([t1, t2]) r2.tags.extend([t2, t3]) r3.tags.extend([t4]) s.add_all([r1, r2, r3]) assert s.query(Tag).count() == 4 r2.tags.remove(t2) assert s.query(Tag).count() == 4 r1.tags.remove(t2) assert s.query(Tag).count() == 3 r1.tags.remove(t1) assert s.query(Tag).count() == 2 .. versionadded: 0.26.4 :param attr: Association relationship attribute to auto delete orphans from """ parent_class = attr.parent.class_ target_class = attr.property.mapper.class_ backref = attr.property.backref if not backref: raise ImproperlyConfigured( 'The relationship argument given for auto_delete_orphans needs to ' 'have a backref relationship set.' 
) if isinstance(backref, tuple): backref = backref[0] @sa.event.listens_for(sa.orm.Session, 'after_flush') def delete_orphan_listener(session, ctx): # Look through Session state to see if we want to emit a DELETE for # orphans orphans_found = ( any( isinstance(obj, parent_class) and sa.orm.attributes.get_history(obj, attr.key).deleted for obj in session.dirty ) or any( isinstance(obj, parent_class) for obj in session.deleted ) ) if orphans_found: # Emit a DELETE for all orphans ( session.query(target_class) .filter( ~getattr(target_class, backref).any() ) .delete(synchronize_session=False) ) sqlalchemy-utils-0.36.1/sqlalchemy_utils/models.py000066400000000000000000000056121360007755400223330ustar00rootroot00000000000000from datetime import datetime import sqlalchemy as sa from sqlalchemy.util.langhelpers import symbol class Timestamp(object): """Adds `created` and `updated` columns to a derived declarative model. The `created` column is handled through a default and the `updated` column is handled through a `before_update` event that propagates for all derived declarative models. :: import sqlalchemy as sa from sqlalchemy_utils import Timestamp class SomeModel(Base, Timestamp): __tablename__ = 'somemodel' id = sa.Column(sa.Integer, primary_key=True) """ created = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False) updated = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False) @sa.event.listens_for(Timestamp, 'before_update', propagate=True) def timestamp_before_update(mapper, connection, target): # When a model with a timestamp is updated; force update the updated # timestamp. target.updated = datetime.utcnow() NO_VALUE = symbol('NO_VALUE') NOT_LOADED_REPR = '' def _generic_repr_method(self, fields): state = sa.inspect(self) field_reprs = [] if not fields: fields = state.mapper.columns.keys() for key in fields: value = state.attrs[key].loaded_value if value == NO_VALUE: value = NOT_LOADED_REPR else: value = repr(value) field_reprs.append('='.join((key, value))) return '%s(%s)' % (self.__class__.__name__, ', '.join(field_reprs)) def generic_repr(*fields): """Adds generic ``__repr__()`` method to a declarative SQLAlchemy model. In case if some fields are not loaded from a database, it doesn't force their loading and instead repesents them as ````. In addition, user can provide field names as arguments to the decorator to specify what fields should present in the string representation and in what order. Example:: import sqlalchemy as sa from sqlalchemy_utils import generic_repr @generic_repr class MyModel(Base): __tablename__ = 'mymodel' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) category = sa.Column(sa.String) session.add(MyModel(name='Foo', category='Bar')) session.commit() foo = session.query(MyModel).options(sa.orm.defer('category')).one(s) assert repr(foo) == 'MyModel(id=1, name='Foo', category=)' """ if len(fields) == 1 and callable(fields[0]): target = fields[0] target.__repr__ = lambda self: _generic_repr_method(self, fields=None) return target else: def decorator(cls): cls.__repr__ = lambda self: _generic_repr_method( self, fields=fields ) return cls return decorator sqlalchemy-utils-0.36.1/sqlalchemy_utils/observer.py000066400000000000000000000276651360007755400227130ustar00rootroot00000000000000""" This module provides a decorator function for observing changes in a given property. Internally the decorator is implemented using SQLAlchemy event listeners. Both column properties and relationship properties can be observed. 
Property observers can be used for pre-calculating aggregates and automatic real-time data denormalization. Simple observers ---------------- At the heart of the observer extension is the :func:`observes` decorator. You mark some property path as being observed and the marked method will get notified when any changes are made to given path. Consider the following model structure: :: class Director(Base): __tablename__ = 'director' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) date_of_birth = sa.Column(sa.Date) class Movie(Base): __tablename__ = 'movie' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id)) director = sa.orm.relationship(Director, backref='movies') Now consider we want to show movies in some listing ordered by director id first and movie id secondly. If we have many movies then using joins and ordering by Director.name will be very slow. Here is where denormalization and :func:`observes` comes to rescue the day. Let's add a new column called director_name to Movie which will get automatically copied from associated Director. :: from sqlalchemy_utils import observes class Movie(Base): # same as before.. director_name = sa.Column(sa.String) @observes('director') def director_observer(self, director): self.director_name = director.name .. note:: This example could be done much more efficiently using a compound foreign key from director_name, director_id to Director.name, Director.id but for the sake of simplicity we added this as an example. Observes vs aggregated ---------------------- :func:`observes` and :func:`.aggregates.aggregated` can be used for similar things. However performance wise you should take the following things into consideration: * :func:`observes` works always inside transaction and deals with objects. If the relationship observer is observing has a large number of objects it's better to use :func:`.aggregates.aggregated`. * :func:`.aggregates.aggregated` always executes one additional query per aggregate so in scenarios where the observed relationship has only a handful of objects it's better to use :func:`observes` instead. Example 1. Movie with many ratings Let's say we have a Movie object with potentially thousands of ratings. In this case we should always use :func:`.aggregates.aggregated` since iterating through thousands of objects is slow and very memory consuming. Example 2. Product with denormalized catalog name Each product belongs to one catalog. Here it is natural to use :func:`observes` for data denormalization. Deeply nested observing ----------------------- Consider the following model structure where Catalog has many Categories and Category has many Products. 
:: class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) product_count = sa.Column(sa.Integer, default=0) @observes('categories.products') def product_observer(self, products): self.product_count = len(products) categories = sa.orm.relationship('Category', backref='catalog') class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) products = sa.orm.relationship('Product', backref='category') class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) :func:`observes` is smart enough to: * Notify catalog objects of any changes in associated Product objects * Notify catalog objects of any changes in Category objects that affect products (for example if Category gets deleted, or a new Category is added to Catalog with any number of Products) :: category = Category( products=[Product(), Product()] ) category2 = Category( product=[Product()] ) catalog = Catalog( categories=[category, category2] ) session.add(catalog) session.commit() catalog.product_count # 2 session.delete(category) session.commit() catalog.product_count # 1 Observing multiple columns ----------------------- You can also observe multiple columns by specifying all the observable columns in the decorator. :: class Order(Base): __tablename__ = 'order' id = sa.Column(sa.Integer, primary_key=True) unit_price = sa.Column(sa.Integer) amount = sa.Column(sa.Integer) total_price = sa.Column(sa.Integer) @observes('amount', 'unit_price') def total_price_observer(self, amount, unit_price): self.total_price = amount * unit_price """ import itertools from collections import defaultdict, namedtuple import sqlalchemy as sa from .functions import getdotattr, has_changes from .path import AttrPath from .utils import is_sequence try: from collections.abc import Iterable except ImportError: # For python 2.7 support from collections import Iterable Callback = namedtuple('Callback', ['func', 'backref', 'fullpath']) class PropertyObserver(object): def __init__(self): self.listener_args = [ ( sa.orm.mapper, 'mapper_configured', self.update_generator_registry ), ( sa.orm.mapper, 'after_configured', self.gather_paths ), ( sa.orm.session.Session, 'before_flush', self.invoke_callbacks ) ] self.callback_map = defaultdict(list) # TODO: make the registry a WeakKey dict self.generator_registry = defaultdict(list) def remove_listeners(self): for args in self.listener_args: sa.event.remove(*args) def register_listeners(self): for args in self.listener_args: if not sa.event.contains(*args): sa.event.listen(*args) def __repr__(self): return '' def update_generator_registry(self, mapper, class_): """ Adds generator functions to generator_registry. 
""" for generator in class_.__dict__.values(): if hasattr(generator, '__observes__'): self.generator_registry[class_].append( generator ) def gather_paths(self): for class_, generators in self.generator_registry.items(): for callback in generators: full_paths = [] for call_path in callback.__observes__: full_paths.append(AttrPath(class_, call_path)) for path in full_paths: self.callback_map[class_].append( Callback( func=callback, backref=None, fullpath=full_paths ) ) for index in range(len(path)): i = index + 1 prop = path[index].property if isinstance(prop, sa.orm.RelationshipProperty): prop_class = path[index].property.mapper.class_ self.callback_map[prop_class].append( Callback( func=callback, backref=~ (path[:i]), fullpath=full_paths ) ) def gather_callback_args(self, obj, callbacks): for callback in callbacks: backref = callback.backref root_objs = getdotattr(obj, backref) if backref else obj if root_objs: if not isinstance(root_objs, Iterable): root_objs = [root_objs] for root_obj in root_objs: if root_obj: args = self.get_callback_args(root_obj, callback) if args: yield args def get_callback_args(self, root_obj, callback): session = sa.orm.object_session(root_obj) objects = [getdotattr( root_obj, path, lambda obj: obj not in session.deleted ) for path in callback.fullpath] paths = [str(path) for path in callback.fullpath] for path in paths: if '.' in path or has_changes(root_obj, path): return ( root_obj, callback.func, objects ) def iterate_objects_and_callbacks(self, session): objs = itertools.chain(session.new, session.dirty, session.deleted) for obj in objs: for class_, callbacks in self.callback_map.items(): if isinstance(obj, class_): yield obj, callbacks def invoke_callbacks(self, session, ctx, instances): callback_args = defaultdict(lambda: defaultdict(set)) for obj, callbacks in self.iterate_objects_and_callbacks(session): args = self.gather_callback_args(obj, callbacks) for (root_obj, func, objects) in args: if not callback_args[root_obj][func]: callback_args[root_obj][func] = {} for i, object_ in enumerate(objects): if is_sequence(object_): callback_args[root_obj][func][i] = ( callback_args[root_obj][func].get(i, set()) | set(object_) ) else: callback_args[root_obj][func][i] = object_ for root_obj, callback_objs in callback_args.items(): for callback, objs in callback_objs.items(): callback(root_obj, *[objs[i] for i in range(len(objs))]) observer = PropertyObserver() def observes(*paths, **observer_kw): """ Mark method as property observer for the given property path. Inside transaction observer gathers all changes made in given property path and feeds the changed objects to observer-marked method at the before flush phase. :: from sqlalchemy_utils import observes class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) category_count = sa.Column(sa.Integer, default=0) @observes('categories') def category_observer(self, categories): self.category_count = len(categories) class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) catalog = Catalog(categories=[Category(), Category()]) session.add(catalog) session.commit() catalog.category_count # 2 .. versionadded: 0.28.0 :param *paths: One or more dot-notated property paths, eg. 
'categories.products.price' :param **observer: A dictionary where value for key 'observer' contains :meth:`PropertyObserver` object """ observer_ = observer_kw.pop('observer', observer) observer_.register_listeners() def wraps(func): def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) wrapper.__observes__ = paths return wrapper return wraps sqlalchemy-utils-0.36.1/sqlalchemy_utils/operators.py000066400000000000000000000036561360007755400230740ustar00rootroot00000000000000import sqlalchemy as sa def inspect_type(mixed): if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): return mixed.property.columns[0].type elif isinstance(mixed, sa.orm.ColumnProperty): return mixed.columns[0].type elif isinstance(mixed, sa.Column): return mixed.type def is_case_insensitive(mixed): try: return isinstance( inspect_type(mixed).comparator, CaseInsensitiveComparator ) except AttributeError: try: return issubclass( inspect_type(mixed).comparator_factory, CaseInsensitiveComparator ) except AttributeError: return False class CaseInsensitiveComparator(sa.Unicode.Comparator): @classmethod def lowercase_arg(cls, func): def operation(self, other, **kwargs): operator = getattr(sa.Unicode.Comparator, func) if other is None: return operator(self, other, **kwargs) if not is_case_insensitive(other): other = sa.func.lower(other) return operator(self, other, **kwargs) return operation def in_(self, other): if isinstance(other, list) or isinstance(other, tuple): other = map(sa.func.lower, other) return sa.Unicode.Comparator.in_(self, other) def notin_(self, other): if isinstance(other, list) or isinstance(other, tuple): other = map(sa.func.lower, other) return sa.Unicode.Comparator.notin_(self, other) string_operator_funcs = [ '__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__', 'concat', 'contains', 'ilike', 'like', 'notlike', 'notilike', 'startswith', 'endswith', ] for func in string_operator_funcs: setattr( CaseInsensitiveComparator, func, CaseInsensitiveComparator.lowercase_arg(func) ) sqlalchemy-utils-0.36.1/sqlalchemy_utils/path.py000066400000000000000000000101331360007755400217760ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.util.langhelpers import symbol from .utils import str_coercible @str_coercible class Path(object): def __init__(self, path, separator='.'): if isinstance(path, Path): self.path = path.path else: self.path = path self.separator = separator @property def parts(self): return self.path.split(self.separator) def __iter__(self): for part in self.parts: yield part def __len__(self): return len(self.parts) def __repr__(self): return "%s('%s')" % (self.__class__.__name__, self.path) def index(self, element): return self.parts.index(element) def __getitem__(self, slice): result = self.parts[slice] if isinstance(result, list): return self.__class__( self.separator.join(result), separator=self.separator ) return result def __eq__(self, other): return self.path == other.path and self.separator == other.separator def __ne__(self, other): return not (self == other) def __unicode__(self): return self.path def get_attr(mixed, attr): if isinstance(mixed, InstrumentedAttribute): return getattr( mixed.property.mapper.class_, attr ) else: return getattr(mixed, attr) @str_coercible class AttrPath(object): def __init__(self, class_, path): self.class_ = class_ self.path = Path(path) self.parts = [] last_attr = class_ for value in self.path: last_attr = get_attr(last_attr, value) 
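            # Collect each resolved attribute so that the whole path can be
            # iterated, sliced and inverted later (see __invert__ below).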
self.parts.append(last_attr) def __iter__(self): for part in self.parts: yield part def __invert__(self): def get_backref(part): prop = part.property backref = prop.backref or prop.back_populates if backref is None: raise Exception( "Invert failed because property '%s' of class " "%s has no backref." % ( prop.key, prop.parent.class_.__name__ ) ) if isinstance(backref, tuple): return backref[0] else: return backref if isinstance(self.parts[-1].property, sa.orm.ColumnProperty): class_ = self.parts[-1].class_ else: class_ = self.parts[-1].mapper.class_ return self.__class__( class_, '.'.join(map(get_backref, reversed(self.parts))) ) def index(self, element): for index, el in enumerate(self.parts): if el is element: return index @property def direction(self): symbols = [part.property.direction for part in self.parts] if symbol('MANYTOMANY') in symbols: return symbol('MANYTOMANY') elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols: return symbol('MANYTOMANY') return symbols[0] @property def uselist(self): return any(part.property.uselist for part in self.parts) def __getitem__(self, slice): result = self.parts[slice] if isinstance(result, list) and result: if result[0] is self.parts[0]: class_ = self.class_ else: class_ = result[0].parent.class_ return self.__class__( class_, self.path[slice] ) else: return result def __len__(self): return len(self.path) def __repr__(self): return "%s(%s, %r)" % ( self.__class__.__name__, self.class_.__name__, self.path.path ) def __eq__(self, other): return self.path == other.path and self.class_ == other.class_ def __ne__(self, other): return not (self == other) def __unicode__(self): return str(self.path) sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/000077500000000000000000000000001360007755400226655ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/__init__.py000066400000000000000000000002711360007755400247760ustar00rootroot00000000000000from .country import Country # noqa from .currency import Currency # noqa from .ltree import Ltree # noqa from .weekday import WeekDay # noqa from .weekdays import WeekDays # noqa sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/country.py000066400000000000000000000053751360007755400247540ustar00rootroot00000000000000from functools import total_ordering import six from .. import i18n from ..utils import str_coercible @total_ordering @str_coercible class Country(object): """ Country class wraps a 2 to 3 letter country code. It provides various convenience properties and methods. :: from babel import Locale from sqlalchemy_utils import Country, i18n # First lets add a locale getter for testing purposes i18n.get_locale = lambda: Locale('en') Country('FI').name # Finland Country('FI').code # FI Country(Country('FI')).code # 'FI' Country always validates the given code if you use at least the optional dependency list 'babel', otherwise no validation are performed. :: Country(None) # raises TypeError Country('UnknownCode') # raises ValueError Country supports equality operators. :: Country('FI') == Country('FI') Country('FI') != Country('US') Country objects are hashable. 
:: assert hash(Country('FI')) == hash('FI') """ def __init__(self, code_or_country): if isinstance(code_or_country, Country): self.code = code_or_country.code elif isinstance(code_or_country, six.string_types): self.validate(code_or_country) self.code = code_or_country else: raise TypeError( "Country() argument must be a string or a country, not '{0}'" .format( type(code_or_country).__name__ ) ) @property def name(self): return i18n.get_locale().territories[self.code] @classmethod def validate(self, code): try: i18n.babel.Locale('en').territories[code] except KeyError: raise ValueError( 'Could not convert string to country code: {0}'.format(code) ) except AttributeError: # As babel is optional, we may raise an AttributeError accessing it pass def __eq__(self, other): if isinstance(other, Country): return self.code == other.code elif isinstance(other, six.string_types): return self.code == other else: return NotImplemented def __hash__(self): return hash(self.code) def __ne__(self, other): return not (self == other) def __lt__(self, other): if isinstance(other, Country): return self.code < other.code elif isinstance(other, six.string_types): return self.code < other return NotImplemented def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.code) def __unicode__(self): return self.name sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/currency.py000066400000000000000000000054001360007755400250700ustar00rootroot00000000000000# -*- coding: utf-8 -*- import six from .. import i18n, ImproperlyConfigured from ..utils import str_coercible @str_coercible class Currency(object): """ Currency class wraps a 3-letter currency code. It provides various convenience properties and methods. :: from babel import Locale from sqlalchemy_utils import Currency, i18n # First lets add a locale getter for testing purposes i18n.get_locale = lambda: Locale('en') Currency('USD').name # US Dollar Currency('USD').symbol # $ Currency(Currency('USD')).code # 'USD' Currency always validates the given code if you use at least the optional dependency list 'babel', otherwise no validation are performed. :: Currency(None) # raises TypeError Currency('UnknownCode') # raises ValueError Currency supports equality operators. :: Currency('USD') == Currency('USD') Currency('USD') != Currency('EUR') Currencies are hashable. :: len(set([Currency('USD'), Currency('USD')])) # 1 """ def __init__(self, code): if i18n.babel is None: raise ImproperlyConfigured( "'babel' package is required in order to use Currency class." ) if isinstance(code, Currency): self.code = code elif isinstance(code, six.string_types): self.validate(code) self.code = code else: raise TypeError( 'First argument given to Currency constructor should be ' 'either an instance of Currency or valid three letter ' 'currency code.' 
) @classmethod def validate(self, code): try: i18n.babel.Locale('en').currencies[code] except KeyError: raise ValueError("'{0}' is not valid currency code.".format(code)) except AttributeError: # As babel is optional, we may raise an AttributeError accessing it pass @property def symbol(self): return i18n.babel.numbers.get_currency_symbol( self.code, i18n.get_locale() ) @property def name(self): return i18n.get_locale().currencies[self.code] def __eq__(self, other): if isinstance(other, Currency): return self.code == other.code elif isinstance(other, six.string_types): return self.code == other else: return NotImplemented def __ne__(self, other): return not (self == other) def __hash__(self): return hash(self.code) def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.code) def __unicode__(self): return self.code sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/ltree.py000066400000000000000000000116471360007755400243630ustar00rootroot00000000000000from __future__ import absolute_import import re import six from ..utils import str_coercible path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$') @str_coercible class Ltree(object): """ Ltree class wraps a valid string label path. It provides various convenience properties and methods. :: from sqlalchemy_utils import Ltree Ltree('1.2.3').path # '1.2.3' Ltree always validates the given path. :: Ltree(None) # raises TypeError Ltree('..') # raises ValueError Validator is also available as class method. :: Ltree.validate('1.2.3') Ltree.validate(None) # raises TypeError Ltree supports equality operators. :: Ltree('Countries.Finland') == Ltree('Countries.Finland') Ltree('Countries.Germany') != Ltree('Countries.Finland') Ltree objects are hashable. :: assert hash(Ltree('Finland')) == hash('Finland') Ltree objects have length. :: assert len(Ltree('1.2')) == 2 assert len(Ltree('some.one.some.where')) # 4 You can easily find subpath indexes. :: assert Ltree('1.2.3').index('2.3') == 1 assert Ltree('1.2.3.4.5').index('3.4') == 2 Ltree objects can be sliced. :: assert Ltree('1.2.3')[0:2] == Ltree('1.2') assert Ltree('1.2.3')[1:] == Ltree('2.3') Finding longest common ancestor. :: assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2' assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1' Ltree objects can be concatenated. :: assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2') """ def __init__(self, path_or_ltree): if isinstance(path_or_ltree, Ltree): self.path = path_or_ltree.path elif isinstance(path_or_ltree, six.string_types): self.validate(path_or_ltree) self.path = path_or_ltree else: raise TypeError( "Ltree() argument must be a string or an Ltree, not '{0}'" .format( type(path_or_ltree).__name__ ) ) @classmethod def validate(cls, path): if path_matcher.match(path) is None: raise ValueError( "'{0}' is not a valid ltree path.".format(path) ) def __len__(self): return len(self.path.split('.')) def index(self, other): subpath = Ltree(other).path.split('.') parts = self.path.split('.') for index, _ in enumerate(parts): if parts[index:len(subpath) + index] == subpath: return index raise ValueError('subpath not found') def descendant_of(self, other): """ is left argument a descendant of right (or equal)? :: assert Ltree('1.2.3.4.5').descendant_of('1.2.3') """ subpath = self[:len(Ltree(other))] return subpath == other def ancestor_of(self, other): """ is left argument an ancestor of right (or equal)? 
:: assert Ltree('1.2.3').ancestor_of('1.2.3.4.5') """ subpath = Ltree(other)[:len(self)] return subpath == self def __getitem__(self, key): if isinstance(key, int): return Ltree(self.path.split('.')[key]) elif isinstance(key, slice): return Ltree('.'.join(self.path.split('.')[key])) raise TypeError( 'Ltree indices must be integers, not {0}'.format( key.__class__.__name__ ) ) def lca(self, *others): """ Lowest common ancestor, i.e., longest common prefix of paths :: assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2' """ other_parts = [Ltree(other).path.split('.') for other in others] parts = self.path.split('.') for index, element in enumerate(parts): if any(( other[index] != element or len(other) <= index + 1 for other in other_parts )): if index == 0: return None return Ltree('.'.join(parts[0:index])) def __add__(self, other): return Ltree(self.path + '.' + Ltree(other).path) def __radd__(self, other): return Ltree(other) + self def __eq__(self, other): if isinstance(other, Ltree): return self.path == other.path elif isinstance(other, six.string_types): return self.path == other else: return NotImplemented def __hash__(self): return hash(self.path) def __ne__(self, other): return not (self == other) def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.path) def __unicode__(self): return self.path def __contains__(self, label): return label in self.path.split('.') sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/weekday.py000066400000000000000000000024171360007755400246740ustar00rootroot00000000000000# -*- coding: utf-8 -*- from functools import total_ordering from .. import i18n from ..utils import str_coercible @str_coercible @total_ordering class WeekDay(object): NUM_WEEK_DAYS = 7 def __init__(self, index): if not (0 <= index < self.NUM_WEEK_DAYS): raise ValueError( "index must be between 0 and %d" % self.NUM_WEEK_DAYS ) self.index = index def __eq__(self, other): if isinstance(other, WeekDay): return self.index == other.index else: return NotImplemented def __hash__(self): return hash(self.index) def __lt__(self, other): return self.position < other.position def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.index) def __unicode__(self): return self.name def get_name(self, width='wide', context='format'): names = i18n.babel.dates.get_day_names( width, context, i18n.get_locale() ) return names[self.index] @property def name(self): return self.get_name() @property def position(self): return ( self.index - i18n.get_locale().first_week_day ) % self.NUM_WEEK_DAYS sqlalchemy-utils-0.36.1/sqlalchemy_utils/primitives/weekdays.py000066400000000000000000000034721360007755400250610ustar00rootroot00000000000000import six from ..utils import str_coercible from .weekday import WeekDay @str_coercible class WeekDays(object): def __init__(self, bit_string_or_week_days): if isinstance(bit_string_or_week_days, six.string_types): self._days = set() if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS: raise ValueError( 'Bit string must be {0} characters long.'.format( WeekDay.NUM_WEEK_DAYS ) ) for index, bit in enumerate(bit_string_or_week_days): if bit not in '01': raise ValueError( 'Bit string may only contain zeroes and ones.' 
) if bit == '1': self._days.add(WeekDay(index)) elif isinstance(bit_string_or_week_days, WeekDays): self._days = bit_string_or_week_days._days else: self._days = set(bit_string_or_week_days) def __eq__(self, other): if isinstance(other, WeekDays): return self._days == other._days elif isinstance(other, six.string_types): return self.as_bit_string() == other else: return NotImplemented def __iter__(self): for day in sorted(self._days): yield day def __contains__(self, value): return value in self._days def __repr__(self): return '%s(%r)' % ( self.__class__.__name__, self.as_bit_string() ) def __unicode__(self): return u', '.join(six.text_type(day) for day in self) def as_bit_string(self): return ''.join( '1' if WeekDay(index) in self._days else '0' for index in six.moves.xrange(WeekDay.NUM_WEEK_DAYS) ) sqlalchemy-utils-0.36.1/sqlalchemy_utils/proxy_dict.py000066400000000000000000000045111360007755400232310ustar00rootroot00000000000000import sqlalchemy as sa class ProxyDict(object): def __init__(self, parent, collection_name, mapping_attr): self.parent = parent self.collection_name = collection_name self.child_class = mapping_attr.class_ self.key_name = mapping_attr.key self.cache = {} @property def collection(self): return getattr(self.parent, self.collection_name) def keys(self): descriptor = getattr(self.child_class, self.key_name) return [x[0] for x in self.collection.values(descriptor)] def __contains__(self, key): if key in self.cache: return self.cache[key] is not None return self.fetch(key) is not None def has_key(self, key): return self.__contains__(key) def fetch(self, key): session = sa.orm.object_session(self.parent) if session and sa.orm.util.has_identity(self.parent): obj = self.collection.filter_by(**{self.key_name: key}).first() self.cache[key] = obj return obj def create_new_instance(self, key): value = self.child_class(**{self.key_name: key}) self.collection.append(value) self.cache[key] = value return value def __getitem__(self, key): if key in self.cache: if self.cache[key] is not None: return self.cache[key] else: value = self.fetch(key) if value: return value return self.create_new_instance(key) def __setitem__(self, key, value): try: existing = self[key] self.collection.remove(existing) except KeyError: pass self.collection.append(value) self.cache[key] = value def proxy_dict(parent, collection_name, mapping_attr): try: parent._proxy_dicts except AttributeError: parent._proxy_dicts = {} try: return parent._proxy_dicts[collection_name] except KeyError: parent._proxy_dicts[collection_name] = ProxyDict( parent, collection_name, mapping_attr ) return parent._proxy_dicts[collection_name] def expire_proxy_dicts(target, context): if hasattr(target, '_proxy_dicts'): target._proxy_dicts = {} sa.event.listen(sa.orm.mapper, 'expire', expire_proxy_dicts) sqlalchemy-utils-0.36.1/sqlalchemy_utils/query_chain.py000066400000000000000000000076031360007755400233610ustar00rootroot00000000000000""" QueryChain is a wrapper for sequence of queries. Features: * Easy iteration for sequence of queries * Limit, offset and count which are applied to all queries in the chain * Smart __getitem__ support Initialization ^^^^^^^^^^^^^^ QueryChain takes iterable of queries as first argument. 
Additionally limit and offset parameters can be given :: chain = QueryChain([session.query(User), session.query(Article)]) chain = QueryChain( [session.query(User), session.query(Article)], limit=4 ) Simple iteration ^^^^^^^^^^^^^^^^ :: chain = QueryChain([session.query(User), session.query(Article)]) for obj in chain: print obj Limit and offset ^^^^^^^^^^^^^^^^ Lets say you have 5 blog posts, 5 articles and 5 news items in your database. :: chain = QueryChain( [ session.query(BlogPost), session.query(Article), session.query(NewsItem) ], limit=5 ) list(chain) # all blog posts but not articles and news items chain = chain.offset(4) list(chain) # last blog post, and first four articles Just like with original query object the limit and offset can be chained to return a new QueryChain. :: chain = chain.limit(5).offset(7) Chain slicing ^^^^^^^^^^^^^ :: chain = QueryChain( [ session.query(BlogPost), session.query(Article), session.query(NewsItem) ] ) chain[3:6] # New QueryChain with offset=3 and limit=6 Count ^^^^^ Let's assume that there are five blog posts, five articles and five news items in the database, and you have the following query chain:: chain = QueryChain( [ session.query(BlogPost), session.query(Article), session.query(NewsItem) ] ) You can then get the total number rows returned by the query chain with :meth:`~QueryChain.count`:: >>> chain.count() 15 """ from copy import copy class QueryChain(object): """ QueryChain can be used as a wrapper for sequence of queries. :param queries: A sequence of SQLAlchemy Query objects :param limit: Similar to normal query limit this parameter can be used for limiting the number of results for the whole query chain. :param offset: Similar to normal query offset this parameter can be used for offsetting the query chain as a whole. .. versionadded: 0.26.0 """ def __init__(self, queries, limit=None, offset=None): self.queries = queries self._limit = limit self._offset = offset def __iter__(self): consumed = 0 skipped = 0 for query in self.queries: query_copy = copy(query) if self._limit: query = query.limit(self._limit - consumed) if self._offset: query = query.offset(self._offset - skipped) obj_count = 0 for obj in query: consumed += 1 obj_count += 1 yield obj if not obj_count: skipped += query_copy.count() else: skipped += obj_count def limit(self, value): return self[:value] def offset(self, value): return self[value:] def count(self): """ Return the total number of rows this QueryChain's queries would return. 
""" return sum(q.count() for q in self.queries) def __getitem__(self, key): if isinstance(key, slice): return self.__class__( queries=self.queries, limit=key.stop if key.stop is not None else self._limit, offset=key.start if key.start is not None else self._offset ) else: for obj in self[key:1]: return obj def __repr__(self): return '' % id(self) sqlalchemy-utils-0.36.1/sqlalchemy_utils/relationships/000077500000000000000000000000001360007755400233565ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/sqlalchemy_utils/relationships/__init__.py000066400000000000000000000066761360007755400255060ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.sql.util import ClauseAdapter from .chained_join import chained_join # noqa def path_to_relationships(path, cls): relationships = [] for path_name in path.split('.'): rel = getattr(cls, path_name) relationships.append(rel) cls = rel.mapper.class_ return relationships def adapt_expr(expr, *selectables): for selectable in selectables: expr = ClauseAdapter(selectable).traverse(expr) return expr def inverse_join(selectable, left_alias, right_alias, relationship): if relationship.property.secondary is not None: secondary_alias = sa.alias(relationship.property.secondary) return selectable.join( secondary_alias, adapt_expr( relationship.property.secondaryjoin, sa.inspect(left_alias).selectable, secondary_alias ) ).join( right_alias, adapt_expr( relationship.property.primaryjoin, sa.inspect(right_alias).selectable, secondary_alias ) ) else: join = sa.orm.join(right_alias, left_alias, relationship) onclause = join.onclause return selectable.join(right_alias, onclause) def relationship_to_correlation(relationship, alias): if relationship.property.secondary is not None: return adapt_expr( relationship.property.primaryjoin, alias, ) else: return sa.orm.join( relationship.parent, alias, relationship ).onclause def chained_inverse_join(relationships, leaf_model): selectable = sa.inspect(leaf_model).selectable aliases = [leaf_model] for index, relationship in enumerate(relationships[1:]): aliases.append(sa.orm.aliased(relationship.mapper.class_)) selectable = inverse_join( selectable, aliases[index], aliases[index + 1], relationships[index] ) if relationships[-1].property.secondary is not None: secondary_alias = sa.alias(relationships[-1].property.secondary) selectable = selectable.join( secondary_alias, adapt_expr( relationships[-1].property.secondaryjoin, secondary_alias, sa.inspect(aliases[-1]).selectable ) ) aliases.append(secondary_alias) return selectable, aliases def select_correlated_expression( root_model, expr, path, leaf_model, from_obj=None, order_by=None, correlate=True ): relationships = list(reversed(path_to_relationships(path, root_model))) query = sa.select([expr]) join_expr, aliases = chained_inverse_join(relationships, leaf_model) if order_by: query = query.order_by( *[ adapt_expr( o, *(sa.inspect(alias).selectable for alias in aliases) ) for o in order_by ] ) condition = relationship_to_correlation( relationships[-1], aliases[-1] ) if from_obj is not None: condition = adapt_expr(condition, from_obj) query = query.select_from(join_expr.selectable) if correlate: query = query.correlate( from_obj if from_obj is not None else root_model ) return query.where(condition) sqlalchemy-utils-0.36.1/sqlalchemy_utils/relationships/chained_join.py000066400000000000000000000015601360007755400263440ustar00rootroot00000000000000def chained_join(*relationships): """ Return a chained Join object for given relationships. 
""" property_ = relationships[0].property if property_.secondary is not None: from_ = property_.secondary.join( property_.mapper.class_.__table__, property_.secondaryjoin ) else: from_ = property_.mapper.class_.__table__ for relationship in relationships[1:]: prop = relationship.property if prop.secondary is not None: from_ = from_.join( prop.secondary, prop.primaryjoin ) from_ = from_.join( prop.mapper.class_, prop.secondaryjoin ) else: from_ = from_.join( prop.mapper.class_, prop.primaryjoin ) return from_ sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/000077500000000000000000000000001360007755400216365ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/__init__.py000066400000000000000000000033101360007755400237440ustar00rootroot00000000000000from functools import wraps from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from .arrow import ArrowType # noqa from .choice import Choice, ChoiceType # noqa from .color import ColorType # noqa from .country import CountryType # noqa from .currency import CurrencyType # noqa from .email import EmailType # noqa from .encrypted.encrypted_type import EncryptedType # noqa from .ip_address import IPAddressType # noqa from .json import JSONType # noqa from .locale import LocaleType # noqa from .ltree import LtreeType # noqa from .password import Password, PasswordType # noqa from .pg_composite import ( # noqa CompositeArray, CompositeType, register_composites, remove_composite_listeners ) from .phone_number import ( # noqa PhoneNumber, PhoneNumberParseException, PhoneNumberType ) from .range import ( # noqa DateRangeType, DateTimeRangeType, Int8RangeType, IntRangeType, NumericRangeType ) from .scalar_list import ScalarListException, ScalarListType # noqa from .timezone import TimezoneType # noqa from .ts_vector import TSVectorType # noqa from .url import URLType # noqa from .uuid import UUIDType # noqa from .weekdays import WeekDaysType # noqa class InstrumentedList(_InstrumentedList): """Enhanced version of SQLAlchemy InstrumentedList. Provides some additional functionality.""" def any(self, attr): return any(getattr(item, attr) for item in self) def all(self, attr): return all(getattr(item, attr) for item in self) def instrumented_list(f): @wraps(f) def wrapper(*args, **kwargs): return InstrumentedList([item for item in f(*args, **kwargs)]) return wrapper sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/arrow.py000066400000000000000000000047221360007755400233470ustar00rootroot00000000000000from __future__ import absolute_import from datetime import datetime import six from sqlalchemy import types from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible try: from collections.abc import Iterable except ImportError: # For python 2.7 support from collections import Iterable arrow = None try: import arrow except ImportError: pass class ArrowType(types.TypeDecorator, ScalarCoercible): """ ArrowType provides way of saving Arrow_ objects into database. It automatically changes Arrow_ objects to datetime objects on the way in and datetime objects back to Arrow_ objects on the way out (when querying database). ArrowType needs Arrow_ library installed. .. 
_Arrow: http://crsmithdev.com/arrow/ :: from datetime import datetime from sqlalchemy_utils import ArrowType import arrow class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) created_at = sa.Column(ArrowType) article = Article(created_at=arrow.utcnow()) As you may expect all the arrow goodies come available: :: article.created_at = article.created_at.replace(hours=-1) article.created_at.humanize() # 'an hour ago' """ impl = types.DateTime def __init__(self, *args, **kwargs): if not arrow: raise ImproperlyConfigured( "'arrow' package is required to use 'ArrowType'" ) super(ArrowType, self).__init__(*args, **kwargs) def process_bind_param(self, value, dialect): if value: utc_val = self._coerce(value).to('UTC') return utc_val.datetime if self.impl.timezone else utc_val.naive return value def process_result_value(self, value, dialect): if value: return arrow.get(value) return value def process_literal_param(self, value, dialect): return str(value) def _coerce(self, value): if value is None: return None elif isinstance(value, six.string_types): value = arrow.get(value) elif isinstance(value, Iterable): value = arrow.get(*value) elif isinstance(value, datetime): value = arrow.get(value) return value @property def python_type(self): return self.impl.type.python_type sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/bit.py000066400000000000000000000013371360007755400227720ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.dialects.postgresql import BIT class BitType(sa.types.TypeDecorator): """ BitType offers way of saving BITs into database. """ impl = sa.types.BINARY def __init__(self, length=1, **kwargs): self.length = length sa.types.TypeDecorator.__init__(self, **kwargs) def load_dialect_impl(self, dialect): # Use the native BIT type for drivers that has it. if dialect.name == 'postgresql': return dialect.type_descriptor(BIT(self.length)) elif dialect.name == 'sqlite': return dialect.type_descriptor(sa.String(self.length)) else: return dialect.type_descriptor(type(self.impl)(self.length)) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/choice.py000066400000000000000000000137611360007755400234520ustar00rootroot00000000000000import six from sqlalchemy import types from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible try: from enum import Enum except ImportError: Enum = None class Choice(object): def __init__(self, code, value): self.code = code self.value = value def __eq__(self, other): if isinstance(other, Choice): return self.code == other.code return other == self.code def __hash__(self): return hash(self.code) def __ne__(self, other): return not (self == other) def __unicode__(self): return six.text_type(self.value) def __str__(self): return six.ensure_str(self.__unicode__()) def __repr__(self): return 'Choice(code={code}, value={value})'.format( code=self.code, value=self.value ) class ChoiceType(types.TypeDecorator, ScalarCoercible): """ ChoiceType offers way of having fixed set of choices for given column. It could work with a list of tuple (a collection of key-value pairs), or integrate with :mod:`enum` in the standard library of Python 3.4+ (the enum34_ backported package on PyPI is compatible too for ``< 3.4``). .. _enum34: https://pypi.python.org/pypi/enum34 Columns with ChoiceTypes are automatically coerced to Choice objects while a list of tuple been passed to the constructor. 
If a subclass of :class:`enum.Enum` is passed, columns will be coerced to :class:`enum.Enum` objects instead. :: class User(Base): TYPES = [ (u'admin', u'Admin'), (u'regular-user', u'Regular user') ] __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) type = sa.Column(ChoiceType(TYPES)) user = User(type=u'admin') user.type # Choice(code='admin', value=u'Admin') Or:: import enum class UserType(enum.Enum): admin = 1 regular = 2 class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) type = sa.Column(ChoiceType(UserType, impl=sa.Integer())) user = User(type=1) user.type # ChoiceType is very useful when the rendered values change based on user's locale: :: from babel import lazy_gettext as _ class User(Base): TYPES = [ (u'admin', _(u'Admin')), (u'regular-user', _(u'Regular user')) ] __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) type = sa.Column(ChoiceType(TYPES)) user = User(type=u'admin') user.type # Choice(code='admin', value=u'Admin') print user.type # u'Admin' Or:: from enum import Enum from babel import lazy_gettext as _ class UserType(Enum): admin = 1 regular = 2 UserType.admin.label = _(u'Admin') UserType.regular.label = _(u'Regular user') class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) type = sa.Column(ChoiceType(UserType, impl=sa.Integer())) user = User(type=UserType.admin) user.type # print user.type.label # u'Admin' """ impl = types.Unicode(255) def __init__(self, choices, impl=None): self.choices = choices if ( Enum is not None and isinstance(choices, type) and issubclass(choices, Enum) ): self.type_impl = EnumTypeImpl(enum_class=choices) else: self.type_impl = ChoiceTypeImpl(choices=choices) if impl: self.impl = impl @property def python_type(self): return self.impl.python_type def _coerce(self, value): return self.type_impl._coerce(value) def process_bind_param(self, value, dialect): return self.type_impl.process_bind_param(value, dialect) def process_result_value(self, value, dialect): return self.type_impl.process_result_value(value, dialect) class ChoiceTypeImpl(object): """The implementation for the ``Choice`` usage.""" def __init__(self, choices): if not choices: raise ImproperlyConfigured( 'ChoiceType needs list of choices defined.' ) self.choices_dict = dict(choices) def _coerce(self, value): if value is None: return value if isinstance(value, Choice): return value return Choice(value, self.choices_dict[value]) def process_bind_param(self, value, dialect): if value and isinstance(value, Choice): return value.code return value def process_result_value(self, value, dialect): if value: return Choice(value, self.choices_dict[value]) return value class EnumTypeImpl(object): """The implementation for the ``Enum`` usage.""" def __init__(self, enum_class): if Enum is None: raise ImproperlyConfigured( "'enum34' package is required to use 'EnumType' in Python " "< 3.4" ) if not issubclass(enum_class, Enum): raise ImproperlyConfigured( "EnumType needs a class of enum defined." 
) self.enum_class = enum_class def _coerce(self, value): if value is None: return None return self.enum_class(value) def process_bind_param(self, value, dialect): if value is None: return None return self.enum_class(value).value def process_result_value(self, value, dialect): return self._coerce(value) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/color.py000066400000000000000000000041101360007755400233220ustar00rootroot00000000000000import six from sqlalchemy import types from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible colour = None try: import colour python_colour_type = colour.Color except ImportError: python_colour_type = None class ColorType(types.TypeDecorator, ScalarCoercible): """ ColorType provides a way for saving Color (from colour_ package) objects into database. ColorType saves Color objects as strings on the way in and converts them back to objects when querying the database. :: from colour import Color from sqlalchemy_utils import ColorType class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(50)) background_color = sa.Column(ColorType) document = Document() document.background_color = Color('#F5F5F5') session.commit() Querying the database returns Color objects: :: document = session.query(Document).first() document.background_color.hex # '#f5f5f5' .. _colour: https://github.com/vaab/colour """ STORE_FORMAT = u'hex' impl = types.Unicode(20) python_type = python_colour_type def __init__(self, max_length=20, *args, **kwargs): # Fail if colour is not found. if colour is None: raise ImproperlyConfigured( "'colour' package is required to use 'ColorType'" ) super(ColorType, self).__init__(*args, **kwargs) self.impl = types.Unicode(max_length) def process_bind_param(self, value, dialect): if value and isinstance(value, colour.Color): return six.text_type(getattr(value, self.STORE_FORMAT)) return value def process_result_value(self, value, dialect): if value: return colour.Color(value) return value def _coerce(self, value): if value is not None and not isinstance(value, colour.Color): return colour.Color(value) return value sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/country.py000066400000000000000000000030561360007755400237170ustar00rootroot00000000000000import six from sqlalchemy import types from ..primitives import Country from .scalar_coercible import ScalarCoercible class CountryType(types.TypeDecorator, ScalarCoercible): """ Changes :class:`.Country` objects to a string representation on the way in and changes them back to :class:`.Country objects on the way out. In order to use CountryType you need to install Babel_ first. .. 
_Babel: http://babel.pocoo.org/ :: from sqlalchemy_utils import CountryType, Country class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) country = sa.Column(CountryType) user = User() user.country = Country('FI') session.add(user) session.commit() user.country # Country('FI') user.country.name # Finland print user.country # Finland CountryType is scalar coercible:: user.country = 'US' user.country # Country('US') """ impl = types.String(2) python_type = Country def process_bind_param(self, value, dialect): if isinstance(value, Country): return value.code if isinstance(value, six.string_types): return value def process_result_value(self, value, dialect): if value is not None: return Country(value) def _coerce(self, value): if value is not None and not isinstance(value, Country): return Country(value) return value sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/currency.py000066400000000000000000000036421360007755400240470ustar00rootroot00000000000000import six from sqlalchemy import types from .. import i18n, ImproperlyConfigured from ..primitives import Currency from .scalar_coercible import ScalarCoercible class CurrencyType(types.TypeDecorator, ScalarCoercible): """ Changes :class:`.Currency` objects to a string representation on the way in and changes them back to :class:`.Currency` objects on the way out. In order to use CurrencyType you need to install Babel_ first. .. _Babel: http://babel.pocoo.org/ :: from sqlalchemy_utils import CurrencyType, Currency class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) currency = sa.Column(CurrencyType) user = User() user.currency = Currency('USD') session.add(user) session.commit() user.currency # Currency('USD') user.currency.name # US Dollar str(user.currency) # US Dollar user.currency.symbol # $ CurrencyType is scalar coercible:: user.currency = 'US' user.currency # Currency('US') """ impl = types.String(3) python_type = Currency def __init__(self, *args, **kwargs): if i18n.babel is None: raise ImproperlyConfigured( "'babel' package is required in order to use CurrencyType." ) super(CurrencyType, self).__init__(*args, **kwargs) def process_bind_param(self, value, dialect): if isinstance(value, Currency): return value.code elif isinstance(value, six.string_types): return value def process_result_value(self, value, dialect): if value is not None: return Currency(value) def _coerce(self, value): if value is not None and not isinstance(value, Currency): return Currency(value) return value sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/email.py000066400000000000000000000023461360007755400233040ustar00rootroot00000000000000import sqlalchemy as sa from ..operators import CaseInsensitiveComparator class EmailType(sa.types.TypeDecorator): """ Provides a way for storing emails in a lower case. Example:: from sqlalchemy_utils import EmailType class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) email = sa.Column(EmailType) user = User() user.email = 'John.Smith@foo.com' user.name = 'John Smith' session.add(user) session.commit() # Notice - email in filter() is lowercase. 
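        # EmailType uses CaseInsensitiveComparator, which also lowercases the
        # right-hand side of the comparison, so a mixed-case address in
        # filter() would match as well.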
user = (session.query(User) .filter(User.email == 'john.smith@foo.com') .one()) assert user.name == 'John Smith' """ impl = sa.Unicode comparator_factory = CaseInsensitiveComparator def __init__(self, length=255, *args, **kwargs): super(EmailType, self).__init__(length=length, *args, **kwargs) def process_bind_param(self, value, dialect): if value is not None: return value.lower() return value @property def python_type(self): return self.impl.type.python_type sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/encrypted/000077500000000000000000000000001360007755400236335ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/encrypted/__init__.py000066400000000000000000000000341360007755400257410ustar00rootroot00000000000000# Module for encrypted type sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/encrypted/encrypted_type.py000066400000000000000000000361561360007755400272560ustar00rootroot00000000000000# -*- coding: utf-8 -*- import base64 import datetime import os import six from sqlalchemy.types import LargeBinary, String, TypeDecorator from sqlalchemy_utils.exceptions import ImproperlyConfigured from sqlalchemy_utils.types.encrypted.padding import PADDING_MECHANISM from sqlalchemy_utils.types.scalar_coercible import ScalarCoercible cryptography = None try: import cryptography from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.ciphers import ( Cipher, algorithms, modes ) from cryptography.fernet import Fernet from cryptography.exceptions import InvalidTag except ImportError: pass dateutil = None try: import dateutil from dateutil.parser import parse as datetime_parse except ImportError: pass class InvalidCiphertextError(Exception): pass class EncryptionDecryptionBaseEngine(object): """A base encryption and decryption engine. This class must be sub-classed in order to create new engines. """ def _update_key(self, key): if isinstance(key, six.string_types): key = key.encode() digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) digest.update(key) engine_key = digest.finalize() self._initialize_engine(engine_key) def encrypt(self, value): raise NotImplementedError('Subclasses must implement this!') def decrypt(self, value): raise NotImplementedError('Subclasses must implement this!') class AesEngine(EncryptionDecryptionBaseEngine): """Provide AES encryption and decryption methods. You may also consider using the AesGcmEngine instead -- that may be a better fit for some cases. You should NOT use the AesGcmEngine if you want to be able to search for a row based on the value of an encrypted column. Use AesEngine instead, since that allows you to perform such searches. If you don't need to search by the value of an encypted column, the AesGcmEngine provides better security. 
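
    A hypothetical column definition using this engine (``sa`` is assumed to
    be ``sqlalchemy`` and ``secret_key`` an assumed variable holding the
    encryption key)::

        from sqlalchemy_utils import EncryptedType
        from sqlalchemy_utils.types.encrypted.encrypted_type import AesEngine

        username = sa.Column(
            EncryptedType(sa.Unicode, secret_key, AesEngine, 'pkcs5')
        )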
""" BLOCK_SIZE = 16 def _initialize_engine(self, parent_class_key): self.secret_key = parent_class_key self.iv = self.secret_key[:16] self.cipher = Cipher( algorithms.AES(self.secret_key), modes.CBC(self.iv), backend=default_backend() ) def _set_padding_mechanism(self, padding_mechanism=None): """Set the padding mechanism.""" if isinstance(padding_mechanism, six.string_types): if padding_mechanism not in PADDING_MECHANISM.keys(): raise ImproperlyConfigured( "There is not padding mechanism with name {}".format( padding_mechanism ) ) if padding_mechanism is None: padding_mechanism = 'naive' padding_class = PADDING_MECHANISM[padding_mechanism] self.padding_engine = padding_class(self.BLOCK_SIZE) def encrypt(self, value): if not isinstance(value, six.string_types): value = repr(value) if isinstance(value, six.text_type): value = str(value) value = value.encode() value = self.padding_engine.pad(value) encryptor = self.cipher.encryptor() encrypted = encryptor.update(value) + encryptor.finalize() encrypted = base64.b64encode(encrypted) return encrypted def decrypt(self, value): if isinstance(value, six.text_type): value = str(value) decryptor = self.cipher.decryptor() decrypted = base64.b64decode(value) decrypted = decryptor.update(decrypted) + decryptor.finalize() decrypted = self.padding_engine.unpad(decrypted) if not isinstance(decrypted, six.string_types): try: decrypted = decrypted.decode('utf-8') except UnicodeDecodeError: raise ValueError('Invalid decryption key') return decrypted class AesGcmEngine(EncryptionDecryptionBaseEngine): """Provide AES/GCM encryption and decryption methods. You may also consider using the AesEngine instead -- that may be a better fit for some cases. You should NOT use this AesGcmEngine if you want to be able to search for a row based on the value of an encrypted column. Use AesEngine instead, since that allows you to perform such searches. If you don't need to search by the value of an encypted column, the AesGcmEngine provides better security. 
""" BLOCK_SIZE = 16 IV_BYTES_NEEDED = 12 TAG_SIZE_BYTES = BLOCK_SIZE def _initialize_engine(self, parent_class_key): self.secret_key = parent_class_key def encrypt(self, value): if not isinstance(value, six.string_types): value = repr(value) if isinstance(value, six.text_type): value = str(value) value = value.encode() iv = os.urandom(self.IV_BYTES_NEEDED) cipher = Cipher( algorithms.AES(self.secret_key), modes.GCM(iv), backend=default_backend() ) encryptor = cipher.encryptor() encrypted = encryptor.update(value) + encryptor.finalize() assert len(encryptor.tag) == self.TAG_SIZE_BYTES encrypted = base64.b64encode(iv + encryptor.tag + encrypted) return encrypted def decrypt(self, value): if isinstance(value, six.text_type): value = str(value) decrypted = base64.b64decode(value) if len(decrypted) < self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES: raise InvalidCiphertextError() iv = decrypted[:self.IV_BYTES_NEEDED] tag = decrypted[self.IV_BYTES_NEEDED: self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES] decrypted = decrypted[self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES:] cipher = Cipher( algorithms.AES(self.secret_key), modes.GCM(iv, tag), backend=default_backend() ) decryptor = cipher.decryptor() try: decrypted = decryptor.update(decrypted) + decryptor.finalize() except InvalidTag: raise InvalidCiphertextError() if not isinstance(decrypted, six.string_types): try: decrypted = decrypted.decode('utf-8') except UnicodeDecodeError: raise InvalidCiphertextError() return decrypted class FernetEngine(EncryptionDecryptionBaseEngine): """Provide Fernet encryption and decryption methods.""" def _initialize_engine(self, parent_class_key): self.secret_key = base64.urlsafe_b64encode(parent_class_key) self.fernet = Fernet(self.secret_key) def encrypt(self, value): if not isinstance(value, six.string_types): value = repr(value) if isinstance(value, six.text_type): value = str(value) value = value.encode() encrypted = self.fernet.encrypt(value) return encrypted def decrypt(self, value): if isinstance(value, six.text_type): value = str(value) decrypted = self.fernet.decrypt(value) if not isinstance(decrypted, six.string_types): decrypted = decrypted.decode('utf-8') return decrypted class EncryptedType(TypeDecorator, ScalarCoercible): """ EncryptedType provides a way to encrypt and decrypt values, to and from databases, that their type is a basic SQLAlchemy type. For example Unicode, String or even Boolean. On the way in, the value is encrypted and on the way out the stored value is decrypted. EncryptedType needs Cryptography_ library in order to work. When declaring a column which will be of type EncryptedType it is better to be as precise as possible and follow the pattern below. .. _Cryptography: https://cryptography.io/en/latest/ :: a_column = sa.Column(EncryptedType(sa.Unicode, secret_key, FernetEngine)) another_column = sa.Column(EncryptedType(sa.Unicode, secret_key, AesEngine, 'pkcs5')) A more complete example is given below. 
:: import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy_utils import EncryptedType from sqlalchemy_utils.types.encrypted.encrypted_type import AesEngine secret_key = 'secretkey1234' # setup engine = create_engine('sqlite:///:memory:') connection = engine.connect() Base = declarative_base() class User(Base): __tablename__ = "user" id = sa.Column(sa.Integer, primary_key=True) username = sa.Column(EncryptedType(sa.Unicode, secret_key, AesEngine, 'pkcs5')) access_token = sa.Column(EncryptedType(sa.String, secret_key, AesEngine, 'pkcs5')) is_active = sa.Column(EncryptedType(sa.Boolean, secret_key, AesEngine, 'zeroes')) number_of_accounts = sa.Column(EncryptedType(sa.Integer, secret_key, AesEngine, 'oneandzeroes')) sa.orm.configure_mappers() Base.metadata.create_all(connection) # create a configured "Session" class Session = sessionmaker(bind=connection) # create a Session session = Session() # example user_name = u'secret_user' test_token = 'atesttoken' active = True num_of_accounts = 2 user = User(username=user_name, access_token=test_token, is_active=active, number_of_accounts=num_of_accounts) session.add(user) session.commit() user_id = user.id session.expunge_all() user_instance = session.query(User).get(user_id) print('id: {}'.format(user_instance.id)) print('username: {}'.format(user_instance.username)) print('token: {}'.format(user_instance.access_token)) print('active: {}'.format(user_instance.is_active)) print('accounts: {}'.format(user_instance.number_of_accounts)) # teardown session.close_all() Base.metadata.drop_all(connection) connection.close() engine.dispose() The key parameter accepts a callable to allow for the key to change per-row instead of being fixed for the whole table. 
:: def get_key(): return 'dynamic-key' class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) username = sa.Column(EncryptedType( sa.Unicode, get_key)) """ impl = LargeBinary def __init__(self, type_in=None, key=None, engine=None, padding=None, **kwargs): """Initialization.""" if not cryptography: raise ImproperlyConfigured( "'cryptography' is required to use EncryptedType" ) super(EncryptedType, self).__init__(**kwargs) # set the underlying type if type_in is None: type_in = String() elif isinstance(type_in, type): type_in = type_in() self.underlying_type = type_in self._key = key if not engine: engine = AesEngine self.engine = engine() if isinstance(self.engine, AesEngine): self.engine._set_padding_mechanism(padding) @property def key(self): return self._key @key.setter def key(self, value): self._key = value def _update_key(self): key = self._key() if callable(self._key) else self._key self.engine._update_key(key) def process_bind_param(self, value, dialect): """Encrypt a value on the way in.""" if value is not None: self._update_key() try: value = self.underlying_type.process_bind_param( value, dialect ) except AttributeError: # Doesn't have 'process_bind_param' # Handle 'boolean' and 'dates' type_ = self.underlying_type.python_type if issubclass(type_, bool): value = 'true' if value else 'false' elif issubclass(type_, (datetime.date, datetime.time)): value = value.isoformat() return self.engine.encrypt(value) def process_result_value(self, value, dialect): """Decrypt value on the way out.""" if value is not None: self._update_key() decrypted_value = self.engine.decrypt(value) try: return self.underlying_type.process_result_value( decrypted_value, dialect ) except AttributeError: # Doesn't have 'process_result_value' # Handle 'boolean' and 'dates' type_ = self.underlying_type.python_type date_types = [datetime.datetime, datetime.time, datetime.date] if issubclass(type_, bool): return decrypted_value == 'true' elif type_ in date_types: return DatetimeHandler.process_value( decrypted_value, type_ ) # Handle all others return self.underlying_type.python_type(decrypted_value) def _coerce(self, value): if isinstance(self.underlying_type, ScalarCoercible): return self.underlying_type._coerce(value) return value class DatetimeHandler(object): """ DatetimeHandler is responsible for parsing strings and returning the appropriate date, datetime or time objects. """ @classmethod def process_value(cls, value, python_type): """ process_value returns a datetime, date or time object according to a given string value and a python type. 
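For instance, assuming python-dateutil is installed, the handler behaves roughly as follows (illustrative sketch, not part of the public API contract):: DatetimeHandler.process_value('2004-10-19 10:23:54', datetime.date) # datetime.date(2004, 10, 19)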
""" if not dateutil: raise ImproperlyConfigured( "'python-dateutil' is required to process datetimes" ) return_value = datetime_parse(value) if issubclass(python_type, datetime.datetime): return return_value elif issubclass(python_type, datetime.time): return return_value.time() elif issubclass(python_type, datetime.date): return return_value.date() sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/encrypted/padding.py000066400000000000000000000107711360007755400256210ustar00rootroot00000000000000import six class InvalidPaddingError(Exception): pass class Padding(object): """Base class for padding and unpadding.""" def __init__(self, block_size): self.block_size = block_size def pad(value): raise NotImplementedError('Subclasses must implement this!') def unpad(value): raise NotImplementedError('Subclasses must implement this!') class PKCS5Padding(Padding): """Provide PKCS5 padding and unpadding.""" def pad(self, value): if not isinstance(value, six.binary_type): value = value.encode() padding_length = (self.block_size - len(value) % self.block_size) padding_sequence = padding_length * six.b(chr(padding_length)) value_with_padding = value + padding_sequence return value_with_padding def unpad(self, value): # Perform some input validations. # In case of error, we throw a generic InvalidPaddingError() if not value or len(value) < self.block_size: # PKCS5 padded output will always be at least 1 block size raise InvalidPaddingError() if len(value) % self.block_size != 0: # PKCS5 padded output will be a multiple of the block size raise InvalidPaddingError() if isinstance(value, six.binary_type): padding_length = value[-1] if isinstance(value, six.string_types): padding_length = ord(value[-1]) if padding_length == 0 or padding_length > self.block_size: raise InvalidPaddingError() def convert_byte_or_char_to_number(x): return ord(x) if isinstance(x, six.string_types) else x if any([padding_length != convert_byte_or_char_to_number(x) for x in value[-padding_length:]]): raise InvalidPaddingError() value_without_padding = value[0:-padding_length] return value_without_padding class OneAndZeroesPadding(Padding): """Provide the one and zeroes padding and unpadding. This mechanism pads with 0x80 followed by zero bytes. For unpadding it strips off all trailing zero bytes and the 0x80 byte. """ BYTE_80 = 0x80 BYTE_00 = 0x00 def pad(self, value): if not isinstance(value, six.binary_type): value = value.encode() padding_length = (self.block_size - len(value) % self.block_size) one_part_bytes = six.b(chr(self.BYTE_80)) zeroes_part_bytes = (padding_length - 1) * six.b(chr(self.BYTE_00)) padding_sequence = one_part_bytes + zeroes_part_bytes value_with_padding = value + padding_sequence return value_with_padding def unpad(self, value): value_without_padding = value.rstrip(six.b(chr(self.BYTE_00))) value_without_padding = value_without_padding.rstrip( six.b(chr(self.BYTE_80))) return value_without_padding class ZeroesPadding(Padding): """Provide zeroes padding and unpadding. This mechanism pads with 0x00 except the last byte equals to the padding length. For unpadding it reads the last byte and strips off that many bytes. 
""" BYTE_00 = 0x00 def pad(self, value): if not isinstance(value, six.binary_type): value = value.encode() padding_length = (self.block_size - len(value) % self.block_size) zeroes_part_bytes = (padding_length - 1) * six.b(chr(self.BYTE_00)) last_part_bytes = six.b(chr(padding_length)) padding_sequence = zeroes_part_bytes + last_part_bytes value_with_padding = value + padding_sequence return value_with_padding def unpad(self, value): if isinstance(value, six.binary_type): padding_length = value[-1] if isinstance(value, six.string_types): padding_length = ord(value[-1]) value_without_padding = value[0:-padding_length] return value_without_padding class NaivePadding(Padding): """Naive padding and unpadding using '*'. The class is provided only for backwards compatibility. """ CHARACTER = six.b('*') def pad(self, value): num_of_bytes = (self.block_size - len(value) % self.block_size) value_with_padding = value + num_of_bytes * self.CHARACTER return value_with_padding def unpad(self, value): value_without_padding = value.rstrip(self.CHARACTER) return value_without_padding PADDING_MECHANISM = { 'pkcs5': PKCS5Padding, 'oneandzeroes': OneAndZeroesPadding, 'zeroes': ZeroesPadding, 'naive': NaivePadding } sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/ip_address.py000066400000000000000000000036341360007755400243330ustar00rootroot00000000000000import six from sqlalchemy import types from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible ip_address = None try: from ipaddress import ip_address except ImportError: try: from ipaddr import IPAddress as ip_address except ImportError: pass class IPAddressType(types.TypeDecorator, ScalarCoercible): """ Changes IPAddress objects to a string representation on the way in and changes them back to IPAddress objects on the way out. IPAddressType uses ipaddress package on Python >= 3 and ipaddr_ package on Python 2. In order to use IPAddressType with python you need to install ipaddr_ first. .. 
_ipaddr: https://pypi.python.org/pypi/ipaddr :: from sqlalchemy_utils import IPAddressType class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) ip_address = sa.Column(IPAddressType) user = User() user.ip_address = '123.123.123.123' session.add(user) session.commit() user.ip_address # IPAddress object """ impl = types.Unicode(50) def __init__(self, max_length=50, *args, **kwargs): if not ip_address: raise ImproperlyConfigured( "'ipaddr' package is required to use 'IPAddressType' " "in python 2" ) super(IPAddressType, self).__init__(*args, **kwargs) self.impl = types.Unicode(max_length) def process_bind_param(self, value, dialect): return six.text_type(value) if value else None def process_result_value(self, value, dialect): return ip_address(value) if value else None def _coerce(self, value): return ip_address(value) if value else None @property def python_type(self): return self.impl.type.python_type sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/json.py000066400000000000000000000045701360007755400231670ustar00rootroot00000000000000from __future__ import absolute_import import six import sqlalchemy as sa from sqlalchemy.dialects.postgresql.base import ischema_names from ..exceptions import ImproperlyConfigured json = None try: import anyjson as json except ImportError: import json as json try: from sqlalchemy.dialects.postgresql import JSON has_postgres_json = True except ImportError: class PostgresJSONType(sa.types.UserDefinedType): """ Text search vector type for postgresql. """ def get_col_spec(self): return 'json' ischema_names['json'] = PostgresJSONType has_postgres_json = False class JSONType(sa.types.TypeDecorator): """ JSONType offers way of saving JSON data structures to database. On PostgreSQL the underlying implementation of this data type is 'json' while on other databases its simply 'text'. :: from sqlalchemy_utils import JSONType class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(50)) details = sa.Column(JSONType) product = Product() product.details = { 'color': 'red', 'type': 'car', 'max-speed': '400 mph' } session.commit() """ impl = sa.UnicodeText def __init__(self, *args, **kwargs): if json is None: raise ImproperlyConfigured( 'JSONType needs anyjson package installed.' ) super(JSONType, self).__init__(*args, **kwargs) def load_dialect_impl(self, dialect): if dialect.name == 'postgresql': # Use the native JSON type. if has_postgres_json: return dialect.type_descriptor(JSON()) else: return dialect.type_descriptor(PostgresJSONType()) else: return dialect.type_descriptor(self.impl) def process_bind_param(self, value, dialect): if dialect.name == 'postgresql' and has_postgres_json: return value if value is not None: value = six.text_type(json.dumps(value)) return value def process_result_value(self, value, dialect): if dialect.name == 'postgresql': return value if value is not None: value = json.loads(value) return value sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/locale.py000066400000000000000000000034211360007755400234470ustar00rootroot00000000000000import six from sqlalchemy import types from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible babel = None try: import babel except ImportError: pass class LocaleType(types.TypeDecorator, ScalarCoercible): """ LocaleType saves Babel_ Locale objects into database. 
The Locale objects are converted to strings on the way in and back to Locale objects on the way out. In order to use LocaleType you need to install Babel_ first. .. _Babel: http://babel.pocoo.org/ :: from sqlalchemy_utils import LocaleType from babel import Locale class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(50)) locale = sa.Column(LocaleType) user = User() user.locale = Locale('en_US') session.add(user) session.commit() Like many other types this type also supports scalar coercion: :: user.locale = 'de_DE' user.locale # Locale('de', territory='DE') """ impl = types.Unicode(10) def __init__(self): if babel is None: raise ImproperlyConfigured( 'Babel package is required to use LocaleType.' ) def process_bind_param(self, value, dialect): if isinstance(value, babel.Locale): return six.text_type(value) if isinstance(value, six.string_types): return value def process_result_value(self, value, dialect): if value is not None: return babel.Locale.parse(value) def _coerce(self, value): if value is not None and not isinstance(value, babel.Locale): return babel.Locale.parse(value) return value sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/ltree.py000066400000000000000000000065171360007755400233320ustar00rootroot00000000000000from __future__ import absolute_import from sqlalchemy import types from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql.base import ischema_names, PGTypeCompiler from sqlalchemy.sql import expression from ..primitives import Ltree from .scalar_coercible import ScalarCoercible class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible): """Postgresql LtreeType type. The LtreeType datatype can be used for representing labels of data stored in a hierarchical tree-like structure. For more detailed information please refer to http://www.postgresql.org/docs/current/static/ltree.html :: from sqlalchemy_utils import LtreeType, Ltree class DocumentSection(Base): __tablename__ = 'document_section' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) path = sa.Column(LtreeType) section = DocumentSection(path=Ltree('Countries.Finland')) session.add(section) session.commit() section.path # Ltree('Countries.Finland') .. note:: Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types may require installation of the Postgresql ltree extension on the server side. Please visit http://www.postgres.org for details. """ class comparator_factory(types.Concatenable.Comparator): def ancestor_of(self, other): if isinstance(other, list): return self.op('@>')(expression.cast(other, ARRAY(LtreeType))) else: return self.op('@>')(other) def descendant_of(self, other): if isinstance(other, list): return self.op('<@')(expression.cast(other, ARRAY(LtreeType))) else: return self.op('<@')(other) def lquery(self, other): if isinstance(other, list): return self.op('?')(expression.cast(other, ARRAY(LQUERY))) else: return self.op('~')(other) def ltxtquery(self, other): return self.op('@')(other) def bind_processor(self, dialect): def process(value): if value: return value.path return process def result_processor(self, dialect, coltype): def process(value): return self._coerce(value) return process def literal_processor(self, dialect): def process(value): value = value.replace("'", "''") return "'%s'" % value return process __visit_name__ = 'LTREE' def _coerce(self, value): if value: return Ltree(value) class LQUERY(types.TypeEngine): """Postgresql LQUERY type. See :class:`LTREE` for details.
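For example, the ``DocumentSection`` model above could be filtered with an lquery pattern via the ``lquery`` operator (illustrative sketch, requires the ltree extension):: session.query(DocumentSection).filter(DocumentSection.path.lquery('Countries.*'))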
""" __visit_name__ = 'LQUERY' class LTXTQUERY(types.TypeEngine): """Postresql LTXTQUERY type. See :class:`LTREE` for details. """ __visit_name__ = 'LTXTQUERY' ischema_names['ltree'] = LtreeType ischema_names['lquery'] = LQUERY ischema_names['ltxtquery'] = LTXTQUERY def visit_LTREE(self, type_, **kw): return 'LTREE' def visit_LQUERY(self, type_, **kw): return 'LQUERY' def visit_LTXTQUERY(self, type_, **kw): return 'LTXTQUERY' PGTypeCompiler.visit_LTREE = visit_LTREE PGTypeCompiler.visit_LQUERY = visit_LQUERY PGTypeCompiler.visit_LTXTQUERY = visit_LTXTQUERY sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/password.py000066400000000000000000000174411360007755400240610ustar00rootroot00000000000000import weakref import six from sqlalchemy import types from sqlalchemy.dialects import oracle, postgresql, sqlite from sqlalchemy.ext.mutable import Mutable from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible passlib = None try: import passlib from passlib.context import LazyCryptContext except ImportError: pass class Password(Mutable, object): @classmethod def coerce(cls, key, value): if isinstance(value, Password): return value if isinstance(value, (six.string_types, six.binary_type)): return cls(value, secret=True) super(Password, cls).coerce(key, value) def __init__(self, value, context=None, secret=False): # Store the hash (if it is one). self.hash = value if not secret else None # Store the secret if we have one. self.secret = value if secret else None # The hash should be bytes. if isinstance(self.hash, six.text_type): self.hash = self.hash.encode('utf8') # Save weakref of the password context (if we have one) self.context = weakref.proxy(context) if context is not None else None def __eq__(self, value): if self.hash is None or value is None: # Ensure that we don't continue comparison if one of us is None. return self.hash is value if isinstance(value, Password): # Comparing 2 hashes isn't very useful; but this equality # method breaks otherwise. return value.hash == self.hash if self.context is None: # Compare 2 hashes again as we don't know how to validate. return value == self if isinstance(value, (six.string_types, six.binary_type)): valid, new = self.context.verify_and_update(value, self.hash) if valid and new: # New hash was calculated due to various reasons; stored one # wasn't optimal, etc. self.hash = new # The hash should be bytes. if isinstance(self.hash, six.string_types): self.hash = self.hash.encode('utf8') self.changed() return valid return False def __ne__(self, value): return not (self == value) class PasswordType(types.TypeDecorator, ScalarCoercible): """ PasswordType hashes passwords as they come into the database and allows verifying them using a Pythonic interface. This Pythonic interface relies on setting up automatic data type coercion using the :func:`~sqlalchemy_utils.listeners.force_auto_coercion` function. All keyword arguments (aside from max_length) are forwarded to the construction of a `passlib.context.LazyCryptContext` object, which also supports deferred configuration via the `onload` callback. The following usage will create a password column that will automatically hash new passwords as `pbkdf2_sha512` but still compare passwords against pre-existing `md5_crypt` hashes. As passwords are compared; the password hash in the database will be updated to be `pbkdf2_sha512`. 
:: class Model(Base): password = sa.Column(PasswordType( schemes=[ 'pbkdf2_sha512', 'md5_crypt' ], deprecated=['md5_crypt'] )) Verifying password is as easy as: :: target = Model() target.password = 'b' # '$5$rounds=80000$H.............' target.password == 'b' # True Lazy configuration of the type with Flask config: :: import flask from sqlalchemy_utils import PasswordType, force_auto_coercion force_auto_coercion() class User(db.Model): __tablename__ = 'user' password = db.Column( PasswordType( # The returned dictionary is forwarded to the CryptContext onload=lambda **kwargs: dict( schemes=flask.current_app.config['PASSWORD_SCHEMES'], **kwargs ), ), unique=False, nullable=False, ) """ impl = types.VARBINARY(1024) python_type = Password def __init__(self, max_length=None, **kwargs): # Fail if passlib is not found. if passlib is None: raise ImproperlyConfigured( "'passlib' is required to use 'PasswordType'" ) # Construct the passlib crypt context. self.context = LazyCryptContext(**kwargs) self._max_length = max_length @property def hashing_method(self): return ( 'hash' if hasattr(self.context, 'hash') else 'encrypt' ) @property def length(self): """Get column length.""" if self._max_length is None: self._max_length = self.calculate_max_length() return self._max_length def calculate_max_length(self): # Calculate the largest possible encoded password. # name + rounds + salt + hash + ($ * 4) of largest hash max_lengths = [1024] for name in self.context.schemes(): scheme = getattr(__import__('passlib.hash').hash, name) length = 4 + len(scheme.name) length += len(str(getattr(scheme, 'max_rounds', ''))) length += (getattr(scheme, 'max_salt_size', 0) or 0) length += getattr( scheme, 'encoded_checksum_size', scheme.checksum_size ) max_lengths.append(length) # Return the maximum calculated max length. return max(max_lengths) def load_dialect_impl(self, dialect): if dialect.name == 'postgresql': # Use a BYTEA type for postgresql. impl = postgresql.BYTEA(self.length) elif dialect.name == 'oracle': # Use a RAW type for oracle. impl = oracle.RAW(self.length) elif dialect.name == 'sqlite': # Use a BLOB type for sqlite impl = sqlite.BLOB(self.length) else: # Use a VARBINARY for all other dialects. impl = types.VARBINARY(self.length) return dialect.type_descriptor(impl) def process_bind_param(self, value, dialect): if isinstance(value, Password): # If were given a password secret; hash it. if value.secret is not None: return self._hash(value.secret).encode('utf8') # Value has already been hashed. return value.hash if isinstance(value, six.string_types): # Assume value has not been hashed. return self._hash(value).encode('utf8') def process_result_value(self, value, dialect): if value is not None: return Password(value, self.context) def _hash(self, value): return getattr(self.context, self.hashing_method)(value) def _coerce(self, value): if value is None: return if not isinstance(value, Password): # Hash the password using the default scheme. value = self._hash(value).encode('utf8') return Password(value, context=self.context) else: # If were given a password object; ensure the context is right. value.context = weakref.proxy(self.context) # If were given a password secret; hash it. 
if value.secret is not None: value.hash = self._hash(value.secret).encode('utf8') value.secret = None return value @property def python_type(self): return self.impl.type.python_type Password.associate_with(PasswordType) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/pg_composite.py000066400000000000000000000235241360007755400247060ustar00rootroot00000000000000""" CompositeType provides means to interact with `PostgreSQL composite types`_. Currently this type features: * Easy attribute access to composite type fields * Supports SQLAlchemy TypeDecorator types * Ability to include composite types as part of PostgreSQL arrays * Type creation and dropping Installation ^^^^^^^^^^^^ CompositeType automatically attaches `before_create` and `after_drop` DDL listeners. These listeners create and drop the composite type in the database. This means it works out of the box in your test environment where you create the tables on each test run. When you already have your database set up you should call :func:`register_composites` after you've set up all models. :: register_composites(conn) Usage ^^^^^ :: from collections import OrderedDict import sqlalchemy as sa from sqlalchemy_utils import CompositeType, CurrencyType class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balance = sa.Column( CompositeType( 'money_type', [ sa.Column('currency', CurrencyType), sa.Column('amount', sa.Integer) ] ) ) Accessing fields ^^^^^^^^^^^^^^^^ CompositeType provides attribute access to underlying fields. In the following example we find all accounts with balance amount more than 5000. :: session.query(Account).filter(Account.balance.amount > 5000) Arrays of composites ^^^^^^^^^^^^^^^^^^^^ :: from sqlalchemy_utils import CompositeArray class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balances = sa.Column( CompositeArray( CompositeType( 'money_type', [ sa.Column('currency', CurrencyType), sa.Column('amount', sa.Integer) ] ) ) ) .. _PostgreSQL composite types: http://www.postgresql.org/docs/current/static/rowtypes.html Related links: http://schinckel.net/2014/09/24/using-postgres-composite-types-in-django/ """ from collections import namedtuple import six import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.ext.compiler import compiles from sqlalchemy.schema import _CreateDropBase from sqlalchemy.sql.expression import FunctionElement from sqlalchemy.types import ( SchemaType, to_instance, TypeDecorator, UserDefinedType ) from .. import ImproperlyConfigured psycopg2 = None CompositeCaster = None adapt = None AsIs = None register_adapter = None try: import psycopg2 from psycopg2.extras import CompositeCaster from psycopg2.extensions import adapt, AsIs, register_adapter except ImportError: pass class CompositeElement(FunctionElement): """ Instances of this class wrap a Postgres composite type. 
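For example, with the ``Account`` model from the module docstring above, an expression such as ``Account.balance.amount`` is compiled to SQL along the lines of (illustrative):: (account.balance).amount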
""" def __init__(self, base, field, type_): self.name = field self.type = to_instance(type_) super(CompositeElement, self).__init__(base) @compiles(CompositeElement) def _compile_pgelem(expr, compiler, **kw): return '(%s).%s' % (compiler.process(expr.clauses, **kw), expr.name) class CompositeArray(ARRAY): def _proc_array(self, arr, itemproc, dim, collection): if dim is None: if isinstance(self.item_type, CompositeType): arr = [itemproc(a) for a in arr] return arr return ARRAY._proc_array(self, arr, itemproc, dim, collection) # TODO: Make the registration work on connection level instead of global level registered_composites = {} class CompositeType(UserDefinedType, SchemaType): """ Represents a PostgreSQL composite type. :param name: Name of the composite type. :param columns: List of columns that this composite type consists of """ python_type = tuple class comparator_factory(UserDefinedType.Comparator): def __getattr__(self, key): try: type_ = self.type.typemap[key] except KeyError: raise KeyError( "Type '%s' doesn't have an attribute: '%s'" % ( self.name, key ) ) return CompositeElement(self.expr, key, type_) def __init__(self, name, columns): if psycopg2 is None: raise ImproperlyConfigured( "'psycopg2' package is required in order to use CompositeType." ) SchemaType.__init__(self) self.name = name self.columns = columns if name in registered_composites: self.type_cls = registered_composites[name].type_cls else: self.type_cls = namedtuple( self.name, [c.name for c in columns] ) registered_composites[name] = self class Caster(CompositeCaster): def make(obj, values): return self.type_cls(*values) self.caster = Caster attach_composite_listeners() def get_col_spec(self): return self.name def bind_processor(self, dialect): def process(value): if value is None: return None processed_value = [] for i, column in enumerate(self.columns): if isinstance(column.type, TypeDecorator): processed_value.append( column.type.process_bind_param( value[i], dialect ) ) else: processed_value.append(value[i]) return self.type_cls(*processed_value) return process def result_processor(self, dialect, coltype): def process(value): if value is None: return None cls = value.__class__ kwargs = {} for column in self.columns: if isinstance(column.type, TypeDecorator): kwargs[column.name] = column.type.process_result_value( getattr(value, column.name), dialect ) else: kwargs[column.name] = getattr(value, column.name) return cls(**kwargs) return process def create(self, bind=None, checkfirst=None): if ( not checkfirst or not bind.dialect.has_type(bind, self.name, schema=self.schema) ): bind.execute(CreateCompositeType(self)) def drop(self, bind=None, checkfirst=True): if ( checkfirst and bind.dialect.has_type(bind, self.name, schema=self.schema) ): bind.execute(DropCompositeType(self)) def register_psycopg2_composite(dbapi_connection, composite): psycopg2.extras.register_composite( composite.name, dbapi_connection, globally=True, factory=composite.caster ) def adapt_composite(value): adapted = [ adapt( getattr(value, column.name) if not isinstance(column.type, TypeDecorator) else column.type.process_bind_param( getattr(value, column.name), PGDialect_psycopg2() ) ) for column in composite.columns ] for value in adapted: if hasattr(value, 'prepare'): value.prepare(dbapi_connection) values = [ value.getquoted().decode(dbapi_connection.encoding) if six.PY3 else value.getquoted() for value in adapted ] return AsIs("(%s)::%s" % (', '.join(values), composite.name)) register_adapter(composite.type_cls, adapt_composite) def 
before_create(target, connection, **kw): for name, composite in registered_composites.items(): composite.create(connection, checkfirst=True) register_psycopg2_composite( connection.connection.connection, composite ) def after_drop(target, connection, **kw): for name, composite in registered_composites.items(): composite.drop(connection, checkfirst=True) def register_composites(connection): for name, composite in registered_composites.items(): register_psycopg2_composite( connection.connection.connection, composite ) def attach_composite_listeners(): listeners = [ (sa.MetaData, 'before_create', before_create), (sa.MetaData, 'after_drop', after_drop), ] for listener in listeners: if not sa.event.contains(*listener): sa.event.listen(*listener) def remove_composite_listeners(): listeners = [ (sa.MetaData, 'before_create', before_create), (sa.MetaData, 'after_drop', after_drop), ] for listener in listeners: if sa.event.contains(*listener): sa.event.remove(*listener) class CreateCompositeType(_CreateDropBase): pass @compiles(CreateCompositeType) def _visit_create_composite_type(create, compiler, **kw): type_ = create.element fields = ', '.join( '{name} {type}'.format( name=column.name, type=compiler.dialect.type_compiler.process( to_instance(column.type) ) ) for column in type_.columns ) return 'CREATE TYPE {name} AS ({fields})'.format( name=compiler.preparer.format_type(type_), fields=fields ) class DropCompositeType(_CreateDropBase): pass @compiles(DropCompositeType) def _visit_drop_composite_type(drop, compiler, **kw): type_ = drop.element return 'DROP TYPE {name}'.format(name=compiler.preparer.format_type(type_)) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/phone_number.py000066400000000000000000000147311360007755400246770ustar00rootroot00000000000000import six from sqlalchemy import exc, types from ..exceptions import ImproperlyConfigured from ..utils import str_coercible from .scalar_coercible import ScalarCoercible try: import phonenumbers from phonenumbers.phonenumber import PhoneNumber as BasePhoneNumber from phonenumbers.phonenumberutil import NumberParseException except ImportError: phonenumbers = None BasePhoneNumber = object NumberParseException = Exception class PhoneNumberParseException(NumberParseException, exc.DontWrapMixin): ''' Wraps exceptions from phonenumbers with SQLAlchemy's DontWrapMixin so we get more meaningful exceptions on validation failure instead of the StatementException Clients can catch this as either a PhoneNumberParseException or NumberParseException from the phonenumbers library. ''' pass @str_coercible class PhoneNumber(BasePhoneNumber): """ Extends a PhoneNumber class from `Python phonenumbers library`_. Adds different phone number formats to attributes, so they can be easily used in templates. Phone number validation method is also implemented. Takes the raw phone number and country code as params and parses them into a PhoneNumber object. .. 
_Python phonenumbers library: https://github.com/daviddrysdale/python-phonenumbers :: from sqlalchemy_utils import PhoneNumber class User(self.Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) _phone_number = sa.Column(sa.Unicode(20)) country_code = sa.Column(sa.Unicode(8)) phonenumber = sa.orm.composite( PhoneNumber, _phone_number, country_code ) user = User(phone_number=PhoneNumber('0401234567', 'FI')) user.phone_number.e164 # u'+358401234567' user.phone_number.international # u'+358 40 1234567' user.phone_number.national # u'040 1234567' user.country_code # 'FI' :param raw_number: String representation of the phone number. :param region: Region of the phone number. :param check_region: Whether to check the supplied region parameter; should always be True for external callers. Can be useful for short codes or toll free """ def __init__(self, raw_number, region=None, check_region=True): # Bail if phonenumbers is not found. if phonenumbers is None: raise ImproperlyConfigured( "'phonenumbers' is required to use 'PhoneNumber'" ) try: self._phone_number = phonenumbers.parse( raw_number, region, _check_region=check_region ) except NumberParseException as e: # Wrap exception so SQLAlchemy doesn't swallow it as a # StatementError # # Worth noting that if -1 shows up as the error_type # it's likely because the API has changed upstream and these # bindings need to be updated. raise PhoneNumberParseException( getattr(e, 'error_type', -1), six.text_type(e) ) super(PhoneNumber, self).__init__( country_code=self._phone_number.country_code, national_number=self._phone_number.national_number, extension=self._phone_number.extension, italian_leading_zero=self._phone_number.italian_leading_zero, raw_input=self._phone_number.raw_input, country_code_source=self._phone_number.country_code_source, preferred_domestic_carrier_code=( self._phone_number.preferred_domestic_carrier_code ) ) self.region = region self.national = phonenumbers.format_number( self._phone_number, phonenumbers.PhoneNumberFormat.NATIONAL ) self.international = phonenumbers.format_number( self._phone_number, phonenumbers.PhoneNumberFormat.INTERNATIONAL ) self.e164 = phonenumbers.format_number( self._phone_number, phonenumbers.PhoneNumberFormat.E164 ) def __composite_values__(self): return self.national, self.region def is_valid_number(self): return phonenumbers.is_valid_number(self._phone_number) def __unicode__(self): return self.national class PhoneNumberType(types.TypeDecorator, ScalarCoercible): """ Changes PhoneNumber objects to a string representation on the way in and changes them back to PhoneNumber objects on the way out. If E164 is used as storing format, no country code is needed for parsing the database value to PhoneNumber object. :: class User(self.Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) phone_number = sa.Column(PhoneNumberType()) user = User(phone_number='+358401234567') user.phone_number.e164 # u'+358401234567' user.phone_number.international # u'+358 40 1234567' user.phone_number.national # u'040 1234567' """ STORE_FORMAT = 'e164' impl = types.Unicode(20) python_type = PhoneNumber def __init__(self, region='US', max_length=20, *args, **kwargs): # Bail if phonenumbers is not found. 
if phonenumbers is None: raise ImproperlyConfigured( "'phonenumbers' is required to use 'PhoneNumberType'" ) super(PhoneNumberType, self).__init__(*args, **kwargs) self.region = region self.impl = types.Unicode(max_length) def process_bind_param(self, value, dialect): if value: if not isinstance(value, PhoneNumber): value = PhoneNumber(value, region=self.region) if self.STORE_FORMAT == 'e164' and value.extension: return '%s;ext=%s' % (value.e164, value.extension) return getattr(value, self.STORE_FORMAT) return value def process_result_value(self, value, dialect): if value: return PhoneNumber(value, self.region) return value def _coerce(self, value): if value and not isinstance(value, PhoneNumber): value = PhoneNumber(value, region=self.region) return value or None sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/range.py000066400000000000000000000276571360007755400233250ustar00rootroot00000000000000""" SQLAlchemy-Utils provides wide variety of range data types. All range data types return Interval objects of intervals_ package. In order to use range data types you need to install intervals_ with: :: pip install intervals Intervals package provides good chunk of additional interval operators that for example psycopg2 range objects do not support. Some good reading for practical interval implementations: http://wiki.postgresql.org/images/f/f0/Range-types.pdf Range type initialization ------------------------- :: from sqlalchemy_utils import IntRangeType class Event(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) estimated_number_of_persons = sa.Column(IntRangeType) You can also set a step parameter for range type. The values that are not multipliers of given step will be rounded up to nearest step multiplier. :: from sqlalchemy_utils import IntRangeType class Event(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) estimated_number_of_persons = sa.Column(IntRangeType(step=1000)) event = Event(estimated_number_of_persons=[100, 1200]) event.estimated_number_of_persons.lower # 0 event.estimated_number_of_persons.upper # 1000 Range type operators -------------------- SQLAlchemy-Utils supports many range type operators. These operators follow the `intervals` package interval coercion rules. So for example when we make a query such as: :: session.query(Car).filter(Car.price_range == 300) It is essentially the same as: :: session.query(Car).filter(Car.price_range == DecimalInterval([300, 300])) Comparison operators ^^^^^^^^^^^^^^^^^^^^ All range types support all comparison operators (>, >=, ==, !=, <=, <). :: Car.price_range < [12, 300] Car.price_range == [12, 300] Car.price_range < 300 Car.price_range > (300, 500) # Whether or not range is strictly left of another range Car.price_range << [300, 500] # Whether or not range is strictly right of another range Car.price_range >> [300, 500] Membership operators ^^^^^^^^^^^^^^^^^^^^ :: Car.price_range.contains([300, 500]) Car.price_range.contained_by([300, 500]) Car.price_range.in_([[300, 500], [800, 900]]) ~ Car.price_range.in_([[300, 400], [700, 800]]) Length ^^^^^^ SQLAlchemy-Utils provides length property for all range types. The implementation of this property varies on different range types. In the following example we find all cars whose price range's length is more than 500. :: session.query(Car).filter( Car.price_range.length > 500 ) .. 
_intervals: https://github.com/kvesteri/intervals """ try: from collections.abc import Iterable except ImportError: # For python 2.7 support from collections import Iterable from datetime import timedelta import six import sqlalchemy as sa from sqlalchemy import types from sqlalchemy.dialects.postgresql import ( DATERANGE, INT4RANGE, INT8RANGE, NUMRANGE, TSRANGE ) from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible intervals = None try: import intervals except ImportError: pass class RangeComparator(types.TypeEngine.Comparator): @classmethod def coerced_func(cls, func): def operation(self, other, **kwargs): other = self.coerce_arg(other) return getattr(types.TypeEngine.Comparator, func)( self, other, **kwargs ) return operation def coerce_arg(self, other): coerced_types = ( self.type.interval_class.type, tuple, list, ) + six.string_types if isinstance(other, coerced_types): return self.type.interval_class(other) return other def in_(self, other): if ( isinstance(other, Iterable) and not isinstance(other, six.string_types) ): other = map(self.coerce_arg, other) return super(RangeComparator, self).in_(other) def notin_(self, other): if ( isinstance(other, Iterable) and not isinstance(other, six.string_types) ): other = map(self.coerce_arg, other) return super(RangeComparator, self).notin_(other) def __rshift__(self, other, **kwargs): """ Returns whether or not given interval is strictly right of another interval. [a, b] >> [c, d] True, if a > d """ other = self.coerce_arg(other) return self.op('>>')(other) def __lshift__(self, other, **kwargs): """ Returns whether or not given interval is strictly left of another interval. [a, b] << [c, d] True, if b < c """ other = self.coerce_arg(other) return self.op('<<')(other) def contains(self, other, **kwargs): other = self.coerce_arg(other) return self.op('@>')(other) def contained_by(self, other, **kwargs): other = self.coerce_arg(other) return self.op('<@')(other) class DiscreteRangeComparator(RangeComparator): @property def length(self): return sa.func.upper(self.expr) - self.step - sa.func.lower(self.expr) class IntRangeComparator(DiscreteRangeComparator): step = 1 class DateRangeComparator(DiscreteRangeComparator): step = timedelta(days=1) class ContinuousRangeComparator(RangeComparator): @property def length(self): return sa.func.upper(self.expr) - sa.func.lower(self.expr) funcs = [ '__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__', ] for func in funcs: setattr( RangeComparator, func, RangeComparator.coerced_func(func) ) class RangeType(types.TypeDecorator, ScalarCoercible): comparator_factory = RangeComparator def __init__(self, *args, **kwargs): if intervals is None: raise ImproperlyConfigured( 'RangeType needs intervals package installed.' ) self.step = kwargs.pop('step', None) super(RangeType, self).__init__(*args, **kwargs) def load_dialect_impl(self, dialect): if dialect.name == 'postgresql': # Use the native range type for postgres. return dialect.type_descriptor(self.impl) else: # Other drivers don't have native types. 
return dialect.type_descriptor(sa.String(255)) def process_bind_param(self, value, dialect): if value is not None: return str(value) return value def process_result_value(self, value, dialect): if isinstance(value, six.string_types): factory_func = self.interval_class.from_string else: factory_func = self.interval_class if value is not None: if self.interval_class.step is not None: return self.canonicalize_result_value( factory_func(value, step=self.step) ) else: return factory_func(value, step=self.step) return value def canonicalize_result_value(self, value): return intervals.canonicalize(value, True, True) def _coerce(self, value): if value is None: return None return self.interval_class(value, step=self.step) class IntRangeType(RangeType): """ IntRangeType provides a way for saving ranges of integers into the database. On PostgreSQL this type maps to the native INT4RANGE type while on other drivers this maps to a simple string column. Example:: from sqlalchemy_utils import IntRangeType class Event(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) estimated_number_of_persons = sa.Column(IntRangeType) party = Event(name=u'party') # we estimate the party to contain a minimum of 10 persons and at max # 100 persons party.estimated_number_of_persons = [10, 100] print party.estimated_number_of_persons # '10-100' IntRangeType returns the values as IntInterval objects. These objects support many arithmetic operators:: meeting = Event(name=u'meeting') meeting.estimated_number_of_persons = [20, 40] total = ( meeting.estimated_number_of_persons + party.estimated_number_of_persons ) print total # '30-140' """ impl = INT4RANGE comparator_factory = IntRangeComparator def __init__(self, *args, **kwargs): super(IntRangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.IntInterval class Int8RangeType(RangeType): """ Int8RangeType provides a way for saving ranges of 8-byte integers into the database. On PostgreSQL this type maps to the native INT8RANGE type while on other drivers this maps to a simple string column. Example:: from sqlalchemy_utils import Int8RangeType class Event(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) estimated_number_of_persons = sa.Column(Int8RangeType) party = Event(name=u'party') # we estimate the party to contain a minimum of 10 persons and at max # 100 persons party.estimated_number_of_persons = [10, 100] print party.estimated_number_of_persons # '10-100' Int8RangeType returns the values as IntInterval objects. These objects support many arithmetic operators:: meeting = Event(name=u'meeting') meeting.estimated_number_of_persons = [20, 40] total = ( meeting.estimated_number_of_persons + party.estimated_number_of_persons ) print total # '30-140' """ impl = INT8RANGE comparator_factory = IntRangeComparator def __init__(self, *args, **kwargs): super(Int8RangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.IntInterval class DateRangeType(RangeType): """ DateRangeType provides a way for saving ranges of dates into the database. On PostgreSQL this type maps to the native DATERANGE type while on other drivers this maps to a simple string column.
Example:: from sqlalchemy_utils import DateRangeType class Reservation(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) room_id = sa.Column(sa.Integer) during = sa.Column(DateRangeType) """ impl = DATERANGE comparator_factory = DateRangeComparator def __init__(self, *args, **kwargs): super(DateRangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.DateInterval class NumericRangeType(RangeType): """ NumericRangeType provides a way for saving ranges of decimals into the database. On PostgreSQL this type maps to the native NUMRANGE type while on other drivers this maps to a simple string column. Example:: from sqlalchemy_utils import NumericRangeType class Car(Base): __tablename__ = 'car' id = sa.Column(sa.Integer, autoincrement=True) name = sa.Column(sa.Unicode(255)) price_range = sa.Column(NumericRangeType) """ impl = NUMRANGE comparator_factory = ContinuousRangeComparator def __init__(self, *args, **kwargs): super(NumericRangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.DecimalInterval class DateTimeRangeType(RangeType): impl = TSRANGE comparator_factory = ContinuousRangeComparator def __init__(self, *args, **kwargs): super(DateTimeRangeType, self).__init__(*args, **kwargs) self.interval_class = intervals.DateTimeInterval sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/scalar_coercible.py000066400000000000000000000003101360007755400254600ustar00rootroot00000000000000class ScalarCoercible(object): def _coerce(self, value): raise NotImplementedError def coercion_listener(self, target, value, oldvalue, initiator): return self._coerce(value) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/scalar_list.py000066400000000000000000000044151360007755400245140ustar00rootroot00000000000000import six import sqlalchemy as sa from sqlalchemy import types class ScalarListException(Exception): pass class ScalarListType(types.TypeDecorator): """ ScalarListType provides a convenient way for saving multiple scalar values in one column. ScalarListType works like a list on the Python side and saves the result as a comma-separated list in the database (custom separators can also be used). Example :: from sqlalchemy_utils import ScalarListType class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True) hobbies = sa.Column(ScalarListType()) user = User() user.hobbies = [u'football', u'ice_hockey'] session.commit() You can easily set up integer lists too: :: from sqlalchemy_utils import ScalarListType class Player(Base): __tablename__ = 'player' id = sa.Column(sa.Integer, autoincrement=True) points = sa.Column(ScalarListType(int)) player = Player() player.points = [11, 12, 8, 80] session.commit() """ impl = sa.UnicodeText() def __init__(self, coerce_func=six.text_type, separator=u','): self.separator = six.text_type(separator) self.coerce_func = coerce_func def process_bind_param(self, value, dialect): # Convert list of values to unicode separator-separated list # Example: [1, 2, 3, 4] -> u'1, 2, 3, 4' if value is not None: if any(self.separator in six.text_type(item) for item in value): raise ScalarListException( "List values can't contain string '%s' (it's being used as the " "separator.
If you wish for scalar list values to contain " "these strings, use a different separator string.)" % self.separator ) return self.separator.join( map(six.text_type, value) ) def process_result_value(self, value, dialect): if value is not None: if value == u'': return [] # coerce each value return list(map( self.coerce_func, value.split(self.separator) )) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/timezone.py000066400000000000000000000051651360007755400240510ustar00rootroot00000000000000import six from sqlalchemy import types from ..exceptions import ImproperlyConfigured from .scalar_coercible import ScalarCoercible class TimezoneType(types.TypeDecorator, ScalarCoercible): """ TimezoneType provides a way for saving timezones (from either the pytz or the dateutil package) objects into database. TimezoneType saves timezone objects as strings on the way in and converts them back to objects when querying the database. :: from sqlalchemy_utils import TimezoneType class User(Base): __tablename__ = 'user' # Pass backend='pytz' to change it to use pytz (dateutil by # default) timezone = sa.Column(TimezoneType(backend='pytz')) """ impl = types.Unicode(50) python_type = None def __init__(self, backend='dateutil'): """ :param backend: Whether to use 'dateutil' or 'pytz' for timezones. """ self.backend = backend if backend == 'dateutil': try: from dateutil.tz import tzfile from dateutil.zoneinfo import get_zonefile_instance self.python_type = tzfile self._to = get_zonefile_instance().zones.get self._from = lambda x: six.text_type(x._filename) except ImportError: raise ImproperlyConfigured( "'python-dateutil' is required to use the " "'dateutil' backend for 'TimezoneType'" ) elif backend == 'pytz': try: from pytz import timezone from pytz.tzinfo import BaseTzInfo self.python_type = BaseTzInfo self._to = timezone self._from = six.text_type except ImportError: raise ImproperlyConfigured( "'pytz' is required to use the 'pytz' backend " "for 'TimezoneType'" ) else: raise ImproperlyConfigured( "'pytz' or 'dateutil' are the backends supported for " "'TimezoneType'" ) def _coerce(self, value): if value is not None and not isinstance(value, self.python_type): obj = self._to(value) if obj is None: raise ValueError("unknown time zone '%s'" % value) return obj return value def process_bind_param(self, value, dialect): return self._from(self._coerce(value)) if value else None def process_result_value(self, value, dialect): return self._to(value) if value else None sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/ts_vector.py000066400000000000000000000061111360007755400242170ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.dialects.postgresql import TSVECTOR class TSVectorType(sa.types.TypeDecorator): """ .. note:: This type is PostgreSQL specific and is not supported by other dialects. Provides additional functionality for SQLAlchemy PostgreSQL dialect's TSVECTOR_ type. This additional functionality includes: * Vector concatenation * regconfig constructor parameter which is applied to match function if no postgresql_regconfig parameter is given * Provides extensible base for extensions such as SQLAlchemy-Searchable_ .. _TSVECTOR: http://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search .. 
_SQLAlchemy-Searchable: https://www.github.com/kvesteri/sqlalchemy-searchable :: from sqlalchemy_utils import TSVectorType class Article(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(100)) search_vector = sa.Column(TSVectorType) # Find all articles whose name matches 'finland' session.query(Article).filter(Article.search_vector.match('finland')) TSVectorType also supports vector concatenation. :: class Article(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(100)) name_vector = sa.Column(TSVectorType) content = sa.Column(sa.String) content_vector = sa.Column(TSVectorType) # Find all articles whose name or content matches 'finland' session.query(Article).filter( (Article.name_vector | Article.content_vector).match('finland') ) You can configure TSVectorType to use a specific regconfig. :: class Article(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(100)) search_vector = sa.Column( TSVectorType(regconfig='pg_catalog.simple') ) Now an expression such as:: Article.search_vector.match('finland') Would be equivalent to SQL:: search_vector @@ to_tsquery('pg_catalog.simple', 'finland') """ impl = TSVECTOR class comparator_factory(TSVECTOR.Comparator): def match(self, other, **kwargs): if 'postgresql_regconfig' not in kwargs: if 'regconfig' in self.type.options: kwargs['postgresql_regconfig'] = ( self.type.options['regconfig'] ) return TSVECTOR.Comparator.match(self, other, **kwargs) def __or__(self, other): return self.op('||')(other) def __init__(self, *args, **kwargs): """ Initializes a new TSVectorType :param *args: list of column names :param **kwargs: various other options for this TSVectorType """ self.columns = args self.options = kwargs super(TSVectorType, self).__init__() sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/url.py000066400000000000000000000030041360007755400230100ustar00rootroot00000000000000import six from sqlalchemy import types from .scalar_coercible import ScalarCoercible furl = None try: from furl import furl except ImportError: pass class URLType(types.TypeDecorator, ScalarCoercible): """ URLType stores furl_ objects into the database. ..
_furl: https://github.com/gruns/furl :: from sqlalchemy_utils import URLType from furl import furl class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) website = sa.Column(URLType) user = User(website=u'www.example.com') # website is coerced to furl object, hence all nice furl operations # come available user.website.args['some_argument'] = '12' print user.website # www.example.com?some_argument=12 """ impl = types.UnicodeText def process_bind_param(self, value, dialect): if furl is not None and isinstance(value, furl): return six.text_type(value) if isinstance(value, six.string_types): return value def process_result_value(self, value, dialect): if furl is None: return value if value is not None: return furl(value) def _coerce(self, value): if furl is None: return value if value is not None and not isinstance(value, furl): return furl(value) return value @property def python_type(self): return self.impl.type.python_type sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/uuid.py000066400000000000000000000050111360007755400231530ustar00rootroot00000000000000from __future__ import absolute_import import uuid from sqlalchemy import types from sqlalchemy.dialects import mssql, postgresql from .scalar_coercible import ScalarCoercible class UUIDType(types.TypeDecorator, ScalarCoercible): """ Stores a UUID in the database natively when it can and falls back to a BINARY(16) or a CHAR(32) when it can't. :: from sqlalchemy_utils import UUIDType import uuid class User(Base): __tablename__ = 'user' # Pass `binary=False` to fallback to CHAR instead of BINARY id = sa.Column(UUIDType(binary=False), primary_key=True) """ impl = types.BINARY(16) python_type = uuid.UUID def __init__(self, binary=True, native=True): """ :param binary: Whether to use a BINARY(16) or CHAR(32) fallback. """ self.binary = binary self.native = native def load_dialect_impl(self, dialect): if dialect.name == 'postgresql' and self.native: # Use the native UUID type. return dialect.type_descriptor(postgresql.UUID()) if dialect.name == 'mssql' and self.native: # Use the native UNIQUEIDENTIFIER type. return dialect.type_descriptor(mssql.UNIQUEIDENTIFIER()) else: # Fallback to either a BINARY or a CHAR. kind = self.impl if self.binary else types.CHAR(32) return dialect.type_descriptor(kind) @staticmethod def _coerce(value): if value and not isinstance(value, uuid.UUID): try: value = uuid.UUID(value) except (TypeError, ValueError): value = uuid.UUID(bytes=value) return value def process_bind_param(self, value, dialect): if value is None: return value if not isinstance(value, uuid.UUID): value = self._coerce(value) if self.native and dialect.name in ('postgresql', 'mssql'): return str(value) return value.bytes if self.binary else value.hex def process_result_value(self, value, dialect): if value is None: return value if self.native and dialect.name in ('postgresql', 'mssql'): if isinstance(value, uuid.UUID): # Some drivers convert PostgreSQL's uuid values to # Python's uuid.UUID objects by themselves return value return uuid.UUID(value) return uuid.UUID(bytes=value) if self.binary else uuid.UUID(value) sqlalchemy-utils-0.36.1/sqlalchemy_utils/types/weekdays.py000066400000000000000000000042121360007755400240230ustar00rootroot00000000000000import six from sqlalchemy import types from .. 
import i18n from ..exceptions import ImproperlyConfigured from ..primitives import WeekDay, WeekDays from .bit import BitType from .scalar_coercible import ScalarCoercible class WeekDaysType(types.TypeDecorator, ScalarCoercible): """ WeekDaysType offers way of saving WeekDays objects into database. The WeekDays objects are converted to bit strings on the way in and back to WeekDays objects on the way out. In order to use WeekDaysType you need to install Babel_ first. .. _Babel: http://babel.pocoo.org/ :: from sqlalchemy_utils import WeekDaysType, WeekDays from babel import Locale class Schedule(Base): __tablename__ = 'schedule' id = sa.Column(sa.Integer, autoincrement=True) working_days = sa.Column(WeekDaysType) schedule = Schedule() schedule.working_days = WeekDays('0001111') session.add(schedule) session.commit() print schedule.working_days # Thursday, Friday, Saturday, Sunday WeekDaysType also supports scalar coercion: :: schedule.working_days = '1110000' schedule.working_days # WeekDays object """ impl = BitType(WeekDay.NUM_WEEK_DAYS) def __init__(self, *args, **kwargs): if i18n.babel is None: raise ImproperlyConfigured( "'babel' package is required to use 'WeekDaysType'" ) super(WeekDaysType, self).__init__(*args, **kwargs) @property def comparator_factory(self): return self.impl.comparator_factory def process_bind_param(self, value, dialect): if isinstance(value, WeekDays): value = value.as_bit_string() if dialect.name == 'mysql': func = bytes if six.PY3 else bytearray return func(value, 'utf8') return value def process_result_value(self, value, dialect): if value is not None: return WeekDays(value) def _coerce(self, value): if value is not None and not isinstance(value, WeekDays): return WeekDays(value) return value sqlalchemy-utils-0.36.1/sqlalchemy_utils/utils.py000066400000000000000000000013401360007755400222020ustar00rootroot00000000000000import sys import six try: from collections.abc import Iterable except ImportError: # For python 2.7 support from collections import Iterable def str_coercible(cls): if sys.version_info[0] >= 3: # Python 3 def __str__(self): return self.__unicode__() else: # Python 2 def __str__(self): return self.__unicode__().encode('utf8') cls.__str__ = __str__ return cls def is_sequence(value): return ( isinstance(value, Iterable) and not isinstance(value, six.string_types) ) def starts_with(iterable, prefix): """ Returns whether or not given iterable starts with given prefix. 
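For example (illustrative):: starts_with([1, 2, 3], [1, 2]) # True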
""" return list(iterable)[0:len(prefix)] == list(prefix) sqlalchemy-utils-0.36.1/sqlalchemy_utils/view.py000066400000000000000000000130571360007755400220240ustar00rootroot00000000000000import sqlalchemy as sa from sqlalchemy.ext import compiler from sqlalchemy.schema import DDLElement, PrimaryKeyConstraint class CreateView(DDLElement): def __init__(self, name, selectable, materialized=False): self.name = name self.selectable = selectable self.materialized = materialized @compiler.compiles(CreateView) def compile_create_materialized_view(element, compiler, **kw): return 'CREATE {}VIEW {} AS {}'.format( 'MATERIALIZED ' if element.materialized else '', element.name, compiler.sql_compiler.process(element.selectable, literal_binds=True), ) class DropView(DDLElement): def __init__(self, name, materialized=False, cascade=True): self.name = name self.materialized = materialized self.cascade = cascade @compiler.compiles(DropView) def compile_drop_materialized_view(element, compiler, **kw): return 'DROP {}VIEW IF EXISTS {} {}'.format( 'MATERIALIZED ' if element.materialized else '', element.name, 'CASCADE' if element.cascade else '' ) def create_table_from_selectable( name, selectable, indexes=None, metadata=None, aliases=None ): if indexes is None: indexes = [] if metadata is None: metadata = sa.MetaData() if aliases is None: aliases = {} args = [ sa.Column( c.name, c.type, key=aliases.get(c.name, c.name), primary_key=c.primary_key ) for c in selectable.c ] + indexes table = sa.Table(name, metadata, *args) if not any([c.primary_key for c in selectable.c]): table.append_constraint( PrimaryKeyConstraint(*[c.name for c in selectable.c]) ) return table def create_materialized_view( name, selectable, metadata, indexes=None, aliases=None ): """ Create a view on a given metadata :param name: The name of the view to create. :param selectable: An SQLAlchemy selectable e.g. a select() statement. :param metadata: An SQLAlchemy Metadata instance that stores the features of the database being described. :param indexes: An optional list of SQLAlchemy Index instances. :param aliases: An optional dictionary containing with keys as column names and values as column aliases. Same as for ``create_view`` except that a ``CREATE MATERIALIZED VIEW`` statement is emitted instead of a ``CREATE VIEW``. """ table = create_table_from_selectable( name=name, selectable=selectable, indexes=indexes, metadata=None, aliases=aliases ) sa.event.listen( metadata, 'after_create', CreateView(name, selectable, materialized=True) ) @sa.event.listens_for(metadata, 'after_create') def create_indexes(target, connection, **kw): for idx in table.indexes: idx.create(connection) sa.event.listen( metadata, 'before_drop', DropView(name, materialized=True) ) return table def create_view( name, selectable, metadata, cascade_on_drop=True ): """ Create a view on a given metadata :param name: The name of the view to create. :param selectable: An SQLAlchemy selectable e.g. a select() statement. :param metadata: An SQLAlchemy Metadata instance that stores the features of the database being described. The process for creating a view is similar to the standard way that a table is constructed, except that a selectable is provided instead of a set of columns. The view is created once a ``CREATE`` statement is executed against the supplied metadata (e.g. ``metadata.create_all(..)``), and dropped when a ``DROP`` is executed against the metadata. To create a view that performs basic filtering on a table. 
:: metadata = MetaData() users = Table('users', metadata, Column('id', Integer, primary_key=True), Column('name', String), Column('fullname', String), Column('premium_user', Boolean, default=False), ) premium_members = select([users]).where(users.c.premium_user == True) create_view('premium_users', premium_members, metadata) metadata.create_all(engine) # View is created at this point """ table = create_table_from_selectable( name=name, selectable=selectable, metadata=None ) sa.event.listen(metadata, 'after_create', CreateView(name, selectable)) @sa.event.listens_for(metadata, 'after_create') def create_indexes(target, connection, **kw): for idx in table.indexes: idx.create(connection) sa.event.listen( metadata, 'before_drop', DropView(name, cascade=cascade_on_drop) ) return table def refresh_materialized_view(session, name, concurrently=False): """ Refreshes an already existing materialized view :param session: An SQLAlchemy Session instance. :param name: The name of the materialized view to refresh. :param concurrently: Optional flag that causes the ``CONCURRENTLY`` parameter to be specified when the materialized view is refreshed. """ # Since session.execute() bypasses autoflush, we must manually flush in # order to include newly-created/modified objects in the refresh. session.flush() session.execute( 'REFRESH MATERIALIZED VIEW {}{}'.format( 'CONCURRENTLY ' if concurrently else '', name ) ) sqlalchemy-utils-0.36.1/tests/000077500000000000000000000000001360007755400162525ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/__init__.py000066400000000000000000000001631360007755400203630ustar00rootroot00000000000000def assert_contains(clause, query): # Test that query executes query.all() assert clause in str(query) sqlalchemy-utils-0.36.1/tests/aggregate/000077500000000000000000000000001360007755400202005ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/aggregate/__init__.py000066400000000000000000000000001360007755400222770ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/aggregate/test_backrefs.py000066400000000000000000000041621360007755400233740ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Thread(Base): class Thread(Base): __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('comments', sa.Column(sa.Integer, default=0)) def comment_count(self): return sa.func.count('1') return Thread @pytest.fixture def Comment(Base, Thread): class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.Unicode(255)) thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) thread = sa.orm.relationship(Thread, backref='comments') return Comment @pytest.fixture def init_models(Thread, Comment): pass class TestAggregateValueGenerationWithBackrefs(object): def test_assigns_aggregates_on_insert(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 def test_assigns_aggregates_on_separate_insert( self, session, Thread, Comment ): thread = Thread() thread.name = u'some article name' session.add(thread) session.commit() comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 def 
test_assigns_aggregates_on_delete(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) session.commit() comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.delete(comment) session.commit() session.refresh(thread) assert thread.comment_count == 0 sqlalchemy-utils-0.36.1/tests/aggregate/test_custom_select_expressions.py000066400000000000000000000040041360007755400271220ustar00rootroot00000000000000from decimal import Decimal import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) return Product @pytest.fixture def Catalog(Base, Product): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('products', sa.Column(sa.Numeric, default=0)) def net_worth(self): return sa.func.sum(Product.price) products = sa.orm.relationship('Product', backref='catalog') return Catalog @pytest.fixture def init_models(Product, Catalog): pass @pytest.mark.usefixtures('postgresql_dsn') class TestLazyEvaluatedSelectExpressionsForAggregates(object): def test_assigns_aggregates_on_insert(self, session, Product, Catalog): catalog = Catalog( name=u'Some catalog' ) session.add(catalog) session.commit() product = Product( name=u'Some product', price=Decimal('1000'), catalog=catalog ) session.add(product) session.commit() session.refresh(catalog) assert catalog.net_worth == Decimal('1000') def test_assigns_aggregates_on_update(self, session, Product, Catalog): catalog = Catalog( name=u'Some catalog' ) session.add(catalog) session.commit() product = Product( name=u'Some product', price=Decimal('1000'), catalog=catalog ) session.add(product) session.commit() product.price = Decimal('500') session.commit() session.refresh(catalog) assert catalog.net_worth == Decimal('500') sqlalchemy-utils-0.36.1/tests/aggregate/test_join_table_inheritance.py000066400000000000000000000073521360007755400262770ustar00rootroot00000000000000from decimal import Decimal import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) type = sa.Column(sa.String(255)) __mapper_args__ = { 'polymorphic_on': type, 'polymorphic_identity': 'simple' } catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) return Product @pytest.fixture(params=['simple', 'child']) def AnyProduct(Product, request): if request.param == 'simple': return Product class ChildProduct(Product): __tablename__ = 'child_product' id = sa.Column( sa.Integer, sa.ForeignKey(Product.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': 'child', } return ChildProduct @pytest.fixture def Catalog(Base, Product): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) type = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': type } @aggregated('products', sa.Column(sa.Numeric, default=0)) def net_worth(self): return sa.func.sum(Product.price) products = sa.orm.relationship('Product', backref='catalog') return 
Catalog @pytest.fixture def CostumeCatalog(Catalog): class CostumeCatalog(Catalog): __tablename__ = 'costume_catalog' id = sa.Column( sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': 'costumes', } return CostumeCatalog @pytest.fixture def CarCatalog(Catalog): class CarCatalog(Catalog): __tablename__ = 'car_catalog' id = sa.Column( sa.Integer, sa.ForeignKey(Catalog.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': 'cars', } return CarCatalog @pytest.fixture def init_models(AnyProduct, Catalog, CostumeCatalog, CarCatalog): pass @pytest.mark.usefixtures('postgresql_dsn') class TestLazyEvaluatedSelectExpressionsForAggregates(object): def test_columns_inherited_from_parent( self, Catalog, CarCatalog, CostumeCatalog ): assert CarCatalog.net_worth assert CostumeCatalog.net_worth assert Catalog.net_worth assert not hasattr(CarCatalog.__table__.c, 'net_worth') assert not hasattr(CostumeCatalog.__table__.c, 'net_worth') def test_assigns_aggregates_on_insert(self, session, AnyProduct, Catalog): catalog = Catalog( name=u'Some catalog' ) session.add(catalog) session.commit() product = AnyProduct( name=u'Some product', price=Decimal('1000'), catalog=catalog ) session.add(product) session.commit() session.refresh(catalog) assert catalog.net_worth == Decimal('1000') def test_assigns_aggregates_on_update(self, session, Catalog, AnyProduct): catalog = Catalog( name=u'Some catalog' ) session.add(catalog) session.commit() product = AnyProduct( name=u'Some product', price=Decimal('1000'), catalog=catalog ) session.add(product) session.commit() product.price = Decimal('500') session.commit() session.refresh(catalog) assert catalog.net_worth == Decimal('500') sqlalchemy-utils-0.36.1/tests/aggregate/test_m2m.py000066400000000000000000000037211360007755400223070ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def User(Base): user_group = sa.Table( 'user_group', Base.metadata, sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) ) class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('groups', sa.Column(sa.Integer, default=0)) def group_count(self): return sa.func.count('1') groups = sa.orm.relationship( 'Group', backref='users', secondary=user_group ) return User @pytest.fixture def Group(Base): class Group(Base): __tablename__ = 'group' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Group @pytest.fixture def init_models(User, Group): pass @pytest.mark.usefixtures('postgresql_dsn') class TestAggregatesWithManyToManyRelationships(object): def test_assigns_aggregates_on_insert(self, session, User, Group): user = User( name=u'John Matrix' ) session.add(user) session.commit() group = Group( name=u'Some group', users=[user] ) session.add(group) session.commit() session.refresh(user) assert user.group_count == 1 def test_updates_aggregates_on_delete(self, session, User, Group): user = User( name=u'John Matrix' ) session.add(user) session.commit() group = Group( name=u'Some group', users=[user] ) session.add(group) session.commit() session.refresh(user) user.groups = [] session.commit() session.refresh(user) assert user.group_count == 0 sqlalchemy-utils-0.36.1/tests/aggregate/test_m2m_m2m.py000066400000000000000000000046751360007755400230730ustar00rootroot00000000000000import pytest 
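# --- Hedged usage sketch (editor's addition, not part of the released archive). ---
# The surrounding aggregate tests exercise sqlalchemy_utils.aggregated over
# many-to-many relationships.  The minimal model pair below shows the same
# pattern in isolation; the names ExamplePost/ExampleTag and the association
# table are illustrative assumptions, not taken from the package itself.
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import aggregated

ExampleBase = declarative_base()

example_post_tag = sa.Table(
    'example_post_tag', ExampleBase.metadata,
    sa.Column('post_id', sa.Integer, sa.ForeignKey('example_post.id')),
    sa.Column('tag_id', sa.Integer, sa.ForeignKey('example_tag.id'))
)


class ExampleTag(ExampleBase):
    __tablename__ = 'example_tag'
    id = sa.Column(sa.Integer, primary_key=True)


class ExamplePost(ExampleBase):
    __tablename__ = 'example_post'
    id = sa.Column(sa.Integer, primary_key=True)

    # The aggregate column is recomputed by sqlalchemy-utils whenever the
    # 'tags' collection changes and the session is flushed; the decorated
    # method only has to return the aggregate SQL expression.
    @aggregated('tags', sa.Column(sa.Integer, default=0))
    def tag_count(self):
        return sa.func.count('1')

    tags = sa.orm.relationship(
        ExampleTag, secondary=example_post_tag, backref='posts'
    )
# Note that the package's own many-to-many aggregate tests (above and below)
# run against PostgreSQL via the postgresql_dsn fixture.
# --- End of editor's sketch. ---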
import sqlalchemy as sa from sqlalchemy_utils import aggregated @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Category @pytest.fixture def Catalog(Base, Category): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated( 'products.categories', sa.Column(sa.Integer, default=0) ) def category_count(self): return sa.func.count(sa.distinct(Category.id)) return Catalog @pytest.fixture def Product(Base, Catalog, Category): catalog_products = sa.Table( 'catalog_product', Base.metadata, sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) ) product_categories = sa.Table( 'category_product', Base.metadata, sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) ) class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) catalog_id = sa.Column( sa.Integer, sa.ForeignKey('catalog.id') ) catalogs = sa.orm.relationship( Catalog, backref='products', secondary=catalog_products ) categories = sa.orm.relationship( Category, backref='products', secondary=product_categories ) return Product @pytest.fixture def init_models(Category, Catalog, Product): pass @pytest.mark.usefixtures('postgresql_dsn') class TestAggregateManyToManyAndManyToMany(object): def test_insert(self, session, Product, Category, Catalog): category = Category() products = [ Product(categories=[category]), Product(categories=[category]) ] catalog = Catalog(products=products) session.add(catalog) catalog2 = Catalog(products=products) session.add(catalog) session.commit() assert catalog.category_count == 1 assert catalog2.category_count == 1 sqlalchemy-utils-0.36.1/tests/aggregate/test_multiple_aggregates_per_class.py000066400000000000000000000052361360007755400276760ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Comment(Base): class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.Unicode(255)) thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) return Comment @pytest.fixture def Thread(Base, Comment): class Thread(Base): __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated( 'comments', sa.Column(sa.Integer, default=0) ) def comment_count(self): return sa.func.count('1') @aggregated('comments', sa.Column(sa.Integer)) def last_comment_id(self): return sa.func.max(Comment.id) comments = sa.orm.relationship( 'Comment', backref='thread' ) Thread.last_comment = sa.orm.relationship( 'Comment', primaryjoin='Thread.last_comment_id == Comment.id', foreign_keys=[Thread.last_comment_id], viewonly=True ) return Thread @pytest.fixture def init_models(Comment, Thread): pass class TestAggregateValueGenerationForSimpleModelPaths(object): def test_assigns_aggregates_on_insert(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 assert thread.last_comment_id == comment.id def 
test_assigns_aggregates_on_separate_insert( self, session, Thread, Comment ): thread = Thread() thread.name = u'some article name' session.add(thread) session.commit() comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 assert thread.last_comment_id == 1 def test_assigns_aggregates_on_delete(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) session.commit() comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.delete(comment) session.commit() session.refresh(thread) assert thread.comment_count == 0 assert thread.last_comment_id is None sqlalchemy-utils-0.36.1/tests/aggregate/test_o2m_m2m.py000066400000000000000000000044361360007755400230700ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Category @pytest.fixture def Catalog(Base, Category): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated( 'products.categories', sa.Column(sa.Integer, default=0) ) def category_count(self): return sa.func.count(sa.distinct(Category.id)) return Catalog @pytest.fixture def Product(Base, Catalog, Category): product_categories = sa.Table( 'category_product', Base.metadata, sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')), sa.Column('product_id', sa.Integer, sa.ForeignKey('product.id')) ) class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) catalog_id = sa.Column( sa.Integer, sa.ForeignKey('catalog.id') ) catalog = sa.orm.relationship( Catalog, backref='products' ) categories = sa.orm.relationship( Category, backref='products', secondary=product_categories ) return Product @pytest.fixture def init_models(Category, Catalog, Product): pass @pytest.mark.usefixtures('postgresql_dsn') class TestAggregateOneToManyAndManyToMany(object): def test_insert(self, session, Category, Catalog, Product): category = Category() products = [ Product(categories=[category]), Product(categories=[category]) ] catalog = Catalog(products=products) session.add(catalog) products2 = [ Product(categories=[category]), Product(categories=[category]) ] catalog2 = Catalog(products=products2) session.add(catalog) session.commit() assert catalog.category_count == 1 assert catalog2.category_count == 1 sqlalchemy-utils-0.36.1/tests/aggregate/test_o2m_o2m.py000066400000000000000000000037061360007755400230710ustar00rootroot00000000000000from decimal import Decimal import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Catalog(Base): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated( 'categories.products', sa.Column(sa.Integer, default=0) ) def product_count(self): return sa.func.count('1') categories = sa.orm.relationship('Category', backref='catalog') return Catalog @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) catalog_id = sa.Column(sa.Integer, 
sa.ForeignKey('catalog.id')) products = sa.orm.relationship('Product', backref='category') return Category @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) return Product @pytest.fixture def init_models(Catalog, Category, Product): pass @pytest.mark.usefixtures('postgresql_dsn') class TestAggregateOneToManyAndOneToMany(object): def test_assigns_aggregates(self, session, Category, Catalog, Product): category = Category(name=u'Some category') catalog = Catalog( categories=[category] ) catalog.name = u'Some catalog' session.add(catalog) session.commit() product = Product( name=u'Some product', price=Decimal('1000'), category=category ) session.add(product) session.commit() session.refresh(catalog) assert catalog.product_count == 1 sqlalchemy-utils-0.36.1/tests/aggregate/test_o2m_o2m_o2m.py000066400000000000000000000066121360007755400236450ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated @pytest.fixture def Catalog(Base): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) @aggregated( 'categories.sub_categories.products', sa.Column(sa.Integer, default=0) ) def product_count(self): return sa.func.count('1') categories = sa.orm.relationship('Category', backref='catalog') return Catalog @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) sub_categories = sa.orm.relationship( 'SubCategory', backref='category' ) return Category @pytest.fixture def SubCategory(Base): class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column(sa.Integer, primary_key=True) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) products = sa.orm.relationship('Product', backref='sub_category') return SubCategory @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) sub_category_id = sa.Column( sa.Integer, sa.ForeignKey('sub_category.id') ) return Product @pytest.fixture def init_models(Catalog, Category, SubCategory, Product): pass @pytest.fixture def catalog_factory(Product, SubCategory, Category, Catalog, session): def catalog_factory(): product = Product() sub_category = SubCategory( products=[product] ) category = Category(sub_categories=[sub_category]) catalog = Catalog(categories=[category]) session.add(catalog) return catalog return catalog_factory @pytest.mark.usefixtures('postgresql_dsn') class Test3LevelDeepOneToMany(object): def test_assigns_aggregates(self, session, catalog_factory): catalog = catalog_factory() session.commit() session.refresh(catalog) assert catalog.product_count == 1 def catalog_factory( self, session, Product, SubCategory, Category, Catalog ): product = Product() sub_category = SubCategory( products=[product] ) category = Category(sub_categories=[sub_category]) catalog = Catalog(categories=[category]) session.add(catalog) return catalog def test_only_updates_affected_aggregates( self, session, catalog_factory, Product ): catalog = catalog_factory() catalog2 = catalog_factory() session.commit() # force set catalog2 product_count to zero in order to check if it gets # updated when the other catalog's product count gets updated 
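# Editor's note (added comment): the raw UPDATE below deliberately bypasses the
# ORM and the aggregate machinery, so catalog2.product_count is forced stale on
# purpose.  The assertions afterwards then prove that appending a product to the
# first catalog only refreshes that catalog's counter and leaves catalog2 at 0.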
session.execute( 'UPDATE catalog SET product_count = 0 WHERE id = %d' % catalog2.id ) catalog.categories[0].sub_categories[0].products.append( Product() ) session.commit() session.refresh(catalog) session.refresh(catalog2) assert catalog.product_count == 2 assert catalog2.product_count == 0 sqlalchemy-utils-0.36.1/tests/aggregate/test_search_vectors.py000066400000000000000000000032711360007755400246260ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import aggregated, TSVectorType def tsvector_reduce_concat(vectors): return sa.sql.expression.cast( sa.func.coalesce( sa.func.array_to_string(sa.func.array_agg(vectors), ' ') ), TSVectorType ) @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) price = sa.Column(sa.Numeric) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) return Product @pytest.fixture def Catalog(Base, Product): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('products', sa.Column(TSVectorType)) def product_search_vector(self): return tsvector_reduce_concat( sa.func.to_tsvector(Product.name) ) products = sa.orm.relationship('Product', backref='catalog') return Catalog @pytest.fixture def init_models(Product, Catalog): pass @pytest.mark.usefixtures('postgresql_dsn') class TestSearchVectorAggregates(object): def test_assigns_aggregates_on_insert(self, session, Product, Catalog): catalog = Catalog( name=u'Some catalog' ) session.add(catalog) session.commit() product = Product( name=u'Product XYZ', catalog=catalog ) session.add(product) session.commit() session.refresh(catalog) assert catalog.product_search_vector == "'product':1 'xyz':2" sqlalchemy-utils-0.36.1/tests/aggregate/test_simple_paths.py000066400000000000000000000041641360007755400243060ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Thread(Base): class Thread(Base): __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('comments', sa.Column(sa.Integer, default=0)) def comment_count(self): return sa.func.count('1') comments = sa.orm.relationship('Comment', backref='thread') return Thread @pytest.fixture def Comment(Base): class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.Unicode(255)) thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) return Comment @pytest.fixture def init_models(Thread, Comment): pass class TestAggregateValueGenerationForSimpleModelPaths(object): def test_assigns_aggregates_on_insert(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 def test_assigns_aggregates_on_separate_insert( self, session, Thread, Comment ): thread = Thread() thread.name = u'some article name' session.add(thread) session.commit() comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 def test_assigns_aggregates_on_delete(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) session.commit() comment = 
Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.delete(comment) session.commit() session.refresh(thread) assert thread.comment_count == 0 sqlalchemy-utils-0.36.1/tests/aggregate/test_with_column_alias.py000066400000000000000000000035631360007755400253210ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Thread(Base): class Thread(Base): __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) @aggregated( 'comments', sa.Column('_comment_count', sa.Integer, default=0) ) def comment_count(self): return sa.func.count('1') comments = sa.orm.relationship('Comment', backref='thread') return Thread @pytest.fixture def Comment(Base): class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) thread_id = sa.Column(sa.Integer, sa.ForeignKey('thread.id')) return Comment @pytest.fixture def init_models(Thread, Comment): pass class TestAggregatedWithColumnAlias(object): def test_assigns_aggregates_on_insert(self, session, Thread, Comment): thread = Thread() session.add(thread) comment = Comment(thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 def test_assigns_aggregates_on_separate_insert( self, session, Thread, Comment ): thread = Thread() session.add(thread) session.commit() comment = Comment(thread=thread) session.add(comment) session.commit() session.refresh(thread) assert thread.comment_count == 1 def test_assigns_aggregates_on_delete(self, session, Thread, Comment): thread = Thread() session.add(thread) session.commit() comment = Comment(thread=thread) session.add(comment) session.commit() session.delete(comment) session.commit() session.refresh(thread) assert thread.comment_count == 0 sqlalchemy-utils-0.36.1/tests/aggregate/test_with_ondelete_cascade.py000066400000000000000000000026321360007755400261110ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.aggregates import aggregated @pytest.fixture def Thread(Base): class Thread(Base): __tablename__ = 'thread' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @aggregated('comments', sa.Column(sa.Integer, default=0)) def comment_count(self): return sa.func.count('1') comments = sa.orm.relationship( 'Comment', passive_deletes=True, backref='thread' ) return Thread @pytest.fixture def Comment(Base): class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.Unicode(255)) thread_id = sa.Column( sa.Integer, sa.ForeignKey('thread.id', ondelete='CASCADE') ) return Comment @pytest.fixture def init_models(Thread, Comment): pass @pytest.mark.usefixtures('postgresql_dsn') class TestAggregateValueGenerationWithCascadeDelete(object): def test_something(self, session, Thread, Comment): thread = Thread() thread.name = u'some article name' session.add(thread) comment = Comment(content=u'Some content', thread=thread) session.add(comment) session.commit() session.expire_all() session.delete(thread) session.commit() sqlalchemy-utils-0.36.1/tests/functions/000077500000000000000000000000001360007755400202625ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/functions/__init__.py000066400000000000000000000000001360007755400223610ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/functions/test_cast_if.py000066400000000000000000000024111360007755400233010ustar00rootroot00000000000000import pytest 
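# Editor's note (added comment): the assertions in TestCastIf below only cover
# the pass-through paths of sqlalchemy_utils.cast_if -- ORM constructs that
# already carry the requested type (columns, column properties, instrumented
# attributes, synonyms, scalar selects) and plain Python scalars of the matching
# type are returned unchanged rather than being wrapped in a cast.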
import sqlalchemy as sa from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import cast_if @pytest.fixture(scope='class') def base(): return declarative_base() @pytest.fixture(scope='class') def article_cls(base): class Article(base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) name_synonym = sa.orm.synonym('name') return Article class TestCastIf(object): def test_column(self, article_cls): expr = article_cls.__table__.c.name assert cast_if(expr, sa.String) is expr def test_column_property(self, article_cls): expr = article_cls.name.property assert cast_if(expr, sa.String) is expr def test_instrumented_attribute(self, article_cls): expr = article_cls.name assert cast_if(expr, sa.String) is expr def test_synonym(self, article_cls): expr = article_cls.name_synonym assert cast_if(expr, sa.String) is expr def test_scalar_selectable(self, article_cls): expr = sa.select([article_cls.id]).as_scalar() assert cast_if(expr, sa.Integer) is expr def test_scalar(self): assert cast_if('something', sa.String) == 'something' sqlalchemy-utils-0.36.1/tests/functions/test_database.py000066400000000000000000000077551360007755400234550ustar00rootroot00000000000000import pytest import sqlalchemy as sa from flexmock import flexmock from sqlalchemy_utils import create_database, database_exists, drop_database pymysql = None try: import pymysql # noqa except ImportError: pass class DatabaseTest(object): def test_create_and_drop(self, dsn): assert not database_exists(dsn) create_database(dsn) assert database_exists(dsn) drop_database(dsn) assert not database_exists(dsn) @pytest.mark.usefixtures('sqlite_memory_dsn') class TestDatabaseSQLiteMemory(object): def test_exists_memory(self, dsn): assert database_exists(dsn) @pytest.mark.usefixtures('sqlite_none_database_dsn') class TestDatabaseSQLiteMemoryNoDatabaseString(object): def test_exists_memory_none_database(self, sqlite_none_database_dsn): assert database_exists(sqlite_none_database_dsn) @pytest.mark.usefixtures('sqlite_file_dsn') class TestDatabaseSQLiteFile(DatabaseTest): def test_existing_non_sqlite_file(self, dsn): database = sa.engine.url.make_url(dsn).database open(database, 'w').close() self.test_create_and_drop(dsn) @pytest.mark.skipif('pymysql is None') @pytest.mark.usefixtures('mysql_dsn') class TestDatabaseMySQL(DatabaseTest): @pytest.fixture def db_name(self): return 'db_test_sqlalchemy_util' @pytest.mark.skipif('pymysql is None') @pytest.mark.usefixtures('mysql_dsn') class TestDatabaseMySQLWithQuotedName(DatabaseTest): @pytest.fixture def db_name(self): return 'db_test_sqlalchemy-util' @pytest.mark.usefixtures('postgresql_dsn') class TestDatabasePostgres(DatabaseTest): @pytest.fixture def db_name(self): return 'db_test_sqlalchemy_util' def test_template(self, postgresql_db_user): ( flexmock(sa.engine.Engine) .should_receive('execute') .with_args( "CREATE DATABASE db_test_sqlalchemy_util ENCODING 'utf8' " "TEMPLATE my_template" ) ) dsn = 'postgresql://{0}@localhost/db_test_sqlalchemy_util'.format( postgresql_db_user ) create_database(dsn, template='my_template') class TestDatabasePostgresPg8000(DatabaseTest): @pytest.fixture def dsn(self, postgresql_db_user): return 'postgresql+pg8000://{0}@localhost/{1}'.format( postgresql_db_user, 'db_to_test_create_and_drop_via_pg8000_driver' ) @pytest.mark.usefixtures('postgresql_dsn') class TestDatabasePostgresWithQuotedName(DatabaseTest): @pytest.fixture def db_name(self): return 'db_test_sqlalchemy-util' def test_template(self, 
postgresql_db_user): ( flexmock(sa.engine.Engine) .should_receive('execute') .with_args( '''CREATE DATABASE "db_test_sqlalchemy-util"''' " ENCODING 'utf8' " 'TEMPLATE "my-template"' ) ) dsn = 'postgresql://{0}@localhost/db_test_sqlalchemy-util'.format( postgresql_db_user ) create_database(dsn, template='my-template') class TestDatabasePostgresCreateDatabaseCloseConnection(object): def test_create_database_twice(self, postgresql_db_user): dsn_list = [ 'postgresql://{0}@localhost/db_test_sqlalchemy-util-a'.format( postgresql_db_user ), 'postgres://{0}@localhost/db_test_sqlalchemy-util-b'.format( postgresql_db_user ), ] for dsn_item in dsn_list: assert not database_exists(dsn_item) create_database(dsn_item, template="template1") assert database_exists(dsn_item) for dsn_item in dsn_list: drop_database(dsn_item) assert not database_exists(dsn_item) @pytest.mark.usefixtures('mssql_dsn') class TestDatabaseMssql(DatabaseTest): @pytest.fixture def db_name(self): pytest.importorskip('pyodbc') return 'db_test_sqlalchemy_util' sqlalchemy-utils-0.36.1/tests/functions/test_dependent_objects.py000066400000000000000000000244651360007755400253650ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import dependent_objects, get_referencing_foreign_keys class TestDependentObjects(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) return User @pytest.fixture def Article(self, Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) owner_id = sa.Column( sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL') ) author = sa.orm.relationship(User, foreign_keys=[author_id]) owner = sa.orm.relationship(User, foreign_keys=[owner_id]) return Article @pytest.fixture def BlogPost(self, Base, User): class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) owner_id = sa.Column( sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE') ) owner = sa.orm.relationship(User) return BlogPost @pytest.fixture def init_models(self, User, Article, BlogPost): pass def test_returns_all_dependent_objects(self, session, User, Article): user = User(first_name=u'John') articles = [ Article(author=user), Article(), Article(owner=user), Article(author=user, owner=user) ] session.add_all(articles) session.commit() deps = list(dependent_objects(user)) assert len(deps) == 3 assert articles[0] in deps assert articles[2] in deps assert articles[3] in deps def test_with_foreign_keys_parameter( self, session, User, Article, BlogPost ): user = User(first_name=u'John') objects = [ Article(author=user), Article(), Article(owner=user), Article(author=user, owner=user), BlogPost(owner=user) ] session.add_all(objects) session.commit() deps = list( dependent_objects( user, ( fk for fk in get_referencing_foreign_keys(User) if fk.ondelete == 'RESTRICT' or fk.ondelete is None ) ).limit(5) ) assert len(deps) == 2 assert objects[0] in deps assert objects[3] in deps class TestDependentObjectsWithColumnAliases(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) return User @pytest.fixture def Article(self, Base, User): class Article(Base): __tablename__ = 
'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column( '_author_id', sa.Integer, sa.ForeignKey('user.id') ) owner_id = sa.Column( '_owner_id', sa.Integer, sa.ForeignKey('user.id', ondelete='SET NULL') ) author = sa.orm.relationship(User, foreign_keys=[author_id]) owner = sa.orm.relationship(User, foreign_keys=[owner_id]) return Article @pytest.fixture def BlogPost(self, Base, User): class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) owner_id = sa.Column( '_owner_id', sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE') ) owner = sa.orm.relationship(User) return BlogPost @pytest.fixture def init_models(self, User, Article, BlogPost): pass def test_returns_all_dependent_objects(self, session, User, Article): user = User(first_name=u'John') articles = [ Article(author=user), Article(), Article(owner=user), Article(author=user, owner=user) ] session.add_all(articles) session.commit() deps = list(dependent_objects(user)) assert len(deps) == 3 assert articles[0] in deps assert articles[2] in deps assert articles[3] in deps def test_with_foreign_keys_parameter( self, session, User, Article, BlogPost ): user = User(first_name=u'John') objects = [ Article(author=user), Article(), Article(owner=user), Article(author=user, owner=user), BlogPost(owner=user) ] session.add_all(objects) session.commit() deps = list( dependent_objects( user, ( fk for fk in get_referencing_foreign_keys(User) if fk.ondelete == 'RESTRICT' or fk.ondelete is None ) ).limit(5) ) assert len(deps) == 2 assert objects[0] in deps assert objects[3] in deps class TestDependentObjectsWithManyReferences(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) return User @pytest.fixture def BlogPost(self, Base, User): class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) return BlogPost @pytest.fixture def Article(self, Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) return Article @pytest.fixture def init_models(self, User, BlogPost, Article): pass def test_with_many_dependencies(self, session, User, Article, BlogPost): user = User(first_name=u'John') objects = [ Article(author=user), BlogPost(author=user) ] session.add_all(objects) session.commit() deps = list(dependent_objects(user)) assert len(deps) == 2 class TestDependentObjectsWithCompositeKeys(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) return User @pytest.fixture def Article(self, Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), ) author = sa.orm.relationship(User) return Article @pytest.fixture def init_models(self, User, Article): pass def test_returns_all_dependent_objects(self, session, User, Article): user = User(first_name=u'John', 
last_name=u'Smith') articles = [ Article(author=user), Article(), Article(), Article(author=user) ] session.add_all(articles) session.commit() deps = list(dependent_objects(user)) assert len(deps) == 2 assert articles[0] in deps assert articles[3] in deps class TestDependentObjectsWithSingleTableInheritance(object): @pytest.fixture def Category(self, Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Category @pytest.fixture def TextItem(self, Base, Category): class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) category_id = sa.Column( sa.Integer, sa.ForeignKey(Category.id) ) category = sa.orm.relationship( Category, backref=sa.orm.backref( 'articles' ) ) type = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': type, } return TextItem @pytest.fixture def Article(self, TextItem): class Article(TextItem): __mapper_args__ = { 'polymorphic_identity': u'article' } return Article @pytest.fixture def BlogPost(self, TextItem): class BlogPost(TextItem): __mapper_args__ = { 'polymorphic_identity': u'blog_post' } return BlogPost @pytest.fixture def init_models(self, Category, TextItem, Article, BlogPost): pass def test_returns_all_dependent_objects(self, session, Category, Article): category1 = Category(name=u'Category #1') category2 = Category(name=u'Category #2') articles = [ Article(category=category1), Article(category=category1), Article(category=category2), Article(category=category2), ] session.add_all(articles) session.commit() deps = list(dependent_objects(category1)) assert len(deps) == 2 assert articles[0] in deps assert articles[1] in deps sqlalchemy-utils-0.36.1/tests/functions/test_escape_like.py000066400000000000000000000002351360007755400241370ustar00rootroot00000000000000from sqlalchemy_utils import escape_like class TestEscapeLike(object): def test_escapes_wildcards(self): assert escape_like('_*%') == '*_***%' sqlalchemy-utils-0.36.1/tests/functions/test_get_bind.py000066400000000000000000000010711360007755400234450ustar00rootroot00000000000000import pytest from sqlalchemy_utils import get_bind class TestGetBind(object): def test_with_session(self, session, connection): assert get_bind(session) == connection def test_with_connection(self, session, connection): assert get_bind(connection) == connection def test_with_model_object(self, session, connection, Article): article = Article() session.add(article) assert get_bind(article) == connection def test_with_unknown_type(self): with pytest.raises(TypeError): get_bind(None) sqlalchemy-utils-0.36.1/tests/functions/test_get_class_by_table.py000066400000000000000000000055261360007755400255100ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_class_by_table class TestGetClassByTableWithJoinedTableInheritance(object): @pytest.fixture def Entity(self, Base): class Entity(Base): __tablename__ = 'entity' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) type = sa.Column(sa.String) __mapper_args__ = { 'polymorphic_on': type, 'polymorphic_identity': 'entity' } return Entity @pytest.fixture def User(self, Entity): class User(Entity): __tablename__ = 'user' id = sa.Column( sa.Integer, sa.ForeignKey(Entity.id, ondelete='CASCADE'), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': 'user' } return User def test_returns_class(self, Base, User, Entity): assert get_class_by_table(Base, 
User.__table__) == User assert get_class_by_table( Base, Entity.__table__ ) == Entity def test_table_with_no_associated_class(self, Base): table = sa.Table( 'some_table', Base.metadata, sa.Column('id', sa.Integer) ) assert get_class_by_table(Base, table) is None class TestGetClassByTableWithSingleTableInheritance(object): @pytest.fixture def Entity(self, Base): class Entity(Base): __tablename__ = 'entity' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) type = sa.Column(sa.String) __mapper_args__ = { 'polymorphic_on': type, 'polymorphic_identity': 'entity' } return Entity @pytest.fixture def User(self, Entity): class User(Entity): __mapper_args__ = { 'polymorphic_identity': 'user' } return User def test_multiple_classes_without_data_parameter(self, Base, Entity, User): with pytest.raises(ValueError): assert get_class_by_table( Base, Entity.__table__ ) def test_multiple_classes_with_data_parameter(self, Base, Entity, User): assert get_class_by_table( Base, Entity.__table__, {'type': 'entity'} ) == Entity assert get_class_by_table( Base, Entity.__table__, {'type': 'user'} ) == User def test_multiple_classes_with_bogus_data(self, Base, Entity, User): with pytest.raises(ValueError): assert get_class_by_table( Base, Entity.__table__, {'type': 'unknown'} ) sqlalchemy-utils-0.36.1/tests/functions/test_get_column_key.py000066400000000000000000000022221360007755400246750ustar00rootroot00000000000000from copy import copy import pytest import sqlalchemy as sa from sqlalchemy_utils import get_column_key @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column('_name', sa.Unicode(255)) return Building @pytest.fixture def Movie(Base): class Movie(Base): __tablename__ = 'movie' id = sa.Column(sa.Integer, primary_key=True) return Movie class TestGetColumnKey(object): def test_supports_aliases(self, Building): assert ( get_column_key(Building, Building.__table__.c.id) == 'id' ) assert ( get_column_key(Building, Building.__table__.c._name) == 'name' ) def test_supports_vague_matching_of_column_objects(self, Building): column = copy(Building.__table__.c._name) assert get_column_key(Building, column) == 'name' def test_throws_value_error_for_unknown_column(self, Building, Movie): with pytest.raises(sa.orm.exc.UnmappedColumnError): get_column_key(Building, Movie.__table__.c.id) sqlalchemy-utils-0.36.1/tests/functions/test_get_columns.py000066400000000000000000000034601360007755400242150ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_columns @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column('_name', sa.Unicode(255)) return Building class TestGetColumns(object): def test_table(self, Building): assert isinstance( get_columns(Building.__table__), sa.sql.base.ImmutableColumnCollection ) def test_instrumented_attribute(self, Building): assert get_columns(Building.id) == [Building.__table__.c._id] def test_column_property(self, Building): assert get_columns(Building.id.property) == [ Building.__table__.c._id ] def test_column(self, Building): assert get_columns(Building.__table__.c._id) == [ Building.__table__.c._id ] def test_declarative_class(self, Building): assert isinstance( get_columns(Building), sa.util._collections.OrderedProperties ) def test_declarative_object(self, Building): assert isinstance( get_columns(Building()), 
sa.util._collections.OrderedProperties ) def test_mapper(self, Building): assert isinstance( get_columns(Building.__mapper__), sa.util._collections.OrderedProperties ) def test_class_alias(self, Building): assert isinstance( get_columns(sa.orm.aliased(Building)), sa.util._collections.OrderedProperties ) def test_table_alias(self, Building): alias = sa.orm.aliased(Building.__table__) assert isinstance( get_columns(alias), sa.sql.base.ImmutableColumnCollection ) sqlalchemy-utils-0.36.1/tests/functions/test_get_hybrid_properties.py000066400000000000000000000021201360007755400262620ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import get_hybrid_properties @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) @hybrid_property def lowercase_name(self): return self.name.lower() @lowercase_name.expression def lowercase_name(cls): return sa.func.lower(cls.name) return Category class TestGetHybridProperties(object): def test_declarative_model(self, Category): assert ( list(get_hybrid_properties(Category).keys()) == ['lowercase_name'] ) def test_mapper(self, Category): assert ( list(get_hybrid_properties(sa.inspect(Category)).keys()) == ['lowercase_name'] ) def test_aliased_class(self, Category): props = get_hybrid_properties(sa.orm.aliased(Category)) assert list(props.keys()) == ['lowercase_name'] sqlalchemy-utils-0.36.1/tests/functions/test_get_mapper.py000066400000000000000000000070521360007755400240220ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_mapper class TestGetMapper(object): @pytest.fixture def Building(self, Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) return Building def test_table(self, Building): assert get_mapper(Building.__table__) == sa.inspect(Building) def test_declarative_class(self, Building): assert ( get_mapper(Building) == sa.inspect(Building) ) def test_declarative_object(self, Building): assert ( get_mapper(Building()) == sa.inspect(Building) ) def test_mapper(self, Building): assert ( get_mapper(Building.__mapper__) == sa.inspect(Building) ) def test_class_alias(self, Building): assert ( get_mapper(sa.orm.aliased(Building)) == sa.inspect(Building) ) def test_instrumented_attribute(self, Building): assert ( get_mapper(Building.id) == sa.inspect(Building) ) def test_table_alias(self, Building): alias = sa.orm.aliased(Building.__table__) assert ( get_mapper(alias) == sa.inspect(Building) ) def test_column(self, Building): assert ( get_mapper(Building.__table__.c.id) == sa.inspect(Building) ) def test_column_of_an_alias(self, Building): assert ( get_mapper(sa.orm.aliased(Building.__table__).c.id) == sa.inspect(Building) ) class TestGetMapperWithQueryEntities(object): @pytest.fixture def Building(self, Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) return Building @pytest.fixture def init_models(self, Building): pass def test_mapper_entity_with_mapper(self, session, Building): entity = session.query(Building.__mapper__)._entities[0] assert ( get_mapper(entity) == sa.inspect(Building) ) def test_mapper_entity_with_class(self, session, Building): entity = session.query(Building)._entities[0] assert ( get_mapper(entity) == sa.inspect(Building) ) def test_column_entity(self, session, Building): query = 
session.query(Building.id) assert get_mapper(query._entities[0]) == sa.inspect(Building) class TestGetMapperWithMultipleMappersFound(object): @pytest.fixture def Building(self, Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) class BigBuilding(Building): pass return Building def test_table(self, Building): with pytest.raises(ValueError): get_mapper(Building.__table__) def test_table_alias(self, Building): alias = sa.orm.aliased(Building.__table__) with pytest.raises(ValueError): get_mapper(alias) class TestGetMapperForTableWithoutMapper(object): @pytest.fixture def building(self): metadata = sa.MetaData() return sa.Table('building', metadata) def test_table(self, building): with pytest.raises(ValueError): get_mapper(building) def test_table_alias(self, building): alias = sa.orm.aliased(building) with pytest.raises(ValueError): get_mapper(alias) sqlalchemy-utils-0.36.1/tests/functions/test_get_primary_keys.py000066400000000000000000000024631360007755400252550ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_primary_keys try: from collections import OrderedDict except ImportError: from ordereddict import OrderedDict @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column('_name', sa.Unicode(255)) return Building class TestGetPrimaryKeys(object): def test_table(self, Building): assert get_primary_keys(Building.__table__) == OrderedDict({ '_id': Building.__table__.c._id }) def test_declarative_class(self, Building): assert get_primary_keys(Building) == OrderedDict({ 'id': Building.__table__.c._id }) def test_declarative_object(self, Building): assert get_primary_keys(Building()) == OrderedDict({ 'id': Building.__table__.c._id }) def test_class_alias(self, Building): alias = sa.orm.aliased(Building) assert get_primary_keys(alias) == OrderedDict({ 'id': Building.__table__.c._id }) def test_table_alias(self, Building): alias = sa.orm.aliased(Building.__table__) assert get_primary_keys(alias) == OrderedDict({ '_id': alias.c._id }) sqlalchemy-utils-0.36.1/tests/functions/test_get_query_entities.py000066400000000000000000000064471360007755400256160ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_query_entities @pytest.fixture def TextItem(Base): class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': type, } return TextItem @pytest.fixture def Article(TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) category = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_identity': u'article' } return Article @pytest.fixture def BlogPost(TextItem): class BlogPost(TextItem): __tablename__ = 'blog_post' id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': u'blog_post' } return BlogPost @pytest.fixture def init_models(TextItem, Article, BlogPost): pass class TestGetQueryEntities(object): def test_mapper(self, session, TextItem): query = session.query(sa.inspect(TextItem)) assert get_query_entities(query) == [TextItem] def test_entity(self, session, TextItem): query = session.query(TextItem) assert get_query_entities(query) == [TextItem] def test_instrumented_attribute(self, session, TextItem): 
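# Editor's note (added comment): querying the ORM attribute TextItem.id still
# resolves to the TextItem entity here, whereas test_column just below shows
# that querying the raw table column (TextItem.__table__.c.id) resolves to the
# Table object instead.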
query = session.query(TextItem.id) assert get_query_entities(query) == [TextItem] def test_column(self, session, TextItem): query = session.query(TextItem.__table__.c.id) assert get_query_entities(query) == [TextItem.__table__] def test_aliased_selectable(self, session, TextItem, BlogPost): selectable = sa.orm.with_polymorphic(TextItem, [BlogPost]) query = session.query(selectable) assert get_query_entities(query) == [selectable] def test_joined_entity(self, session, TextItem, BlogPost): query = session.query(TextItem).join( BlogPost, BlogPost.id == TextItem.id ) assert get_query_entities(query) == [ TextItem, sa.inspect(BlogPost) ] def test_joined_aliased_entity(self, session, TextItem, BlogPost): alias = sa.orm.aliased(BlogPost) query = session.query(TextItem).join( alias, alias.id == TextItem.id ) assert get_query_entities(query) == [TextItem, alias] def test_column_entity_with_label(self, session, Article): query = session.query(Article.id.label('id')) assert get_query_entities(query) == [Article] def test_with_subquery(self, session, Article): number_of_articles = ( sa.select( [sa.func.count(Article.id)], ) .select_from( Article.__table__ ) ).label('number_of_articles') query = session.query(Article, number_of_articles) assert get_query_entities(query) == [ Article, number_of_articles ] def test_aliased_entity(self, session, Article): alias = sa.orm.aliased(Article) query = session.query(alias) assert get_query_entities(query) == [alias] sqlalchemy-utils-0.36.1/tests/functions/test_get_referencing_foreign_keys.py000066400000000000000000000062231360007755400275700ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_referencing_foreign_keys class TestGetReferencingFksWithCompositeKeys(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) return User @pytest.fixture def Article(self, Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), ) return Article @pytest.fixture def init_models(self, User, Article): pass def test_with_declarative_class(self, User, Article): fks = get_referencing_foreign_keys(User) assert Article.__table__.foreign_keys == fks def test_with_table(self, User, Article): fks = get_referencing_foreign_keys(User.__table__) assert Article.__table__.foreign_keys == fks class TestGetReferencingFksWithInheritance(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode) first_name = sa.Column(sa.Unicode(255)) last_name = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': 'type' } return User @pytest.fixture def Admin(self, User): class Admin(User): __tablename__ = 'admin' id = sa.Column( sa.Integer, sa.ForeignKey(User.id), primary_key=True ) return Admin @pytest.fixture def TextItem(self, Base, User): class TextItem(Base): __tablename__ = 'textitem' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) __mapper_args__ = { 'polymorphic_on': 'type' } return TextItem @pytest.fixture def Article(self, TextItem): class 
Article(TextItem): __tablename__ = 'article' id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': 'article' } return Article @pytest.fixture def init_models(self, User, Admin, TextItem, Article): pass def test_with_declarative_class(self, Admin, TextItem): fks = get_referencing_foreign_keys(Admin) assert TextItem.__table__.foreign_keys == fks def test_with_table(self, Admin): fks = get_referencing_foreign_keys(Admin.__table__) assert fks == set([]) sqlalchemy-utils-0.36.1/tests/functions/test_get_tables.py000066400000000000000000000045011360007755400240040ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_tables @pytest.fixture def TextItem(Base): class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) type = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': type, 'with_polymorphic': '*' } return TextItem @pytest.fixture def Article(TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': u'article' } return Article @pytest.fixture def init_models(TextItem, Article): pass class TestGetTables(object): def test_child_class_using_join_table_inheritance(self, TextItem, Article): assert get_tables(Article) == [ TextItem.__table__, Article.__table__ ] def test_entity_using_with_polymorphic(self, TextItem, Article): assert get_tables(TextItem) == [ TextItem.__table__, Article.__table__ ] def test_instrumented_attribute(self, TextItem): assert get_tables(TextItem.name) == [ TextItem.__table__, ] def test_polymorphic_instrumented_attribute(self, TextItem, Article): assert get_tables(Article.id) == [ TextItem.__table__, Article.__table__ ] def test_column(self, Article): assert get_tables(Article.__table__.c.id) == [ Article.__table__ ] def test_mapper_entity_with_class(self, session, TextItem, Article): query = session.query(Article) assert get_tables(query._entities[0]) == [ TextItem.__table__, Article.__table__ ] def test_mapper_entity_with_mapper(self, session, TextItem, Article): query = session.query(sa.inspect(Article)) assert get_tables(query._entities[0]) == [ TextItem.__table__, Article.__table__ ] def test_column_entity(self, session, TextItem, Article): query = session.query(Article.id) assert get_tables(query._entities[0]) == [ TextItem.__table__, Article.__table__ ] sqlalchemy-utils-0.36.1/tests/functions/test_get_type.py000066400000000000000000000025231360007755400235150ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_type @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) return User @pytest.fixture def Article(Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship(User) some_property = sa.orm.column_property( sa.func.coalesce(id, 1) ) return Article class TestGetType(object): def test_instrumented_attribute(self, Article): assert isinstance(get_type(Article.id), sa.Integer) def test_column_property(self, Article): assert isinstance(get_type(Article.id.property), sa.Integer) def test_column(self, Article): assert isinstance(get_type(Article.__table__.c.id), sa.Integer) def test_calculated_column_property(self, 
Article): assert isinstance(get_type(Article.some_property), sa.Integer) def test_relationship_property(self, Article, User): assert get_type(Article.author) == User def test_scalar_select(self, Article): query = sa.select([Article.id]).as_scalar() assert isinstance(get_type(query), sa.Integer) sqlalchemy-utils-0.36.1/tests/functions/test_getdotattr.py000066400000000000000000000056151360007755400240630ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import getdotattr @pytest.fixture def Document(Base): class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Document @pytest.fixture def Section(Base, Document): class Section(Base): __tablename__ = 'section' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) document_id = sa.Column( sa.Integer, sa.ForeignKey(Document.id) ) document = sa.orm.relationship(Document, backref='sections') return Section @pytest.fixture def SubSection(Base, Section): class SubSection(Base): __tablename__ = 'subsection' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) section_id = sa.Column( sa.Integer, sa.ForeignKey(Section.id) ) section = sa.orm.relationship(Section, backref='subsections') return SubSection @pytest.fixture def SubSubSection(Base, SubSection): class SubSubSection(Base): __tablename__ = 'subsubsection' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) locale = sa.Column(sa.String(10)) subsection_id = sa.Column( sa.Integer, sa.ForeignKey(SubSection.id) ) subsection = sa.orm.relationship( SubSection, backref='subsubsections' ) return SubSubSection @pytest.fixture def init_models(Document, Section, SubSection, SubSubSection): pass class TestGetDotAttr(object): def test_simple_objects(self, Document, Section, SubSection): document = Document(name=u'some document') section = Section(document=document) subsection = SubSection(section=section) assert getdotattr( subsection, 'section.document.name' ) == u'some document' def test_with_instrumented_lists( self, Document, Section, SubSection, SubSubSection ): document = Document(name=u'some document') section = Section(document=document) subsection = SubSection(section=section) subsubsection = SubSubSection(subsection=subsection) assert getdotattr(document, 'sections') == [section] assert getdotattr(document, 'sections.subsections') == [ subsection ] assert getdotattr(document, 'sections.subsections.subsubsections') == [ subsubsection ] def test_class_paths(self, Document, Section, SubSection): assert getdotattr(Section, 'document') is Section.document assert ( getdotattr(SubSection, 'section.document') is Section.document ) assert getdotattr(Section, 'document.name') is Document.name sqlalchemy-utils-0.36.1/tests/functions/test_has_changes.py000066400000000000000000000024511360007755400241400ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import has_changes @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) title = sa.Column(sa.String(100)) return Article class TestHasChangesWithStringAttr(object): def test_without_changed_attr(self, Article): article = Article() assert not has_changes(article, 'title') def test_with_changed_attr(self, Article): article = Article(title='Some title') assert has_changes(article, 'title') class TestHasChangesWithMultipleAttrs(object): def 
test_without_changed_attr(self, Article): article = Article() assert not has_changes(article, ['title']) def test_with_changed_attr(self, Article): article = Article(title='Some title') assert has_changes(article, ['title', 'id']) class TestHasChangesWithExclude(object): def test_without_changed_attr(self, Article): article = Article() assert not has_changes(article, exclude=['id']) def test_with_changed_attr(self, Article): article = Article(title='Some title') assert has_changes(article, exclude=['id']) assert not has_changes(article, exclude=['title']) sqlalchemy-utils-0.36.1/tests/functions/test_has_index.py000066400000000000000000000111341360007755400236350ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_fk_constraint_for_columns, has_index class TestHasIndex(object): @pytest.fixture def table(self, Base): class ArticleTranslation(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) locale = sa.Column(sa.String(10), primary_key=True) title = sa.Column(sa.String(100)) is_published = sa.Column(sa.Boolean, index=True) is_deleted = sa.Column(sa.Boolean) is_archived = sa.Column(sa.Boolean) __table_args__ = ( sa.Index('my_index', is_deleted, is_archived), ) return ArticleTranslation.__table__ def test_column_that_belongs_to_an_alias(self, table): alias = sa.orm.aliased(table) with pytest.raises(TypeError): assert has_index(alias.c.id) def test_compound_primary_key(self, table): assert has_index(table.c.id) assert not has_index(table.c.locale) def test_single_column_index(self, table): assert has_index(table.c.is_published) def test_compound_column_index(self, table): assert has_index(table.c.is_deleted) assert not has_index(table.c.is_archived) def test_table_without_primary_key(self): article = sa.Table( 'article', sa.MetaData(), sa.Column('name', sa.String) ) assert not has_index(article.c.name) class TestHasIndexWithFKConstraint(object): def test_composite_fk_without_index(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), ) table = Article.__table__ constraint = get_fk_constraint_for_columns( table, table.c.author_first_name, table.c.author_last_name ) assert not has_index(constraint) def test_composite_fk_with_index(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), sa.Index( 'my_index', author_first_name, author_last_name ) ) table = Article.__table__ constraint = get_fk_constraint_for_columns( table, table.c.author_first_name, table.c.author_last_name ) assert has_index(constraint) def test_composite_fk_with_partial_index_match(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = 
sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), sa.Index( 'my_index', author_first_name, author_last_name, id ) ) table = Article.__table__ constraint = get_fk_constraint_for_columns( table, table.c.author_first_name, table.c.author_last_name ) assert has_index(constraint) sqlalchemy-utils-0.36.1/tests/functions/test_has_unique_index.py000066400000000000000000000122331360007755400252240ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import get_fk_constraint_for_columns, has_unique_index class TestHasUniqueIndex(object): @pytest.fixture def articles(self, Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) return Article.__table__ @pytest.fixture def article_translations(self, Base): class ArticleTranslation(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) locale = sa.Column(sa.String(10), primary_key=True) title = sa.Column(sa.String(100)) is_published = sa.Column(sa.Boolean, index=True) is_deleted = sa.Column(sa.Boolean, unique=True) is_archived = sa.Column(sa.Boolean) __table_args__ = ( sa.Index('my_index', is_archived, is_published, unique=True), ) return ArticleTranslation.__table__ def test_primary_key(self, articles): assert has_unique_index(articles.c.id) def test_column_of_aliased_table(self, articles): alias = sa.orm.aliased(articles) with pytest.raises(TypeError): assert has_unique_index(alias.c.id) def test_unique_index(self, article_translations): assert has_unique_index(article_translations.c.is_deleted) def test_compound_primary_key(self, article_translations): assert not has_unique_index(article_translations.c.id) assert not has_unique_index(article_translations.c.locale) def test_single_column_index(self, article_translations): assert not has_unique_index(article_translations.c.is_published) def test_compound_column_unique_index(self, article_translations): assert not has_unique_index(article_translations.c.is_published) assert not has_unique_index(article_translations.c.is_archived) class TestHasUniqueIndexWithFKConstraint(object): def test_composite_fk_without_index(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), ) table = Article.__table__ constraint = get_fk_constraint_for_columns( table, table.c.author_first_name, table.c.author_last_name ) assert not has_unique_index(constraint) def test_composite_fk_with_index(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, 
author_last_name], [User.first_name, User.last_name] ), sa.Index( 'my_index', author_first_name, author_last_name, unique=True ) ) table = Article.__table__ constraint = get_fk_constraint_for_columns( table, table.c.author_first_name, table.c.author_last_name ) assert has_unique_index(constraint) def test_composite_fk_with_partial_index_match(self, Base): class User(Base): __tablename__ = 'user' first_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255), primary_key=True) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_first_name = sa.Column(sa.Unicode(255)) author_last_name = sa.Column(sa.Unicode(255)) __table_args__ = ( sa.ForeignKeyConstraint( [author_first_name, author_last_name], [User.first_name, User.last_name] ), sa.Index( 'my_index', author_first_name, author_last_name, id, unique=True ) ) table = Article.__table__ constraint = get_fk_constraint_for_columns( table, table.c.author_first_name, table.c.author_last_name ) assert not has_unique_index(constraint) sqlalchemy-utils-0.36.1/tests/functions/test_identity.py000066400000000000000000000023141360007755400235240ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import identity class IdentityTestCase(object): @pytest.fixture def init_models(self, Building): pass def test_for_transient_class_without_id(self, Building): assert identity(Building()) == (None, ) def test_for_transient_class_with_id(self, session, Building): building = Building(name=u'Some building') session.add(building) session.flush() assert identity(building) == (building.id, ) def test_identity_for_class(self, Building): assert identity(Building) == (Building.id, ) class TestIdentity(IdentityTestCase): @pytest.fixture def Building(self, Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Building class TestIdentityWithColumnAlias(IdentityTestCase): @pytest.fixture def Building(self, Base): class Building(Base): __tablename__ = 'building' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Building sqlalchemy-utils-0.36.1/tests/functions/test_is_loaded.py000066400000000000000000000011201360007755400236100ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import is_loaded @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article_translation' id = sa.Column(sa.Integer, primary_key=True) title = sa.orm.deferred(sa.Column(sa.String(100))) return Article class TestIsLoaded(object): def test_loaded_property(self, Article): article = Article(id=1) assert is_loaded(article, 'id') def test_unloaded_property(self, Article): article = Article(id=4) assert not is_loaded(article, 'title') sqlalchemy-utils-0.36.1/tests/functions/test_json_sql.py000066400000000000000000000014311360007755400235220ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import json_sql @pytest.mark.usefixtures('postgresql_dsn') class TestJSONSQL(object): @pytest.mark.parametrize( ('value', 'result'), ( (1, 1), (14.14, 14.14), ({'a': 2, 'b': 'c'}, {'a': 2, 'b': 'c'}), ( {'a': {'b': 'c'}}, {'a': {'b': 'c'}} ), ({}, {}), ([1, 2], [1, 2]), ([], []), ( [sa.select([sa.text('1')]).label('alias')], [1] ) ) ) def test_compiled_scalars(self, connection, value, result): assert result == ( connection.execute(sa.select([json_sql(value)])).fetchone()[0] ) 
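The TestJSONSQL cases above drive json_sql through a PostgreSQL connection: plain Python scalars, dicts and lists go in, and the equivalent JSON value comes back from the database. A minimal standalone sketch of that call pattern follows; it assumes a reachable PostgreSQL database, and the DSN below is hypothetical rather than part of the test suite.

import sqlalchemy as sa
from sqlalchemy_utils import json_sql

# Hypothetical DSN; any reachable PostgreSQL database works here.
engine = sa.create_engine('postgresql://localhost/sqlalchemy_utils_test')

with engine.connect() as connection:
    # json_sql turns nested Python values into a SQL expression that
    # PostgreSQL evaluates to the equivalent JSON value.
    expression = json_sql({'a': {'b': 'c'}, 'numbers': [1, 2]})
    value = connection.execute(sa.select([expression])).fetchone()[0]
    # The driver deserialises the JSON, so the structure round-trips.
    assert value == {'a': {'b': 'c'}, 'numbers': [1, 2]}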
sqlalchemy-utils-0.36.1/tests/functions/test_make_order_by_deterministic.py000066400000000000000000000067161360007755400274320ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.functions.sort_query import make_order_by_deterministic from .. import assert_contains @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship('User') return Article @pytest.fixture def User(Base, Article): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode) email = sa.Column(sa.Unicode, unique=True) email_lower = sa.orm.column_property( sa.func.lower(name) ) User.article_count = sa.orm.column_property( sa.select([sa.func.count()], from_obj=Article) .where(Article.author_id == User.id) .label('article_count') ) return User @pytest.fixture def init_models(Article, User): pass class TestMakeOrderByDeterministic(object): def test_column_property(self, session, User): query = session.query(User).order_by(User.email_lower) query = make_order_by_deterministic(query) assert_contains('lower(user.name) AS lower_1', query) assert_contains('lower_1, user.id ASC', query) def test_unique_column(self, session, User): query = session.query(User).order_by(User.email) query = make_order_by_deterministic(query) assert str(query).endswith('ORDER BY user.email') def test_non_unique_column(self, session, User): query = session.query(User).order_by(User.name) query = make_order_by_deterministic(query) assert_contains('ORDER BY user.name, user.id ASC', query) def test_descending_order_by(self, session, User): query = session.query(User).order_by( sa.desc(User.name) ) query = make_order_by_deterministic(query) assert_contains('ORDER BY user.name DESC, user.id DESC', query) def test_ascending_order_by(self, session, User): query = session.query(User).order_by( sa.asc(User.name) ) query = make_order_by_deterministic(query) assert_contains('ORDER BY user.name ASC, user.id ASC', query) def test_string_order_by(self, session, User): query = session.query(User).order_by('name') query = make_order_by_deterministic(query) assert_contains('ORDER BY user.name, user.id ASC', query) def test_annotated_label(self, session, User): query = session.query(User).order_by(User.article_count) query = make_order_by_deterministic(query) assert_contains('article_count, user.id ASC', query) def test_annotated_label_with_descending_order(self, session, User): query = session.query(User).order_by( sa.desc(User.article_count) ) query = make_order_by_deterministic(query) assert_contains('ORDER BY article_count DESC, user.id DESC', query) def test_query_without_order_by(self, session, User): query = session.query(User) query = make_order_by_deterministic(query) assert 'ORDER BY user.id' in str(query) def test_alias(self, session, User): alias = sa.orm.aliased(User.__table__) query = session.query(alias).order_by(alias.c.name) query = make_order_by_deterministic(query) assert str(query).endswith('ORDER BY user_1.name, user.id ASC') sqlalchemy-utils-0.36.1/tests/functions/test_merge_references.py000066400000000000000000000134551360007755400252030ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import merge_references class TestMergeReferences(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) 
name = sa.Column(sa.Unicode(255)) def __repr__(self): return 'User(%r)' % self.name return User @pytest.fixture def BlogPost(self, Base, User): class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) title = sa.Column(sa.Unicode(255)) content = sa.Column(sa.UnicodeText) author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) author = sa.orm.relationship(User) return BlogPost @pytest.fixture def init_models(self, User, BlogPost): pass def test_updates_foreign_keys(self, session, User, BlogPost): john = User(name=u'John') jack = User(name=u'Jack') post = BlogPost(title=u'Some title', author=john) post2 = BlogPost(title=u'Other title', author=jack) session.add(john) session.add(jack) session.add(post) session.add(post2) session.commit() merge_references(john, jack) session.commit() assert post.author == jack assert post2.author == jack def test_object_merging_whenever_possible(self, session, User, BlogPost): john = User(name=u'John') jack = User(name=u'Jack') post = BlogPost(title=u'Some title', author=john) post2 = BlogPost(title=u'Other title', author=jack) session.add(john) session.add(jack) session.add(post) session.add(post2) session.commit() # Load the author for post assert post.author_id == john.id merge_references(john, jack) assert post.author_id == jack.id assert post2.author_id == jack.id class TestMergeReferencesWithManyToManyAssociations(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) def __repr__(self): return 'User(%r)' % self.name return User @pytest.fixture def Team(self, Base): team_member = sa.Table( 'team_member', Base.metadata, sa.Column( 'user_id', sa.Integer, sa.ForeignKey('user.id', ondelete='CASCADE'), primary_key=True ), sa.Column( 'team_id', sa.Integer, sa.ForeignKey('team.id', ondelete='CASCADE'), primary_key=True ) ) class Team(Base): __tablename__ = 'team' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) members = sa.orm.relationship( 'User', secondary=team_member, backref='teams' ) return Team @pytest.fixture def init_models(self, User, Team): pass def test_supports_associations(self, session, User, Team): john = User(name=u'John') jack = User(name=u'Jack') team = Team(name=u'Team') team.members.append(john) session.add(john) session.add(jack) session.commit() merge_references(john, jack) assert john not in team.members assert jack in team.members class TestMergeReferencesWithManyToManyAssociationObjects(object): @pytest.fixture def Team(self, Base): class Team(Base): __tablename__ = 'team' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) return Team @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) return User @pytest.fixture def TeamMember(self, Base, User, Team): class TeamMember(Base): __tablename__ = 'team_member' user_id = sa.Column( sa.Integer, sa.ForeignKey(User.id, ondelete='CASCADE'), primary_key=True ) team_id = sa.Column( sa.Integer, sa.ForeignKey(Team.id, ondelete='CASCADE'), primary_key=True ) role = sa.Column(sa.Unicode(255)) team = sa.orm.relationship( Team, backref=sa.orm.backref( 'members', cascade='all, delete-orphan' ), primaryjoin=team_id == Team.id, ) user = sa.orm.relationship( User, backref=sa.orm.backref( 'memberships', cascade='all, delete-orphan' ), primaryjoin=user_id == 
User.id, ) return TeamMember @pytest.fixture def init_models(self, User, Team, TeamMember): pass def test_supports_associations(self, session, User, Team, TeamMember): john = User(name=u'John') jack = User(name=u'Jack') team = Team(name=u'Team') team.members.append(TeamMember(user=john)) session.add(john) session.add(jack) session.add(team) session.commit() merge_references(john, jack) session.commit() users = [member.user for member in team.members] assert john not in users assert jack in users sqlalchemy-utils-0.36.1/tests/functions/test_naturally_equivalent.py000066400000000000000000000006541360007755400261500ustar00rootroot00000000000000from sqlalchemy_utils.functions import naturally_equivalent class TestNaturallyEquivalent(object): def test_returns_true_when_properties_match(self, User): assert naturally_equivalent( User(name=u'someone'), User(name=u'someone') ) def test_skips_primary_keys(self, User): assert naturally_equivalent( User(id=1, name=u'someone'), User(id=2, name=u'someone') ) sqlalchemy-utils-0.36.1/tests/functions/test_non_indexed_foreign_keys.py000066400000000000000000000036321360007755400267350ustar00rootroot00000000000000from itertools import chain import pytest import sqlalchemy as sa from sqlalchemy_utils.functions import non_indexed_foreign_keys class TestFindNonIndexedForeignKeys(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) return User @pytest.fixture def Category(self, Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Category @pytest.fixture def Article(self, Base, User, Category): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) author_id = sa.Column( sa.Integer, sa.ForeignKey(User.id), index=True ) category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) category = sa.orm.relationship( Category, primaryjoin=category_id == Category.id, backref=sa.orm.backref( 'articles', ) ) return Article @pytest.fixture def init_models(self, User, Category, Article): pass def test_finds_all_non_indexed_fks(self, session, Base, engine): fks = non_indexed_foreign_keys(Base.metadata, engine) assert ( 'article' in fks ) column_names = list(chain( *( names for names in ( fk.columns.keys() for fk in fks['article'] ) ) )) assert 'category_id' in column_names assert 'author_id' not in column_names sqlalchemy-utils-0.36.1/tests/functions/test_quote.py000066400000000000000000000014311360007755400230270ustar00rootroot00000000000000from sqlalchemy.dialects import postgresql from sqlalchemy_utils.functions import quote class TestQuote(object): def test_quote_with_preserved_keyword(self, engine, connection, session): assert quote(connection, 'order') == '"order"' assert quote(session, 'order') == '"order"' assert quote(engine, 'order') == '"order"' assert quote(postgresql.dialect(), 'order') == '"order"' def test_quote_with_non_preserved_keyword( self, engine, connection, session ): assert quote(connection, 'some_order') == 'some_order' assert quote(session, 'some_order') == 'some_order' assert quote(engine, 'some_order') == 'some_order' assert quote(postgresql.dialect(), 'some_order') == 'some_order' sqlalchemy-utils-0.36.1/tests/functions/test_render.py000066400000000000000000000035771360007755400231660ustar00rootroot00000000000000import pytest import sqlalchemy as sa from 
sqlalchemy_utils.functions import ( mock_engine, render_expression, render_statement ) class TestRender(object): @pytest.fixture def User(self, Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) return User @pytest.fixture def init_models(self, User): pass def test_render_orm_query(self, session, User): query = session.query(User).filter_by(id=3) text = render_statement(query) assert 'SELECT user.id, user.name' in text assert 'FROM user' in text assert 'WHERE user.id = 3' in text def test_render_statement(self, session, User): statement = User.__table__.select().where(User.id == 3) text = render_statement(statement, bind=session.bind) assert 'SELECT user.id, user.name' in text assert 'FROM user' in text assert 'WHERE user.id = 3' in text def test_render_statement_without_mapper(self, session): statement = sa.select([sa.text('1')]) text = render_statement(statement, bind=session.bind) assert 'SELECT 1' in text def test_render_ddl(self, engine, User): expression = 'User.__table__.create(engine)' stream = render_expression(expression, engine) text = stream.getvalue() assert 'CREATE TABLE user' in text assert 'PRIMARY KEY' in text def test_render_mock_ddl(self, engine, User): # TODO: mock_engine doesn't seem to work with locally scoped variables. self.engine = engine with mock_engine('self.engine') as stream: User.__table__.create(self.engine) text = stream.getvalue() assert 'CREATE TABLE user' in text assert 'PRIMARY KEY' in text sqlalchemy-utils-0.36.1/tests/functions/test_table_name.py000066400000000000000000000014211360007755400237600ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import table_name @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) return Building @pytest.fixture def init_models(Base): pass class TestTableName(object): def test_class(self, Building): assert table_name(Building) == 'building' del Building.__tablename__ assert table_name(Building) == 'building' def test_attribute(self, Building): assert table_name(Building.id) == 'building' assert table_name(Building.name) == 'building' def test_target(self, Building): assert table_name(Building()) == 'building' sqlalchemy-utils-0.36.1/tests/generic_relationship/000077500000000000000000000000001360007755400224475ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/generic_relationship/__init__.py000066400000000000000000000047601360007755400245670ustar00rootroot00000000000000import six class GenericRelationshipTestCase(object): def test_set_as_none(self, Event): event = Event() event.object = None assert event.object is None def test_set_manual_and_get(self, session, User, Event): user = User() session.add(user) session.commit() event = Event() event.object_id = user.id event.object_type = six.text_type(type(user).__name__) assert event.object is None session.add(event) session.commit() assert event.object == user def test_set_and_get(self, session, User, Event): user = User() session.add(user) session.commit() event = Event(object=user) assert event.object_id == user.id assert event.object_type == type(user).__name__ session.add(event) session.commit() assert event.object == user def test_compare_instance(self, session, User, Event): user1 = User() user2 = User() session.add_all([user1, user2]) session.commit() event = Event(object=user1) session.add(event) session.commit() 
assert event.object == user1 assert event.object != user2 def test_compare_query(self, session, User, Event): user1 = User() user2 = User() session.add_all([user1, user2]) session.commit() event = Event(object=user1) session.add(event) session.commit() q = session.query(Event) assert q.filter_by(object=user1).first() is not None assert q.filter_by(object=user2).first() is None assert q.filter(Event.object == user2).first() is None def test_compare_not_query(self, session, User, Event): user1 = User() user2 = User() session.add_all([user1, user2]) session.commit() event = Event(object=user1) session.add(event) session.commit() q = session.query(Event) assert q.filter(Event.object != user2).first() is not None def test_compare_type(self, session, User, Event): user1 = User() user2 = User() session.add_all([user1, user2]) session.commit() event1 = Event(object=user1) event2 = Event(object=user2) session.add_all([event1, event2]) session.commit() statement = Event.object.is_type(User) q = session.query(Event).filter(statement) assert q.first() is not None sqlalchemy-utils-0.36.1/tests/generic_relationship/test_abstract_base_class.py000066400000000000000000000022311360007755400300400ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr from sqlalchemy_utils import generic_relationship from . import GenericRelationshipTestCase @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) return Building @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) return User @pytest.fixture def EventBase(Base): class EventBase(Base): __abstract__ = True object_type = sa.Column(sa.Unicode(255)) object_id = sa.Column(sa.Integer, nullable=False) @declared_attr def object(cls): return generic_relationship('object_type', 'object_id') return EventBase @pytest.fixture def Event(EventBase): class Event(EventBase): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) return Event @pytest.fixture def init_models(Building, User, Event): pass class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): pass sqlalchemy-utils-0.36.1/tests/generic_relationship/test_column_aliases.py000066400000000000000000000016641360007755400270650ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import generic_relationship from . 
import GenericRelationshipTestCase @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) return Building @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) return User @pytest.fixture def Event(Base): class Event(Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) object_type = sa.Column(sa.Unicode(255), name="objectType") object_id = sa.Column(sa.Integer, nullable=False) object = generic_relationship(object_type, object_id) return Event @pytest.fixture def init_models(Building, User, Event): pass class TestGenericRelationship(GenericRelationshipTestCase): pass sqlalchemy-utils-0.36.1/tests/generic_relationship/test_composite_keys.py000066400000000000000000000037331360007755400271230ustar00rootroot00000000000000import pytest import six import sqlalchemy as sa from sqlalchemy_utils import generic_relationship from ..generic_relationship import GenericRelationshipTestCase @pytest.fixture def incrementor(): class Incrementor(object): value = 1 return Incrementor() @pytest.fixture def Building(Base, incrementor): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) code = sa.Column(sa.Integer, primary_key=True) def __init__(self): incrementor.value += 1 self.id = incrementor.value self.code = incrementor.value return Building @pytest.fixture def User(Base, incrementor): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) code = sa.Column(sa.Integer, primary_key=True) def __init__(self): incrementor.value += 1 self.id = incrementor.value self.code = incrementor.value return User @pytest.fixture def Event(Base): class Event(Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) object_type = sa.Column(sa.Unicode(255)) object_id = sa.Column(sa.Integer, nullable=False) object_code = sa.Column(sa.Integer, nullable=False) object = generic_relationship( object_type, (object_id, object_code) ) return Event @pytest.fixture def init_models(Building, User, Event): pass class TestGenericRelationship(GenericRelationshipTestCase): def test_set_manual_and_get(self, session, Event, User): user = User() session.add(user) session.commit() event = Event() event.object_id = user.id event.object_type = six.text_type(type(user).__name__) event.object_code = user.code assert event.object is None session.add(event) session.commit() assert event.object == user sqlalchemy-utils-0.36.1/tests/generic_relationship/test_hybrid_properties.py000066400000000000000000000037041360007755400276210ustar00rootroot00000000000000import pytest import six import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils import generic_relationship @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) return User @pytest.fixture def UserHistory(Base): class UserHistory(Base): __tablename__ = 'user_history' id = sa.Column(sa.Integer, primary_key=True) transaction_id = sa.Column(sa.Integer, primary_key=True) return UserHistory @pytest.fixture def Event(Base): class Event(Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) transaction_id = sa.Column(sa.Integer) object_type = sa.Column(sa.Unicode(255)) object_id = sa.Column(sa.Integer, nullable=False) object = generic_relationship( object_type, object_id ) @hybrid_property def object_version_type(self): return 
self.object_type + 'History' @object_version_type.expression def object_version_type(cls): return sa.func.concat(cls.object_type, 'History') object_version = generic_relationship( object_version_type, (object_id, transaction_id) ) return Event @pytest.fixture def init_models(User, UserHistory, Event): pass class TestGenericRelationship(object): def test_set_manual_and_get(self, session, User, UserHistory, Event): user = User(id=1) history = UserHistory(id=1, transaction_id=1) session.add(user) session.add(history) session.commit() event = Event(transaction_id=1) event.object_id = user.id event.object_type = six.text_type(type(user).__name__) assert event.object is None session.add(event) session.commit() assert event.object == user assert event.object_version == history sqlalchemy-utils-0.36.1/tests/generic_relationship/test_single_table_inheritance.py000066400000000000000000000105101360007755400310560ustar00rootroot00000000000000import pytest import six import sqlalchemy as sa from sqlalchemy_utils import generic_relationship @pytest.fixture def Employee(Base): class Employee(Base): __tablename__ = 'employee' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(50)) type = sa.Column(sa.String(20)) __mapper_args__ = { 'polymorphic_on': type, 'polymorphic_identity': 'employee' } return Employee @pytest.fixture def Manager(Employee): class Manager(Employee): __mapper_args__ = { 'polymorphic_identity': 'manager' } return Manager @pytest.fixture def Engineer(Employee): class Engineer(Employee): __mapper_args__ = { 'polymorphic_identity': 'engineer' } return Engineer @pytest.fixture def Event(Base): class Event(Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) object_type = sa.Column(sa.Unicode(255)) object_id = sa.Column(sa.Integer, nullable=False) object = generic_relationship(object_type, object_id) return Event @pytest.fixture def init_models(Employee, Manager, Engineer, Event): pass class TestGenericRelationship(object): def test_set_as_none(self, Event): event = Event() event.object = None assert event.object is None def test_set_manual_and_get(self, session, Manager, Event): manager = Manager() session.add(manager) session.commit() event = Event() event.object_id = manager.id event.object_type = six.text_type(type(manager).__name__) assert event.object is None session.add(event) session.commit() assert event.object == manager def test_set_and_get(self, session, Manager, Event): manager = Manager() session.add(manager) session.commit() event = Event(object=manager) assert event.object_id == manager.id assert event.object_type == type(manager).__name__ session.add(event) session.commit() assert event.object == manager def test_compare_instance(self, session, Manager, Event): manager1 = Manager() manager2 = Manager() session.add_all([manager1, manager2]) session.commit() event = Event(object=manager1) session.add(event) session.commit() assert event.object == manager1 assert event.object != manager2 def test_compare_query(self, session, Manager, Event): manager1 = Manager() manager2 = Manager() session.add_all([manager1, manager2]) session.commit() event = Event(object=manager1) session.add(event) session.commit() q = session.query(Event) assert q.filter_by(object=manager1).first() is not None assert q.filter_by(object=manager2).first() is None assert q.filter(Event.object == manager2).first() is None def test_compare_not_query(self, session, Manager, Event): manager1 = Manager() manager2 = Manager() session.add_all([manager1, manager2]) 
session.commit() event = Event(object=manager1) session.add(event) session.commit() q = session.query(Event) assert q.filter(Event.object != manager2).first() is not None def test_compare_type(self, session, Manager, Event): manager1 = Manager() manager2 = Manager() session.add_all([manager1, manager2]) session.commit() event1 = Event(object=manager1) event2 = Event(object=manager2) session.add_all([event1, event2]) session.commit() statement = Event.object.is_type(Manager) q = session.query(Event).filter(statement) assert q.first() is not None def test_compare_super_type(self, session, Manager, Event, Employee): manager1 = Manager() manager2 = Manager() session.add_all([manager1, manager2]) session.commit() event1 = Event(object=manager1) event2 = Event(object=manager2) session.add_all([event1, event2]) session.commit() statement = Event.object.is_type(Employee) q = session.query(Event).filter(statement) assert q.first() is not None sqlalchemy-utils-0.36.1/tests/mixins.py000066400000000000000000000147221360007755400201410ustar00rootroot00000000000000import pytest import sqlalchemy as sa class ThreeLevelDeepOneToOne(object): @pytest.fixture def Catalog(self, Base, Category): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column('_id', sa.Integer, primary_key=True) category = sa.orm.relationship( Category, uselist=False, backref='catalog' ) return Catalog @pytest.fixture def Category(self, Base, SubCategory): class Category(Base): __tablename__ = 'category' id = sa.Column('_id', sa.Integer, primary_key=True) catalog_id = sa.Column( '_catalog_id', sa.Integer, sa.ForeignKey('catalog._id') ) sub_category = sa.orm.relationship( SubCategory, uselist=False, backref='category' ) return Category @pytest.fixture def SubCategory(self, Base, Product): class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column('_id', sa.Integer, primary_key=True) category_id = sa.Column( '_category_id', sa.Integer, sa.ForeignKey('category._id') ) product = sa.orm.relationship( Product, uselist=False, backref='sub_category' ) return SubCategory @pytest.fixture def Product(self, Base): class Product(Base): __tablename__ = 'product' id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Integer) sub_category_id = sa.Column( '_sub_category_id', sa.Integer, sa.ForeignKey('sub_category._id') ) return Product @pytest.fixture def init_models(self, Catalog, Category, SubCategory, Product): pass class ThreeLevelDeepOneToMany(object): @pytest.fixture def Catalog(self, Base, Category): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column('_id', sa.Integer, primary_key=True) categories = sa.orm.relationship(Category, backref='catalog') return Catalog @pytest.fixture def Category(self, Base, SubCategory): class Category(Base): __tablename__ = 'category' id = sa.Column('_id', sa.Integer, primary_key=True) catalog_id = sa.Column( '_catalog_id', sa.Integer, sa.ForeignKey('catalog._id') ) sub_categories = sa.orm.relationship( SubCategory, backref='category' ) return Category @pytest.fixture def SubCategory(self, Base, Product): class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column('_id', sa.Integer, primary_key=True) category_id = sa.Column( '_category_id', sa.Integer, sa.ForeignKey('category._id') ) products = sa.orm.relationship( Product, backref='sub_category' ) return SubCategory @pytest.fixture def Product(self, Base): class Product(Base): __tablename__ = 'product' id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) sub_category_id = 
sa.Column( '_sub_category_id', sa.Integer, sa.ForeignKey('sub_category._id') ) def __repr__(self): return '' % self.id return Product @pytest.fixture def init_models(self, Catalog, Category, SubCategory, Product): pass class ThreeLevelDeepManyToMany(object): @pytest.fixture def Catalog(self, Base, Category): catalog_category = sa.Table( 'catalog_category', Base.metadata, sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog._id')), sa.Column('category_id', sa.Integer, sa.ForeignKey('category._id')) ) class Catalog(Base): __tablename__ = 'catalog' id = sa.Column('_id', sa.Integer, primary_key=True) categories = sa.orm.relationship( Category, backref='catalogs', secondary=catalog_category ) return Catalog @pytest.fixture def Category(self, Base, SubCategory): category_subcategory = sa.Table( 'category_subcategory', Base.metadata, sa.Column( 'category_id', sa.Integer, sa.ForeignKey('category._id') ), sa.Column( 'subcategory_id', sa.Integer, sa.ForeignKey('sub_category._id') ) ) class Category(Base): __tablename__ = 'category' id = sa.Column('_id', sa.Integer, primary_key=True) sub_categories = sa.orm.relationship( SubCategory, backref='categories', secondary=category_subcategory ) return Category @pytest.fixture def SubCategory(self, Base, Product): subcategory_product = sa.Table( 'subcategory_product', Base.metadata, sa.Column( 'subcategory_id', sa.Integer, sa.ForeignKey('sub_category._id') ), sa.Column( 'product_id', sa.Integer, sa.ForeignKey('product._id') ) ) class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column('_id', sa.Integer, primary_key=True) products = sa.orm.relationship( Product, backref='sub_categories', secondary=subcategory_product ) return SubCategory @pytest.fixture def Product(self, Base): class Product(Base): __tablename__ = 'product' id = sa.Column('_id', sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) return Product @pytest.fixture def init_models(self, Catalog, Category, SubCategory, Product): pass sqlalchemy-utils-0.36.1/tests/observes/000077500000000000000000000000001360007755400201025ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/observes/__init__.py000066400000000000000000000000001360007755400222010ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/observes/test_column_property.py000066400000000000000000000071761360007755400247670ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForColumn(object): @pytest.fixture def Product(self, Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Integer) @observes('price') def product_price_observer(self, price): self.price = price * 2 return Product @pytest.fixture def init_models(self, Product): pass def test_simple_insert(self, session, Product): product = Product(price=100) session.add(product) session.flush() assert product.price == 200 @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForColumnWithoutActualChanges(object): @pytest.fixture def Product(self, Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Integer) @observes('price') def product_price_observer(self, price): raise Exception('Trying to change price') return Product @pytest.fixture def init_models(self, Product): pass def test_only_notifies_observer_on_actual_changes(self, session, Product): product = Product() session.add(product) 
session.flush() with pytest.raises(Exception) as e: product.price = 500 session.commit() assert str(e.value) == 'Trying to change price' @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForMultipleColumns(object): @pytest.fixture def Order(self, Base): class Order(Base): __tablename__ = 'order' id = sa.Column(sa.Integer, primary_key=True) unit_price = sa.Column(sa.Integer) amount = sa.Column(sa.Integer) total_price = sa.Column(sa.Integer) @observes('amount', 'unit_price') def total_price_observer(self, amount, unit_price): self.total_price = amount * unit_price return Order @pytest.fixture def init_models(self, Order): pass def test_only_notifies_observer_on_actual_changes(self, session, Order): order = Order() order.amount = 2 order.unit_price = 10 session.add(order) session.flush() order.amount = 1 session.flush() assert order.total_price == 10 order.unit_price = 100 session.flush() assert order.total_price == 100 @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForMultipleColumnsFiresOnlyOnce(object): @pytest.fixture def Order(self, Base): class Order(Base): __tablename__ = 'order' id = sa.Column(sa.Integer, primary_key=True) unit_price = sa.Column(sa.Integer) amount = sa.Column(sa.Integer) @observes('amount', 'unit_price') def total_price_observer(self, amount, unit_price): self.call_count = self.call_count + 1 return Order @pytest.fixture def init_models(self, Order): pass def test_only_notifies_observer_on_actual_changes(self, session, Order): order = Order() order.amount = 2 order.unit_price = 10 order.call_count = 0 session.add(order) session.flush() assert order.call_count == 1 order.amount = 1 order.unit_price = 100 session.flush() assert order.call_count == 2 sqlalchemy-utils-0.36.1/tests/observes/test_m2m_m2m_m2m.py000066400000000000000000000102001360007755400235250ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes @pytest.fixture def Catalog(Base): catalog_category = sa.Table( 'catalog_category', Base.metadata, sa.Column('catalog_id', sa.Integer, sa.ForeignKey('catalog.id')), sa.Column('category_id', sa.Integer, sa.ForeignKey('category.id')) ) class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) product_count = sa.Column(sa.Integer, default=0) @observes('categories.sub_categories.products') def product_observer(self, products): self.product_count = len(products) categories = sa.orm.relationship( 'Category', backref='catalogs', secondary=catalog_category ) return Catalog @pytest.fixture def Category(Base): category_subcategory = sa.Table( 'category_subcategory', Base.metadata, sa.Column( 'category_id', sa.Integer, sa.ForeignKey('category.id') ), sa.Column( 'subcategory_id', sa.Integer, sa.ForeignKey('sub_category.id') ) ) class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) sub_categories = sa.orm.relationship( 'SubCategory', backref='categories', secondary=category_subcategory ) return Category @pytest.fixture def SubCategory(Base): subcategory_product = sa.Table( 'subcategory_product', Base.metadata, sa.Column( 'subcategory_id', sa.Integer, sa.ForeignKey('sub_category.id') ), sa.Column( 'product_id', sa.Integer, sa.ForeignKey('product.id') ) ) class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column(sa.Integer, primary_key=True) products = sa.orm.relationship( 'Product', backref='sub_categories', secondary=subcategory_product ) return SubCategory @pytest.fixture def Product(Base): class 
Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) return Product @pytest.fixture def init_models(Catalog, Category, SubCategory, Product): pass @pytest.fixture def catalog(session, Catalog, Category, SubCategory, Product): sub_category = SubCategory(products=[Product()]) category = Category(sub_categories=[sub_category]) catalog = Catalog(categories=[category]) session.add(catalog) session.flush() return catalog @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForManyToManyToManyToMany(object): def test_simple_insert(self, catalog): assert catalog.product_count == 1 def test_add_leaf_object(self, catalog, session, Product): product = Product() catalog.categories[0].sub_categories[0].products.append(product) session.flush() assert catalog.product_count == 2 def test_remove_leaf_object(self, catalog, session, Product): product = Product() catalog.categories[0].sub_categories[0].products.append(product) session.flush() session.delete(product) session.flush() assert catalog.product_count == 1 def test_delete_intermediate_object(self, catalog, session): session.delete(catalog.categories[0].sub_categories[0]) session.commit() assert catalog.product_count == 0 def test_gathered_objects_are_distinct( self, session, Catalog, Category, SubCategory, Product ): catalog = Catalog() category = Category(catalogs=[catalog]) product = Product() category.sub_categories.append( SubCategory(products=[product]) ) session.add( SubCategory(categories=[category], products=[product]) ) session.commit() assert catalog.product_count == 1 sqlalchemy-utils-0.36.1/tests/observes/test_o2m_o2m_o2m.py000066400000000000000000000070131360007755400235430ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes @pytest.fixture def Catalog(Base): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) product_count = sa.Column(sa.Integer, default=0) @observes('categories.sub_categories.products') def product_observer(self, products): self.product_count = len(products) categories = sa.orm.relationship('Category', backref='catalog') return Catalog @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) sub_categories = sa.orm.relationship( 'SubCategory', backref='category' ) return Category @pytest.fixture def SubCategory(Base): class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column(sa.Integer, primary_key=True) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) products = sa.orm.relationship( 'Product', backref='sub_category' ) return SubCategory @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) sub_category_id = sa.Column( sa.Integer, sa.ForeignKey('sub_category.id') ) def __repr__(self): return '' % self.id return Product @pytest.fixture def init_models(Catalog, Category, SubCategory, Product): pass @pytest.fixture def catalog(session, Catalog, Category, SubCategory, Product): sub_category = SubCategory(products=[Product()]) category = Category(sub_categories=[sub_category]) catalog = Catalog(categories=[category]) session.add(catalog) session.commit() return catalog @pytest.mark.usefixtures('postgresql_dsn') class TestObservesFor3LevelDeepOneToMany(object): def test_simple_insert(self, 
catalog): assert catalog.product_count == 1 def test_add_leaf_object(self, catalog, session, Product): product = Product() catalog.categories[0].sub_categories[0].products.append(product) session.flush() assert catalog.product_count == 2 def test_remove_leaf_object(self, catalog, session, Product): product = Product() catalog.categories[0].sub_categories[0].products.append(product) session.flush() session.delete(product) session.commit() assert catalog.product_count == 1 session.delete( catalog.categories[0].sub_categories[0].products[0] ) session.commit() assert catalog.product_count == 0 def test_delete_intermediate_object(self, catalog, session): session.delete(catalog.categories[0].sub_categories[0]) session.commit() assert catalog.product_count == 0 def test_gathered_objects_are_distinct( self, session, Catalog, Category, SubCategory, Product ): catalog = Catalog() category = Category(catalog=catalog) product = Product() category.sub_categories.append( SubCategory(products=[product]) ) session.add( SubCategory(category=category, products=[product]) ) session.commit() assert catalog.product_count == 1 sqlalchemy-utils-0.36.1/tests/observes/test_o2m_o2o_o2m.py000066400000000000000000000063641360007755400235550ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes @pytest.fixture def Catalog(Base): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) product_count = sa.Column(sa.Integer, default=0) @observes('categories.sub_category.products') def product_observer(self, products): self.product_count = len(products) categories = sa.orm.relationship('Category', backref='catalog') return Catalog @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) sub_category = sa.orm.relationship( 'SubCategory', uselist=False, backref='category' ) return Category @pytest.fixture def SubCategory(Base): class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column(sa.Integer, primary_key=True) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) products = sa.orm.relationship('Product', backref='sub_category') return SubCategory @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Numeric) sub_category_id = sa.Column( sa.Integer, sa.ForeignKey('sub_category.id') ) return Product @pytest.fixture def init_models(Catalog, Category, SubCategory, Product): pass @pytest.fixture def catalog(session, Catalog, Category, SubCategory, Product): sub_category = SubCategory(products=[Product()]) category = Category(sub_category=sub_category) catalog = Catalog(categories=[category]) session.add(catalog) session.flush() return catalog @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForOneToManyToOneToMany(object): def test_simple_insert(self, catalog): assert catalog.product_count == 1 def test_add_leaf_object(self, catalog, session, Product): product = Product() catalog.categories[0].sub_category.products.append(product) session.flush() assert catalog.product_count == 2 def test_remove_leaf_object(self, catalog, session, Product): product = Product() catalog.categories[0].sub_category.products.append(product) session.flush() session.delete(product) session.flush() assert catalog.product_count == 1 def test_delete_intermediate_object(self, catalog, session): 
session.delete(catalog.categories[0].sub_category) session.commit() assert catalog.product_count == 0 def test_gathered_objects_are_distinct( self, session, Catalog, Category, SubCategory, Product ): catalog = Catalog() category = Category(catalog=catalog) product = Product() category.sub_category = SubCategory(products=[product]) session.add( Category(catalog=catalog, sub_category=category.sub_category) ) session.commit() assert catalog.product_count == 1 sqlalchemy-utils-0.36.1/tests/observes/test_o2o_o2o.py000066400000000000000000000031251360007755400227720ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes @pytest.fixture def Device(Base): class Device(Base): __tablename__ = 'device' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) return Device @pytest.fixture def Order(Base): class Order(Base): __tablename__ = 'order' id = sa.Column(sa.Integer, primary_key=True) device_id = sa.Column( 'device', sa.ForeignKey('device.id'), nullable=False ) device = sa.orm.relationship('Device', backref='orders') return Order @pytest.fixture def SalesInvoice(Base): class SalesInvoice(Base): __tablename__ = 'sales_invoice' id = sa.Column(sa.Integer, primary_key=True) order_id = sa.Column( 'order', sa.ForeignKey('order.id'), nullable=False ) order = sa.orm.relationship( 'Order', backref=sa.orm.backref( 'invoice', uselist=False ) ) device_name = sa.Column(sa.String) @observes('order.device') def process_device(self, device): self.device_name = device.name return SalesInvoice @pytest.fixture def init_models(Device, Order, SalesInvoice): pass @pytest.mark.usefixtures('postgresql_dsn') class TestObservesForOneToManyToOneToMany(object): def test_observable_root_obj_is_none(self, session, Device, Order): order = Order(device=Device(name='Something')) session.add(order) session.flush() sqlalchemy-utils-0.36.1/tests/observes/test_o2o_o2o_o2o.py000066400000000000000000000051501360007755400235510ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.observer import observes @pytest.fixture def Catalog(Base): class Catalog(Base): __tablename__ = 'catalog' id = sa.Column(sa.Integer, primary_key=True) product_price = sa.Column(sa.Integer) @observes('category.sub_category.product') def product_observer(self, product): self.product_price = product.price if product else None category = sa.orm.relationship( 'Category', uselist=False, backref='catalog' ) return Catalog @pytest.fixture def Category(Base): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id')) sub_category = sa.orm.relationship( 'SubCategory', uselist=False, backref='category' ) return Category @pytest.fixture def SubCategory(Base): class SubCategory(Base): __tablename__ = 'sub_category' id = sa.Column(sa.Integer, primary_key=True) category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) product = sa.orm.relationship( 'Product', uselist=False, backref='sub_category' ) return SubCategory @pytest.fixture def Product(Base): class Product(Base): __tablename__ = 'product' id = sa.Column(sa.Integer, primary_key=True) price = sa.Column(sa.Integer) sub_category_id = sa.Column( sa.Integer, sa.ForeignKey('sub_category.id') ) return Product @pytest.fixture def init_models(Catalog, Category, SubCategory, Product): pass @pytest.fixture def catalog(session, Catalog, Category, SubCategory, Product): sub_category = 
SubCategory(product=Product(price=123))
    category = Category(sub_category=sub_category)
    catalog = Catalog(category=category)
    session.add(catalog)
    session.flush()
    return catalog


@pytest.mark.usefixtures('postgresql_dsn')
class TestObservesForOneToOneToOneToOne(object):
    def test_simple_insert(self, catalog):
        assert catalog.product_price == 123

    def test_replace_leaf_object(self, catalog, session, Product):
        product = Product(price=44)
        catalog.category.sub_category.product = product
        session.flush()
        assert catalog.product_price == 44

    def test_delete_leaf_object(self, catalog, session):
        session.delete(catalog.category.sub_category.product)
        session.flush()
        assert catalog.product_price is None
sqlalchemy-utils-0.36.1/tests/primitives/
sqlalchemy-utils-0.36.1/tests/primitives/__init__.py
sqlalchemy-utils-0.36.1/tests/primitives/test_country.py
import operator

import pytest
import six

from sqlalchemy_utils import Country, i18n


@pytest.fixture
def set_get_locale():
    i18n.get_locale = lambda: i18n.babel.Locale('en')


@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCountry(object):
    def test_init(self):
        assert Country(u'FI') == Country(Country(u'FI'))

    def test_constructor_with_wrong_type(self):
        with pytest.raises(TypeError) as e:
            Country(None)
        assert str(e.value) == (
            "Country() argument must be a string or a country, not 'NoneType'"
        )

    def test_constructor_with_invalid_code(self):
        with pytest.raises(ValueError) as e:
            Country('SomeUnknownCode')
        assert str(e.value) == (
            'Could not convert string to country code: SomeUnknownCode'
        )

    @pytest.mark.parametrize(
        'code',
        (
            'FI',
            'US',
        )
    )
    def test_validate_with_valid_codes(self, code):
        Country.validate(code)

    def test_validate_with_invalid_code(self):
        with pytest.raises(ValueError) as e:
            Country.validate('SomeUnknownCode')
        assert str(e.value) == (
            'Could not convert string to country code: SomeUnknownCode'
        )

    def test_equality_operator(self):
        assert Country(u'FI') == u'FI'
        assert u'FI' == Country(u'FI')
        assert Country(u'FI') == Country(u'FI')

    def test_non_equality_operator(self):
        assert Country(u'FI') != u'sv'
        assert not (Country(u'FI') != u'FI')

    @pytest.mark.parametrize(
        'op, code_left, code_right, is_',
        [
            (operator.lt, u'ES', u'FI', True),
            (operator.lt, u'FI', u'ES', False),
            (operator.lt, u'ES', u'ES', False),
            (operator.le, u'ES', u'FI', True),
            (operator.le, u'FI', u'ES', False),
            (operator.le, u'ES', u'ES', True),
            (operator.ge, u'ES', u'FI', False),
            (operator.ge, u'FI', u'ES', True),
            (operator.ge, u'ES', u'ES', True),
            (operator.gt, u'ES', u'FI', False),
            (operator.gt, u'FI', u'ES', True),
            (operator.gt, u'ES', u'ES', False),
        ]
    )
    def test_ordering(self, op, code_left, code_right, is_):
        country_left = Country(code_left)
        country_right = Country(code_right)
        assert op(country_left, country_right) is is_
        assert op(country_left, code_right) is is_
        assert op(code_left, country_right) is is_

    def test_hash(self):
        assert hash(Country('FI')) == hash('FI')

    def test_repr(self):
        assert repr(Country('FI')) == "Country('FI')"

    def test_unicode(self):
        country = Country('FI')
        assert six.text_type(country) == u'Finland'

    def test_str(self):
        country = Country('FI')
        assert str(country) == 'Finland'
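# A minimal usage sketch of the Country primitive exercised by the tests
# above. It is illustrative only, assumes Babel is installed (so that
# i18n.babel is not None), and uses only behaviour the tests demonstrate.
#
#     from sqlalchemy_utils import Country, i18n
#
#     # Country names are resolved via the locale returned by i18n.get_locale.
#     i18n.get_locale = lambda: i18n.babel.Locale('en')
#
#     finland = Country('FI')            # built from an ISO 3166-1 alpha-2 code
#     assert finland == 'FI'             # equality works against plain code strings
#     assert str(finland) == 'Finland'   # str() gives the localized country name
#     assert Country('ES') < finland     # ordering follows the underlying code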
sqlalchemy-utils-0.36.1/tests/primitives/test_currency.py
# -*- coding: utf-8 -*-
import pytest
import six

from sqlalchemy_utils import Currency, i18n


@pytest.fixture
def set_get_locale():
    i18n.get_locale = lambda: i18n.babel.Locale('en')


@pytest.mark.skipif('i18n.babel is None')
@pytest.mark.usefixtures('set_get_locale')
class TestCurrency(object):
    def test_init(self):
        assert Currency('USD') == Currency(Currency('USD'))

    def test_hashability(self):
        assert len(set([Currency('USD'), Currency('USD')])) == 1

    def test_invalid_currency_code(self):
        with pytest.raises(ValueError):
            Currency('Unknown code')

    def test_invalid_currency_code_type(self):
        with pytest.raises(TypeError):
            Currency(None)

    @pytest.mark.parametrize(
        ('code', 'name'),
        (
            ('USD', 'US Dollar'),
            ('EUR', 'Euro')
        )
    )
    def test_name_property(self, code, name):
        assert Currency(code).name == name

    @pytest.mark.parametrize(
        ('code', 'symbol'),
        (
            ('USD', u'$'),
            ('EUR', u'€')
        )
    )
    def test_symbol_property(self, code, symbol):
        assert Currency(code).symbol == symbol

    def test_equality_operator(self):
        assert Currency('USD') == 'USD'
        assert 'USD' == Currency('USD')
        assert Currency('USD') == Currency('USD')

    def test_non_equality_operator(self):
        assert Currency('USD') != 'EUR'
        assert not (Currency('USD') != 'USD')

    def test_unicode(self):
        currency = Currency('USD')
        assert six.text_type(currency) == u'USD'

    def test_str(self):
        currency = Currency('USD')
        assert str(currency) == 'USD'

    def test_representation(self):
        currency = Currency('USD')
        assert repr(currency) == "Currency('USD')"
sqlalchemy-utils-0.36.1/tests/primitives/test_ltree.py
# -*- coding: utf-8 -*-
import pytest
import six

from sqlalchemy_utils import Ltree


class TestLtree(object):
    def test_init(self):
        assert Ltree('path.path') == Ltree(Ltree('path.path'))

    def test_constructor_with_wrong_type(self):
        with pytest.raises(TypeError) as e:
            Ltree(None)
        assert str(e.value) == (
            "Ltree() argument must be a string or an Ltree, not 'NoneType'"
        )

    def test_constructor_with_invalid_code(self):
        with pytest.raises(ValueError) as e:
            Ltree('..')
        assert str(e.value) == "'..' is not a valid ltree path."
@pytest.mark.parametrize( 'code', ( 'path', 'path.path', '1_.2', '_._', ) ) def test_validate_with_valid_codes(self, code): Ltree.validate(code) @pytest.mark.parametrize( 'path', ( '', '.', 'path.', 'path..path', 'path.path..path', 'path.path..', 'path.äö', ) ) def test_validate_with_invalid_path(self, path): with pytest.raises(ValueError) as e: Ltree.validate(path) assert str(e.value) == ( "'{0}' is not a valid ltree path.".format(path) ) @pytest.mark.parametrize( ('path', 'length'), ( ('path', 1), ('1.1', 2), ('1.2.3', 3) ) ) def test_length(self, path, length): assert len(Ltree(path)) == length @pytest.mark.parametrize( ('path', 'subpath', 'index'), ( ('path.path', 'path', 0), ('1.2.3', '2.3', 1), ('1.2.3.4', '2.3', 1), ('1.2.3.4', '3.4', 2) ) ) def test_index(self, path, subpath, index): assert Ltree(path).index(subpath) == index @pytest.mark.parametrize( ('path', 'item_slice', 'result'), ( ('path.path', 0, 'path'), ('1.1.2.3', slice(1, 3), '1.2'), ('1.1.2.3', slice(1, None), '1.2.3'), ) ) def test_getitem(self, path, item_slice, result): assert Ltree(path)[item_slice] == result @pytest.mark.parametrize( ('path', 'others', 'result'), ( ('1.2.3', ['1.2.3', '1.2'], '1'), ('1.2.3.4.5', ['1.2.3', '1.2.3.4'], '1.2'), ('1.2.3.4.5', ['3.4', '1.2.3.4'], None), ) ) def test_lca(self, path, others, result): assert Ltree(path).lca(*others) == result @pytest.mark.parametrize( ('path', 'other', 'result'), ( ('1.2.3', '4.5', '1.2.3.4.5'), ('1', '1', '1.1'), ) ) def test_add(self, path, other, result): assert Ltree(path) + other == result @pytest.mark.parametrize( ('path', 'other', 'result'), ( ('1.2.3', '4.5', '4.5.1.2.3'), ('1', '1', '1.1'), ) ) def test_radd(self, path, other, result): assert other + Ltree(path) == result @pytest.mark.parametrize( ('path', 'other', 'result'), ( ('1.2.3', '4.5', '1.2.3.4.5'), ('1', '1', '1.1'), ) ) def test_iadd(self, path, other, result): ltree = Ltree(path) ltree += other assert ltree == result @pytest.mark.parametrize( ('path', 'other', 'result'), ( ('1.2.3', '2', True), ('1.2.3', '3', True), ('1', '1', True), ('1', '2', False), ) ) def test_contains(self, path, other, result): assert (other in Ltree(path)) == result @pytest.mark.parametrize( ('path', 'other', 'result'), ( ('1', '1.2.3', True), ('1.2', '1.2.3', True), ('1.2.3', '1.2.3', True), ('1.2.3', '1', False), ('1.2.3', '1.2', False), ('1', '1', True), ('1', '2', False), ) ) def test_ancestor_of(self, path, other, result): assert Ltree(path).ancestor_of(other) == result @pytest.mark.parametrize( ('path', 'other', 'result'), ( ('1', '1.2.3', False), ('1.2', '1.2.3', False), ('1.2', '1.2.3', False), ('1.2.3', '1', True), ('1.2.3', '1.2', True), ('1.2.3', '1.2.3', True), ('1', '1', True), ('1', '2', False), ) ) def test_descendant_of(self, path, other, result): assert Ltree(path).descendant_of(other) == result def test_getitem_with_other_than_slice_or_in(self): with pytest.raises(TypeError): Ltree('1.2')['something'] def test_index_raises_value_error_if_subpath_not_found(self): with pytest.raises(ValueError): Ltree('1.2').index('3') def test_equality_operator(self): assert Ltree('path.path') == 'path.path' assert 'path.path' == Ltree('path.path') assert Ltree('path.path') == Ltree('path.path') def test_non_equality_operator(self): assert Ltree('path.path') != u'path.' 
assert not (Ltree('path.path') != 'path.path') def test_hash(self): return hash(Ltree('path')) == hash('path') def test_repr(self): return repr(Ltree('path')) == "Ltree('path')" def test_unicode(self): ltree = Ltree('path.path') assert six.text_type(ltree) == 'path.path' def test_str(self): ltree = Ltree('path') assert str(ltree) == 'path' sqlalchemy-utils-0.36.1/tests/primitives/test_weekdays.py000066400000000000000000000117141360007755400236760ustar00rootroot00000000000000import pytest import six from flexmock import flexmock from sqlalchemy_utils import i18n from sqlalchemy_utils.primitives import WeekDay, WeekDays @pytest.fixture def set_get_locale(): i18n.get_locale = lambda: i18n.babel.Locale('fi') @pytest.mark.skipif('i18n.babel is None') @pytest.mark.usefixtures('set_get_locale') class TestWeekDay(object): def test_constructor_with_valid_index(self): day = WeekDay(1) assert day.index == 1 @pytest.mark.parametrize('index', [-1, 7]) def test_constructor_with_invalid_index(self, index): with pytest.raises(ValueError): WeekDay(index) def test_equality_with_equal_week_day(self): day = WeekDay(1) day2 = WeekDay(1) assert day == day2 def test_equality_with_unequal_week_day(self): day = WeekDay(1) day2 = WeekDay(2) assert day != day2 def test_equality_with_unsupported_comparison(self): day = WeekDay(1) assert day != 'foobar' def test_hash_is_equal_to_index_hash(self): day = WeekDay(1) assert hash(day) == hash(day.index) def test_representation(self): day = WeekDay(1) assert repr(day) == "WeekDay(1)" @pytest.mark.parametrize( ('index', 'first_week_day', 'position'), [ (0, 0, 0), (1, 0, 1), (6, 0, 6), (0, 6, 1), (1, 6, 2), (6, 6, 0), ] ) def test_position(self, index, first_week_day, position): i18n.get_locale = flexmock(first_week_day=first_week_day) day = WeekDay(index) assert day.position == position def test_get_name_returns_localized_week_day_name(self): day = WeekDay(0) assert day.get_name() == u'maanantaina' def test_override_get_locale_as_class_method(self): day = WeekDay(0) assert day.get_name() == u'maanantaina' def test_name_delegates_to_get_name(self): day = WeekDay(0) flexmock(day).should_receive('get_name').and_return(u'maanantaina') assert day.name == u'maanantaina' def test_unicode(self): day = WeekDay(0) flexmock(day).should_receive('name').and_return(u'maanantaina') assert six.text_type(day) == u'maanantaina' def test_str(self): day = WeekDay(0) flexmock(day).should_receive('name').and_return(u'maanantaina') assert str(day) == 'maanantaina' @pytest.mark.skipif('i18n.babel is None') class TestWeekDays(object): def test_constructor_with_valid_bit_string(self): days = WeekDays('1000100') assert days._days == set([WeekDay(0), WeekDay(4)]) @pytest.mark.parametrize( 'bit_string', [ '000000', # too short '00000000', # too long ] ) def test_constructor_with_bit_string_of_invalid_length(self, bit_string): with pytest.raises(ValueError): WeekDays(bit_string) def test_constructor_with_bit_string_containing_invalid_characters(self): with pytest.raises(ValueError): WeekDays('foobarz') def test_constructor_with_another_week_days_object(self): days = WeekDays('0000000') another_days = WeekDays(days) assert days._days == another_days._days def test_representation(self): days = WeekDays('0000000') assert repr(days) == "WeekDays('0000000')" @pytest.mark.parametrize( 'bit_string', [ '0000000', '1000000', '0000001', '0101000', '1111111', ] ) def test_as_bit_string(self, bit_string): days = WeekDays(bit_string) assert days.as_bit_string() == bit_string def 
test_equality_with_equal_week_days_object(self): days = WeekDays('0001000') days2 = WeekDays('0001000') assert days == days2 def test_equality_with_unequal_week_days_object(self): days = WeekDays('0001000') days2 = WeekDays('1000000') assert days != days2 def test_equality_with_equal_bit_string(self): days = WeekDays('0001000') assert days == '0001000' def test_equality_with_unequal_bit_string(self): days = WeekDays('0001000') assert days != '0101000' def test_equality_with_unsupported_comparison(self): days = WeekDays('0001000') assert days != 0 def test_iterator_starts_from_locales_first_week_day(self): i18n.get_locale = lambda: flexmock(first_week_day=1) days = WeekDays('1111111') indices = list(day.index for day in days) assert indices == [1, 2, 3, 4, 5, 6, 0] def test_unicode(self): i18n.get_locale = lambda: i18n.babel.Locale('fi') days = WeekDays('1000100') assert six.text_type(days) == u'maanantaina, perjantaina' def test_str(self): i18n.get_locale = lambda: i18n.babel.Locale('fi') days = WeekDays('1000100') assert str(days) == 'maanantaina, perjantaina' sqlalchemy-utils-0.36.1/tests/relationships/000077500000000000000000000000001360007755400211365ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/relationships/__init__.py000066400000000000000000000000001360007755400232350ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/relationships/test_chained_join.py000066400000000000000000000070311360007755400251620ustar00rootroot00000000000000import pytest from sqlalchemy_utils.relationships import chained_join from ..mixins import ( ThreeLevelDeepManyToMany, ThreeLevelDeepOneToMany, ThreeLevelDeepOneToOne ) @pytest.mark.usefixtures('postgresql_dsn') class TestChainedJoinFoDeepToManyToMany(ThreeLevelDeepManyToMany): def test_simple_join(self, Catalog): assert str(chained_join(Catalog.categories)) == ( 'catalog_category JOIN category ON ' 'category._id = catalog_category.category_id' ) def test_two_relations(self, Catalog, Category): sql = chained_join( Catalog.categories, Category.sub_categories ) assert str(sql) == ( 'catalog_category JOIN category ON category._id = ' 'catalog_category.category_id JOIN category_subcategory ON ' 'category._id = category_subcategory.category_id JOIN ' 'sub_category ON sub_category._id = ' 'category_subcategory.subcategory_id' ) def test_three_relations(self, Catalog, Category, SubCategory): sql = chained_join( Catalog.categories, Category.sub_categories, SubCategory.products ) assert str(sql) == ( 'catalog_category JOIN category ON category._id = ' 'catalog_category.category_id JOIN category_subcategory ON ' 'category._id = category_subcategory.category_id JOIN sub_category' ' ON sub_category._id = category_subcategory.subcategory_id JOIN ' 'subcategory_product ON sub_category._id = ' 'subcategory_product.subcategory_id JOIN product ON product._id =' ' subcategory_product.product_id' ) @pytest.mark.usefixtures('postgresql_dsn') class TestChainedJoinForDeepOneToMany(ThreeLevelDeepOneToMany): def test_simple_join(self, Catalog): assert str(chained_join(Catalog.categories)) == 'category' def test_two_relations(self, Catalog, Category): sql = chained_join( Catalog.categories, Category.sub_categories ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' 'sub_category._category_id' ) def test_three_relations(self, Catalog, Category, SubCategory): sql = chained_join( Catalog.categories, Category.sub_categories, SubCategory.products ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' 'sub_category._category_id JOIN 
product ON sub_category._id = ' 'product._sub_category_id' ) @pytest.mark.usefixtures('postgresql_dsn') class TestChainedJoinForDeepOneToOne(ThreeLevelDeepOneToOne): def test_simple_join(self, Catalog): assert str(chained_join(Catalog.category)) == 'category' def test_two_relations(self, Catalog, Category): sql = chained_join( Catalog.category, Category.sub_category ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' 'sub_category._category_id' ) def test_three_relations(self, Catalog, Category, SubCategory): sql = chained_join( Catalog.category, Category.sub_category, SubCategory.product ) assert str(sql) == ( 'category JOIN sub_category ON category._id = ' 'sub_category._category_id JOIN product ON sub_category._id = ' 'product._sub_category_id' ) sqlalchemy-utils-0.36.1/tests/relationships/test_select_correlated_expression.py000066400000000000000000000257071360007755400305240ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy_utils.relationships import select_correlated_expression @pytest.fixture def group_user_tbl(Base): return sa.Table( 'group_user', Base.metadata, sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')), sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id')) ) @pytest.fixture def group_tbl(Base): class Group(Base): __tablename__ = 'group' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) return Group @pytest.fixture def friendship_tbl(Base): return sa.Table( 'friendships', Base.metadata, sa.Column( 'friend_a_id', sa.Integer, sa.ForeignKey('user.id'), primary_key=True ), sa.Column( 'friend_b_id', sa.Integer, sa.ForeignKey('user.id'), primary_key=True ) ) @pytest.fixture def User(Base, group_user_tbl, friendship_tbl): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) groups = sa.orm.relationship( 'Group', secondary=group_user_tbl, backref='users' ) # this relationship is used for persistence friends = sa.orm.relationship( 'User', secondary=friendship_tbl, primaryjoin=id == friendship_tbl.c.friend_a_id, secondaryjoin=id == friendship_tbl.c.friend_b_id, ) friendship_union = ( sa.select([ friendship_tbl.c.friend_a_id, friendship_tbl.c.friend_b_id ]).union( sa.select([ friendship_tbl.c.friend_b_id, friendship_tbl.c.friend_a_id] ) ).alias() ) User.all_friends = sa.orm.relationship( 'User', secondary=friendship_union, primaryjoin=User.id == friendship_union.c.friend_a_id, secondaryjoin=User.id == friendship_union.c.friend_b_id, viewonly=True, order_by=User.id ) return User @pytest.fixture def Category(Base, group_user_tbl, friendship_tbl): class Category(Base): __tablename__ = 'category' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) created_at = sa.Column(sa.DateTime) parent_id = sa.Column(sa.Integer, sa.ForeignKey('category.id')) parent = sa.orm.relationship( 'Category', backref='subcategories', remote_side=[id], order_by=id ) return Category @pytest.fixture def Article(Base, Category, User): class Article(Base): __tablename__ = 'article' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column(sa.String) name_synonym = sa.orm.synonym('name') @hybrid_property def name_upper(self): return self.name.upper() if self.name else None @name_upper.expression def name_upper(cls): return sa.func.upper(cls.name) content = sa.Column(sa.String) category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) category = sa.orm.relationship(Category, 
backref='articles') author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship( User, primaryjoin=author_id == User.id, backref='authored_articles' ) owner_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) owner = sa.orm.relationship( User, primaryjoin=owner_id == User.id, backref='owned_articles' ) return Article @pytest.fixture def Comment(Base, Article, User): class Comment(Base): __tablename__ = 'comment' id = sa.Column(sa.Integer, primary_key=True) content = sa.Column(sa.String) article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) article = sa.orm.relationship(Article, backref='comments') author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship(User, backref='comments') Article.comment_count = sa.orm.column_property( sa.select([sa.func.count(Comment.id)]) .where(Comment.article_id == Article.id) .correlate_except(Article) ) return Comment @pytest.fixture def model_mapping(Article, Category, Comment, group_tbl, User): return { 'articles': Article, 'categories': Category, 'comments': Comment, 'groups': group_tbl, 'users': User } @pytest.fixture def init_models(Article, Category, Comment, group_tbl, User): pass @pytest.fixture def dataset( session, User, group_tbl, Article, Category, Comment ): group = group_tbl(name='Group 1') group2 = group_tbl(name='Group 2') user = User(id=1, name='User 1', groups=[group, group2]) user2 = User(id=2, name='User 2') user3 = User(id=3, name='User 3', groups=[group]) user4 = User(id=4, name='User 4', groups=[group2]) user5 = User(id=5, name='User 5') user.friends = [user2] user2.friends = [user3, user4] user3.friends = [user5] article = Article( name='Some article', author=user, owner=user2, category=Category( id=1, name='Some category', subcategories=[ Category( id=2, name='Subcategory 1', subcategories=[ Category( id=3, name='Subsubcategory 1', subcategories=[ Category( id=5, name='Subsubsubcategory 1', ), Category( id=6, name='Subsubsubcategory 2', ) ] ) ] ), Category(id=4, name='Subcategory 2'), ] ), comments=[ Comment( content='Some comment', author=user ) ] ) session.add(user3) session.add(user4) session.add(article) session.commit() @pytest.mark.usefixtures('dataset', 'postgresql_dsn') class TestSelectCorrelatedExpression(object): @pytest.mark.parametrize( ('model_key', 'related_model_key', 'path', 'result'), ( ( 'categories', 'categories', 'subcategories', [ (1, 2), (2, 1), (3, 2), (4, 0), (5, 0), (6, 0) ] ), ( 'articles', 'comments', 'comments', [ (1, 1), ] ), ( 'users', 'groups', 'groups', [ (1, 2), (2, 0), (3, 1), (4, 1), (5, 0) ] ), ( 'users', 'users', 'all_friends', [ (1, 1), (2, 3), (3, 2), (4, 1), (5, 1) ] ), ( 'users', 'users', 'all_friends.all_friends', [ (1, 3), (2, 2), (3, 3), (4, 3), (5, 2) ] ), ( 'users', 'users', 'groups.users', [ (1, 3), (2, 0), (3, 2), (4, 2), (5, 0) ] ), ( 'groups', 'articles', 'users.authored_articles', [ (1, 1), (2, 1), ] ), ( 'categories', 'categories', 'subcategories.subcategories', [ (1, 1), (2, 2), (3, 0), (4, 0), (5, 0), (6, 0) ] ), ( 'categories', 'categories', 'subcategories.subcategories.subcategories', [ (1, 2), (2, 0), (3, 0), (4, 0), (5, 0), (6, 0) ] ), ) ) def test_returns_correct_results( self, session, model_mapping, model_key, related_model_key, path, result ): model = model_mapping[model_key] alias = sa.orm.aliased(model_mapping[related_model_key]) aggregate = select_correlated_expression( model, sa.func.count(sa.distinct(alias.id)), path, alias ) query = session.query( model.id, aggregate.label('count') ).order_by(model.id) 
assert query.all() == result def test_order_by_intermediate_table_column( self, session, model_mapping, group_user_tbl ): model = model_mapping['users'] alias = sa.orm.aliased(model_mapping['groups']) aggregate = select_correlated_expression( model, sa.func.json_build_object('id', alias.id), 'groups', alias, order_by=[group_user_tbl.c.user_id] ).alias('test') # Just check that the query execution doesn't fail because of wrongly # constructed aliases assert session.execute(aggregate) def test_with_non_aggregate_function( self, session, User, Article ): aggregate = select_correlated_expression( Article, sa.func.json_build_object('name', User.name), 'comments.author', User ) query = session.query( Article.id, aggregate.label('author_json') ).order_by(Article.id) result = query.all() assert result == [ (1, {'name': 'User 1'}) ] sqlalchemy-utils-0.36.1/tests/test_asserts.py000066400000000000000000000115451360007755400213550ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy_utils import ( assert_max_length, assert_max_value, assert_min_value, assert_non_nullable, assert_nullable ) @pytest.fixture() def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column('_id', sa.Integer, primary_key=True) name = sa.Column('_name', sa.String(20)) age = sa.Column('_age', sa.Integer, nullable=False) email = sa.Column( '_email', sa.String(200), nullable=False, unique=True ) fav_numbers = sa.Column('_fav_numbers', ARRAY(sa.Integer)) __table_args__ = ( sa.CheckConstraint(sa.and_(age >= 0, age <= 150)), sa.CheckConstraint( sa.and_( sa.func.array_length(fav_numbers, 1) <= 8 ) ) ) return User @pytest.fixture() def user(User, session): user = User( name='Someone', email='someone@example.com', age=15, fav_numbers=[1, 2, 3] ) session.add(user) session.commit() return user @pytest.mark.usefixtures('postgresql_dsn') class TestAssertMaxLengthWithArray(object): def test_with_max_length(self, user): assert_max_length(user, 'fav_numbers', 8) assert_max_length(user, 'fav_numbers', 8) def test_smaller_than_max_length(self, user): with pytest.raises(AssertionError): assert_max_length(user, 'fav_numbers', 7) with pytest.raises(AssertionError): assert_max_length(user, 'fav_numbers', 7) def test_bigger_than_max_length(self, user): with pytest.raises(AssertionError): assert_max_length(user, 'fav_numbers', 9) with pytest.raises(AssertionError): assert_max_length(user, 'fav_numbers', 9) @pytest.mark.usefixtures('postgresql_dsn') class TestAssertNonNullable(object): def test_non_nullable_column(self, user): # Test everything twice so that session gets rolled back properly assert_non_nullable(user, 'age') assert_non_nullable(user, 'age') def test_nullable_column(self, user): with pytest.raises(AssertionError): assert_non_nullable(user, 'name') with pytest.raises(AssertionError): assert_non_nullable(user, 'name') @pytest.mark.usefixtures('postgresql_dsn') class TestAssertNullable(object): def test_nullable_column(self, user): assert_nullable(user, 'name') assert_nullable(user, 'name') def test_non_nullable_column(self, user): with pytest.raises(AssertionError): assert_nullable(user, 'age') with pytest.raises(AssertionError): assert_nullable(user, 'age') @pytest.mark.usefixtures('postgresql_dsn') class TestAssertMaxLength(object): def test_with_max_length(self, user): assert_max_length(user, 'name', 20) assert_max_length(user, 'name', 20) def test_with_non_nullable_column(self, user): assert_max_length(user, 'email', 200) assert_max_length(user, 
'email', 200) def test_smaller_than_max_length(self, user): with pytest.raises(AssertionError): assert_max_length(user, 'name', 19) with pytest.raises(AssertionError): assert_max_length(user, 'name', 19) def test_bigger_than_max_length(self, user): with pytest.raises(AssertionError): assert_max_length(user, 'name', 21) with pytest.raises(AssertionError): assert_max_length(user, 'name', 21) @pytest.mark.usefixtures('postgresql_dsn') class TestAssertMinValue(object): def test_with_min_value(self, user): assert_min_value(user, 'age', 0) assert_min_value(user, 'age', 0) def test_smaller_than_min_value(self, user): with pytest.raises(AssertionError): assert_min_value(user, 'age', -1) with pytest.raises(AssertionError): assert_min_value(user, 'age', -1) def test_bigger_than_min_value(self, user): with pytest.raises(AssertionError): assert_min_value(user, 'age', 1) with pytest.raises(AssertionError): assert_min_value(user, 'age', 1) @pytest.mark.usefixtures('postgresql_dsn') class TestAssertMaxValue(object): def test_with_min_value(self, user): assert_max_value(user, 'age', 150) assert_max_value(user, 'age', 150) def test_smaller_than_max_value(self, user): with pytest.raises(AssertionError): assert_max_value(user, 'age', 149) with pytest.raises(AssertionError): assert_max_value(user, 'age', 149) def test_bigger_than_max_value(self, user): with pytest.raises(AssertionError): assert_max_value(user, 'age', 151) with pytest.raises(AssertionError): assert_max_value(user, 'age', 151) sqlalchemy-utils-0.36.1/tests/test_auto_delete_orphans.py000066400000000000000000000053061360007755400237130ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy.orm import backref from sqlalchemy_utils import auto_delete_orphans, ImproperlyConfigured @pytest.fixture def tagging_tbl(Base): return sa.Table( 'tagging', Base.metadata, sa.Column( 'tag_id', sa.Integer, sa.ForeignKey('tag.id', ondelete='cascade'), primary_key=True ), sa.Column( 'entry_id', sa.Integer, sa.ForeignKey('entry.id', ondelete='cascade'), primary_key=True ) ) @pytest.fixture def Tag(Base): class Tag(Base): __tablename__ = 'tag' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String(100), unique=True, nullable=False) def __init__(self, name=None): self.name = name return Tag @pytest.fixture( params=['entries', backref('entries', lazy='select')], ids=['backref_string', 'backref_with_keywords'] ) def Entry(Base, Tag, tagging_tbl, request): class Entry(Base): __tablename__ = 'entry' id = sa.Column(sa.Integer, primary_key=True) tags = sa.orm.relationship( Tag, secondary=tagging_tbl, backref=request.param ) auto_delete_orphans(Entry.tags) return Entry @pytest.fixture def EntryWithoutTagsBackref(Base, Tag, tagging_tbl): class EntryWithoutTagsBackref(Base): __tablename__ = 'entry' id = sa.Column(sa.Integer, primary_key=True) tags = sa.orm.relationship( Tag, secondary=tagging_tbl ) return EntryWithoutTagsBackref class TestAutoDeleteOrphans(object): @pytest.fixture def init_models(self, Entry, Tag): pass def test_orphan_deletion(self, session, Entry, Tag): r1 = Entry() r2 = Entry() r3 = Entry() t1, t2, t3, t4 = ( Tag('t1'), Tag('t2'), Tag('t3'), Tag('t4') ) r1.tags.extend([t1, t2]) r2.tags.extend([t2, t3]) r3.tags.extend([t4]) session.add_all([r1, r2, r3]) assert session.query(Tag).count() == 4 r2.tags.remove(t2) assert session.query(Tag).count() == 4 r1.tags.remove(t2) assert session.query(Tag).count() == 3 r1.tags.remove(t1) assert session.query(Tag).count() == 2 class TestAutoDeleteOrphansWithoutBackref(object): 
@pytest.fixture def init_models(self, EntryWithoutTagsBackref, Tag): pass def test_orphan_deletion(self, EntryWithoutTagsBackref): with pytest.raises(ImproperlyConfigured): auto_delete_orphans(EntryWithoutTagsBackref.tags) sqlalchemy-utils-0.36.1/tests/test_case_insensitive_comparator.py000066400000000000000000000026341360007755400254520ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import EmailType @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) email = sa.Column(EmailType) def __repr__(self): return 'Building(%r)' % self.id return User @pytest.fixture def init_models(User): pass class TestCaseInsensitiveComparator(object): def test_supports_equals(self, session, User): query = ( session.query(User) .filter(User.email == u'email@example.com') ) assert 'user.email = lower(?)' in str(query) def test_supports_in_(self, session, User): query = ( session.query(User) .filter(User.email.in_([u'email@example.com', u'a'])) ) assert ( 'user.email IN (lower(?), lower(?))' in str(query) ) def test_supports_notin_(self, session, User): query = ( session.query(User) .filter(User.email.notin_([u'email@example.com', u'a'])) ) assert ( 'user.email NOT IN (lower(?), lower(?))' in str(query) ) def test_does_not_apply_lower_to_types_that_are_already_lowercased( self, User ): assert str(User.email == User.email) == ( '"user".email = "user".email' ) sqlalchemy-utils-0.36.1/tests/test_expressions.py000066400000000000000000000035151360007755400222510ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy.dialects import postgresql from sqlalchemy_utils import Asterisk, row_to_json @pytest.fixture def assert_startswith(session): def assert_startswith(query, query_part): assert str( query.compile(dialect=postgresql.dialect()) ).startswith(query_part) # Check that query executes properly session.execute(query) return assert_startswith @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) content = sa.Column(sa.UnicodeText) return Article class TestAsterisk(object): def test_with_table_object(self): Base = sa.ext.declarative.declarative_base() class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) assert str(Asterisk(Article.__table__)) == 'article.*' def test_with_quoted_identifier(self): Base = sa.ext.declarative.declarative_base() class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) assert str(Asterisk(User.__table__).compile( dialect=postgresql.dialect() )) == '"user".*' class TestRowToJson(object): def test_compiler_with_default_dialect(self): assert str(row_to_json(sa.text('article.*'))) == ( 'row_to_json(article.*)' ) def test_compiler_with_postgresql(self): assert str(row_to_json(sa.text('article.*')).compile( dialect=postgresql.dialect() )) == 'row_to_json(article.*)' def test_type(self): assert isinstance( sa.func.row_to_json(sa.text('article.*')).type, postgresql.JSON ) sqlalchemy-utils-0.36.1/tests/test_instant_defaults_listener.py000066400000000000000000000014201360007755400251340ustar00rootroot00000000000000from datetime import datetime import pytest import sqlalchemy as sa from sqlalchemy_utils.listeners import force_instant_defaults force_instant_defaults() @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = 
sa.Column(sa.Unicode(255), default=u'Some article') created_at = sa.Column(sa.DateTime, default=datetime.now) return Article class TestInstantDefaultListener(object): def test_assigns_defaults_on_object_construction(self, Article): article = Article() assert article.name == u'Some article' def test_callables_as_defaults(self, Article): article = Article() assert isinstance(article.created_at, datetime) sqlalchemy-utils-0.36.1/tests/test_instrumented_list.py000066400000000000000000000010721360007755400234370ustar00rootroot00000000000000class TestInstrumentedList(object): def test_any_returns_true_if_member_has_attr_defined( self, Category, Article ): category = Category() category.articles.append(Article()) category.articles.append(Article(name=u'some name')) assert category.articles.any('name') def test_any_returns_false_if_no_member_has_attr_defined( self, Category, Article ): category = Category() category.articles.append(Article()) assert not category.articles.any('name') sqlalchemy-utils-0.36.1/tests/test_models.py000066400000000000000000000052551360007755400211550ustar00rootroot00000000000000import sys from datetime import datetime import pytest import sqlalchemy as sa from sqlalchemy_utils import generic_repr, Timestamp class TestTimestamp(object): @pytest.fixture def Article(self, Base): class Article(Base, Timestamp): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255), default=u'Some article') return Article def test_created(self, session, Article): then = datetime.utcnow() article = Article() session.add(article) session.commit() assert article.created >= then and article.created <= datetime.utcnow() def test_updated(self, session, Article): article = Article() session.add(article) session.commit() then = datetime.utcnow() article.name = u"Something" session.commit() assert article.updated >= then and article.updated <= datetime.utcnow() class TestGenericRepr: @pytest.fixture def Article(self, Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255), default=u'Some article') return Article def test_repr(self, Article): """Representation of a basic model.""" Article = generic_repr(Article) article = Article(id=1, name=u'Foo') if sys.version_info[0] == 2: expected_repr = u'Article(id=1, name=u\'Foo\')' elif sys.version_info[0] == 3: expected_repr = u'Article(id=1, name=\'Foo\')' else: raise AssertionError actual_repr = repr(article) assert actual_repr == expected_repr def test_repr_partial(self, Article): """Representation of a basic model with selected fields.""" Article = generic_repr('id')(Article) article = Article(id=1, name=u'Foo') expected_repr = u'Article(id=1)' actual_repr = repr(article) assert actual_repr == expected_repr def test_not_loaded(self, session, Article): """:py:func:`~sqlalchemy_utils.models.generic_repr` doesn't force execution of additional queries if some fields are not loaded and instead represents them as "". 
""" Article = generic_repr(Article) article = Article(name=u'Foo') session.add(article) session.commit() article = session.query(Article).options(sa.orm.defer('name')).one() actual_repr = repr(article) expected_repr = u'Article(id={}, name=)'.format(article.id) assert actual_repr == expected_repr sqlalchemy-utils-0.36.1/tests/test_path.py000066400000000000000000000125351360007755400206250ustar00rootroot00000000000000import pytest import six import sqlalchemy as sa from sqlalchemy.util.langhelpers import symbol from sqlalchemy_utils.path import AttrPath, Path @pytest.fixture def Document(Base): class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) locale = sa.Column(sa.String(10)) return Document @pytest.fixture def Section(Base, Document): class Section(Base): __tablename__ = 'section' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) locale = sa.Column(sa.String(10)) document_id = sa.Column( sa.Integer, sa.ForeignKey(Document.id) ) document = sa.orm.relationship(Document, backref='sections') return Section @pytest.fixture def SubSection(Base, Section): class SubSection(Base): __tablename__ = 'subsection' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.Unicode(255)) locale = sa.Column(sa.String(10)) section_id = sa.Column( sa.Integer, sa.ForeignKey(Section.id) ) section = sa.orm.relationship(Section, backref='subsections') return SubSection class TestAttrPath(object): @pytest.fixture def init_models(self, Document, Section, SubSection): pass def test_direction(self, SubSection): assert ( AttrPath(SubSection, 'section').direction == symbol('MANYTOONE') ) def test_invert(self, Document, Section, SubSection): path = ~ AttrPath(SubSection, 'section.document') assert path.parts == [ Document.sections, Section.subsections ] assert str(path.path) == 'sections.subsections' def test_len(self, SubSection): len(AttrPath(SubSection, 'section.document')) == 2 def test_init(self, SubSection): path = AttrPath(SubSection, 'section.document') assert path.class_ == SubSection assert path.path == Path('section.document') def test_iter(self, Section, SubSection): path = AttrPath(SubSection, 'section.document') assert list(path) == [ SubSection.section, Section.document ] def test_repr(self, SubSection): path = AttrPath(SubSection, 'section.document') assert repr(path) == ( "AttrPath(SubSection, 'section.document')" ) def test_index(self, Section, SubSection): path = AttrPath(SubSection, 'section.document') assert path.index(Section.document) == 1 assert path.index(SubSection.section) == 0 def test_getitem(self, Section, SubSection): path = AttrPath(SubSection, 'section.document') assert path[0] is SubSection.section assert path[1] is Section.document def test_getitem_with_slice(self, Section, SubSection): path = AttrPath(SubSection, 'section.document') assert path[:] == AttrPath(SubSection, 'section.document') assert path[:-1] == AttrPath(SubSection, 'section') assert path[1:] == AttrPath(Section, 'document') def test_eq(self, SubSection): assert ( AttrPath(SubSection, 'section.document') == AttrPath(SubSection, 'section.document') ) assert not ( AttrPath(SubSection, 'section') == AttrPath(SubSection, 'section.document') ) def test_ne(self, SubSection): assert not ( AttrPath(SubSection, 'section.document') != AttrPath(SubSection, 'section.document') ) assert ( AttrPath(SubSection, 'section') != AttrPath(SubSection, 'section.document') ) class TestPath(object): def test_init(self): path = 
Path('attr.attr2') assert path.path == 'attr.attr2' def test_init_with_path_object(self): path = Path(Path('attr.attr2')) assert path.path == 'attr.attr2' def test_iter(self): path = Path('s.s2.s3') assert list(path) == ['s', 's2', 's3'] @pytest.mark.parametrize(('path', 'length'), ( (Path('s.s2.s3'), 3), (Path('s.s2'), 2), (Path(''), 0) )) def test_len(self, path, length): return len(path) == length def test_reversed(self): path = Path('s.s2.s3') assert list(reversed(path)) == ['s3', 's2', 's'] def test_repr(self): path = Path('s.s2') assert repr(path) == "Path('s.s2')" def test_getitem(self): path = Path('s.s2') assert path[0] == 's' assert path[1] == 's2' def test_str(self): assert str(Path('s.s2')) == 's.s2' def test_index(self): assert Path('s.s2.s3').index('s2') == 1 def test_unicode(self): assert six.text_type(Path('s.s2')) == u's.s2' def test_getitem_with_slice(self): path = Path('s.s2.s3') assert path[1:] == Path('s2.s3') @pytest.mark.parametrize(('test', 'result'), ( (Path('s.s2') == Path('s.s2'), True), (Path('s.s2') == Path('s.s3'), False) )) def test_eq(self, test, result): assert test is result @pytest.mark.parametrize(('test', 'result'), ( (Path('s.s2') != Path('s.s2'), False), (Path('s.s2') != Path('s.s3'), True) )) def test_ne(self, test, result): assert test is result sqlalchemy-utils-0.36.1/tests/test_proxy_dict.py000066400000000000000000000070371360007755400220560ustar00rootroot00000000000000import pytest import sqlalchemy as sa from flexmock import flexmock from sqlalchemy_utils import proxy_dict, ProxyDict @pytest.fixture def ArticleTranslation(Base): class ArticleTranslation(Base): __tablename__ = 'article_translation' id = sa.Column( sa.Integer, sa.ForeignKey('article.id'), primary_key=True ) locale = sa.Column(sa.String(10), primary_key=True) name = sa.Column(sa.UnicodeText) return ArticleTranslation @pytest.fixture def Article(Base, ArticleTranslation): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) description = sa.Column(sa.UnicodeText) _translations = sa.orm.relationship( ArticleTranslation, lazy='dynamic', cascade='all, delete-orphan', passive_deletes=True, backref=sa.orm.backref('parent'), ) @property def translations(self): return proxy_dict( self, '_translations', ArticleTranslation.locale ) return Article @pytest.fixture def init_models(ArticleTranslation, Article): pass class TestProxyDict(object): def test_access_key_for_pending_parent(self, session, Article): article = Article() session.add(article) assert article.translations['en'] def test_access_key_for_transient_parent(self, Article): article = Article() assert article.translations['en'] def test_cache(self, session, Article): article = Article() ( flexmock(ProxyDict) .should_receive('fetch') .once() ) session.add(article) session.commit() article.translations['en'] article.translations['en'] def test_set_updates_cache(self, session, Article, ArticleTranslation): article = Article() ( flexmock(ProxyDict) .should_receive('fetch') .once() ) session.add(article) session.commit() article.translations['en'] article.translations['en'] = ArticleTranslation( locale='en', name=u'something' ) article.translations['en'] def test_contains_efficiency(self, connection, session, Article): article = Article() session.add(article) session.commit() article.id query_count = connection.query_count 'en' in article.translations 'en' in article.translations 'en' in article.translations assert connection.query_count == query_count + 1 def 
test_getitem_with_none_value_in_cache(self, session, Article): article = Article() session.add(article) session.commit() article.id 'en' in article.translations assert article.translations['en'] def test_contains(self, Article): article = Article() assert 'en' not in article.translations # does not auto-append new translation assert 'en' not in article.translations def test_committing_session_empties_proxy_dict_cache( self, session, Article ): article = Article() ( flexmock(ProxyDict) .should_receive('fetch') .twice() ) session.add(article) session.commit() article.translations['en'] session.commit() article.translations['en'] sqlalchemy-utils-0.36.1/tests/test_query_chain.py000066400000000000000000000051331360007755400221740ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import QueryChain @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) return User @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) return Article @pytest.fixture def BlogPost(Base): class BlogPost(Base): __tablename__ = 'blog_post' id = sa.Column(sa.Integer, primary_key=True) return BlogPost @pytest.fixture def init_models(User, Article, BlogPost): pass @pytest.fixture def users(session, User): users = [User(), User()] session.add_all(users) session.commit() return users @pytest.fixture def articles(session, Article): articles = [Article(), Article(), Article(), Article()] session.add_all(articles) session.commit() return articles @pytest.fixture def posts(session, BlogPost): posts = [BlogPost(), BlogPost(), BlogPost()] session.add_all(posts) session.commit() return posts @pytest.fixture def chain(session, users, articles, posts, User, Article, BlogPost): return QueryChain( [ session.query(User).order_by('id'), session.query(Article).order_by('id'), session.query(BlogPost).order_by('id') ] ) class TestQueryChain(object): def test_iter(self, chain): assert len(list(chain)) == 9 def test_iter_with_limit(self, chain, users, articles): c = chain.limit(4) objects = list(c) assert users == objects[0:2] assert articles[0:2] == objects[2:] def test_iter_with_offset(self, chain, articles, posts): c = chain.offset(3) objects = list(c) assert articles[1:] + posts == objects def test_iter_with_limit_and_offset(self, chain, articles, posts): c = chain.offset(3).limit(4) objects = list(c) assert articles[1:] + posts[0:1] == objects def test_iter_with_offset_spanning_multiple_queries(self, chain, posts): c = chain.offset(7) objects = list(c) assert posts[1:] == objects def test_repr(self, chain): assert repr(chain) == '' % id(chain) def test_getitem_with_slice(self, chain): c = chain[1:] assert c._offset == 1 assert c._limit is None def test_getitem_with_single_key(self, chain, articles): article = chain[2] assert article == articles[0] def test_count(self, chain): assert chain.count() == 9 sqlalchemy-utils-0.36.1/tests/test_sort_query.py000066400000000000000000000275551360007755400221150ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import sort_query from sqlalchemy_utils.functions import QuerySorterException from . 
import assert_contains class TestSortQuery(object): def test_without_sort_param_returns_the_query_object_untouched( self, session, Article ): query = session.query(Article) query = sort_query(query, '') assert query == query def test_column_ascending(self, session, Article): query = sort_query(session.query(Article), 'name') assert_contains('ORDER BY article.name ASC', query) def test_column_descending(self, session, Article): query = sort_query(session.query(Article), '-name') assert_contains('ORDER BY article.name DESC', query) def test_skips_unknown_columns(self, session, Article): query = session.query(Article) query = sort_query(query, '-unknown') assert query == query def test_non_silent_mode(self, session, Article): query = session.query(Article) with pytest.raises(QuerySorterException): sort_query(query, '-unknown', silent=False) def test_join(self, session, Article): query = ( session.query(Article) .join(Article.category) ) query = sort_query(query, 'name', silent=False) assert_contains('ORDER BY article.name ASC', query) def test_calculated_value_ascending(self, session, Article, Category): query = session.query( Category, sa.func.count(Article.id).label('articles') ) query = sort_query(query, 'articles') assert_contains('ORDER BY articles ASC', query) def test_calculated_value_descending(self, session, Article, Category): query = session.query( Category, sa.func.count(Article.id).label('articles') ) query = sort_query(query, '-articles') assert_contains('ORDER BY articles DESC', query) def test_subqueried_scalar(self, session, Article, Category): article_count = ( sa.sql.select( [sa.func.count(Article.id)], from_obj=[Article.__table__] ) .where(Article.category_id == Category.id) .correlate(Category.__table__) ) query = session.query( Category, article_count.label('articles') ) query = sort_query(query, '-articles') assert_contains('ORDER BY articles DESC', query) def test_aliased_joined_entity(self, session, Article, Category): alias = sa.orm.aliased(Category, name='categories') query = session.query( Article ).join( alias, Article.category ) query = sort_query(query, '-categories-name') assert_contains('ORDER BY categories.name DESC', query) def test_joined_table_column(self, session, Article): query = session.query(Article).join(Article.category) query = sort_query(query, 'category-name') assert_contains('category.name ASC', query) def test_multiple_columns(self, session, Article): query = session.query(Article) query = sort_query(query, 'name', 'id') assert_contains('article.name ASC, article.id ASC', query) def test_column_property(self, session, Article, Category): Category.article_count = sa.orm.column_property( sa.select([sa.func.count(Article.id)]) .where(Article.category_id == Category.id) .label('article_count') ) query = session.query(Category) query = sort_query(query, 'article_count') assert_contains('article_count ASC', query) def test_column_property_descending(self, session, Article, Category): Category.article_count = sa.orm.column_property( sa.select([sa.func.count(Article.id)]) .where(Article.category_id == Category.id) .label('article_count') ) query = session.query(Category) query = sort_query(query, '-article_count') assert_contains('article_count DESC', query) def test_relationship_property(self, session, Category): query = session.query(Category) query = sort_query(query, 'articles') assert 'ORDER BY' not in str(query) def test_regular_property(self, session, Category): query = session.query(Category) query = sort_query(query, 'name_alias') assert 
'ORDER BY' not in str(query) def test_synonym_property(self, session, Category): query = session.query(Category) query = sort_query(query, 'name_synonym') assert_contains('ORDER BY category.name ASC', query) def test_hybrid_property(self, session, Category): query = session.query(Category) query = sort_query(query, 'articles_count') assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) def test_hybrid_property_descending(self, session, Category): query = session.query(Category) query = sort_query(query, '-articles_count') assert_contains( 'ORDER BY (SELECT count(article.id) AS count_1', query ) assert ' DESC' in str(query) def test_assigned_hybrid_property(self, session, Article): def getter(self): return self.name Article.some_hybrid = sa.ext.hybrid.hybrid_property( fget=getter ) query = session.query(Article) query = sort_query(query, 'some_hybrid') assert_contains('ORDER BY article.name ASC', query) def test_with_mapper_and_column_property(self, session, Base, Article): class Apple(Base): __tablename__ = 'apple' id = sa.Column(sa.Integer, primary_key=True) article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id)) Article.apples = sa.orm.relationship(Apple) Article.apple_count = sa.orm.column_property( sa.select([sa.func.count(Apple.id)]) .where(Apple.article_id == Article.id) .correlate(Article.__table__) .label('apple_count'), deferred=True ) query = ( session.query(sa.inspect(Article)) .outerjoin(Article.apples) .options( sa.orm.undefer(Article.apple_count) ) .options(sa.orm.contains_eager(Article.apples)) ) query = sort_query(query, 'apple_count') assert 'ORDER BY apple_count' in str(query) def test_table(self, session, Article): query = session.query(Article.__table__) query = sort_query(query, 'name') assert_contains('ORDER BY article.name', query) @pytest.mark.usefixtures('postgresql_dsn') class TestSortQueryRelationshipCounts(object): """ Currently this doesn't work with SQLite """ def test_relation_hybrid_property(self, session, Article): query = ( session.query(Article) .join(Article.category) ).group_by(Article.id) query = sort_query(query, '-category-articles_count') assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) def test_aliased_hybrid_property(self, session, Article, Category): alias = sa.orm.aliased( Category, name='categories' ) query = ( session.query(Article) .outerjoin(alias, Article.category) .options( sa.orm.contains_eager(Article.category, alias=alias) ) ).group_by(alias.id, Article.id) query = sort_query(query, '-categories-articles_count') assert_contains('ORDER BY (SELECT count(article.id) AS count_1', query) def test_aliased_concat_hybrid_property(self, session, Article, Category): alias = sa.orm.aliased( Category, name='aliased' ) query = ( session.query(Article) .outerjoin(alias, Article.category) .options( sa.orm.contains_eager(Article.category, alias=alias) ) ) query = sort_query(query, 'aliased-full_name') assert_contains( 'concat(aliased.title, %(concat_1)s, aliased.name)', query ) @pytest.mark.usefixtures('postgresql_dsn') class TestSortQueryWithPolymorphicInheritance(object): """ Currently this doesn't work with SQLite """ @pytest.fixture def TextItem(self, Base): class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': type, 'with_polymorphic': '*' } return TextItem @pytest.fixture def Article(self, TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( sa.Integer, 
sa.ForeignKey(TextItem.id), primary_key=True ) category = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_identity': u'article' } return Article @pytest.fixture def init_models(self, TextItem, Article): pass def test_column_property(self, session, TextItem): TextItem.item_count = sa.orm.column_property( sa.select( [ sa.func.count('1') ], ) .select_from(TextItem.__table__) .label('item_count') ) query = sort_query( session.query(TextItem), 'item_count' ) assert_contains('ORDER BY item_count', query) def test_child_class_attribute(self, session, TextItem): query = sort_query( session.query(TextItem), 'category' ) assert_contains('ORDER BY article.category ASC', query) def test_with_ambiguous_column(self, session, TextItem): query = sort_query( session.query(TextItem), 'id' ) assert_contains('ORDER BY text_item.id ASC', query) @pytest.mark.usefixtures('postgresql_dsn') class TestSortQueryWithCustomPolymorphic(object): """ Currently this doesn't work with SQLite """ @pytest.fixture def TextItem(self, Base): class TextItem(Base): __tablename__ = 'text_item' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': type, } return TextItem @pytest.fixture def Article(self, TextItem): class Article(TextItem): __tablename__ = 'article' id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) category = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_identity': u'article' } return Article @pytest.fixture def BlogPost(self, TextItem): class BlogPost(TextItem): __tablename__ = 'blog_post' id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) __mapper_args__ = { 'polymorphic_identity': u'blog_post' } return BlogPost def test_with_unknown_column(self, session, TextItem, BlogPost): query = sort_query( session.query( sa.orm.with_polymorphic(TextItem, [BlogPost]) ), 'category' ) assert 'ORDER BY' not in str(query) def test_with_existing_column(self, session, TextItem, Article): query = sort_query( session.query( sa.orm.with_polymorphic(TextItem, [Article]) ), 'category' ) assert 'ORDER BY' in str(query) sqlalchemy-utils-0.36.1/tests/test_translation_hybrid.py000066400000000000000000000105551360007755400235700ustar00rootroot00000000000000import pytest import sqlalchemy as sa from flexmock import flexmock from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.orm import aliased from sqlalchemy_utils import i18n, TranslationHybrid # noqa @pytest.fixture def translation_hybrid(): return TranslationHybrid('fi', 'en') @pytest.fixture def City(Base, translation_hybrid): class City(Base): __tablename__ = 'city' id = sa.Column(sa.Integer, primary_key=True) name_translations = sa.Column(HSTORE) name = translation_hybrid(name_translations) locale = 'en' return City @pytest.fixture def init_models(City): pass @pytest.mark.usefixtures('postgresql_dsn') @pytest.mark.skipif('i18n.babel is None') class TestTranslationHybrid(object): def test_using_hybrid_as_constructor(self, City): city = City(name='Helsinki') assert city.name_translations['fi'] == 'Helsinki' def test_if_no_translation_exists_returns_none(self, City): city = City() assert city.name is None def test_custom_default_value(self, City, translation_hybrid): translation_hybrid.default_value = 'Some value' city = City() assert city.name == 'Some value' def test_fall_back_to_default_translation(self, City, translation_hybrid): city = City(name_translations={'en': 'Helsinki'}) translation_hybrid.current_locale = 'sv' assert city.name 
== 'Helsinki' def test_fallback_to_dynamic_locale(self, City, translation_hybrid): translation_hybrid.current_locale = 'en' translation_hybrid.default_locale = lambda self: self.locale city = City(name_translations={}) city.locale = 'fi' city.name_translations['fi'] = 'Helsinki' assert city.name == 'Helsinki' def test_fallback_to_attr_dependent_locale(self, City, translation_hybrid): translation_hybrid.current_locale = 'en' translation_hybrid.default_locale = ( lambda obj, attr: sorted(getattr(obj, attr).keys())[0] ) city = City(name_translations={}) city.name_translations['fi'] = 'Helsinki' assert city.name == 'Helsinki' city.name_translations['de'] = 'Stadt Helsinki' assert city.name == 'Stadt Helsinki' @pytest.mark.parametrize( ('name_translations', 'name'), ( ({'fi': 'Helsinki', 'en': 'Helsing'}, 'Helsinki'), ({'en': 'Helsinki'}, 'Helsinki'), ({'fi': 'Helsinki'}, 'Helsinki'), ({}, None), ) ) def test_hybrid_as_an_expression( self, session, City, name_translations, name ): city = City(name_translations=name_translations) session.add(city) session.commit() assert session.query(City.name).scalar() == name def test_dynamic_locale(self, Base): translation_hybrid = TranslationHybrid( lambda obj: obj.locale, 'fi' ) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name_translations = sa.Column(HSTORE) name = translation_hybrid(name_translations) locale = sa.Column(sa.String) assert ( 'coalesce(article.name_translations -> article.locale' in str(Article.name.expression) ) def test_locales_casted_only_in_compilation_phase(self, Base): class LocaleGetter(object): def current_locale(self): return lambda obj: obj.locale flexmock(LocaleGetter).should_receive('current_locale').never() translation_hybrid = TranslationHybrid( LocaleGetter().current_locale, 'fi' ) class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name_translations = sa.Column(HSTORE) name = translation_hybrid(name_translations) locale = sa.Column(sa.String) Article.name def test_no_implicit_join_when_using_aliased_entities(self, session, City): # Ensure that queried entities are taken from the alias so that # there isn't an extra join to the original entity. 
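# (Added note, not from the original test: with the alias in place the
# rendered statement is expected to select roughly
#   coalesce(city_1.name_translations -> :param_1, ...)
# and to list only "city AS city_1" in the FROM clause; an implicit join back
# to the base entity would add a second, unaliased "city" entry there.)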
CityAlias = aliased(City) query_str = str(session.query(CityAlias.name)) assert query_str.endswith('FROM city AS city_1') sqlalchemy-utils-0.36.1/tests/test_views.py000066400000000000000000000107541360007755400210270ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import ( create_materialized_view, create_view, refresh_materialized_view ) @pytest.fixture def Article(Base, User): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) author = sa.orm.relationship(User) return Article @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) name = sa.Column(sa.String) return User @pytest.fixture def ArticleMV(Base, Article, User): class ArticleMV(Base): __table__ = create_materialized_view( name='article_mv', selectable=sa.select( [ Article.id, Article.name, User.id.label('author_id'), User.name.label('author_name') ], from_obj=( Article.__table__ .join(User, Article.author_id == User.id) ) ), aliases={'name': 'article_name'}, metadata=Base.metadata, indexes=[sa.Index('article_mv_id_idx', 'id')] ) return ArticleMV @pytest.fixture def ArticleView(Base, Article, User): class ArticleView(Base): __table__ = create_view( name='article_view', selectable=sa.select( [ Article.id, Article.name, User.id.label('author_id'), User.name.label('author_name') ], from_obj=( Article.__table__ .join(User, Article.author_id == User.id) ) ), metadata=Base.metadata ) return ArticleView @pytest.fixture def init_models(ArticleMV, ArticleView): pass @pytest.mark.usefixtures('postgresql_dsn') class TestMaterializedViews: def test_refresh_materialized_view( self, session, Article, User, ArticleMV ): article = Article( name='Some article', author=User(name='Some user') ) session.add(article) session.commit() refresh_materialized_view(session, 'article_mv') materialized = session.query(ArticleMV).first() assert materialized.article_name == 'Some article' assert materialized.author_name == 'Some user' def test_querying_view( self, session, Article, User, ArticleView ): article = Article( name='Some article', author=User(name='Some user') ) session.add(article) session.commit() row = session.query(ArticleView).first() assert row.name == 'Some article' assert row.author_name == 'Some user' class TrivialViewTestCases: def life_cycle( self, engine, metadata, column, cascade_on_drop ): __table__ = create_view( name='trivial_view', selectable=sa.select([column]), metadata=metadata, cascade_on_drop=cascade_on_drop ) __table__.create(engine) __table__.drop(engine) class SupportsCascade(TrivialViewTestCases): def test_life_cycle_cascade( self, connection, engine, Base, User ): self.life_cycle(engine, Base.metadata, User.id, cascade_on_drop=True) class DoesntSupportCascade(SupportsCascade): @pytest.mark.xfail def test_life_cycle_cascade(self, *args, **kwargs): super(DoesntSupportCascade, self).test_life_cycle_cascade( *args, **kwargs ) class SupportsNoCascade(TrivialViewTestCases): def test_life_cycle_no_cascade( self, connection, engine, Base, User ): self.life_cycle(engine, Base.metadata, User.id, cascade_on_drop=False) @pytest.mark.usefixtures('postgresql_dsn') class TestPostgresTrivialView(SupportsCascade, SupportsNoCascade): pass @pytest.mark.usefixtures('mysql_dsn') class TestMySqlTrivialView(SupportsCascade, SupportsNoCascade): pass @pytest.mark.usefixtures('sqlite_none_database_dsn') class 
TestSqliteTrivialView(DoesntSupportCascade, SupportsNoCascade): pass sqlalchemy-utils-0.36.1/tests/types/000077500000000000000000000000001360007755400174165ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/types/__init__.py000066400000000000000000000000001360007755400215150ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/types/encrypted/000077500000000000000000000000001360007755400214135ustar00rootroot00000000000000sqlalchemy-utils-0.36.1/tests/types/encrypted/test_padding.py000066400000000000000000000033371360007755400244400ustar00rootroot00000000000000import pytest from sqlalchemy_utils.types.encrypted.padding import ( InvalidPaddingError, PKCS5Padding ) class TestPkcs5Padding(object): def setup_method(self): self.BLOCK_SIZE = 8 self.padder = PKCS5Padding(self.BLOCK_SIZE) def test_various_lengths_roundtrip(self): for l in range(0, 3 * self.BLOCK_SIZE): val = b'*' * l padded = self.padder.pad(val) unpadded = self.padder.unpad(padded) assert val == unpadded, 'Round trip error for length %d' % l def test_invalid_unpad(self): with pytest.raises(InvalidPaddingError): self.padder.unpad(None) with pytest.raises(InvalidPaddingError): self.padder.unpad(b'') with pytest.raises(InvalidPaddingError): self.padder.unpad(b'\01') with pytest.raises(InvalidPaddingError): self.padder.unpad((b'*' * (self.BLOCK_SIZE - 1)) + b'\00') with pytest.raises(InvalidPaddingError): self.padder.unpad((b'*' * self.BLOCK_SIZE) + b'\01') def test_pad_longer_than_block(self): with pytest.raises(InvalidPaddingError): self.padder.unpad( 'x' * (self.BLOCK_SIZE - 1) + chr(self.BLOCK_SIZE + 1) * (self.BLOCK_SIZE + 1) ) def test_incorrect_padding(self): # Hard-coded for blocksize of 8 assert self.padder.unpad(b'1234\04\04\04\04') == b'1234' with pytest.raises(InvalidPaddingError): self.padder.unpad(b'1234\02\04\04\04') with pytest.raises(InvalidPaddingError): self.padder.unpad(b'1234\04\02\04\04') with pytest.raises(InvalidPaddingError): self.padder.unpad(b'1234\04\04\02\04') sqlalchemy-utils-0.36.1/tests/types/test_arrow.py000066400000000000000000000046111360007755400221630ustar00rootroot00000000000000from datetime import datetime import pytest import sqlalchemy as sa from dateutil import tz from sqlalchemy_utils.types import arrow @pytest.fixture def Article(Base): class Article(Base): __tablename__ = 'article' id = sa.Column(sa.Integer, primary_key=True) created_at = sa.Column(arrow.ArrowType) published_at = sa.Column(arrow.ArrowType(timezone=True)) published_at_dt = sa.Column(sa.DateTime(timezone=True)) return Article @pytest.fixture def init_models(Article): pass @pytest.mark.skipif('arrow.arrow is None') class TestArrowDateTimeType(object): def test_parameter_processing(self, session, Article): article = Article( created_at=arrow.arrow.get(datetime(2000, 11, 1)) ) session.add(article) session.commit() article = session.query(Article).first() assert article.created_at.datetime def test_string_coercion(self, Article): article = Article( created_at='2013-01-01' ) assert article.created_at.year == 2013 def test_utc(self, session, Article): time = arrow.arrow.utcnow() article = Article(created_at=time) session.add(article) assert article.created_at == time session.commit() assert article.created_at == time def test_other_tz(self, session, Article): time = arrow.arrow.utcnow() local = time.to('US/Pacific') article = Article(created_at=local) session.add(article) assert article.created_at == time == local session.commit() assert article.created_at == time def test_literal_param(self, session, Article): clause 
= Article.created_at > '2015-01-01' compiled = str(clause.compile(compile_kwargs={"literal_binds": True})) assert compiled == 'article.created_at > 2015-01-01' @pytest.mark.usefixtures('postgresql_dsn') def test_timezone(self, session, Article): timezone = tz.gettz('Europe/Stockholm') dt = arrow.arrow.get(datetime(2015, 1, 1, 15, 30, 45), timezone) article = Article(published_at=dt, published_at_dt=dt.datetime) session.add(article) session.commit() session.expunge_all() item = session.query(Article).one() assert item.published_at.datetime == item.published_at_dt assert item.published_at.to(timezone) == dt sqlalchemy-utils-0.36.1/tests/types/test_choice.py000066400000000000000000000126341360007755400222670ustar00rootroot00000000000000import pytest import sqlalchemy as sa from flexmock import flexmock from sqlalchemy_utils import Choice, ChoiceType, ImproperlyConfigured from sqlalchemy_utils.types.choice import Enum class TestChoice(object): def test_equality_operator(self): assert Choice(1, 1) == 1 assert 1 == Choice(1, 1) assert Choice(1, 1) == Choice(1, 1) def test_non_equality_operator(self): assert Choice(1, 1) != 2 assert not (Choice(1, 1) != 1) def test_hash(self): assert hash(Choice(1, 1)) == hash(1) class TestChoiceType(object): @pytest.fixture def User(self, Base): class User(Base): TYPES = [ ('admin', 'Admin'), ('regular-user', 'Regular user') ] __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(ChoiceType(TYPES)) def __repr__(self): return 'User(%r)' % self.id return User @pytest.fixture def init_models(self, User): pass def test_python_type(self, User): type_ = User.__table__.c.type.type assert type_.python_type def test_string_processing(self, session, User): flexmock(ChoiceType).should_receive('_coerce').and_return( u'admin' ) user = User( type=u'admin' ) session.add(user) session.commit() user = session.query(User).first() assert user.type.value == u'Admin' def test_parameter_processing(self, session, User): user = User( type=u'admin' ) session.add(user) session.commit() user = session.query(User).first() assert user.type.value == u'Admin' def test_scalar_attributes_get_coerced_to_objects(self, User): user = User(type=u'admin') assert isinstance(user.type, Choice) def test_throws_exception_if_no_choices_given(self): with pytest.raises(ImproperlyConfigured): ChoiceType([]) class TestChoiceTypeWithCustomUnderlyingType(object): def test_init_type(self): type_ = ChoiceType([(1, u'something')], impl=sa.Integer) assert type_.impl == sa.Integer @pytest.mark.skipif('Enum is None') class TestEnumType(object): @pytest.fixture def OrderStatus(self): class OrderStatus(Enum): unpaid = 0 paid = 1 return OrderStatus @pytest.fixture def Order(self, Base, OrderStatus): class Order(Base): __tablename__ = 'order' id_ = sa.Column(sa.Integer, primary_key=True) status = sa.Column( ChoiceType(OrderStatus, impl=sa.Integer()), default=OrderStatus.unpaid, ) def __repr__(self): return 'Order(%r, %r)' % (self.id_, self.status) return Order @pytest.fixture def OrderNullable(self, Base, OrderStatus): class OrderNullable(Base): __tablename__ = 'order_nullable' id_ = sa.Column(sa.Integer, primary_key=True) status = sa.Column( ChoiceType(OrderStatus, impl=sa.Integer()), nullable=True, ) return OrderNullable @pytest.fixture def init_models(self, Order, OrderNullable): pass def test_parameter_initialization(self, session, Order, OrderStatus): order = Order() session.add(order) session.commit() order = session.query(Order).first() assert order.status is OrderStatus.unpaid 
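# Added comment (not in the original test): ChoiceType wrapping a python Enum
# with impl=sa.Integer() hands the Enum member itself back on the attribute,
# while .value exposes the underlying integer that is actually stored in the
# column, which is what the next assertion checks.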
assert order.status.value == 0 def test_setting_by_value(self, session, Order, OrderStatus): order = Order() order.status = 1 session.add(order) session.commit() order = session.query(Order).first() assert order.status is OrderStatus.paid def test_setting_by_enum(self, session, Order, OrderStatus): order = Order() order.status = OrderStatus.paid session.add(order) session.commit() order = session.query(Order).first() assert order.status is OrderStatus.paid def test_setting_value_that_resolves_to_none( self, session, Order, OrderStatus ): order = Order() order.status = 0 session.add(order) session.commit() order = session.query(Order).first() assert order.status is OrderStatus.unpaid def test_setting_to_wrong_enum_raises_valueerror(self, Order): class WrongEnum(Enum): foo = 0 bar = 1 order = Order() with pytest.raises(ValueError): order.status = WrongEnum.foo def test_setting_to_uncoerceable_type_raises_valueerror(self, Order): order = Order() with pytest.raises(ValueError): order.status = 'Bad value' def test_order_nullable_stores_none(self, session, OrderNullable): # With nullable=False as in `Order`, a `None` value is always # converted to the default value, unless we explicitly set it to # sqlalchemy.sql.null(), so we use this class to test our ability # to set and retrive `None`. order_nullable = OrderNullable() assert order_nullable.status is None order_nullable.status = None session.add(order_nullable) session.commit() assert order_nullable.status is None sqlalchemy-utils-0.36.1/tests/types/test_color.py000066400000000000000000000033271360007755400221520ustar00rootroot00000000000000import pytest import sqlalchemy as sa from flexmock import flexmock from sqlalchemy_utils import ColorType, types # noqa @pytest.fixture def Document(Base): class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, primary_key=True) bg_color = sa.Column(ColorType) def __repr__(self): return 'Document(%r)' % self.id return Document @pytest.fixture def init_models(Document): pass @pytest.mark.skipif('types.color.python_colour_type is None') class TestColorType(object): def test_string_parameter_processing(self, session, Document): from colour import Color flexmock(ColorType).should_receive('_coerce').and_return( u'white' ) document = Document( bg_color='white' ) session.add(document) session.commit() document = session.query(Document).first() assert document.bg_color.hex == Color(u'white').hex def test_color_parameter_processing(self, session, Document): from colour import Color document = Document(bg_color=Color(u'white')) session.add(document) session.commit() document = session.query(Document).first() assert document.bg_color.hex == Color(u'white').hex def test_scalar_attributes_get_coerced_to_objects(self, Document): from colour import Color document = Document(bg_color='white') assert isinstance(document.bg_color, Color) def test_literal_param(self, session, Document): clause = Document.bg_color == 'white' compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == "document.bg_color = 'white'" sqlalchemy-utils-0.36.1/tests/types/test_composite.py000066400000000000000000000227011360007755400230330ustar00rootroot00000000000000# -*- coding: utf-8 -*- import pytest import sqlalchemy as sa from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.session import close_all_sessions from sqlalchemy_utils import ( CompositeArray, CompositeType, Currency, CurrencyType, i18n, NumericRangeType, register_composites, remove_composite_listeners ) from 
sqlalchemy_utils.types import pg_composite from sqlalchemy_utils.types.range import intervals @pytest.mark.usefixtures('postgresql_dsn') class TestCompositeTypeWithRegularTypes(object): @pytest.fixture def Account(self, Base): class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balance = sa.Column( CompositeType( 'money_type', [ sa.Column('currency', sa.String), sa.Column('amount', sa.Integer) ] ) ) return Account @pytest.fixture def init_models(self, Account): pass def test_parameter_processing(self, session, Account): account = Account( balance=('USD', 15) ) session.add(account) session.commit() account = session.query(Account).first() assert account.balance.currency == 'USD' assert account.balance.amount == 15 def test_non_ascii_chars(self, session, Account): account = Account( balance=(u'ääöö', 15) ) session.add(account) session.commit() account = session.query(Account).first() assert account.balance.currency == u'ääöö' assert account.balance.amount == 15 @pytest.mark.skipif('i18n.babel is None') @pytest.mark.usefixtures('postgresql_dsn') class TestCompositeTypeWithTypeDecorators(object): @pytest.fixture def Account(self, Base): class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balance = sa.Column( CompositeType( 'money_type', [ sa.Column('currency', CurrencyType), sa.Column('amount', sa.Integer) ] ) ) return Account @pytest.fixture def init_models(self, Account): i18n.get_locale = lambda: i18n.babel.Locale('en') def test_result_set_processing(self, session, Account): account = Account( balance=('USD', 15) ) session.add(account) session.commit() account = session.query(Account).first() assert account.balance.currency == Currency('USD') assert account.balance.amount == 15 def test_parameter_processing(self, session, Account): account = Account( balance=(Currency('USD'), 15) ) session.add(account) session.commit() account = session.query(Account).first() assert account.balance.currency == Currency('USD') assert account.balance.amount == 15 @pytest.mark.skipif('i18n.babel is None') @pytest.mark.usefixtures('postgresql_dsn') class TestCompositeTypeInsideArray(object): @pytest.fixture def type_(self): return CompositeType( 'money_type', [ sa.Column('currency', CurrencyType), sa.Column('amount', sa.Integer) ] ) @pytest.fixture def Account(self, Base, type_): class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balances = sa.Column( CompositeArray(type_) ) return Account @pytest.fixture def init_models(self, Account): i18n.get_locale = lambda: i18n.babel.Locale('en') def test_parameter_processing(self, session, type_, Account): account = Account( balances=[ type_.type_cls(Currency('USD'), 15), type_.type_cls(Currency('AUD'), 20) ] ) session.add(account) session.commit() account = session.query(Account).first() assert account.balances[0].currency == Currency('USD') assert account.balances[0].amount == 15 assert account.balances[1].currency == Currency('AUD') assert account.balances[1].amount == 20 @pytest.mark.skipif('intervals is None') @pytest.mark.usefixtures('postgresql_dsn') class TestCompositeTypeWithRangeTypeInsideArray(object): @pytest.fixture def type_(self): return CompositeType( 'category', [ sa.Column('scale', NumericRangeType), sa.Column('name', sa.String) ] ) @pytest.fixture def Account(self, Base, type_): class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) categories = sa.Column( CompositeArray(type_) ) return Account 
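# Sketch of the assumed schema (added comment, not generated by this fixture):
# on PostgreSQL the column above is expected to map to a composite type plus
# an array column roughly like
#   CREATE TYPE category AS (scale NUMRANGE, name VARCHAR);
#   CREATE TABLE account (id SERIAL PRIMARY KEY, categories category[]);
# so each array element is a (scale, name) pair, and the tests below pass the
# elements either as the generated type_.type_cls named tuple or as plain
# python tuples.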
@pytest.fixture def init_models(self, Account): pass def test_parameter_processing_with_named_tuple( self, session, type_, Account ): account = Account( categories=[ type_.type_cls( intervals.DecimalInterval([15, 18]), 'bad' ), type_.type_cls( intervals.DecimalInterval([18, 20]), 'good' ) ] ) session.add(account) session.commit() account = session.query(Account).first() assert ( account.categories[0].scale == intervals.DecimalInterval([15, 18]) ) assert account.categories[0].name == 'bad' assert ( account.categories[1].scale == intervals.DecimalInterval([18, 20]) ) assert account.categories[1].name == 'good' def test_parameter_processing_with_tuple(self, session, Account): account = Account( categories=[ (intervals.DecimalInterval([15, 18]), 'bad'), (intervals.DecimalInterval([18, 20]), 'good') ] ) session.add(account) session.commit() account = session.query(Account).first() assert ( account.categories[0].scale == intervals.DecimalInterval([15, 18]) ) assert account.categories[0].name == 'bad' assert ( account.categories[1].scale == intervals.DecimalInterval([18, 20]) ) assert account.categories[1].name == 'good' def test_parameter_processing_with_nulls_as_composite_fields( self, session, Account ): account = Account( categories=[ (None, 'bad'), (intervals.DecimalInterval([18, 20]), None) ] ) session.add(account) session.commit() assert account.categories[0].scale is None assert account.categories[0].name == 'bad' assert ( account.categories[1].scale == intervals.DecimalInterval([18, 20]) ) assert account.categories[1].name is None def test_parameter_processing_with_nulls_as_composites( self, session, Account ): account = Account( categories=[ (None, None), None ] ) session.add(account) session.commit() assert account.categories[0].scale is None assert account.categories[0].name is None assert account.categories[1] is None @pytest.mark.usefixtures('postgresql_dsn') class TestCompositeTypeWhenTypeAlreadyExistsInDatabase(object): @pytest.fixture def Account(self, Base): pg_composite.registered_composites = {} type_ = CompositeType( 'money_type', [ sa.Column('currency', sa.String), sa.Column('amount', sa.Integer) ] ) class Account(Base): __tablename__ = 'account' id = sa.Column(sa.Integer, primary_key=True) balance = sa.Column(type_) return Account @pytest.fixture def session(self, request, engine, connection, Base, Account): sa.orm.configure_mappers() Session = sessionmaker(bind=connection) session = Session() session.execute( "CREATE TYPE money_type AS (currency VARCHAR, amount INTEGER)" ) session.execute( """CREATE TABLE account ( id SERIAL, balance MONEY_TYPE, PRIMARY KEY(id) )""" ) def teardown(): session.execute('DROP TABLE account') session.execute('DROP TYPE money_type') session.commit() close_all_sessions() connection.close() remove_composite_listeners() engine.dispose() register_composites(connection) request.addfinalizer(teardown) return session def test_parameter_processing(self, session, Account): account = Account( balance=('USD', 15), ) session.add(account) session.commit() account = session.query(Account).first() assert account.balance.currency == 'USD' assert account.balance.amount == 15 sqlalchemy-utils-0.36.1/tests/types/test_country.py000066400000000000000000000021531360007755400225330ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import Country, CountryType, i18n # noqa @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) country = sa.Column(CountryType) def 
__repr__(self): return 'User(%r)' % self.id return User @pytest.fixture def init_models(User): pass @pytest.mark.skipif('i18n.babel is None') class TestCountryType(object): def test_parameter_processing(self, session, User): user = User( country=Country(u'FI') ) session.add(user) session.commit() user = session.query(User).first() assert user.country.name == u'Finland' def test_scalar_attributes_get_coerced_to_objects(self, User): user = User(country='FI') assert isinstance(user.country, Country) def test_literal_param(self, session, User): clause = User.country == 'FI' compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == '"user".country = \'FI\'' sqlalchemy-utils-0.36.1/tests/types/test_currency.py000066400000000000000000000024221360007755400226610ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import Currency, CurrencyType, i18n @pytest.fixture def set_get_locale(): i18n.get_locale = lambda: i18n.babel.Locale('en') @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) currency = sa.Column(CurrencyType) def __repr__(self): return 'User(%r)' % self.id return User @pytest.fixture def init_models(User): pass @pytest.mark.skipif('i18n.babel is None') class TestCurrencyType(object): def test_parameter_processing(self, session, User, set_get_locale): user = User( currency=Currency('USD') ) session.add(user) session.commit() user = session.query(User).first() assert user.currency.name == u'US Dollar' def test_scalar_attributes_get_coerced_to_objects( self, User, set_get_locale ): user = User(currency='USD') assert isinstance(user.currency, Currency) def test_literal_param(self, session, User): clause = User.currency == 'USD' compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == '"user".currency = \'USD\'' sqlalchemy-utils-0.36.1/tests/types/test_date_range.py000066400000000000000000000064331360007755400231260ustar00rootroot00000000000000from datetime import datetime, timedelta import pytest import sqlalchemy as sa from sqlalchemy_utils import DateRangeType intervals = None inf = 0 try: import intervals from infinity import inf except ImportError: pass @pytest.fixture def Booking(Base): class Booking(Base): __tablename__ = 'booking' id = sa.Column(sa.Integer, primary_key=True) during = sa.Column(DateRangeType) return Booking @pytest.fixture def create_booking(session, Booking): def create_booking(date_range): booking = Booking( during=date_range ) session.add(booking) session.commit() return session.query(Booking).first() return create_booking @pytest.fixture def init_models(Booking): pass @pytest.mark.skipif('intervals is None') class DateRangeTestCase(object): def test_nullify_range(self, create_booking): booking = create_booking(None) assert booking.during is None @pytest.mark.parametrize( ('date_range'), ( [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()], [datetime(2015, 1, 1).date(), inf], [-inf, datetime(2015, 1, 1).date()] ) ) def test_save_date_range(self, create_booking, date_range): booking = create_booking(date_range) assert booking.during.lower == date_range[0] assert booking.during.upper == date_range[1] def test_nullify_date_range(self, session, Booking): booking = Booking( during=intervals.DateInterval( [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()] ) ) session.add(booking) session.commit() booking = session.query(Booking).first() booking.during = None session.commit() booking = 
session.query(Booking).first() assert booking.during is None def test_integer_coercion(self, Booking): booking = Booking(during=datetime(2015, 1, 1).date()) assert booking.during.lower == datetime(2015, 1, 1).date() assert booking.during.upper == datetime(2015, 1, 1).date() @pytest.mark.usefixtures('postgresql_dsn') class TestDateRangeOnPostgres(object): @pytest.mark.parametrize( ('date_range', 'length'), ( ( [datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date()], timedelta(days=2) ), ( [datetime(2015, 1, 1).date(), datetime(2015, 1, 1).date()], timedelta(days=0) ), ([-inf, datetime(2015, 1, 1).date()], None), ([datetime(2015, 1, 1).date(), inf], None), ) ) def test_length( self, session, Booking, create_booking, date_range, length ): create_booking(date_range) query = ( session.query(Booking.during.length) ) assert query.scalar() == length def test_literal_param(self, session, Booking): clause = Booking.during == [ datetime(2015, 1, 1).date(), datetime(2015, 1, 3).date() ] compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == "booking.during = '[2015-01-01, 2015-01-03]'" sqlalchemy-utils-0.36.1/tests/types/test_datetime_range.py000066400000000000000000000055551360007755400240110ustar00rootroot00000000000000from datetime import datetime, timedelta import pytest import sqlalchemy as sa from sqlalchemy_utils import DateTimeRangeType intervals = None inf = 0 try: import intervals from infinity import inf except ImportError: pass @pytest.fixture def Booking(Base): class Booking(Base): __tablename__ = 'booking' id = sa.Column(sa.Integer, primary_key=True) during = sa.Column(DateTimeRangeType) return Booking @pytest.fixture def create_booking(session, Booking): def create_booking(date_range): booking = Booking( during=date_range ) session.add(booking) session.commit() return session.query(Booking).first() return create_booking @pytest.fixture def init_models(Booking): pass @pytest.mark.skipif('intervals is None') class DateRangeTestCase(object): def test_nullify_range(self, create_booking): booking = create_booking(None) assert booking.during is None @pytest.mark.parametrize( ('date_range'), ( [datetime(2015, 1, 1), datetime(2015, 1, 3)], [datetime(2015, 1, 1), inf], [-inf, datetime(2015, 1, 1)] ) ) def test_save_date_range(self, create_booking, date_range): booking = create_booking(date_range) assert booking.during.lower == date_range[0] assert booking.during.upper == date_range[1] def test_nullify_date_range(self, session, Booking): booking = Booking( during=intervals.DateInterval( [datetime(2015, 1, 1), datetime(2015, 1, 3)] ) ) session.add(booking) session.commit() booking = session.query(Booking).first() booking.during = None session.commit() booking = session.query(Booking).first() assert booking.during is None def test_integer_coercion(self, Booking): booking = Booking(during=datetime(2015, 1, 1)) assert booking.during.lower == datetime(2015, 1, 1) assert booking.during.upper == datetime(2015, 1, 1) @pytest.mark.usefixtures('postgresql_dsn') class TestDateRangeOnPostgres(object): @pytest.mark.parametrize( ('date_range', 'length'), ( ( [datetime(2015, 1, 1), datetime(2015, 1, 3)], timedelta(days=2) ), ( [datetime(2015, 1, 1), datetime(2015, 1, 1)], timedelta(days=0) ), ([-inf, datetime(2015, 1, 1)], None), ([datetime(2015, 1, 1), inf], None), ) ) def test_length( self, session, Booking, create_booking, date_range, length ): create_booking(date_range) query = ( session.query(Booking.during.length) ) assert query.scalar() == length 
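# Hedged addendum (not part of the upstream test module): a minimal sketch of
# the python-side objects these round-trip tests exercise. It only runs when
# this file is executed directly and the optional `intervals` package is
# installed.
if __name__ == '__main__':
    from datetime import datetime, timedelta

    import intervals

    # DateTimeRangeType columns store and return intervals.DateTimeInterval
    # instances; the tests above pass plain [lower, upper] lists and let the
    # type coerce them, while here the interval is built explicitly.
    during = intervals.DateTimeInterval(
        [datetime(2015, 1, 1), datetime(2015, 1, 3)]
    )
    assert during.lower == datetime(2015, 1, 1)
    assert during.upper == datetime(2015, 1, 3)
    # The SQL `length` expression asserted in test_length corresponds to
    # upper - lower for bounded ranges.
    assert during.upper - during.lower == timedelta(days=2)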
sqlalchemy-utils-0.36.1/tests/types/test_email.py000066400000000000000000000020261360007755400221160ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import EmailType @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) email = sa.Column(EmailType) short_email = sa.Column(EmailType(length=70)) def __repr__(self): return 'User(%r)' % self.id return User class TestEmailType(object): def test_saves_email_as_lowercased(self, session, User): user = User(email=u'Someone@example.com') session.add(user) session.commit() user = session.query(User).first() assert user.email == u'someone@example.com' def test_literal_param(self, session, User): clause = User.email == 'Someone@example.com' compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == '"user".email = lower(\'Someone@example.com\')' def test_custom_length(self, session, User): assert User.short_email.type.impl.length == 70 sqlalchemy-utils-0.36.1/tests/types/test_encrypted.py000066400000000000000000000326631360007755400230360ustar00rootroot00000000000000from datetime import date, datetime, time import pytest import sqlalchemy as sa from sqlalchemy_utils import ColorType, EncryptedType, PhoneNumberType from sqlalchemy_utils.types.encrypted.encrypted_type import ( AesEngine, AesGcmEngine, DatetimeHandler, FernetEngine, InvalidCiphertextError ) cryptography = None try: import cryptography # noqa except ImportError: pass @pytest.fixture def User(Base, encryption_engine, test_key, padding_mechanism): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) username = sa.Column(EncryptedType( sa.Unicode, test_key, encryption_engine, padding_mechanism) ) access_token = sa.Column(EncryptedType( sa.String, test_key, encryption_engine, padding_mechanism) ) is_active = sa.Column(EncryptedType( sa.Boolean, test_key, encryption_engine, padding_mechanism) ) accounts_num = sa.Column(EncryptedType( sa.Integer, test_key, encryption_engine, padding_mechanism) ) phone = sa.Column(EncryptedType( PhoneNumberType, test_key, encryption_engine, padding_mechanism) ) color = sa.Column(EncryptedType( ColorType, test_key, encryption_engine, padding_mechanism) ) date = sa.Column(EncryptedType( sa.Date, test_key, encryption_engine, padding_mechanism) ) time = sa.Column(EncryptedType( sa.Time, test_key, encryption_engine, padding_mechanism) ) datetime = sa.Column(EncryptedType( sa.DateTime, test_key, encryption_engine, padding_mechanism) ) enum = sa.Column(EncryptedType( sa.Enum('One', name='user_enum_t'), test_key, encryption_engine, padding_mechanism) ) return User @pytest.fixture def test_key(): return 'secretkey1234' @pytest.fixture def user_name(): return u'someone' @pytest.fixture def user_phone(): return u'(555) 555-5555' @pytest.fixture def user_color(): return u'#fff' @pytest.fixture def user_enum(): return 'One' @pytest.fixture def user_date(): return date(2010, 10, 2) @pytest.fixture def user_time(): return time(10, 12) @pytest.fixture def user_datetime(): return datetime(2010, 10, 2, 10, 12, 45, 2334) @pytest.fixture def test_token(): import string import random token = '' characters = string.ascii_letters + string.digits for i in range(60): token += ''.join(random.choice(characters)) return token @pytest.fixture def active(): return True @pytest.fixture def accounts_num(): return 2 @pytest.fixture def user( request, session, User, user_name, user_phone, user_color, user_date, user_time, 
user_enum, user_datetime, test_token, active, accounts_num ): # set the values to the user object user = User() user.username = user_name user.phone = user_phone user.color = user_color user.date = user_date user.time = user_time user.enum = user_enum user.datetime = user_datetime user.access_token = test_token user.is_active = active user.accounts_num = accounts_num session.add(user) session.commit() return session.query(User).get(user.id) @pytest.fixture def datetime_with_micro_and_timezone(): import pytz tz = pytz.timezone('Pacific/Tahiti') return datetime(2017, 8, 21, 4, 26, 36, 523010, tzinfo=tz) @pytest.fixture def datetime_with_micro(): return datetime(2017, 8, 21, 10, 12, 45, 22) @pytest.fixture def datetime_simple(): return datetime(2017, 8, 21, 10, 12, 45) @pytest.fixture def time_with_micro(): return time(10, 12, 45, 22) @pytest.fixture def time_simple(): return time(10, 12, 45) @pytest.fixture def date_simple(): return date(2017, 8, 21) @pytest.mark.skipif('cryptography is None') class EncryptedTypeTestCase(object): @pytest.fixture def Team(self, Base, encryption_engine, padding_mechanism): self._team_key = None class Team(Base): __tablename__ = 'team' id = sa.Column(sa.Integer, primary_key=True) key = sa.Column(sa.String(50)) name = sa.Column(EncryptedType( sa.Unicode, lambda: self._team_key, encryption_engine, padding_mechanism) ) return Team @pytest.fixture def init_models(self, User, Team): pass def test_unicode(self, user, user_name): assert user.username == user_name def test_string(self, user, test_token): assert user.access_token == test_token def test_boolean(self, user, active): assert user.is_active == active def test_integer(self, user, accounts_num): assert user.accounts_num == accounts_num def test_phone_number(self, user, user_phone): assert str(user.phone) == user_phone def test_color(self, user, user_color): assert user.color.hex == user_color def test_date(self, user, user_date): assert user.date == user_date def test_datetime(self, user, user_datetime): assert user.datetime == user_datetime def test_time(self, user, user_time): assert user.time == user_time def test_enum(self, user, user_enum): assert user.enum == user_enum def test_lookup_key(self, session, Team): # Add teams self._team_key = 'one' team = Team(key=self._team_key, name=u'One') session.add(team) session.commit() team_1_id = team.id self._team_key = 'two' team = Team(key=self._team_key) team.name = u'Two' session.add(team) session.commit() team_2_id = team.id # Lookup teams self._team_key = session.query(Team.key).filter_by( id=team_1_id ).one()[0] team = session.query(Team).get(team_1_id) assert team.name == u'One' session.expunge_all() self._team_key = session.query(Team.key).filter_by( id=team_2_id ).one()[0] team = session.query(Team).get(team_2_id) assert team.name == u'Two' session.expunge_all() # Remove teams session.query(Team).delete() session.commit() class AesEncryptedTypeTestCase(EncryptedTypeTestCase): @pytest.fixture def encryption_engine(self): return AesEngine def test_lookup_by_encrypted_string(self, session, User, user, user_name): test = session.query(User).filter( User.username == user_name ).first() assert test.username == user.username class TestAesEncryptedTypeWithPKCS5Padding(AesEncryptedTypeTestCase): @pytest.fixture def padding_mechanism(self): return 'pkcs5' class TestAesEncryptedTypeWithOneAndZeroesPadding(AesEncryptedTypeTestCase): @pytest.fixture def padding_mechanism(self): return 'oneandzeroes' class 
TestAesEncryptedTypeWithZeroesPadding(AesEncryptedTypeTestCase): @pytest.fixture def padding_mechanism(self): return 'zeroes' class TestAesEncryptedTypeTestcaseWithNaivePadding(AesEncryptedTypeTestCase): @pytest.fixture def padding_mechanism(self): return 'naive' def test_decrypt_raises_value_error_with_invalid_key(self, session, Team): self._team_key = 'one' team = Team(key=self._team_key, name=u'One') session.add(team) session.commit() self._team_key = 'notone' with pytest.raises(ValueError): assert team.name == u'One' class TestFernetEncryptedTypeTestCase(EncryptedTypeTestCase): @pytest.fixture def encryption_engine(self): return FernetEngine @pytest.fixture def padding_mechanism(self): return None class TestDatetimeHandler(object): def test_datetime_with_micro_and_timezone( self, datetime_with_micro_and_timezone ): original_datetime = datetime_with_micro_and_timezone original_datetime_isoformat = original_datetime.isoformat() python_type = datetime assert DatetimeHandler.process_value( original_datetime_isoformat, python_type ) == original_datetime def test_datetime_with_micro(self, datetime_with_micro): original_datetime = datetime_with_micro original_datetime_isoformat = original_datetime.isoformat() python_type = datetime assert DatetimeHandler.process_value( original_datetime_isoformat, python_type ) == original_datetime def test_datetime_simple(self, datetime_simple): original_datetime = datetime_simple original_datetime_isoformat = original_datetime.isoformat() python_type = datetime assert DatetimeHandler.process_value( original_datetime_isoformat, python_type ) == original_datetime def test_time_with_micro(self, time_with_micro): original_time = time_with_micro original_time_isoformat = original_time.isoformat() python_type = time assert DatetimeHandler.process_value( original_time_isoformat, python_type ) == original_time def test_time_simple(self, time_simple): original_time = time_simple original_time_isoformat = original_time.isoformat() python_type = time assert DatetimeHandler.process_value( original_time_isoformat, python_type ) == original_time def test_date_simple(self, date_simple): original_date = date_simple original_date_isoformat = original_date.isoformat() python_type = date assert DatetimeHandler.process_value( original_date_isoformat, python_type ) == original_date @pytest.mark.skipif('cryptography is None') class TestAesGcmEngine(object): KEY = b'0123456789ABCDEF' def setup_method(self): self.engine = AesGcmEngine() self.engine._initialize_engine(TestAesGcmEngine.KEY) def test_roundtrip(self): for l in range(0, 36): plaintext = '0123456789abcdefghijklmnopqrstuvwxyz'[:l] encrypted = self.engine.encrypt(plaintext) decrypted = self.engine.decrypt(encrypted) assert plaintext == decrypted, "Round-trip failed for len: %d" % l def test_modified_iv_fails_to_decrypt(self): plaintext = 'abcdefgh' encrypted = self.engine.encrypt(plaintext) # 3rd char will be IV. Modify it POS = 3 encrypted = encrypted[:POS] + \ (b'A' if encrypted[POS] != b'A' else b'B') + \ encrypted[POS + 1:] with pytest.raises(InvalidCiphertextError): self.engine.decrypt(encrypted) def test_modified_tag_fails_to_decrypt(self): plaintext = 'abcdefgh' encrypted = self.engine.encrypt(plaintext) # 19th char will be tag. 
Modify it POS = 19 encrypted = encrypted[:POS] + \ (b'A' if encrypted[POS] != b'A' else b'B') + \ encrypted[POS + 1:] with pytest.raises(InvalidCiphertextError): self.engine.decrypt(encrypted) def test_modified_ciphertext_fails_to_decrypt(self): plaintext = 'abcdefgh' encrypted = self.engine.encrypt(plaintext) # 43rd char will be ciphertext. Modify it POS = 43 encrypted = encrypted[:POS] + \ (b'A' if encrypted[POS] != b'A' else b'B') + \ encrypted[POS + 1:] with pytest.raises(InvalidCiphertextError): self.engine.decrypt(encrypted) def test_too_short_ciphertext_fails_to_decrypt(self): plaintext = 'abcdefgh' encrypted = self.engine.encrypt(plaintext)[:20] with pytest.raises(InvalidCiphertextError): self.engine.decrypt(encrypted) def test_different_ciphertexts_each_time(self): plaintext = 'abcdefgh' encrypted1 = self.engine.encrypt(plaintext) encrypted2 = self.engine.encrypt(plaintext) assert self.engine.decrypt(encrypted1) == \ self.engine.decrypt(encrypted2) # The following has a very low probability of failing # accidentally (2^-96) assert encrypted1 != encrypted2 class TestAesGcmEncryptedType(EncryptedTypeTestCase): @pytest.fixture def encryption_engine(self): return AesGcmEngine # GCM doesn't need padding. This is here just because we're reusing test # code that requires this @pytest.fixture def padding_mechanism(self): return 'pkcs5' def test_lookup_by_encrypted_string(self, session, User, user, user_name): test = session.query(User).filter( User.username == "someonex" ).first() # With probability 1-2^-96, the 2 different encryptions will choose a # different IV, and will therefore result in different ciphertexts. # Thus, the 2 values will almost certainly be different, even though # we're really searching for the same username. Hence, the above search # will fail assert test is None sqlalchemy-utils-0.36.1/tests/types/test_int_range.py000066400000000000000000000217341360007755400230040ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import IntRangeType intervals = None inf = -1 try: import intervals from infinity import inf except ImportError: pass @pytest.fixture def Building(Base): class Building(Base): __tablename__ = 'building' id = sa.Column(sa.Integer, primary_key=True) persons_at_night = sa.Column(IntRangeType) def __repr__(self): return 'Building(%r)' % self.id return Building @pytest.fixture def init_models(Building): pass @pytest.fixture def create_building(session, Building): def create_building(number_range): building = Building( persons_at_night=number_range ) session.add(building) session.commit() return session.query(Building).first() return create_building @pytest.mark.skipif('intervals is None') class NumberRangeTestCase(object): def test_nullify_range(self, create_building): building = create_building(None) assert building.persons_at_night is None def test_update_with_none(self, session, create_building): interval = intervals.IntInterval([None, None]) building = create_building(interval) building.persons_at_night = None assert building.persons_at_night is None session.commit() assert building.persons_at_night is None @pytest.mark.parametrize( 'number_range', ( [1, 3], (0, 4), ) ) def test_save_number_range(self, create_building, number_range): building = create_building(number_range) assert building.persons_at_night.lower == 1 assert building.persons_at_night.upper == 3 def test_infinite_upper_bound(self, create_building): building = create_building([1, inf]) assert building.persons_at_night.lower == 1 assert 
building.persons_at_night.upper == inf def test_infinite_lower_bound(self, create_building): building = create_building([-inf, 1]) assert building.persons_at_night.lower == -inf assert building.persons_at_night.upper == 1 def test_nullify_number_range(self, session, Building): building = Building( persons_at_night=intervals.IntInterval([1, 3]) ) session.add(building) session.commit() building = session.query(Building).first() building.persons_at_night = None session.commit() building = session.query(Building).first() assert building.persons_at_night is None def test_integer_coercion(self, Building): building = Building(persons_at_night=15) assert building.persons_at_night.lower == 15 assert building.persons_at_night.upper == 15 @pytest.mark.usefixtures('postgresql_dsn') class TestIntRangeTypeOnPostgres(NumberRangeTestCase): @pytest.mark.parametrize( 'number_range', ( [1, 3], (0, 4) ) ) def test_eq_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night == number_range) ) assert query.count() @pytest.mark.parametrize( ('number_range', 'length'), ( ([1, 3], 2), ([1, 1], 0), ([-1, 1], 2), ([-inf, 1], None), ([0, inf], None), ([0, 0], 0), ([-3, -1], 2) ) ) def test_length( self, session, Building, create_building, number_range, length ): create_building(number_range) query = ( session.query(Building.persons_at_night.length) ) assert query.scalar() == length @pytest.mark.parametrize( 'number_range', ( [[1, 3]], [(0, 4)], ) ) def test_in_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night.in_(number_range)) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [1, 3], (0, 4), ) ) def test_rshift_operator( self, session, Building, create_building, number_range ): create_building([5, 6]) query = ( session.query(Building) .filter(Building.persons_at_night >> number_range) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [1, 3], (0, 4), ) ) def test_lshift_operator( self, session, Building, create_building, number_range ): create_building([-1, 0]) query = ( session.query(Building) .filter(Building.persons_at_night << number_range) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [1, 3], (1, 3), 2 ) ) def test_contains_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night.contains(number_range)) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [1, 3], (0, 8), (-inf, inf) ) ) def test_contained_by_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night.contained_by(number_range)) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [2, 5], 0 ) ) def test_not_in_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(~ Building.persons_at_night.in_([number_range])) ) assert query.count() def test_eq_with_query_arg(self, session, Building, create_building): create_building([1, 3]) query = ( session.query(Building) .filter( Building.persons_at_night == session.query(Building.persons_at_night) ).order_by(Building.persons_at_night).limit(1) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [1, 2], (0, 4), [0, 3], 0, 1, 
) ) def test_ge_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night >= number_range) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [0, 2], 0, [-inf, 2] ) ) def test_gt_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night > number_range) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [1, 4], 4, [2, inf] ) ) def test_le_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night <= number_range) ) assert query.count() @pytest.mark.parametrize( 'number_range', ( [2, 4], 4, [1, inf] ) ) def test_lt_operator( self, session, Building, create_building, number_range ): create_building([1, 3]) query = ( session.query(Building) .filter(Building.persons_at_night < number_range) ) assert query.count() def test_literal_param(self, session, Building): clause = Building.persons_at_night == [1, 3] compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == "building.persons_at_night = '[1, 3]'" class TestNumberRangeTypeOnSqlite(NumberRangeTestCase): pass sqlalchemy-utils-0.36.1/tests/types/test_ip_address.py000066400000000000000000000015221360007755400231440ustar00rootroot00000000000000import pytest import six import sqlalchemy as sa from sqlalchemy_utils.types import ip_address @pytest.fixture def Visitor(Base): class Visitor(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, primary_key=True) ip_address = sa.Column(ip_address.IPAddressType) def __repr__(self): return 'Visitor(%r)' % self.id return Visitor @pytest.fixture def init_models(Visitor): pass @pytest.mark.skipif('ip_address.ip_address is None') class TestIPAddressType(object): def test_parameter_processing(self, session, Visitor): visitor = Visitor( ip_address=u'111.111.111.111' ) session.add(visitor) session.commit() visitor = session.query(Visitor).first() assert six.text_type(visitor.ip_address) == u'111.111.111.111' sqlalchemy-utils-0.36.1/tests/types/test_json.py000066400000000000000000000030061360007755400217770ustar00rootroot00000000000000# -*- coding: utf-8 -*- import pytest import sqlalchemy as sa from sqlalchemy_utils.types import json @pytest.fixture def Document(Base): class Document(Base): __tablename__ = 'document' id = sa.Column(sa.Integer, primary_key=True) json = sa.Column(json.JSONType) return Document @pytest.fixture def init_models(Document): pass class JSONTestCase(object): def test_list(self, session, Document): document = Document( json=[1, 2, 3] ) session.add(document) session.commit() document = session.query(Document).first() assert document.json == [1, 2, 3] def test_parameter_processing(self, session, Document): document = Document( json={'something': 12} ) session.add(document) session.commit() document = session.query(Document).first() assert document.json == {'something': 12} def test_non_ascii_chars(self, session, Document): document = Document( json={'something': u'äääööö'} ) session.add(document) session.commit() document = session.query(Document).first() assert document.json == {'something': u'äääööö'} @pytest.mark.skipif('json.json is None') @pytest.mark.usefixtures('sqlite_memory_dsn') class TestSqliteJSONType(JSONTestCase): pass @pytest.mark.skipif('json.json is None') @pytest.mark.usefixtures('postgresql_dsn') class 
TestPostgresJSONType(JSONTestCase): pass sqlalchemy-utils-0.36.1/tests/types/test_locale.py000066400000000000000000000032351360007755400222710ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils.types import locale @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) locale = sa.Column(locale.LocaleType) def __repr__(self): return 'User(%r)' % self.id return User @pytest.fixture def init_models(User): pass @pytest.mark.skipif('locale.babel is None') class TestLocaleType(object): def test_parameter_processing(self, session, User): user = User( locale=locale.babel.Locale(u'fi') ) session.add(user) session.commit() user = session.query(User).first() def test_territory_parsing(self, session, User): ko_kr = locale.babel.Locale(u'ko', territory=u'KR') user = User(locale=ko_kr) session.add(user) session.commit() assert session.query(User.locale).first()[0] == ko_kr def test_coerce_territory_parsing(self, User): user = User() user.locale = 'ko_KR' assert user.locale == locale.babel.Locale(u'ko', territory=u'KR') def test_scalar_attributes_get_coerced_to_objects(self, User): user = User(locale='en_US') assert isinstance(user.locale, locale.babel.Locale) def test_unknown_locale_throws_exception(self, User): with pytest.raises(locale.babel.UnknownLocaleError): User(locale=u'unknown') def test_literal_param(self, session, User): clause = User.locale == 'en_US' compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == '"user".locale = \'en_US\'' sqlalchemy-utils-0.36.1/tests/types/test_ltree.py000066400000000000000000000021161360007755400221420ustar00rootroot00000000000000import pytest import sqlalchemy as sa from sqlalchemy_utils import Ltree, LtreeType @pytest.fixture def Section(Base): class Section(Base): __tablename__ = 'section' id = sa.Column(sa.Integer, primary_key=True) path = sa.Column(LtreeType) return Section @pytest.fixture def init_models(Section, connection): connection.execute('CREATE EXTENSION IF NOT EXISTS ltree') pass @pytest.mark.usefixtures('postgresql_dsn') class TestLTREE(object): def test_saves_path(self, session, Section): section = Section(path='1.1.2') session.add(section) session.commit() user = session.query(Section).first() assert user.path == '1.1.2' def test_scalar_attributes_get_coerced_to_objects(self, Section): section = Section(path='path.path') assert isinstance(section.path, Ltree) def test_literal_param(self, session, Section): clause = Section.path == 'path' compiled = str(clause.compile(compile_kwargs={'literal_binds': True})) assert compiled == 'section.path = \'path\'' sqlalchemy-utils-0.36.1/tests/types/test_numeric_range.py000066400000000000000000000077461360007755400236630ustar00rootroot00000000000000from decimal import Decimal import pytest import sqlalchemy as sa from sqlalchemy_utils import NumericRangeType intervals = None inf = 0 try: import intervals from infinity import inf except ImportError: pass @pytest.fixture def create_car(session, Car): def create_car(number_range): car = Car( price_range=number_range ) session.add(car) session.commit() return session.query(Car).first() return create_car @pytest.mark.skipif('intervals is None') class NumericRangeTestCase(object): @pytest.fixture def Car(self, Base): class Car(Base): __tablename__ = 'car' id = sa.Column(sa.Integer, primary_key=True) price_range = sa.Column(NumericRangeType) return Car @pytest.fixture def init_models(self, Car): pass def test_nullify_range(self, 
        car = create_car(None)
        assert car.price_range is None

    @pytest.mark.parametrize(
        'number_range',
        (
            [1, 3],
            (1, 3)
        )
    )
    def test_save_number_range(self, create_car, number_range):
        car = create_car(number_range)
        assert car.price_range.lower == 1
        assert car.price_range.upper == 3

    def test_infinite_upper_bound(self, create_car):
        car = create_car([1, inf])
        assert car.price_range.lower == 1
        assert car.price_range.upper == inf

    def test_infinite_lower_bound(self, create_car):
        car = create_car([-inf, 1])
        assert car.price_range.lower == -inf
        assert car.price_range.upper == 1

    def test_nullify_number_range(self, session, Car):
        car = Car(
            price_range=intervals.DecimalInterval([1, 3])
        )
        session.add(car)
        session.commit()

        car = session.query(Car).first()
        car.price_range = None
        session.commit()

        car = session.query(Car).first()
        assert car.price_range is None

    def test_integer_coercion(self, Car):
        car = Car(price_range=15)
        assert car.price_range.lower == 15
        assert car.price_range.upper == 15


@pytest.mark.usefixtures('postgresql_dsn')
class TestNumericRangeOnPostgres(NumericRangeTestCase):
    @pytest.mark.parametrize(
        ('number_range', 'length'),
        (
            ([1, 3], 2),
            ([1, 1], 0),
            ([-1, 1], 2),
            ([-inf, 1], None),
            ([0, inf], None),
            ([0, 0], 0),
            ([-3, -1], 2)
        )
    )
    def test_length(self, session, Car, create_car, number_range, length):
        create_car(number_range)
        query = (
            session.query(Car.price_range.length)
        )
        assert query.scalar() == length

    def test_literal_param(self, session, Car):
        clause = Car.price_range == [1, 3]
        compiled = str(clause.compile(compile_kwargs={'literal_binds': True}))
        assert compiled == "car.price_range = '[1, 3]'"


@pytest.mark.skipif('intervals is None')
class TestNumericRangeWithStep(object):
    @pytest.fixture
    def Car(self, Base):
        class Car(Base):
            __tablename__ = 'car'
            id = sa.Column(sa.Integer, primary_key=True)
            price_range = sa.Column(NumericRangeType(step=Decimal('0.5')))
        return Car

    @pytest.fixture
    def init_models(self, Car):
        pass

    def test_passes_step_argument_to_interval_object(self, create_car):
        car = create_car([Decimal('0.2'), Decimal('0.8')])
        assert car.price_range.lower == Decimal('0')
        assert car.price_range.upper == Decimal('1')
        assert car.price_range.step == Decimal('0.5')

    def test_passes_step_fetched_objects(self, session, Car, create_car):
        create_car([Decimal('0.2'), Decimal('0.8')])
        session.expunge_all()
        car = session.query(Car).first()
        assert car.price_range.lower == Decimal('0')
        assert car.price_range.upper == Decimal('1')
        assert car.price_range.step == Decimal('0.5')
sqlalchemy-utils-0.36.1/tests/types/test_password.py000066400000000000000000000156551360007755400227050ustar00rootroot00000000000000
import mock
import pytest
import sqlalchemy as sa
import sqlalchemy.dialects.mysql
import sqlalchemy.dialects.oracle
import sqlalchemy.dialects.postgresql
import sqlalchemy.dialects.sqlite
from sqlalchemy import inspect

from sqlalchemy_utils import Password, PasswordType, types  # noqa


@pytest.fixture
def extra_kwargs():
    """PasswordType extra keyword arguments."""
    return {}


@pytest.fixture
def User(Base, extra_kwargs):
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        password = sa.Column(PasswordType(
            schemes=[
                'pbkdf2_sha512',
                'pbkdf2_sha256',
                'md5_crypt',
                'hex_md5'
            ],
            deprecated=['md5_crypt', 'hex_md5'],
            **extra_kwargs
        ))

        def __repr__(self):
            return 'User(%r)' % self.id
    return User


@pytest.fixture
def init_models(User):
    pass


def onload_callback(schemes, deprecated):
    """
    Get onload callback that takes the PasswordType arguments from
    the config.
""" def onload(**kwargs): kwargs['schemes'] = schemes kwargs['deprecated'] = deprecated return kwargs return onload @pytest.mark.skipif('types.password.passlib is None') class TestPasswordType(object): @pytest.mark.parametrize('dialect_module,impl', [ (sqlalchemy.dialects.sqlite, sa.dialects.sqlite.BLOB), (sqlalchemy.dialects.postgresql, sa.dialects.postgresql.BYTEA), (sqlalchemy.dialects.oracle, sa.dialects.oracle.RAW), (sqlalchemy.dialects.mysql, sa.VARBINARY), ]) def test_load_dialect_impl(self, dialect_module, impl): """ Should produce the same impl type as Alembic would expect after inspecing a database """ password_type = PasswordType() assert isinstance( password_type.load_dialect_impl(dialect_module.dialect()), impl ) def test_encrypt(self, User): """Should encrypt the password on setting the attribute.""" obj = User() obj.password = b'b' assert obj.password.hash != 'b' assert obj.password.hash.startswith(b'$pbkdf2-sha512$') def test_check(self, session, User): """ Should be able to compare the plaintext against the encrypted form. """ obj = User() obj.password = 'b' assert obj.password == 'b' assert obj.password != 'a' session.add(obj) session.commit() obj = session.query(User).get(obj.id) assert obj.password == b'b' assert obj.password != 'a' def test_check_and_update(self, User): """ Should be able to compare the plaintext against a deprecated encrypted form and have it auto-update to the preferred version. """ from passlib.hash import md5_crypt obj = User() obj.password = Password(md5_crypt.hash('b')) assert obj.password.hash.decode('utf8').startswith('$1$') assert obj.password == 'b' assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha512$') def test_auto_column_length(self, User): """Should derive the correct column length from the specified schemes. """ from passlib.hash import pbkdf2_sha512 kind = inspect(User).c.password.type # name + rounds + salt + hash + ($ * 4) of largest hash expected_length = len(pbkdf2_sha512.name) expected_length += len(str(pbkdf2_sha512.max_rounds)) expected_length += pbkdf2_sha512.max_salt_size expected_length += pbkdf2_sha512.encoded_checksum_size expected_length += 4 assert kind.length == expected_length def test_without_schemes(self): assert PasswordType(schemes=[]).length == 1024 def test_compare(self, User): from passlib.hash import md5_crypt obj = User() obj.password = Password(md5_crypt.hash('b')) other = User() other.password = Password(md5_crypt.hash('b')) # Not sure what to assert here; the test raised an error before. assert obj.password != other.password def test_set_none(self, session, User): obj = User() obj.password = None assert obj.password is None session.add(obj) session.commit() obj = session.query(User).get(obj.id) assert obj.password is None def test_update_none(self, session, User): """ Should be able to change a password from ``None`` to a valid password. """ obj = User() obj.password = None session.add(obj) session.commit() obj = session.query(User).get(obj.id) obj.password = 'b' session.commit() def test_compare_none(self, User): """ Should be able to compare a password of ``None``. """ obj = User() obj.password = None assert obj.password is None assert obj.password == None # noqa obj.password = 'b' assert obj.password is not None assert obj.password != None # noqa def test_check_and_update_persist(self, session, User): """ When a password is compared, the hash should update if needed to change the algorithm; and, commit to the database. 
""" from passlib.hash import md5_crypt obj = User() obj.password = Password(md5_crypt.hash('b')) session.add(obj) session.commit() assert obj.password.hash.decode('utf8').startswith('$1$') assert obj.password == 'b' session.commit() obj = session.query(User).get(obj.id) assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha512$') assert obj.password == 'b' @pytest.mark.parametrize( 'extra_kwargs', [ dict( onload=onload_callback( schemes=['pbkdf2_sha256'], deprecated=[], ) ) ] ) def test_lazy_configuration(self, User): """ Field should be able to read the passlib attributes lazily from the config (e.g. Flask config). """ schemes = User.password.type.context.schemes() assert tuple(schemes) == ('pbkdf2_sha256',) obj = User() obj.password = b'b' assert obj.password.hash.decode('utf8').startswith('$pbkdf2-sha256$') @pytest.mark.parametrize('max_length', [1, 103]) def test_constant_length(self, max_length): """ Test that constant max_length is applied. """ typ = PasswordType(max_length=max_length) assert typ.length == max_length def test_context_is_lazy(self): """ Make sure the init doesn't evaluate the lazy context. """ onload = mock.Mock(return_value={}) PasswordType(onload=onload) assert not onload.called sqlalchemy-utils-0.36.1/tests/types/test_phonenumber.py000066400000000000000000000140061360007755400233520ustar00rootroot00000000000000import pytest import six import sqlalchemy as sa from sqlalchemy_utils import ( # noqa PhoneNumber, PhoneNumberParseException, PhoneNumberType, types ) VALID_PHONE_NUMBERS = ( '040 1234567', '+358 401234567', '09 2501234', '+358 92501234', '0800 939393', '09 4243 0456', '0600 900 500' ) @pytest.fixture def User(Base): class User(Base): __tablename__ = 'user' id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) name = sa.Column(sa.Unicode(255)) phone_number = sa.Column(PhoneNumberType()) return User @pytest.fixture def init_models(User): pass @pytest.fixture def phone_number(): return PhoneNumber( '040 1234567', 'FI' ) @pytest.fixture def user(session, User, phone_number): user = User() user.name = u'Someone' user.phone_number = phone_number session.add(user) session.commit() return user @pytest.mark.skipif('types.phone_number.phonenumbers is None') class TestPhoneNumber(object): @pytest.mark.parametrize('raw_number', VALID_PHONE_NUMBERS) def test_valid_phone_numbers(self, raw_number): number = PhoneNumber(raw_number, 'FI') assert number.is_valid_number() @pytest.mark.parametrize('raw_number', ('abc', '+040 1234567')) def test_invalid_phone_numbers__constructor_fails(self, raw_number): with pytest.raises(PhoneNumberParseException): PhoneNumber(raw_number, 'FI') @pytest.mark.parametrize('raw_number', ('0111234567', '358')) def test_invalid_phone_numbers__is_valid_number(self, raw_number): number = PhoneNumber(raw_number, 'FI') assert not number.is_valid_number() def test_invalid_phone_numbers_throw_dont_wrap_exception( self, session, User ): with pytest.raises(PhoneNumberParseException): session.execute( User.__table__.insert().values( name=u'Someone', phone_number=u'abc' ) ) def test_phone_number_attributes(self): number = PhoneNumber('+358401234567') assert number.e164 == u'+358401234567' assert number.international == u'+358 40 1234567' assert number.national == u'040 1234567' def test_phone_number_attributes_for_short_code(self): """ For international and national shortcode remains the same, if we pass short code to PhoneNumber library without giving check_region it will raise exception :return: """ number = PhoneNumber('72404', 
            check_region=False
        )
        assert number.e164 == u'+072404'
        assert number.international == u'72404'
        assert number.national == u'72404'

    def test_phone_number_str_repr(self):
        number = PhoneNumber('+358401234567')
        if six.PY2:
            assert unicode(number) == number.national  # noqa
            assert str(number) == number.national.encode('utf-8')
        else:
            assert str(number) == number.national


@pytest.mark.skipif('types.phone_number.phonenumbers is None')
class TestPhoneNumberType(object):
    def test_query_returns_phone_number_object(
        self, session, User, user, phone_number
    ):
        queried_user = session.query(User).first()
        assert queried_user.phone_number == phone_number

    def test_phone_number_is_stored_as_string(self, session, user):
        result = session.execute(
            'SELECT phone_number FROM "user" WHERE id=:param',
            {'param': user.id}
        )
        assert result.first()[0] == u'+358401234567'

    def test_phone_number_with_extension(self, session, User):
        user = User(phone_number='555-555-5555 Ext. 555')

        session.add(user)
        session.commit()
        session.refresh(user)

        assert user.phone_number.extension == '555'

    def test_empty_phone_number_is_equiv_to_none(self, session, User):
        user = User(phone_number='')

        session.add(user)
        session.commit()
        session.refresh(user)

        assert user.phone_number is None

    def test_uses_phonenumber_class_as_python_type(self):
        assert PhoneNumberType().python_type is PhoneNumber

    @pytest.mark.usefixtures('user')
    def test_phone_number_is_none(self, session, User):
        phone_number = None
        user = User()
        user.name = u'Someone'
        user.phone_number = phone_number
        session.add(user)
        session.commit()

        queried_user = session.query(User)[1]
        assert queried_user.phone_number is None
        result = session.execute(
            'SELECT phone_number FROM "user" WHERE id=:param',
            {'param': user.id}
        )
        assert result.first()[0] is None

    def test_scalar_attributes_get_coerced_to_objects(self, User):
        user = User(phone_number='050111222')
        assert isinstance(user.phone_number, PhoneNumber)


@pytest.mark.skipif('types.phone_number.phonenumbers is None')
class TestPhoneNumberComposite(object):
    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.String(255))
            _phone_number = sa.Column(sa.String(255))
            country = sa.Column(sa.String(255))
            phone_number = sa.orm.composite(
                PhoneNumber,
                _phone_number,
                country
            )
        return User

    @pytest.fixture
    def user(self, session, User):
        user = User()
        user.name = u'Someone'
        user.phone_number = PhoneNumber('+35840111222', 'FI')
        session.add(user)
        session.commit()
        return user

    def test_query_returns_phone_number_object(
        self, session, User, user
    ):
        queried_user = session.query(User).first()
        assert queried_user.phone_number.national == '040 111222'
        assert queried_user.phone_number.region == 'FI'
sqlalchemy-utils-0.36.1/tests/types/test_scalar_list.py000066400000000000000000000043761360007755400233410ustar00rootroot00000000000000
import pytest
import six
import sqlalchemy as sa

from sqlalchemy_utils import ScalarListType


class TestScalarIntegerList(object):
    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            some_list = sa.Column(ScalarListType(int))

            def __repr__(self):
                return 'User(%r)' % self.id
        return User

    @pytest.fixture
    def init_models(self, User):
        pass

    def test_save_integer_list(self, session, User):
        user = User(
            some_list=[1, 2, 3, 4]
        )

        session.add(user)
        session.commit()

        user = session.query(User).first()
        assert user.some_list == [1, 2, 3, 4]


class TestScalarUnicodeList(object):
    @pytest.fixture
    def User(self, Base):
        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            some_list = sa.Column(ScalarListType(six.text_type))

            def __repr__(self):
                return 'User(%r)' % self.id
        return User

    @pytest.fixture
    def init_models(self, User):
        pass

    def test_throws_exception_if_using_separator_in_list_values(
        self, session, User
    ):
        user = User(
            some_list=[u',']
        )

        session.add(user)
        with pytest.raises(sa.exc.StatementError) as db_err:
            session.commit()
        assert (
            "List values can't contain string ',' (its being used as "
            "separator. If you wish for scalar list values to contain "
            "these strings, use a different separator string.)"
        ) in str(db_err.value)

    def test_save_unicode_list(self, session, User):
        user = User(
            some_list=[u'1', u'2', u'3', u'4']
        )

        session.add(user)
        session.commit()

        user = session.query(User).first()
        assert user.some_list == [u'1', u'2', u'3', u'4']

    def test_save_and_retrieve_empty_list(self, session, User):
        user = User(
            some_list=[]
        )

        session.add(user)
        session.commit()

        user = session.query(User).first()
        assert user.some_list == []
sqlalchemy-utils-0.36.1/tests/types/test_timezone.py000066400000000000000000000052271360007755400226670ustar00rootroot00000000000000
import pytest
import pytz
import sqlalchemy as sa
from dateutil.zoneinfo import getzoneinfofile_stream, tzfile, ZoneInfoFile

from sqlalchemy_utils.types import timezone, TimezoneType


@pytest.fixture
def Visitor(Base):
    class Visitor(Base):
        __tablename__ = 'visitor'
        id = sa.Column(sa.Integer, primary_key=True)
        timezone_dateutil = sa.Column(
            timezone.TimezoneType(backend='dateutil')
        )
        timezone_pytz = sa.Column(
            timezone.TimezoneType(backend='pytz')
        )

        def __repr__(self):
            return 'Visitor(%r)' % self.id
    return Visitor


@pytest.fixture
def init_models(Visitor):
    pass


class TestTimezoneType(object):
    def test_parameter_processing(self, session, Visitor):
        visitor = Visitor(
            timezone_dateutil=u'America/Los_Angeles',
            timezone_pytz=u'America/Los_Angeles'
        )

        session.add(visitor)
        session.commit()

        visitor_dateutil = session.query(Visitor).filter_by(
            timezone_dateutil=u'America/Los_Angeles'
        ).first()
        visitor_pytz = session.query(Visitor).filter_by(
            timezone_pytz=u'America/Los_Angeles'
        ).first()

        assert visitor_dateutil is not None
        assert visitor_pytz is not None


TIMEZONE_BACKENDS = ['dateutil', 'pytz']


def test_can_coerce_pytz_DstTzInfo():
    tzcol = TimezoneType(backend='pytz')
    tz = pytz.timezone('America/New_York')
    assert isinstance(tz, pytz.tzfile.DstTzInfo)
    assert tzcol._coerce(tz) is tz


def test_can_coerce_pytz_StaticTzInfo():
    tzcol = TimezoneType(backend='pytz')
    tz = pytz.timezone('Pacific/Truk')
    assert tzcol._coerce(tz) is tz


@pytest.mark.parametrize('zone', pytz.all_timezones)
def test_can_coerce_string_for_pytz_zone(zone):
    tzcol = TimezoneType(backend='pytz')
    assert tzcol._coerce(zone).zone == zone


@pytest.mark.parametrize(
    'zone', ZoneInfoFile(getzoneinfofile_stream()).zones.keys())
def test_can_coerce_string_for_dateutil_zone(zone):
    tzcol = TimezoneType(backend='dateutil')
    assert isinstance(tzcol._coerce(zone), tzfile)


@pytest.mark.parametrize('backend', TIMEZONE_BACKENDS)
def test_can_coerce_and_raise_UnknownTimeZoneError_or_ValueError(backend):
    tzcol = TimezoneType(backend=backend)
    with pytest.raises((ValueError, pytz.exceptions.UnknownTimeZoneError)):
        tzcol._coerce('SolarSystem/Mars')
    with pytest.raises((ValueError, pytz.exceptions.UnknownTimeZoneError)):
        tzcol._coerce('')


@pytest.mark.parametrize('backend', TIMEZONE_BACKENDS)
def test_can_coerce_None(backend):
    tzcol = TimezoneType(backend=backend)
    assert tzcol._coerce(None) is None
sqlalchemy-utils-0.36.1/tests/types/test_tsvector.py000066400000000000000000000046171360007755400227100ustar00rootroot00000000000000
import pytest
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import TSVECTOR

from sqlalchemy_utils import TSVectorType


@pytest.fixture
def User(Base):
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Unicode(255))
        search_index = sa.Column(
            TSVectorType(name, regconfig='pg_catalog.finnish')
        )

        def __repr__(self):
            return 'User(%r)' % self.id
    return User


@pytest.fixture
def init_models(User):
    pass


@pytest.mark.usefixtures('postgresql_dsn')
class TestTSVector(object):
    def test_generates_table(self, User):
        assert 'search_index' in User.__table__.c

    @pytest.mark.usefixtures('session')
    def test_type_reflection(self, engine):
        reflected_metadata = sa.schema.MetaData()
        table = sa.schema.Table(
            'user',
            reflected_metadata,
            autoload=True,
            autoload_with=engine
        )
        assert isinstance(table.c['search_index'].type, TSVECTOR)

    def test_catalog_and_columns_as_args(self):
        type_ = TSVectorType('name', 'age', regconfig='pg_catalog.simple')
        assert type_.columns == ('name', 'age')
        assert type_.options['regconfig'] == 'pg_catalog.simple'

    def test_match(self, connection, User):
        expr = User.search_index.match(u'something')
        assert str(expr.compile(connection)) == (
            '''"user".search_index @@ to_tsquery('pg_catalog.finnish', '''
            '''%(search_index_1)s)'''
        )

    def test_concat(self, User):
        assert str(User.search_index | User.search_index) == (
            '"user".search_index || "user".search_index'
        )

    def test_match_concatenation(self, session, User):
        concat = User.search_index | User.search_index
        bind = session.bind
        assert str(concat.match('something').compile(bind)) == (
            '("user".search_index || "user".search_index) @@ '
            "to_tsquery('pg_catalog.finnish', %(param_1)s)"
        )

    def test_match_with_catalog(self, connection, User):
        expr = User.search_index.match(
            u'something',
            postgresql_regconfig='pg_catalog.simple'
        )
        assert str(expr.compile(connection)) == (
            '''"user".search_index @@ to_tsquery('pg_catalog.simple', '''
            '''%(search_index_1)s)'''
        )
sqlalchemy-utils-0.36.1/tests/types/test_url.py000066400000000000000000000016271360007755400216370ustar00rootroot00000000000000
import pytest
import sqlalchemy as sa

from sqlalchemy_utils.types import url


@pytest.fixture
def User(Base):
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        website = sa.Column(url.URLType)

        def __repr__(self):
            return 'User(%r)' % self.id
    return User


@pytest.fixture
def init_models(User):
    pass


@pytest.mark.skipif('url.furl is None')
class TestURLType(object):
    def test_color_parameter_processing(self, session, User):
        user = User(
            website=url.furl(u'www.example.com')
        )

        session.add(user)
        session.commit()

        user = session.query(User).first()
        assert isinstance(user.website, url.furl)

    def test_scalar_attributes_get_coerced_to_objects(self, User):
        user = User(website=u'www.example.com')
        assert isinstance(user.website, url.furl)
sqlalchemy-utils-0.36.1/tests/types/test_uuid.py000066400000000000000000000016761360007755400220070ustar00rootroot00000000000000
import uuid

import pytest
import sqlalchemy as sa

from sqlalchemy_utils import UUIDType


@pytest.fixture
def User(Base):
    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(UUIDType, default=uuid.uuid4, primary_key=True)

        def __repr__(self):
            return 'User(%r)' % self.id
    return User


@pytest.fixture
def init_models(User):
    pass

class TestUUIDType(object):
    def test_commit(self, session, User):
        obj = User()
        obj.id = uuid.uuid4().hex

        session.add(obj)
        session.commit()

        u = session.query(User).one()
        assert u.id == obj.id

    def test_coerce(self, User):
        obj = User()
        obj.id = identifier = uuid.uuid4().hex
        assert isinstance(obj.id, uuid.UUID)
        assert obj.id.hex == identifier

        obj.id = identifier = uuid.uuid4().bytes
        assert isinstance(obj.id, uuid.UUID)
        assert obj.id.bytes == identifier
sqlalchemy-utils-0.36.1/tests/types/test_weekdays.py000066400000000000000000000030451360007755400226450ustar00rootroot00000000000000
# -*- coding: utf-8 -*-
import pytest
import sqlalchemy as sa

from sqlalchemy_utils import i18n
from sqlalchemy_utils.primitives import WeekDays
from sqlalchemy_utils.types import WeekDaysType


@pytest.fixture
def Schedule(Base):
    class Schedule(Base):
        __tablename__ = 'schedule'
        id = sa.Column(sa.Integer, primary_key=True)
        working_days = sa.Column(WeekDaysType)

        def __repr__(self):
            return 'Schedule(%r)' % self.id
    return Schedule


@pytest.fixture
def init_models(Schedule):
    pass


@pytest.fixture
def set_get_locale():
    i18n.get_locale = lambda: i18n.babel.Locale('en')


@pytest.mark.usefixtures('set_get_locale')
@pytest.mark.skipif('i18n.babel is None')
class WeekDaysTypeTestCase(object):
    def test_color_parameter_processing(self, session, Schedule):
        schedule = Schedule(
            working_days=b'0001111'
        )
        session.add(schedule)
        session.commit()

        schedule = session.query(Schedule).first()
        assert isinstance(schedule.working_days, WeekDays)

    def test_scalar_attributes_get_coerced_to_objects(self, Schedule):
        schedule = Schedule(working_days=b'1010101')
        assert isinstance(schedule.working_days, WeekDays)


@pytest.mark.usefixtures('sqlite_memory_dsn')
class TestWeekDaysTypeOnSQLite(WeekDaysTypeTestCase):
    pass


@pytest.mark.usefixtures('postgresql_dsn')
class TestWeekDaysTypeOnPostgres(WeekDaysTypeTestCase):
    pass


@pytest.mark.usefixtures('mysql_dsn')
class TestWeekDaysTypeOnMySQL(WeekDaysTypeTestCase):
    pass
sqlalchemy-utils-0.36.1/tox.ini000066400000000000000000000012041360007755400163400ustar00rootroot00000000000000
[tox]
envlist = py27, py35, py36, py37, lint

[testenv]
commands = py.test sqlalchemy_utils tests
deps = .[test_all]
passenv = SQLALCHEMY_UTILS_TEST_DB SQLALCHEMY_UTILS_TEST_POSTGRESQL_USER SQLALCHEMY_UTILS_TEST_MYSQL_USER

[testenv:py27]
recreate = True

[testenv:py35]
recreate = True

[testenv:py36]
recreate = True

[testenv:py37]
recreate = True

[testenv:lint]
recreate = True
commands =
    flake8 sqlalchemy_utils tests
    isort --verbose --recursive --diff sqlalchemy_utils tests
    isort --verbose --recursive --check-only sqlalchemy_utils tests
skip_install = True
deps =
    .[test_all]
    flake8>=2.5.0
    isort==4.2.2