----- pony-0.7.11/LICENSE -----

Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS

Copyright 2016 Alexander Kozlovsky, Alexey Malashkevich

----- pony-0.7.11/MANIFEST.in -----

include pony/orm/tests/queries.txt
recursive-include pony/flask/example/templates *.html
include LICENSE

----- pony-0.7.11/PKG-INFO -----

Metadata-Version: 1.1
Name: pony
Version: 0.7.11
Summary: Pony Object-Relational Mapper
Home-page: https://ponyorm.com
Author: Alexander Kozlovsky, Alexey Malashkevich
Author-email: team@ponyorm.com
License: Apache License Version 2.0
Download-URL: http://pypi.python.org/pypi/pony/
Description: About
        =========

        Pony ORM is an easy-to-use and powerful object-relational mapper for Python.
        Using Pony, developers can create and maintain database-oriented software applications
        faster and with less effort. One of the most interesting features of Pony is its ability
        to write queries to the database using generator expressions. Pony then analyzes the
        abstract syntax tree of a generator and translates it to its SQL equivalent.

        Following is an example of a query in Pony::

            select(p for p in Product if p.name.startswith('A') and p.cost <= 1000)

        Such an approach simplifies the code and allows a programmer to concentrate
        on the business logic of the application.

        Pony translates queries to SQL using a specific database dialect. Currently Pony works
        with SQLite, MySQL, PostgreSQL and Oracle databases.

        The package `pony.orm.examples `_ contains several examples.

        Installation
        =================
        ::

            pip install pony

        Entity-Relationship Diagram Editor
        =============================================

        `Pony online ER Diagram Editor <https://editor.ponyorm.com>`_ is a great tool for
        prototyping. You can draw your ER diagram online, generate Pony entity declarations
        or SQL script for creating database schema based on the diagram and start working
        with the database in seconds.
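        To make the above concrete, here is a minimal, self-contained sketch of the
        workflow (the ``Product`` fields and the in-memory SQLite binding are chosen
        for illustration only, they are not part of the distribution)::

            from pony.orm import Database, Required, db_session, select

            db = Database()

            class Product(db.Entity):
                name = Required(str)
                cost = Required(int)

            db.bind(provider='sqlite', filename=':memory:')
            db.generate_mapping(create_tables=True)

            with db_session:
                Product(name='Apple juice', cost=50)
                # the generator expression below is translated into one SQL query
                cheap = select(p for p in Product if p.cost <= 1000)[:]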
        Pony ORM Links:
        =================

        - Main site: https://ponyorm.com
        - Documentation: https://docs.ponyorm.com
        - GitHub: https://github.com/ponyorm/pony
        - Mailing list: http://ponyorm-list.ponyorm.com
        - ER Diagram Editor: https://editor.ponyorm.com
        - Blog: https://blog.ponyorm.com
Platform: UNKNOWN
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 2
Classifier: Programming Language :: Python :: 2.7
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.3
Classifier: Programming Language :: Python :: 3.4
Classifier: Programming Language :: Python :: 3.5
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: Implementation :: PyPy
Classifier: Topic :: Software Development :: Libraries
Classifier: Topic :: Database

----- pony-0.7.11/README.md -----

Pony Object-Relational Mapper
=============================

Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions and lambdas. Pony analyzes the abstract syntax tree of the expression and translates it into a SQL query.

Here is an example query in Pony:

    select(p for p in Product if p.name.startswith('A') and p.cost <= 1000)

Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases.

By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps to save resources.

Pony achieves this ease of use through the following:

* Compact entity definitions
* The concise query language
* Ability to work with Pony interactively in a Python interpreter
* Comprehensive error messages, showing the exact part where an error occurred in the query
* Displaying the generated SQL in a readable format with indentation

All this helps the developer to focus on implementing the business logic of an application, instead of struggling with a mapper trying to understand how to get the data from the database.

See the example [here](https://github.com/ponyorm/pony/blob/orm/pony/orm/examples/estore.py)

Support Pony ORM Development
----------------------------

Pony ORM is an Apache 2.0 licensed open source project. If you would like to support Pony ORM development, please consider:

[Become a backer or sponsor](https://ponyorm.org/donation.html)

Online tool for database design
-------------------------------

Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool for prototyping. You can create your database diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds.
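Quick start example
-------------------

To see the pieces together, here is a small, hypothetical quick-start session (the `Person` entity and its fields are invented for this illustration and are not shipped with the package):

    from pony.orm import Database, Required, Optional, db_session, select

    db = Database()

    class Person(db.Entity):
        name = Required(str)
        age = Optional(int)

    db.bind(provider='sqlite', filename=':memory:')
    db.generate_mapping(create_tables=True)

    with db_session:
        # objects created inside a db_session are saved on its exit
        Person(name='Ann', age=27)
        Person(name='Bob', age=17)

    with db_session:
        adults = select(p for p in Person if p.age >= 18)[:]
        print([p.name for p in adults])  # ['Ann']

Entities must be declared before `generate_mapping` is called; after that, any code that touches the database has to run inside a `db_session`.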
Documentation
-------------

Documentation is available at [https://docs.ponyorm.org](https://docs.ponyorm.org)

The documentation source is available at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc). Please create new documentation-related issues [here](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements.

License
-------

Pony ORM is released under the Apache 2.0 license.

PonyORM community
-----------------

Please post your questions on [Stack Overflow](http://stackoverflow.com/questions/tagged/ponyorm). Meet the PonyORM team, chat with the community members, and get your questions answered on our community [Telegram group](https://t.me/ponyorm). Join our newsletter at [ponyorm.org](https://ponyorm.org). Reach us on [Twitter](https://twitter.com/ponyorm).

Copyright (c) 2013-2019 Pony ORM. All rights reserved. info (at) ponyorm.org

----- pony-0.7.11/pony/__init__.py -----

from __future__ import absolute_import, print_function

import os, sys
from os.path import dirname

__version__ = '0.7.11'

def detect_mode():
    try: import google.appengine
    except ImportError: pass
    else:
        if os.environ.get('SERVER_SOFTWARE', '').startswith('Development'):
            return 'GAE-LOCAL'
        return 'GAE-SERVER'

    try: from mod_wsgi import version
    except: pass
    else: return 'MOD_WSGI'

    main = sys.modules['__main__']

    if not hasattr(main, '__file__'):  # console
        return 'INTERACTIVE'
    if getattr(main, 'INTERACTIVE_MODE_AVAILABLE', False):  # pycharm console
        return 'INTERACTIVE'

    if 'flup.server.fcgi' in sys.modules: return 'FCGI-FLUP'
    if 'uwsgi' in sys.modules: return 'UWSGI'
    if 'flask' in sys.modules: return 'FLASK'
    if 'cherrypy' in sys.modules: return 'CHERRYPY'
    if 'bottle' in sys.modules: return 'BOTTLE'
    return 'UNKNOWN'

MODE = detect_mode()

MAIN_FILE = None
if MODE == 'MOD_WSGI':
    for module_name, module in sys.modules.items():
        if module_name.startswith('_mod_wsgi_'):
            MAIN_FILE = module.__file__
            break
elif MODE != 'INTERACTIVE':
    MAIN_FILE = sys.modules['__main__'].__file__

if MAIN_FILE is not None: MAIN_DIR = dirname(MAIN_FILE)
else: MAIN_DIR = None

PONY_DIR = dirname(__file__)

----- pony-0.7.11/pony/converting.py -----

# coding: cp1251
from __future__ import absolute_import, print_function
from pony.py23compat import PY2, iteritems, imap, izip, xrange, unicode, basestring

import re
from datetime import datetime, date, time, timedelta

from pony.utils import is_ident

class ValidationError(ValueError):
    pass

def check_ip(s):
    s = s.strip()
    items = s.split('.')
    if len(items) != 4: raise ValueError()
    for item in items:
        if not 0 <= int(item) <= 255: raise ValueError()
    return s

def check_positive(s):
    i = int(s)
    if i > 0: return i
    raise ValueError()

def check_identifier(s):
    if is_ident(s): return s
    raise ValueError()

isbn_re = re.compile(r'(?:\d[ -]?)+x?')

def isbn10_checksum(digits):
    if len(digits) != 9: raise ValueError()
    reminder = sum(digit*coef for digit, coef in izip(imap(int, digits), xrange(10, 1, -1))) % 11
    if reminder == 1: return 'X'
    return reminder and str(11 - reminder) or '0'

def isbn13_checksum(digits):
    if len(digits) != 12: raise ValueError()
    reminder = sum(digit*coef for digit, coef in izip(imap(int, digits), (1, 3)*6)) % 10
    return reminder and str(10 - reminder) or '0'

def check_isbn(s, convert_to=None):
    s = s.strip().upper()
    if s[:4] == 'ISBN': s = s[4:].lstrip()
    digits = s.replace('-', '').replace(' ', '')
    size = len(digits)
    if size == 10: checksum_func = isbn10_checksum
    elif size == 13: checksum_func = isbn13_checksum
    else: raise ValueError()
    digits, last = digits[:-1], digits[-1]
    if checksum_func(digits) != last:
        if last.isdigit() or size == 10 and last == 'X':
            raise ValidationError('Invalid ISBN checksum')
        raise ValueError()
    if convert_to is not None:
        if size == 10 and convert_to == 13:
            digits = '978' + digits
            s = digits + isbn13_checksum(digits)
        elif size == 13 and convert_to == 10 and digits[:3] == '978':
            digits = digits[3:]
            s = digits + isbn10_checksum(digits)
    return s

def isbn10_to_isbn13(s):
    return check_isbn(s, convert_to=13)

def isbn13_to_isbn10(s):
    return check_isbn(s, convert_to=10)

# The next two regular expressions taken from
# http://www.regular-expressions.info/email.html

email_re = re.compile(
    r'^[a-z0-9._%+-]+@[a-z0-9][a-z0-9-]*(?:\.[a-z0-9][a-z0-9-]*)+$',
    re.IGNORECASE)

rfc2822_email_re = re.compile(r'''
    ^(?: [a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*
     |   "(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21\x23-\x5b\x5d-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])*"
     )
    @
    (?: (?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?
    |   \[ (?: (?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}
           (?: 25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?|[a-z0-9-]*[a-z0-9]
               :(?:[\x01-\x08\x0b\x0c\x0e-\x1f\x21-\x5a\x53-\x7f]|\\[\x01-\x09\x0b\x0c\x0e-\x7f])+ )
        \]
    )$''', re.IGNORECASE | re.VERBOSE)

def check_email(s):
    s = s.strip()
    if email_re.match(s) is None: raise ValueError()
    return s

def check_rfc2822_email(s):
    s = s.strip()
    if rfc2822_email_re.match(s) is None: raise ValueError()
    return s

date_str_list = [
    r'(?P<month>\d{1,2})/(?P<day>\d{1,2})/(?P<year>\d{4})',
    r'(?P<day>\d{1,2})\.(?P<month>\d{1,2})\.(?P<year>\d{4})',
    r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,4})',
    r'(?P<year>\d{4})/(?P<month>\d{1,2})/(?P<day>\d{1,4})',
    r'(?P<year>\d{4})\.(?P<month>\d{1,2})\.(?P<day>\d{1,4})',
    r'\D*(?P<year>\d{4})\D+(?P<day>\d{1,2})\D*',
    r'\D*(?P<day>\d{1,2})\D+(?P<year>\d{4})\D*',
]
date_re_list = [ re.compile('^%s$'%s, re.UNICODE) for s in date_str_list ]

time_str = r'''
    (?P<hh>\d{1,2})  # hours
    (?: \s* [hu] \s* )?  # optional hours suffix
    (?:
        (?: (?<=\d)[:. ] | (?<![:. ]) )
        (?P<mm>\d{1,2})  # minutes
        (?: (?: \s* m(?:in)? | ' ) \s* )?  # optional minutes suffix
        (?:
            (?: (?<=\d)[:. ] | (?<![:. ]) )
            (?P<ss>\d{1,2}(?:\.\d{1,6})?)  # seconds with optional microseconds
            \s* (?: (?: s(?:ec)? | " ) \s* )?  # optional seconds suffix
        )?
    )?
    (?:  # optional A.M./P.M. part
        \s* (?: (?P<am> a\.?m\.? ) | (?P<pm> p\.?m\.? ) )
    )?
'''
time_re = re.compile('^%s$'%time_str, re.VERBOSE)

datetime_re_list = [ re.compile('^%s(?:[t ]%s)?$' % (date_str, time_str), re.UNICODE | re.VERBOSE)
                     for date_str in date_str_list ]

month_lists = [
    "jan feb mar apr may jun jul aug sep oct nov dec".split(),
    u"янв фев мар апр май июн июл авг сен окт ноя дек".split(),  # Russian
]
month_dict = {}
for month_list in month_lists:
    for i, month in enumerate(month_list):
        month_dict[month] = i + 1
month_dict[u'мая'] = 5  # Russian

def str2date(s):
    s = s.strip().lower()
    for date_re in date_re_list:
        match = date_re.match(s)
        if match is not None: break
    else: raise ValueError('Unrecognized date format')
    dict = match.groupdict()
    year = dict['year']
    day = dict['day']
    month = dict.get('month')
    if month is None:
        for key, value in iteritems(month_dict):
            if key in s: month = value; break
        else: raise ValueError('Unrecognized date format')
    return date(int(year), int(month), int(day))

def str2time(s):
    s = s.strip().lower()
    match = time_re.match(s)
    if match is None: raise ValueError('Unrecognized time format')
    hh, mm, ss, mcs = _extract_time_parts(match.groupdict())
    return time(hh, mm, ss, mcs)

def str2datetime(s):
    s = s.strip().lower()
    for datetime_re in datetime_re_list:
        match = datetime_re.match(s)
        if match is not None: break
    else: raise ValueError('Unrecognized datetime format')
    d = match.groupdict()
    year, day, month = d['year'], d['day'], d.get('month')
    if month is None:
        for key, value in iteritems(month_dict):
            if key in s: month = value; break
        else: raise ValueError('Unrecognized datetime format')
    hh, mm, ss, mcs = _extract_time_parts(d)
    return datetime(int(year), int(month), int(day), hh, mm, ss, mcs)

def _extract_time_parts(groupdict):
    hh, mm, ss, am, pm = imap(groupdict.get, ('hh', 'mm', 'ss', 'am', 'pm'))
    if hh is None: hh, mm, ss = 12, 00, 00
    elif am and hh == '12': hh = 0
    elif pm and hh != '12': hh = int(hh) + 12
    if isinstance(ss, basestring) and '.' in ss:
        ss, mcs = ss.split('.', 1)
        mcs = (mcs + '000000')[:6]  # pad/truncate fractional part to microseconds
    else: mcs = 0
    return int(hh), int(mm or 0), int(ss or 0), int(mcs)

def str2timedelta(s):
    if '.' in s:
        s, fractional = s.split('.')
        microseconds = int((fractional + '000000')[:6])
    else: microseconds = 0
    h, m, s = imap(int, s.split(':'))
    td = timedelta(hours=abs(h), minutes=m, seconds=s, microseconds=microseconds)
    return -td if h < 0 else td

def timedelta2str(td):
    total_seconds = td.days * (24 * 60 * 60) + td.seconds
    microseconds = td.microseconds
    if td.days < 0:
        total_seconds = abs(total_seconds)
        if microseconds:
            total_seconds -= 1
            microseconds = 1000000 - microseconds
    minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(minutes, 60)
    if microseconds: result = '%d:%d:%d.%06d' % (hours, minutes, seconds, microseconds)
    else: result = '%d:%d:%d' % (hours, minutes, seconds)
    if td.days >= 0: return result
    return '-' + result

converters = {
    int: (int, unicode, 'Incorrect number'),
    float: (float, unicode, 'Must be a real number'),
    'IP': (check_ip, unicode, 'Incorrect IP address'),
    'positive': (check_positive, unicode, 'Must be a positive number'),
    'identifier': (check_identifier, unicode, 'Incorrect identifier'),
    'ISBN': (check_isbn, unicode, 'Incorrect ISBN'),
    'email': (check_email, unicode, 'Incorrect e-mail address'),
    'rfc2822_email': (check_rfc2822_email, unicode, 'Must be correct e-mail address'),
    date: (str2date, unicode, 'Must be correct date (mm/dd/yyyy or dd.mm.yyyy)'),
    time: (str2time, unicode, 'Must be correct time (hh:mm or hh:mm:ss)'),
    datetime: (str2datetime, unicode, 'Must be correct date & time'),
}
if PY2: converters[long] = (long, unicode, 'Incorrect number')

def str2py(value, type):
    if type is None or not isinstance(value, unicode): return value
    if isinstance(type, tuple): str2py, py2str, err_msg = type
    else: str2py, py2str, err_msg = converters.get(type, (type, unicode, None))
    try: return str2py(value)
    except ValidationError: raise
    except:
        if value == '': return None
        raise ValidationError(err_msg or 'Incorrect data')

----- pony-0.7.11/pony/flask/__init__.py -----

from pony.orm import db_session
from flask import request

def _enter_session():
    session = db_session()
    request.pony_session = session
    session.__enter__()

def _exit_session(exception):
    session = getattr(request, 'pony_session', None)
    if session is not None:
        session.__exit__(exc=exception)

class Pony(object):
    def __init__(self, app=None):
        self.app = None
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        self.app = app
        self.app.before_request(_enter_session)
        self.app.teardown_request(_exit_session)

----- pony-0.7.11/pony/flask/example/__init__.py -----

----- pony-0.7.11/pony/flask/example/__main__.py -----

from .views import *
from .app import app

if __name__ == '__main__':
    db.bind(**app.config['PONY'])
    db.generate_mapping(create_tables=True)
    app.run()

----- pony-0.7.11/pony/flask/example/app.py -----

from flask import Flask
from flask_login import LoginManager
from pony.flask import Pony
from .config import config
from .models import db

app = Flask(__name__)
app.config.update(config)
Pony(app)

login_manager = LoginManager(app)
login_manager.login_view = 'login'

@login_manager.user_loader
def load_user(user_id):
    return db.User.get(id=user_id)

----- pony-0.7.11/pony/flask/example/config.py -----

config = dict(
    DEBUG = False,
    SECRET_KEY = 'secret_xxx',
    PONY = {
        'provider': 'sqlite',
        'filename': 'db.db3',
        'create_db': True
    }
)

----- pony-0.7.11/pony/flask/example/models.py -----

from pony.orm import Database, Required, Optional
from flask_login import UserMixin
from datetime import datetime

db = Database()

class User(db.Entity, UserMixin):
    login = Required(str, unique=True)
    password = Required(str)
    last_login = Optional(datetime)

----- pony-0.7.11/pony/flask/example/templates/index.html -----

Hello!
{% with messages = get_flashed_messages() %} {% if messages %} {% for message in messages %} {% endfor %} {% endif %} {% endwith %} {% if not current_user.is_authenticated %}

Hi, please log in or register


{% else %}

Hi, {{ current_user.login }}. Your last login: {{ current_user.last_login.strftime('%Y-%m-%d') }}

Logout

List of users

    {% for user in users %}
  • {% if user == current_user %} {{ user.login }} {% else %} {{ user.login }} {% endif %}
  • {% endfor %}
{% endif %}
----- pony-0.7.11/pony/flask/example/templates/login.html -----

Login page
{% with messages = get_flashed_messages() %} {% if messages %} {% for message in messages %} {% endfor %} {% endif %} {% endwith %}

Please login


{% if error %}

Error: {{ error }} {% endif %}

----- pony-0.7.11/pony/flask/example/templates/reg.html -----

Login page
{% with messages = get_flashed_messages() %} {% if messages %} {% for message in messages %} {% endfor %} {% endif %} {% endwith %}

Register


{% if error %}

Error: {{ error }} {% endif %}

----- pony-0.7.11/pony/flask/example/views.py -----

from .app import app
from .models import db
from flask import render_template, request, flash, redirect, abort
from flask_login import current_user, logout_user, login_user, login_required
from datetime import datetime
from pony.orm import flush

@app.route('/')
def index():
    users = db.User.select()
    return render_template('index.html', user=current_user, users=users)

@app.route('/login', methods=['GET', 'POST'])
def login():
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        possible_user = db.User.get(login=username)
        if not possible_user:
            flash('Wrong username')
            return redirect('/login')
        if possible_user.password == password:
            possible_user.last_login = datetime.now()
            login_user(possible_user)
            return redirect('/')
        flash('Wrong password')
        return redirect('/login')
    else:
        return render_template('login.html')

@app.route('/reg', methods=['GET', 'POST'])
def reg():
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        exist = db.User.get(login=username)
        if exist:
            flash('Username %s is already taken, choose another one' % username)
            return redirect('/reg')
        user = db.User(login=username, password=password)
        user.last_login = datetime.now()
        flush()
        login_user(user)
        flash('Successfully registered')
        return redirect('/')
    else:
        return render_template('reg.html')

@app.route('/logout')
@login_required
def logout():
    logout_user()
    flash('Logged out')
    return redirect('/')

----- pony-0.7.11/pony/options.py -----

DEBUG = True

STATIC_DIR = None

CUT_TRACEBACK = True

# postprocessing options:
STD_DOCTYPE = ''
STD_STYLESHEETS = [
    ("/pony/static/blueprint/screen.css", "screen, projection"),
    ("/pony/static/blueprint/print.css", "print"),
    ("/pony/static/blueprint/ie.css", "screen, projection", "if IE"),
    ("/pony/static/css/default.css", "screen, projection"),
]
BASE_STYLESHEETS_PLACEHOLDER = ''
COMPONENT_STYLESHEETS_PLACEHOLDER = ''
SCRIPTS_PLACEHOLDER = ''

# reloading options:
RELOADING_CHECK_INTERVAL = 1.0  # in seconds

# logging options:
LOG_TO_SQLITE = None
LOGGING_LEVEL = None
LOGGING_PONY_LEVEL = None

# auth options:
MAX_SESSION_CTIME = 60*24  # one day
MAX_SESSION_MTIME = 60*2  # 2 hours
MAX_LONGLIFE_SESSION = 14  # 14 days
COOKIE_SERIALIZATION_TYPE = 'json'  # may be 'json' or 'pickle'
COOKIE_NAME = 'pony'
COOKIE_PATH = '/'
COOKIE_DOMAIN = None

HASH_ALGORITHM = None  # sha-1 by default
# HASH_ALGORITHM = hashlib.sha512

SESSION_STORAGE = None  # pony.sessionstorage.memcachedstorage by default
# SESSION_STORAGE = mystoragemodule
# SESSION_STORAGE = False  # means use cookies for save session data,
                           # can lead to race conditions

# memcached options (ignored under GAE):
MEMCACHE = None  # Use in-process python version by default
# MEMCACHE = [ "127.0.0.1:11211" ]
# MEMCACHE = MyMemcacheConnectionImplementation(...)
ALTERNATIVE_SESSION_MEMCACHE = None  # Use general memcache connection by default
ALTERNATIVE_ORM_MEMCACHE = None  # Use general memcache connection by default
ALTERNATIVE_TEMPLATING_MEMCACHE = None  # Use general memcache connection by default
ALTERNATIVE_RESPONCE_MEMCACHE = None  # Use general memcache connection by default

# pickle options:
PICKLE_START_OFFSET = 230
PICKLE_HTML_AS_PLAIN_STR = True

# encoding options for pony.patches.repr
RESTORE_ESCAPES = True
SOURCE_ENCODING = None
CONSOLE_ENCODING = None

# db options
MAX_FETCH_COUNT = None

# used for select(...).show()
CONSOLE_WIDTH = 80

# sql translator options
SIMPLE_ALIASES = True  # if True just use entity name like "Course-1"
                       # if False use attribute names chain as an alias like "student-grades-course"
INNER_JOIN_SYNTAX = False  # put conditions to INNER JOIN ... ON ... or to WHERE ...

# debugging options
DEBUGGING_REMOVE_ADDR = True
DEBUGGING_RESTORE_ESCAPES = True

----- pony-0.7.11/pony/orm/__init__.py -----

from __future__ import absolute_import

from pony.orm.core import *

----- pony-0.7.11/pony/orm/asttranslation.py -----

from __future__ import absolute_import, print_function, division from pony.py23compat import basestring, iteritems from functools import update_wrapper from pony.thirdparty.compiler import ast from pony.utils import HashableDict, throw, copy_ast class TranslationError(Exception): pass pre_method_caches = {} post_method_caches = {} class ASTTranslator(object): def __init__(translator, tree): translator.tree = tree translator_cls = translator.__class__ pre_method_caches.setdefault(translator_cls, {}) post_method_caches.setdefault(translator_cls, {}) def dispatch(translator, node): translator_cls = translator.__class__ pre_methods = pre_method_caches[translator_cls] post_methods = post_method_caches[translator_cls] node_cls = node.__class__ try: pre_method = pre_methods[node_cls] except KeyError: pre_method = getattr(translator_cls, 'pre' + node_cls.__name__, translator_cls.default_pre) pre_methods[node_cls] = pre_method stop = translator.call(pre_method, node) if stop: return for child in node.getChildNodes(): translator.dispatch(child) try: post_method = post_methods[node_cls] except KeyError: post_method = getattr(translator_cls, 'post' + node_cls.__name__, translator_cls.default_post) post_methods[node_cls] = post_method translator.call(post_method, node) def call(translator, method, node): return method(translator, node) def default_pre(translator, node): pass def default_post(translator, node): pass def priority(p): def decorator(func): def new_func(translator, node): node.priority = p for child in node.getChildNodes(): if getattr(child, 'priority', 0) >= p: child.src = '(%s)' % child.src return func(translator, node) return update_wrapper(new_func, func) return decorator def binop_src(op, node): return op.join((node.left.src, node.right.src)) def ast2src(tree): src = getattr(tree, 'src', None) if src is not None: return src PythonTranslator(tree) return tree.src class
PythonTranslator(ASTTranslator): def __init__(translator, tree): ASTTranslator.__init__(translator, tree) translator.top_level_f_str = None translator.dispatch(tree) def call(translator, method, node): node.src = method(translator, node) def default_pre(translator, node): if getattr(node, 'src', None) is not None: return True # node.src is already calculated, stop dispatching def default_post(translator, node): throw(NotImplementedError, node) def postGenExpr(translator, node): return '(%s)' % node.code.src def postGenExprInner(translator, node): return node.expr.src + ' ' + ' '.join(qual.src for qual in node.quals) def postGenExprFor(translator, node): src = 'for %s in %s' % (node.assign.src, node.iter.src) if node.ifs: ifs = ' '.join(if_.src for if_ in node.ifs) src += ' ' + ifs return src def postGenExprIf(translator, node): return 'if %s' % node.test.src def postIfExp(translator, node): return '%s if %s else %s' % (node.then.src, node.test.src, node.else_.src) def postLambda(translator, node): argnames = list(node.argnames) kwargs_name = argnames.pop() if node.kwargs else None varargs_name = argnames.pop() if node.varargs else None def_argnames = argnames[-len(node.defaults):] if node.defaults else [] nodef_argnames = argnames[:-len(node.defaults)] if node.defaults else argnames args = ', '.join(nodef_argnames) d_args = ', '.join('%s=%s' % (argname, default.src) for argname, default in zip(def_argnames, node.defaults)) v_arg = '*%s' % varargs_name if varargs_name else None kw_arg = '**%s' % kwargs_name if kwargs_name else None args = ', '.join(x for x in [args, d_args, v_arg, kw_arg] if x) return 'lambda %s: %s' % (args, node.code.src) @priority(14) def postOr(translator, node): return ' or '.join(expr.src for expr in node.nodes) @priority(13) def postAnd(translator, node): return ' and '.join(expr.src for expr in node.nodes) @priority(12) def postNot(translator, node): return 'not ' + node.expr.src @priority(11) def postCompare(translator, node): result = [ node.expr.src ] for op, expr in node.ops: result.extend((op, expr.src)) return ' '.join(result) @priority(10) def postBitor(translator, node): return ' | '.join(expr.src for expr in node.nodes) @priority(9) def postBitxor(translator, node): return ' ^ '.join(expr.src for expr in node.nodes) @priority(8) def postBitand(translator, node): return ' & '.join(expr.src for expr in node.nodes) @priority(7) def postLeftShift(translator, node): return binop_src(' << ', node) @priority(7) def postRightShift(translator, node): return binop_src(' >> ', node) @priority(6) def postAdd(translator, node): return binop_src(' + ', node) @priority(6) def postSub(translator, node): return binop_src(' - ', node) @priority(5) def postMul(translator, node): return binop_src(' * ', node) @priority(5) def postDiv(translator, node): return binop_src(' / ', node) @priority(5) def postFloorDiv(translator, node): return binop_src(' // ', node) @priority(5) def postMod(translator, node): return binop_src(' % ', node) @priority(4) def postUnarySub(translator, node): return '-' + node.expr.src @priority(4) def postUnaryAdd(translator, node): return '+' + node.expr.src @priority(4) def postInvert(translator, node): return '~' + node.expr.src @priority(3) def postPower(translator, node): return binop_src(' ** ', node) def postGetattr(translator, node): node.priority = 2 return '.'.join((node.expr.src, node.attrname)) def postCallFunc(translator, node): node.priority = 2 args = [ arg.src for arg in node.args ] if node.star_args: args.append('*'+node.star_args.src) 
if node.dstar_args: args.append('**'+node.dstar_args.src) if len(args) == 1 and isinstance(node.args[0], ast.GenExpr): return node.node.src + args[0] return '%s(%s)' % (node.node.src, ', '.join(args)) def postSubscript(translator, node): node.priority = 2 if len(node.subs) == 1: sub = node.subs[0] if isinstance(sub, ast.Const) and type(sub.value) is tuple and len(sub.value) > 1: key = sub.src assert key.startswith('(') and key.endswith(')') key = key[1:-1] else: key = sub.src else: key = ', '.join([ sub.src for sub in node.subs ]) return '%s[%s]' % (node.expr.src, key) def postSlice(translator, node): node.priority = 2 lower = node.lower.src if node.lower is not None else '' upper = node.upper.src if node.upper is not None else '' return '%s[%s:%s]' % (node.expr.src, lower, upper) def postSliceobj(translator, node): return ':'.join(item.src for item in node.nodes) def postConst(translator, node): node.priority = 1 value = node.value if type(value) is float: # for Python < 2.7 s = str(value) if float(s) == value: return s return repr(value) def postEllipsis(translator, node): return '...' def postList(translator, node): node.priority = 1 return '[%s]' % ', '.join(item.src for item in node.nodes) def postTuple(translator, node): node.priority = 1 if len(node.nodes) == 1: return '(%s,)' % node.nodes[0].src else: return '(%s)' % ', '.join(item.src for item in node.nodes) def postAssTuple(translator, node): node.priority = 1 if len(node.nodes) == 1: return '(%s,)' % node.nodes[0].src else: return '(%s)' % ', '.join(item.src for item in node.nodes) def postDict(translator, node): node.priority = 1 return '{%s}' % ', '.join('%s:%s' % (key.src, value.src) for key, value in node.items) def postSet(translator, node): node.priority = 1 return '{%s}' % ', '.join(item.src for item in node.nodes) def postBackquote(translator, node): node.priority = 1 return '`%s`' % node.expr.src def postName(translator, node): node.priority = 1 return node.name def postAssName(translator, node): node.priority = 1 return node.name def postKeyword(translator, node): return '='.join((node.name, node.expr.src)) def preStr(self, node): if self.top_level_f_str is None: self.top_level_f_str = node def postStr(self, node): if self.top_level_f_str is node: self.top_level_f_str = None return "f%r" % ('{%s}' % node.value.src) return '{%s}' % node.value.src def preJoinedStr(self, node): if self.top_level_f_str is None: self.top_level_f_str = node def postJoinedStr(self, node): result = ''.join( value.value if isinstance(value, ast.Const) else value.src for value in node.values) if self.top_level_f_str is node: self.top_level_f_str = None return "f%r" % result return result def preFormattedValue(self, node): if self.top_level_f_str is None: self.top_level_f_str = node def postFormattedValue(self, node): res = '{%s:%s}' % (node.value.src, node.fmt_spec.src) if self.top_level_f_str is node: self.top_level_f_str = None return "f%r" % res return res nonexternalizable_types = (ast.Keyword, ast.Sliceobj, ast.List, ast.Tuple) class PreTranslator(ASTTranslator): def __init__(translator, tree, globals, locals, special_functions, const_functions, outer_names=()): ASTTranslator.__init__(translator, tree) translator.globals = globals translator.locals = locals translator.special_functions = special_functions translator.const_functions = const_functions translator.contexts = [] if outer_names: translator.contexts.append(outer_names) translator.externals = externals = set() translator.dispatch(tree) for node in externals.copy(): if 
isinstance(node, nonexternalizable_types) \ or node.constant and not isinstance(node, ast.Const): node.external = False externals.remove(node) externals.update(node for node in node.getChildNodes() if node.external and not node.constant) def dispatch(translator, node): node.external = node.constant = None ASTTranslator.dispatch(translator, node) children = node.getChildNodes() if node.external is None and children and all( getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in children): node.external = True if node.external and not node.constant: externals = translator.externals externals.difference_update(children) externals.add(node) def preGenExprInner(translator, node): translator.contexts.append(set()) dispatch = translator.dispatch for i, qual in enumerate(node.quals): dispatch(qual.iter) dispatch(qual.assign) for if_ in qual.ifs: dispatch(if_.test) dispatch(node.expr) translator.contexts.pop() return True def preLambda(translator, node): if node.varargs or node.kwargs or node.defaults: throw(NotImplementedError) translator.contexts.append(set(node.argnames)) translator.dispatch(node.code) translator.contexts.pop() return True def postAssName(translator, node): if node.flags != 'OP_ASSIGN': throw(TypeError) name = node.name if name.startswith('__'): throw(TranslationError, 'Illegal name: %r' % name) translator.contexts[-1].add(name) def postName(translator, node): name = node.name for context in translator.contexts: if name in context: return node.external = True def postConst(translator, node): node.external = node.constant = True def postDict(translator, node): node.external = True def postList(translator, node): node.external = True def postKeyword(translator, node): node.constant = node.expr.constant def postCallFunc(translator, node): func_node = node.node if not func_node.external: return attrs = [] while isinstance(func_node, ast.Getattr): attrs.append(func_node.attrname) func_node = func_node.expr if not isinstance(func_node, ast.Name): return attrs.append(func_node.name) expr = '.'.join(reversed(attrs)) x = eval(expr, translator.globals, translator.locals) try: hash(x) except TypeError: pass else: if x in translator.special_functions: if x.__name__ == 'raw_sql': node.raw_sql = True elif x is getattr: attr_node = node.args[1] attr_node.parent_node = node else: node.external = False elif x in translator.const_functions: for arg in node.args: if not arg.constant: return if node.star_args is not None and not node.star_args.constant: return if node.dstar_args is not None and not node.dstar_args.constant: return node.constant = True extractors_cache = {} def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, outer_names=()): result = extractors_cache.get(code_key) if not result: pretranslator = PreTranslator(tree, globals, locals, special_functions, const_functions, outer_names) extractors = {} for node in pretranslator.externals: src = node.src = ast2src(node) if src == '.0': def extractor(globals, locals): return locals['.0'] else: code = compile(src, src, 'eval') def extractor(globals, locals, code=code): return eval(code, globals, locals) extractors[src] = extractor result = extractors_cache[code_key] = tree, extractors return result

----- pony-0.7.11/pony/orm/core.py -----

from __future__ import absolute_import, print_function, division from
pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \ basestring, unicode, buffer, int_types, builtins, with_metaclass import json, re, sys, types, datetime, logging, itertools, warnings, inspect from operator import attrgetter, itemgetter from itertools import chain, starmap, repeat from time import time from decimal import Decimal from random import shuffle, randint, random from threading import Lock, RLock, currentThread as current_thread, _MainThread from contextlib import contextmanager from collections import defaultdict from hashlib import md5 from inspect import isgeneratorfunction from functools import wraps from pony.thirdparty.compiler import ast, parse import pony from pony import options from pony.orm.decompiling import decompile from pony.orm.ormtypes import ( LongStr, LongUnicode, numeric_types, raw_sql, RawSQL, normalize, Json, TrackedValue, QueryType, Array, IntArray, StrArray, FloatArray ) from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, OperationalError, IntegrityError, InternalError, ProgrammingError, NotSupportedError ) from pony import utils from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ between, concat, coalesce, HashableDict, deref_proxy, deduplicate __all__ = [ 'pony', 'DBException', 'RowNotFound', 'MultipleRowsFound', 'TooManyRowsFound', 'Warning', 'Error', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', 'OrmError', 'ERDiagramError', 'DBSchemaError', 'MappingError', 'BindingError', 'TableDoesNotExist', 'TableIsNotEmpty', 'ConstraintError', 'CacheIndexError', 'ObjectNotFound', 'MultipleObjectsFoundError', 'TooManyObjectsFoundError', 'OperationWithDeletedObjectError', 'TransactionError', 'ConnectionClosedError', 'TransactionIntegrityError', 'IsolationError', 'CommitException', 'RollbackException', 'UnrepeatableReadError', 'OptimisticCheckError', 'UnresolvableCyclicDependency', 'UnexpectedError', 'DatabaseSessionIsOver', 'PonyRuntimeWarning', 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', 'TranslationError', 'ExprEvalError', 'PermissionError', 'Database', 'sql_debug', 'set_sql_debug', 'sql_debugging', 'show', 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator', 'composite_key', 'composite_index', 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', 'make_proxy', 'LongStr', 'LongUnicode', 'Json', 'IntArray', 'StrArray', 'FloatArray', 'select', 'left_join', 'get', 'exists', 'delete', 'count', 'sum', 'min', 'max', 'avg', 'group_concat', 'distinct', 'JOIN', 'desc', 'between', 'concat', 'coalesce', 'raw_sql', 'buffer', 'unicode', 'get_current_user', 'set_current_user', 'perm', 'has_perm', 'get_user_groups', 'get_user_roles', 'get_object_labels', 'user_groups_getter', 'user_roles_getter', 'obj_labels_getter' ] suppress_debug_change = False def sql_debug(value): # todo: make sql_debug deprecated if not suppress_debug_change: local.debug = value def set_sql_debug(debug=True, show_values=None): if not suppress_debug_change: local.debug = debug local.show_values = show_values orm_logger = logging.getLogger('pony.orm') sql_logger = 
logging.getLogger('pony.orm.sql') orm_log_level = logging.INFO def has_handlers(logger): if not PY2: return logger.hasHandlers() while logger: if logger.handlers: return True elif not logger.propagate: return False logger = logger.parent return False def log_orm(msg): if has_handlers(orm_logger): orm_logger.log(orm_log_level, msg) else: print(msg) def log_sql(sql, arguments=None): if type(arguments) is list: sql = 'EXECUTEMANY (%d)\n%s' % (len(arguments), sql) if has_handlers(sql_logger): if local.show_values and arguments: sql = '%s\n%s' % (sql, format_arguments(arguments)) sql_logger.log(orm_log_level, sql) else: if (local.show_values is None or local.show_values) and arguments: sql = '%s\n%s' % (sql, format_arguments(arguments)) print(sql, end='\n\n') def format_arguments(arguments): if type(arguments) is not list: return args2str(arguments) return '\n'.join(args2str(args) for args in arguments) def args2str(args): if isinstance(args, (tuple, list)): return '[%s]' % ', '.join(imap(repr, args)) elif isinstance(args, dict): return '{%s}' % ', '.join('%s:%s' % (repr(key), repr(val)) for key, val in sorted(iteritems(args))) adapted_sql_cache = {} string2ast_cache = {} class OrmError(Exception): pass class ERDiagramError(OrmError): pass class DBSchemaError(OrmError): pass class MappingError(OrmError): pass class BindingError(OrmError): pass class TableDoesNotExist(OrmError): pass class TableIsNotEmpty(OrmError): pass class ConstraintError(OrmError): pass class CacheIndexError(OrmError): pass class RowNotFound(OrmError): pass class MultipleRowsFound(OrmError): pass class TooManyRowsFound(OrmError): pass class PermissionError(OrmError): pass class ObjectNotFound(OrmError): def __init__(exc, entity, pkval=None): if pkval is not None: if type(pkval) is tuple: pkval = ','.join(imap(repr, pkval)) else: pkval = repr(pkval) msg = '%s[%s]' % (entity.__name__, pkval) else: msg = entity.__name__ OrmError.__init__(exc, msg) exc.entity = entity exc.pkval = pkval class MultipleObjectsFoundError(OrmError): pass class TooManyObjectsFoundError(OrmError): pass class OperationWithDeletedObjectError(OrmError): pass class TransactionError(OrmError): pass class ConnectionClosedError(TransactionError): pass class TransactionIntegrityError(TransactionError): def __init__(exc, msg, original_exc=None): Exception.__init__(exc, msg) exc.original_exc = original_exc class CommitException(TransactionError): def __init__(exc, msg, exceptions): Exception.__init__(exc, msg) exc.exceptions = exceptions class PartialCommitException(TransactionError): def __init__(exc, msg, exceptions): Exception.__init__(exc, msg) exc.exceptions = exceptions class RollbackException(TransactionError): def __init__(exc, msg, exceptions): Exception.__init__(exc, msg) exc.exceptions = exceptions class DatabaseSessionIsOver(TransactionError): pass TransactionRolledBack = DatabaseSessionIsOver class IsolationError(TransactionError): pass class UnrepeatableReadError(IsolationError): pass class OptimisticCheckError(IsolationError): pass class UnresolvableCyclicDependency(TransactionError): pass class UnexpectedError(TransactionError): def __init__(exc, msg, original_exc): Exception.__init__(exc, msg) exc.original_exc = original_exc class ExprEvalError(TranslationError): def __init__(exc, src, cause): assert isinstance(cause, Exception) msg = '`%s` raises %s: %s' % (src, type(cause).__name__, str(cause)) TranslationError.__init__(exc, msg) exc.cause = cause class PonyInternalException(Exception): pass class OptimizationFailed(PonyInternalException): 
pass # Internal exception, cannot be encountered in user code class UseAnotherTranslator(PonyInternalException): def __init__(self, translator): Exception.__init__(self, 'This exception should be catched internally by PonyORM') self.translator = translator class PonyRuntimeWarning(RuntimeWarning): pass class DatabaseContainsIncorrectValue(PonyRuntimeWarning): pass class DatabaseContainsIncorrectEmptyValue(DatabaseContainsIncorrectValue): pass def adapt_sql(sql, paramstyle): result = adapted_sql_cache.get((sql, paramstyle)) if result is not None: return result pos = 0 result = [] args = [] kwargs = {} original_sql = sql if paramstyle in ('format', 'pyformat'): sql = sql.replace('%', '%%') while True: try: i = sql.index('$', pos) except ValueError: result.append(sql[pos:]) break result.append(sql[pos:i]) if sql[i+1] == '$': result.append('$') pos = i+2 else: try: expr, _ = parse_expr(sql, i+1) except ValueError: raise # TODO pos = i+1 + len(expr) if expr.endswith(';'): expr = expr[:-1] compile(expr, '', 'eval') # expr correction check if paramstyle == 'qmark': args.append(expr) result.append('?') elif paramstyle == 'format': args.append(expr) result.append('%s') elif paramstyle == 'numeric': args.append(expr) result.append(':%d' % len(args)) elif paramstyle == 'named': key = 'p%d' % (len(kwargs) + 1) kwargs[key] = expr result.append(':' + key) elif paramstyle == 'pyformat': key = 'p%d' % (len(kwargs) + 1) kwargs[key] = expr result.append('%%(%s)s' % key) else: throw(NotImplementedError) if args or kwargs: adapted_sql = ''.join(result) if args: source = '(%s,)' % ', '.join(args) else: source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items()) code = compile(source, '', 'eval') else: adapted_sql = original_sql.replace('$$', '$') code = compile('None', '', 'eval') result = adapted_sql, code adapted_sql_cache[(sql, paramstyle)] = result return result class PrefetchContext(object): def __init__(self, database=None): self.database = database self.attrs_to_prefetch_dict = defaultdict(set) self.entities_to_prefetch = set() self.relations_to_prefetch_cache = {} def copy(self): result = PrefetchContext(self.database) result.attrs_to_prefetch_dict = self.attrs_to_prefetch_dict.copy() result.entities_to_prefetch = self.entities_to_prefetch.copy() return result def __enter__(self): assert local.prefetch_context is None local.prefetch_context = self def __exit__(self, exc_type, exc_val, exc_tb): assert local.prefetch_context is self local.prefetch_context = None def get_frozen_attrs_to_prefetch(self, entity): attrs_to_prefetch = self.attrs_to_prefetch_dict.get(entity, ()) if type(attrs_to_prefetch) is set: attrs_to_prefetch = frozenset(attrs_to_prefetch) self.attrs_to_prefetch_dict[entity] = attrs_to_prefetch return attrs_to_prefetch def get_relations_to_prefetch(self, entity): result = self.relations_to_prefetch_cache.get(entity) if result is None: attrs_to_prefetch = self.attrs_to_prefetch_dict[entity] result = tuple(attr for attr in entity._attrs_ if attr.is_relation and ( attr in attrs_to_prefetch or attr.py_type in self.entities_to_prefetch and not attr.is_collection)) self.relations_to_prefetch_cache[entity] = result return result class Local(localbase): def __init__(local): local.debug = False local.show_values = None local.debug_stack = [] local.db2cache = {} local.db_context_counter = 0 local.db_session = None local.prefetch_context = None local.current_user = None local.perms_context = None local.user_groups_cache = {} local.user_roles_cache = defaultdict(dict) def 
push_debug_state(local, debug, show_values): local.debug_stack.append((local.debug, local.show_values)) if not suppress_debug_change: local.debug = debug local.show_values = show_values def pop_debug_state(local): local.debug, local.show_values = local.debug_stack.pop() local = Local() def _get_caches(): return list(sorted((cache for cache in local.db2cache.values()), reverse=True, key=lambda cache : (cache.database.priority, cache.num))) @cut_traceback def flush(): for cache in _get_caches(): cache.flush() def transact_reraise(exc_class, exceptions): cls, exc, tb = exceptions[0] new_exc = None try: msg = " ".join(tostring(arg) for arg in exc.args) if not issubclass(cls, TransactionError): msg = '%s: %s' % (cls.__name__, msg) new_exc = exc_class(msg, exceptions) new_exc.__cause__ = None reraise(exc_class, new_exc, tb) finally: del exceptions, exc, tb, new_exc def rollback_and_reraise(exc_info): try: rollback() finally: reraise(*exc_info) @cut_traceback def commit(): caches = _get_caches() if not caches: return try: for cache in caches: cache.flush() except: rollback_and_reraise(sys.exc_info()) primary_cache = caches[0] other_caches = caches[1:] exceptions = [] try: primary_cache.commit() except: exceptions.append(sys.exc_info()) for cache in other_caches: try: cache.rollback() except: exceptions.append(sys.exc_info()) transact_reraise(CommitException, exceptions) else: for cache in other_caches: try: cache.commit() except: exceptions.append(sys.exc_info()) if exceptions: transact_reraise(PartialCommitException, exceptions) finally: del exceptions @cut_traceback def rollback(): exceptions = [] try: for cache in _get_caches(): try: cache.rollback() except: exceptions.append(sys.exc_info()) if exceptions: transact_reraise(RollbackException, exceptions) assert not local.db2cache finally: del exceptions select_re = re.compile(r'\s*select\b', re.IGNORECASE) class DBSessionContextManager(object): __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', \ 'immediate', 'ddl', 'serializable', 'strict', 'optimistic', \ 'sql_debug', 'show_values' def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, optimistic=True, retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None): if retry != 0: if type(retry) is not int: throw(TypeError, "'retry' parameter of db_session must be of integer type. Got: %s" % type(retry)) if retry < 0: throw(TypeError, "'retry' parameter of db_session must not be negative. 
Got: %d" % retry) if ddl: throw(TypeError, "'ddl' and 'retry' parameters of db_session cannot be used together") if not callable(allowed_exceptions) and not callable(retry_exceptions): for e in allowed_exceptions: if e in retry_exceptions: throw(TypeError, 'The same exception %s cannot be specified in both ' 'allowed and retry exception lists simultaneously' % e.__name__) db_session.retry = retry db_session.ddl = ddl db_session.serializable = serializable db_session.immediate = immediate or ddl or serializable or not optimistic db_session.strict = strict db_session.optimistic = optimistic and not serializable db_session.retry_exceptions = retry_exceptions db_session.allowed_exceptions = allowed_exceptions db_session.sql_debug = sql_debug db_session.show_values = show_values def __call__(db_session, *args, **kwargs): if not args and not kwargs: return db_session if len(args) > 1: throw(TypeError, 'Pass only keyword arguments to db_session or use db_session as decorator') if not args: return db_session.__class__(**kwargs) if kwargs: throw(TypeError, 'Pass only keyword arguments to db_session or use db_session as decorator') func = args[0] if isgeneratorfunction(func) or hasattr(inspect, 'iscoroutinefunction') and inspect.iscoroutinefunction(func): return db_session._wrap_coroutine_or_generator_function(func) return db_session._wrap_function(func) def __enter__(db_session): if db_session.retry != 0: throw(TypeError, "@db_session can accept 'retry' parameter only when used as decorator and not as context manager") db_session._enter() def _enter(db_session): if local.db_session is None: assert not local.db_context_counter local.db_session = db_session elif db_session.ddl and not local.db_session.ddl: throw(TransactionError, 'Cannot start ddl transaction inside non-ddl transaction') elif db_session.serializable and not local.db_session.serializable: throw(TransactionError, 'Cannot start serializable transaction inside non-serializable transaction') local.db_context_counter += 1 if db_session.sql_debug is not None: local.push_debug_state(db_session.sql_debug, db_session.show_values) def __exit__(db_session, exc_type=None, exc=None, tb=None): local.db_context_counter -= 1 try: if not local.db_context_counter: assert local.db_session is db_session db_session._commit_or_rollback(exc_type, exc, tb) finally: if db_session.sql_debug is not None: local.pop_debug_state() def _commit_or_rollback(db_session, exc_type, exc, tb): try: if exc_type is None: can_commit = True elif not callable(db_session.allowed_exceptions): can_commit = issubclass(exc_type, tuple(db_session.allowed_exceptions)) else: assert exc is not None # exc can be None in Python 2.6 even if exc_type is not None try: can_commit = db_session.allowed_exceptions(exc) except: rollback_and_reraise(sys.exc_info()) if can_commit: commit() for cache in _get_caches(): cache.release() assert not local.db2cache else: try: rollback() except: if exc_type is None: raise # if exc_type is not None it will be reraised outside of __exit__ finally: del exc, tb local.db_session = None local.user_groups_cache.clear() local.user_roles_cache.clear() def _wrap_function(db_session, func): def new_func(func, *args, **kwargs): if local.db_context_counter: if db_session.ddl: fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func throw(TransactionError, '@db_session-decorated %s function with `ddl` option ' 'cannot be called inside of another db_session' % fname) if db_session.retry: fname = func.__name__ + '()' if isinstance(func, 
types.FunctionType) else func message = '@db_session decorator with `retry=%d` option is ignored for %s function ' \ 'because it is called inside another db_session' % (db_session.retry, fname) warnings.warn(message, PonyRuntimeWarning, stacklevel=3) if db_session.sql_debug is None: return func(*args, **kwargs) local.push_debug_state(db_session.sql_debug, db_session.show_values) try: return func(*args, **kwargs) finally: local.pop_debug_state() exc = tb = None try: for i in xrange(db_session.retry+1): db_session._enter() exc_type = exc = tb = None try: result = func(*args, **kwargs) commit() return result except: exc_type, exc, tb = sys.exc_info() retry_exceptions = db_session.retry_exceptions if not callable(retry_exceptions): do_retry = issubclass(exc_type, tuple(retry_exceptions)) else: assert exc is not None # exc can be None in Python 2.6 do_retry = retry_exceptions(exc) if not do_retry: raise finally: db_session.__exit__(exc_type, exc, tb) reraise(exc_type, exc, tb) finally: del exc, tb return decorator(new_func, func) def _wrap_coroutine_or_generator_function(db_session, gen_func): for option in ('ddl', 'retry', 'serializable'): if getattr(db_session, option, None): throw(TypeError, "db_session with `%s` option cannot be applied to generator function" % option) def interact(iterator, input=None, exc_info=None): if exc_info is None: return next(iterator) if input is None else iterator.send(input) if exc_info[0] is GeneratorExit: close = getattr(iterator, 'close', None) if close is not None: close() reraise(*exc_info) throw_ = getattr(iterator, 'throw', None) if throw_ is None: reraise(*exc_info) return throw_(*exc_info) @wraps(gen_func) def new_gen_func(*args, **kwargs): db2cache_copy = {} def wrapped_interact(iterator, input=None, exc_info=None): if local.db_session is not None: throw(TransactionError, '@db_session-wrapped generator cannot be used inside another db_session') assert not local.db_context_counter and not local.db2cache local.db_context_counter = 1 local.db_session = db_session local.db2cache.update(db2cache_copy) db2cache_copy.clear() if db_session.sql_debug is not None: local.push_debug_state(db_session.sql_debug, db_session.show_values) try: try: output = interact(iterator, input, exc_info) except StopIteration as e: commit() for cache in _get_caches(): cache.release() assert not local.db2cache raise e for cache in _get_caches(): if cache.modified or cache.in_transaction: throw(TransactionError, 'You need to manually commit() changes before suspending the generator') except: rollback_and_reraise(sys.exc_info()) else: return output finally: if db_session.sql_debug is not None: local.pop_debug_state() db2cache_copy.update(local.db2cache) local.db2cache.clear() local.db_context_counter = 0 local.db_session = None gen = gen_func(*args, **kwargs) iterator = gen.__await__() if hasattr(gen, '__await__') else iter(gen) try: output = wrapped_interact(iterator) while True: try: input = yield output except: output = wrapped_interact(iterator, exc_info=sys.exc_info()) else: output = wrapped_interact(iterator, input) except StopIteration: assert not db2cache_copy and not local.db2cache return if hasattr(types, 'coroutine'): new_gen_func = types.coroutine(new_gen_func) return new_gen_func db_session = DBSessionContextManager() class SQLDebuggingContextManager(object): def __init__(self, debug=True, show_values=None): self.debug = debug self.show_values = show_values def __call__(self, *args, **kwargs): if not kwargs and len(args) == 1 and callable(args[0]): arg = args[0] if not 
isgeneratorfunction(arg): return self._wrap_function(arg) return self._wrap_generator_function(arg) return self.__class__(*args, **kwargs) def __enter__(self): local.push_debug_state(self.debug, self.show_values) def __exit__(self, exc_type=None, exc=None, tb=None): local.pop_debug_state() def _wrap_function(self, func): def new_func(func, *args, **kwargs): self.__enter__() try: return func(*args, **kwargs) finally: self.__exit__() return decorator(new_func, func) def _wrap_generator_function(self, gen_func): def interact(iterator, input=None, exc_info=None): if exc_info is None: return next(iterator) if input is None else iterator.send(input) if exc_info[0] is GeneratorExit: close = getattr(iterator, 'close', None) if close is not None: close() reraise(*exc_info) throw_ = getattr(iterator, 'throw', None) if throw_ is None: reraise(*exc_info) return throw_(*exc_info) def new_gen_func(gen_func, *args, **kwargs): def wrapped_interact(iterator, input=None, exc_info=None): self.__enter__() try: return interact(iterator, input, exc_info) finally: self.__exit__() gen = gen_func(*args, **kwargs) iterator = iter(gen) output = wrapped_interact(iterator) try: while True: try: input = yield output except: output = wrapped_interact(iterator, exc_info=sys.exc_info()) else: output = wrapped_interact(iterator, input) except StopIteration: return return decorator(new_gen_func, gen_func) sql_debugging = SQLDebuggingContextManager() def throw_db_session_is_over(action, obj, attr=None): msg = 'Cannot %s %s%s: the database session is over' throw(DatabaseSessionIsOver, msg % (action, safe_repr(obj), '.%s' % attr.name if attr else '')) def with_transaction(*args, **kwargs): deprecated(3, "@with_transaction decorator is deprecated, use @db_session decorator instead") return db_session(*args, **kwargs) @decorator def db_decorator(func, *args, **kwargs): web = sys.modules.get('pony.web') allowed_exceptions = [ web.HttpRedirect ] if web else [] try: with db_session(allowed_exceptions=allowed_exceptions): return func(*args, **kwargs) except (ObjectNotFound, RowNotFound): if web: throw(web.Http404NotFound) raise known_providers = ('sqlite', 'postgres', 'mysql', 'oracle') class OnConnectDecorator(object): @staticmethod def check_provider(provider): if provider: if not isinstance(provider, basestring): throw(TypeError, "'provider' option should be of type 'string', got %r" % type(provider).__name__) if provider not in known_providers: throw(BindingError, 'Unknown provider %s' % provider) def __init__(self, database, provider): OnConnectDecorator.check_provider(provider) self.provider = provider self.database = database def __call__(self, func=None, provider=None): if isinstance(func, types.FunctionType): self.database._on_connect_funcs.append((func, provider or self.provider)) return func if not provider and isinstance(func, basestring): provider = func OnConnectDecorator.check_provider(provider) return OnConnectDecorator(self.database, provider) class Database(object): def __deepcopy__(self, memo): return self # Database cannot be cloned by deepcopy() @cut_traceback def __init__(self, *args, **kwargs): # argument 'self' cannot be named 'database', because 'database' can be in kwargs self.priority = 0 self._insert_cache = {} # ER-diagram related stuff: self._translator_cache = {} self._constructed_sql_cache = {} self.entities = {} self.schema = None self.Entity = type.__new__(EntityMeta, 'Entity', (Entity,), {}) self.Entity._database_ = self # Statistics-related stuff: self._global_stats = {} self._global_stats_lock = RLock() self._dblocal
= DbLocal() self.on_connect = OnConnectDecorator(self, None) self._on_connect_funcs = [] self.provider = self.provider_name = None if args or kwargs: self._bind(*args, **kwargs) def call_on_connect(database, con): for func, provider in database._on_connect_funcs: if not provider or provider == database.provider_name: func(database, con) con.commit() @cut_traceback def bind(self, *args, **kwargs): self._bind(*args, **kwargs) def _bind(self, *args, **kwargs): # argument 'self' cannot be named 'database', because 'database' can be in kwargs if self.provider is not None: throw(BindingError, 'Database object was already bound to %s provider' % self.provider.dialect) if len(args) == 1 and not kwargs and hasattr(args[0], 'keys'): args, kwargs = (), args[0] provider = None if args: provider, args = args[0], args[1:] elif 'provider' not in kwargs: throw(TypeError, 'Database provider is not specified') else: provider = kwargs.pop('provider') if isinstance(provider, type) and issubclass(provider, DBAPIProvider): provider_cls = provider else: if not isinstance(provider, basestring): throw(TypeError, 'Provider name should be string. Got: %r' % type(provider).__name__) if provider == 'pygresql': throw(TypeError, 'Pony no longer supports PyGreSQL module. Please use psycopg2 instead.') self.provider_name = provider provider_module = import_module('pony.orm.dbproviders.' + provider) provider_cls = provider_module.provider_cls kwargs['pony_call_on_connect'] = self.call_on_connect self.provider = provider_cls(*args, **kwargs) @property def last_sql(database): return database._dblocal.last_sql @property def local_stats(database): return database._dblocal.stats def _update_local_stat(database, sql, query_start_time): dblocal = database._dblocal dblocal.last_sql = sql stats = dblocal.stats query_end_time = time() duration = query_end_time - query_start_time stat = stats.get(sql) if stat is not None: stat.query_executed(duration) else: stats[sql] = QueryStat(sql, duration) total_stat = stats.get(None) if total_stat is not None: total_stat.query_executed(duration) else: stats[None] = QueryStat(None, duration) def merge_local_stats(database): setdefault = database._global_stats.setdefault with database._global_stats_lock: for sql, stat in iteritems(database._dblocal.stats): global_stat = setdefault(sql, stat) if global_stat is not stat: global_stat.merge(stat) database._dblocal.stats = {None: QueryStat(None)} @property def global_stats(database): with database._global_stats_lock: return {sql: stat.copy() for sql, stat in iteritems(database._global_stats)} @property def global_stats_lock(database): deprecated(3, "global_stats_lock is deprecated, just use global_stats property without any locking") return database._global_stats_lock @cut_traceback def get_connection(database): cache = database._get_cache() if not cache.in_transaction: cache.immediate = True cache.prepare_connection_for_query_execution() cache.in_transaction = True connection = cache.connection assert connection is not None return connection @cut_traceback def disconnect(database): provider = database.provider if provider is None: return if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_session') cache = local.db2cache.get(database) if cache is not None: cache.rollback() provider.disconnect() def _get_cache(database): if database.provider is None: throw(MappingError, 'Database object is not bound with a provider yet') cache = local.db2cache.get(database) if cache is not None: return cache if not 
local.db_context_counter and not ( pony.MODE == 'INTERACTIVE' and current_thread().__class__ is _MainThread ): throw(TransactionError, 'db_session is required when working with the database') cache = local.db2cache[database] = SessionCache(database) return cache @cut_traceback def flush(database): database._get_cache().flush() @cut_traceback def commit(database): cache = local.db2cache.get(database) if cache is not None: cache.flush_and_commit() @cut_traceback def rollback(database): cache = local.db2cache.get(database) if cache is not None: try: cache.rollback() except: transact_reraise(RollbackException, [sys.exc_info()]) @cut_traceback def execute(database, sql, globals=None, locals=None): return database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1, start_transaction=True) def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction=False): provider = database.provider if provider is None: throw(MappingError, 'Database object is not bound with a provider yet') sql = sql[:] # sql = templating.plainstr(sql) if globals is None: assert locals is None frame_depth += 1 globals = sys._getframe(frame_depth).f_globals locals = sys._getframe(frame_depth).f_locals adapted_sql, code = adapt_sql(sql, provider.paramstyle) arguments = eval(code, globals, locals) return database._exec_sql(adapted_sql, arguments, False, start_transaction) @cut_traceback def select(database, sql, globals=None, locals=None, frame_depth=0): if not select_re.match(sql): sql = 'select ' + sql cursor = database._exec_raw_sql(sql, globals, locals, frame_depth+cut_traceback_depth+1) max_fetch_count = options.MAX_FETCH_COUNT if max_fetch_count is not None: result = cursor.fetchmany(max_fetch_count) if cursor.fetchone() is not None: throw(TooManyRowsFound) else: result = cursor.fetchall() if len(cursor.description) == 1: return [ row[0] for row in result ] row_class = type("row", (tuple,), {}) for i, column_info in enumerate(cursor.description): column_name = column_info[0] if not is_ident(column_name): continue if hasattr(tuple, column_name) and column_name.startswith('__'): continue setattr(row_class, column_name, property(itemgetter(i))) return [ row_class(row) for row in result ] @cut_traceback def get(database, sql, globals=None, locals=None): rows = database.select(sql, globals, locals, frame_depth=cut_traceback_depth+1) if not rows: throw(RowNotFound) if len(rows) > 1: throw(MultipleRowsFound) row = rows[0] return row @cut_traceback def exists(database, sql, globals=None, locals=None): if not select_re.match(sql): sql = 'select ' + sql cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1) result = cursor.fetchone() return bool(result) @cut_traceback def insert(database, table_name, returning=None, **kwargs): table_name = database._get_table_name(table_name) if database.provider is None: throw(MappingError, 'Database object is not bound with a provider yet') query_key = (table_name,) + tuple(kwargs) # keys are not sorted deliberately!! 
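# Illustrative sketch (not part of the PonyORM sources): the raw-SQL helpers of this
# class (execute/select/get/exists and this insert method) interpolate parameters
# written as $name by compiling them in adapt_sql() and evaluating them against the
# caller's frame in _exec_raw_sql(); '$$' produces a literal '$'. Hypothetical usage,
# assuming a bound Database `db` with a Person table:
#
#     min_age = 18
#     name = 'John Doe'
#     names = db.select('name FROM Person WHERE age >= $min_age')  # 'select ' is prepended
#     john = db.get('* FROM Person WHERE name = $name')  # throws RowNotFound / MultipleRowsFound
#     new_id = db.insert('Person', name='John Doe', age=30)  # returns cursor.lastrowid
#
# select() fetches at most options.MAX_FETCH_COUNT rows and throws TooManyRowsFound
# beyond that limit.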
if returning is not None: query_key = query_key + (returning,) cached_sql = database._insert_cache.get(query_key) if cached_sql is None: ast = [ 'INSERT', table_name, kwargs.keys(), [ [ 'PARAM', (i, None, None) ] for i in xrange(len(kwargs)) ], returning ] sql, adapter = database._ast2sql(ast) cached_sql = sql, adapter database._insert_cache[query_key] = cached_sql else: sql, adapter = cached_sql arguments = adapter(values_list(kwargs)) # order of values same as order of keys if returning is not None: return database._exec_sql(sql, arguments, returning_id=True, start_transaction=True) cursor = database._exec_sql(sql, arguments, start_transaction=True) return getattr(cursor, 'lastrowid', None) def _ast2sql(database, sql_ast): sql, adapter = database.provider.ast2sql(sql_ast) return sql, adapter def _exec_sql(database, sql, arguments=None, returning_id=False, start_transaction=False): cache = database._get_cache() if start_transaction: cache.immediate = True connection = cache.prepare_connection_for_query_execution() cursor = connection.cursor() if local.debug: log_sql(sql, arguments) provider = database.provider t = time() try: new_id = provider.execute(cursor, sql, arguments, returning_id) except Exception as e: connection = cache.reconnect(e) cursor = connection.cursor() if local.debug: log_sql(sql, arguments) t = time() new_id = provider.execute(cursor, sql, arguments, returning_id) if cache.immediate: cache.in_transaction = True database._update_local_stat(sql, t) if not returning_id: return cursor if PY2 and type(new_id) is long: new_id = int(new_id) return new_id @cut_traceback def generate_mapping(database, filename=None, check_tables=True, create_tables=False): provider = database.provider if provider is None: throw(MappingError, 'Database object is not bound with a provider yet') if database.schema: throw(BindingError, 'Mapping was already generated') if filename is not None: throw(NotImplementedError) schema = database.schema = provider.dbschema_cls(provider) entities = list(sorted(database.entities.values(), key=attrgetter('_id_'))) for entity in entities: entity._resolve_attr_types_() for entity in entities: entity._link_reverse_attrs_() for entity in entities: entity._check_table_options_() def get_columns(table, column_names): column_dict = table.column_dict return tuple(column_dict[name] for name in column_names) for entity in entities: entity._get_pk_columns_() table_name = entity._table_ is_subclass = entity._root_ is not entity if is_subclass: if table_name is not None: throw(NotImplementedError, 'Cannot specify table name for entity %r which is subclass of %r' % (entity.__name__, entity._root_.__name__)) table_name = entity._root_._table_ entity._table_ = table_name elif table_name is None: table_name = provider.get_default_entity_table_name(entity) entity._table_ = table_name else: assert isinstance(table_name, (basestring, tuple)) table = schema.tables.get(table_name) if table is None: table = schema.add_table(table_name, entity) else: table.add_entity(entity) for attr in entity._new_attrs_: if attr.is_collection: if not isinstance(attr, Set): throw(NotImplementedError) reverse = attr.reverse if not reverse.is_collection: # many-to-one: if attr.table is not None: throw(MappingError, "Parameter 'table' is not allowed for many-to-one attribute %s" % attr) elif attr.columns: throw(NotImplementedError, "Parameter 'column' is not allowed for many-to-one attribute %s" % attr) continue # many-to-many: if not isinstance(reverse, Set): throw(NotImplementedError) if 
attr.entity.__name__ > reverse.entity.__name__: continue if attr.entity is reverse.entity and attr.name > reverse.name: continue if attr.table: if not reverse.table: reverse.table = attr.table elif reverse.table != attr.table: throw(MappingError, "Parameter 'table' for %s and %s do not match" % (attr, reverse)) table_name = attr.table elif reverse.table: table_name = attr.table = reverse.table else: table_name = provider.get_default_m2m_table_name(attr, reverse) m2m_table = schema.tables.get(table_name) if m2m_table is not None: if not attr.table: seq_counter = itertools.count(2) while m2m_table is not None: if isinstance(table_name, basestring): new_table_name = table_name + '_%d' % next(seq_counter) else: schema_name, base_name = provider.split_table_name(table_name) new_table_name = schema_name, base_name + '_%d' % next(seq_counter) m2m_table = schema.tables.get(new_table_name) table_name = new_table_name elif m2m_table.entities or m2m_table.m2m: throw(MappingError, "Table name %s is already in use" % provider.format_table_name(table_name)) else: throw(NotImplementedError) attr.table = reverse.table = table_name m2m_table = schema.add_table(table_name) m2m_columns_1 = attr.get_m2m_columns(is_reverse=False) m2m_columns_2 = reverse.get_m2m_columns(is_reverse=True) if m2m_columns_1 == m2m_columns_2: throw(MappingError, 'Different column names should be specified for attributes %s and %s' % (attr, reverse)) assert len(m2m_columns_1) == len(reverse.converters) assert len(m2m_columns_2) == len(attr.converters) for column_name, converter in izip(m2m_columns_1 + m2m_columns_2, reverse.converters + attr.converters): m2m_table.add_column(column_name, converter.get_sql_type(), converter, True) m2m_table.add_index(None, tuple(m2m_table.column_list), is_pk=True) m2m_table.m2m.add(attr) m2m_table.m2m.add(reverse) else: if attr.is_required: pass elif not attr.type_has_empty_value: if attr.nullable is False: throw(TypeError, 'Optional attribute with non-string type %s must be nullable' % attr) attr.nullable = True elif entity._database_.provider.dialect == 'Oracle': if attr.nullable is False: throw(ERDiagramError, 'In Oracle, optional string attribute %s must be nullable' % attr) attr.nullable = True columns = attr.get_columns() # initializes attr.converters if not attr.reverse and attr.default is not None: assert len(attr.converters) == 1 if not callable(attr.default): attr.default = attr.validate(attr.default) assert len(columns) == len(attr.converters) if len(columns) == 1: converter = attr.converters[0] table.add_column(columns[0], converter.get_sql_type(attr), converter, not attr.nullable, attr.sql_default) elif columns: if attr.sql_type is not None: throw(NotImplementedError, 'sql_type cannot be specified for composite attribute %s' % attr) for (column_name, converter) in izip(columns, attr.converters): table.add_column(column_name, converter.get_sql_type(), converter, not attr.nullable) else: pass # virtual attribute of one-to-one pair entity._attrs_with_columns_ = [ attr for attr in entity._attrs_ if not attr.is_collection and attr.columns ] if not table.pk_index: if len(entity._pk_columns_) == 1 and entity._pk_attrs_[0].auto: is_pk = "auto" else: is_pk = True table.add_index(None, get_columns(table, entity._pk_columns_), is_pk) for index in entity._indexes_: if index.is_pk: continue column_names = [] attrs = index.attrs for attr in attrs: column_names.extend(attr.columns) index_name = attrs[0].index if len(attrs) == 1 else None table.add_index(index_name, get_columns(table, column_names), 
is_unique=index.is_unique) columns = [] columns_without_pk = [] converters = [] converters_without_pk = [] for attr in entity._attrs_with_columns_: columns.extend(attr.columns) # todo: inheritance converters.extend(attr.converters) if not attr.is_pk: columns_without_pk.extend(attr.columns) converters_without_pk.extend(attr.converters) entity._columns_ = columns entity._columns_without_pk_ = columns_without_pk entity._converters_ = converters entity._converters_without_pk_ = converters_without_pk for entity in entities: table = schema.tables[entity._table_] for attr in entity._new_attrs_: if attr.is_collection: reverse = attr.reverse if not reverse.is_collection: continue if not isinstance(attr, Set): throw(NotImplementedError) if not isinstance(reverse, Set): throw(NotImplementedError) m2m_table = schema.tables[attr.table] parent_columns = get_columns(table, entity._pk_columns_) child_columns = get_columns(m2m_table, reverse.columns) on_delete = 'CASCADE' m2m_table.add_foreign_key(reverse.fk_name, child_columns, table, parent_columns, attr.index, on_delete) if attr.symmetric: reverse_child_columns = get_columns(m2m_table, attr.reverse_columns) m2m_table.add_foreign_key(attr.reverse_fk_name, reverse_child_columns, table, parent_columns, attr.reverse_index, on_delete) elif attr.reverse and attr.columns: rentity = attr.reverse.entity parent_table = schema.tables[rentity._table_] parent_columns = get_columns(parent_table, rentity._pk_columns_) child_columns = get_columns(table, attr.columns) if attr.reverse.cascade_delete: on_delete = 'CASCADE' elif isinstance(attr, Optional) and attr.nullable: on_delete = 'SET NULL' else: on_delete = None table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index, on_delete) elif attr.index and attr.columns: if isinstance(attr.py_type, Array) and provider.dialect != 'PostgreSQL': pass # GIN indexes are supported only in PostgreSQL else: columns = tuple(imap(table.column_dict.__getitem__, attr.columns)) table.add_index(attr.index, columns, is_unique=attr.is_unique) entity._initialize_bits_() if create_tables: database.create_tables(check_tables) elif check_tables: database.check_tables() @cut_traceback @db_session(ddl=True) def drop_table(database, table_name, if_exists=False, with_all_data=False): database._drop_tables([ table_name ], if_exists, with_all_data, try_normalized=True) def _get_table_name(database, table_name): if isinstance(table_name, EntityMeta): entity = table_name table_name = entity._table_ elif isinstance(table_name, Set): attr = table_name table_name = attr.table if attr.reverse.is_collection else attr.entity._table_ elif isinstance(table_name, Attribute): throw(TypeError, "Attribute %s is not Set and doesn't have corresponding table" % table_name) elif table_name is None: if database.schema is None: throw(MappingError, 'No mapping was generated for the database') else: throw(TypeError, 'Table name cannot be None') elif isinstance(table_name, tuple): for component in table_name: if not isinstance(component, basestring): throw(TypeError, 'Invalid table name component: {}'.format(component)) elif isinstance(table_name, basestring): table_name = table_name[:] # table_name = templating.plainstr(table_name) else: throw(TypeError, 'Invalid table name: {}'.format(table_name)) return table_name @cut_traceback @db_session(ddl=True) def drop_all_tables(database, with_all_data=False): if database.schema is None: throw(ERDiagramError, 'No mapping was generated for the database') 
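# Illustrative sketch (not part of the PonyORM sources): the usual lifecycle of the
# schema methods defined here, with a hypothetical entity:
#
#     db = Database()
#     class Person(db.Entity):
#         name = Required(str)
#         age = Required(int)
#     db.bind(provider='sqlite', filename=':memory:')
#     db.generate_mapping(create_tables=True)  # resolves attrs, builds schema, creates tables
#     ...
#     db.drop_table(Person, if_exists=True, with_all_data=True)
#
# generate_mapping() may run only once per Database object: a second call throws
# BindingError('Mapping was already generated'), just as a second bind() throws
# BindingError for an already-bound database.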
database._drop_tables(database.schema.tables, True, with_all_data) def _drop_tables(database, table_names, if_exists, with_all_data, try_normalized=False): cache = database._get_cache() connection = cache.prepare_connection_for_query_execution() provider = database.provider existed_tables = [] for table_name in table_names: table_name = database._get_table_name(table_name) if provider.table_exists(connection, table_name): existed_tables.append(table_name) elif not if_exists: if try_normalized: if isinstance(table_name, basestring): normalized_table_name = provider.normalize_name(table_name) else: schema_name, base_name = provider.split_table_name(table_name) normalized_table_name = schema_name, provider.normalize_name(base_name) if normalized_table_name != table_name and provider.table_exists(connection, normalized_table_name): throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)' % ( provider.format_table_name(table_name), provider.format_table_name(normalized_table_name))) throw(TableDoesNotExist, 'Table %s does not exist' % provider.format_table_name(table_name)) if not with_all_data: for table_name in existed_tables: if provider.table_has_data(connection, table_name): throw(TableIsNotEmpty, 'Cannot drop table %s because it is not empty. Specify option ' 'with_all_data=True if you want to drop table with all data' % provider.format_table_name(table_name)) for table_name in existed_tables: if local.debug: log_orm('DROPPING TABLE %s' % provider.format_table_name(table_name)) provider.drop_table(connection, table_name) @cut_traceback @db_session(ddl=True) def create_tables(database, check_tables=False): cache = database._get_cache() if database.schema is None: throw(MappingError, 'No mapping was generated for the database') connection = cache.prepare_connection_for_query_execution() database.schema.create_tables(database.provider, connection) if check_tables: database.schema.check_tables(database.provider, connection) @cut_traceback @db_session() def check_tables(database): cache = database._get_cache() if database.schema is None: throw(MappingError, 'No mapping was generated for the database') connection = cache.prepare_connection_for_query_execution() database.schema.check_tables(database.provider, connection) @contextmanager def set_perms_for(database, *entities): if not entities: throw(TypeError, 'You should specify at least one positional argument') entity_set = set(entities) for entity in entities: if not isinstance(entity, EntityMeta): throw(TypeError, 'Entity class expected. 
Got: %s' % entity) entity_set.update(entity._subclasses_) if local.perms_context is not None: throw(OrmError, "'set_perms_for' context manager calls cannot be nested") local.perms_context = database, entity_set try: yield finally: assert local.perms_context and local.perms_context[0] is database local.perms_context = None def _get_schema_dict(database): result = [] user = get_current_user() for entity in sorted(database.entities.values(), key=attrgetter('_id_')): if not can_view(user, entity): continue attrs = [] for attr in entity._new_attrs_: if not can_view(user, attr): continue d = dict(name=attr.name, type=attr.py_type.__name__, kind=attr.__class__.__name__) if attr.auto: d['auto'] = True if attr.reverse: if not can_view(user, attr.reverse.entity): continue if not can_view(user, attr.reverse): continue d['reverse'] = attr.reverse.name if attr.lazy: d['lazy'] = True if attr.nullable: d['nullable'] = True if attr.default and issubclass(type(attr.default), (int_types, basestring)): d['defaultValue'] = attr.default attrs.append(d) d = dict(name=entity.__name__, newAttrs=attrs, pkAttrs=[ attr.name for attr in entity._pk_attrs_ ]) if entity._all_bases_: d['bases'] = [ base.__name__ for base in entity._all_bases_ ] if entity._simple_keys_: d['simpleKeys'] = [ attr.name for attr in entity._simple_keys_ ] if entity._composite_keys_: d['compositeKeys'] = [ [ attr.name for attr in attrs ] for attrs in entity._composite_keys_ ] result.append(d) return result def _get_schema_json(database): schema_json = json.dumps(database._get_schema_dict(), default=basic_converter, sort_keys=True) schema_hash = md5(schema_json.encode('utf-8')).hexdigest() return schema_json, schema_hash @cut_traceback def to_json(database, data, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): for attrs, param_name in ((include, 'include'), (exclude, 'exclude')): for attr in attrs: if not isinstance(attr, Attribute): throw(TypeError, "Each item of '%s' list should be attribute. 
Got: %s" % (param_name, attr)) include, exclude = set(include), set(exclude) if converter is None: converter = basic_converter user = get_current_user() def user_has_no_rights_to_see(obj, attr=None): user_groups = get_user_groups(user) throw(PermissionError, 'The current user %s which belongs to groups %s ' 'has no rights to see the object %s on the frontend' % (user, sorted(user_groups), obj)) object_set = set() caches = set() def obj_converter(obj): if not isinstance(obj, Entity): return converter(obj) cache = obj._session_cache_ if cache is not None: caches.add(cache) if len(caches) > 1: throw(TransactionError, 'An attempt to serialize objects belonging to different transactions') if not can_view(user, obj): user_has_no_rights_to_see(obj) object_set.add(obj) pkval = obj._get_raw_pkval_() if len(pkval) == 1: pkval = pkval[0] return { 'class': obj.__class__.__name__, 'pk': pkval } data_json = json.dumps(data, default=obj_converter) objects = {} if caches: cache = caches.pop() if cache.database is not database: throw(TransactionError, 'An object does not belong to specified database') object_list = list(object_set) objects = {} for obj in object_list: if obj in cache.seeds[obj._pk_attrs_]: obj._load_() entity = obj.__class__ if not can_view(user, obj): user_has_no_rights_to_see(obj) d = objects.setdefault(entity.__name__, {}) for val in obj._get_raw_pkval_(): d = d.setdefault(val, {}) assert not d, d for attr in obj._attrs_: if attr in exclude: continue if attr in include: pass # if attr not in entity_perms.can_read: user_has_no_rights_to_see(obj, attr) elif attr.is_collection: continue elif attr.lazy: continue # elif attr not in entity_perms.can_read: continue if attr.is_collection: if not isinstance(attr, Set): throw(NotImplementedError) value = [] for item in attr.__get__(obj): if item not in object_set: object_set.add(item) object_list.append(item) pkval = item._get_raw_pkval_() value.append(pkval[0] if len(pkval) == 1 else pkval) value.sort() else: value = attr.__get__(obj) if value is not None and attr.is_relation: if attr in include and value not in object_set: object_set.add(value) object_list.append(value) pkval = value._get_raw_pkval_() value = pkval[0] if len(pkval) == 1 else pkval d[attr.name] = value objects_json = json.dumps(objects, default=converter) if not with_schema: return '{"data": %s, "objects": %s}' % (data_json, objects_json) schema_json, new_schema_hash = database._get_schema_json() if schema_hash is not None and schema_hash == new_schema_hash: return '{"data": %s, "objects": %s, "schema_hash": "%s"}' \ % (data_json, objects_json, new_schema_hash) return '{"data": %s, "objects": %s, "schema": %s, "schema_hash": "%s"}' \ % (data_json, objects_json, schema_json, new_schema_hash) @cut_traceback @db_session def from_json(database, changes, observer=None): changes = json.loads(changes) import pprint; pprint.pprint(changes) objmap = {} for diff in changes['objects']: if diff['_status_'] == 'c': continue pk = diff['_pk_'] pk = (pk,) if type(pk) is not list else tuple(pk) entity_name = diff['class'] entity = database.entities[entity_name] obj = entity._get_by_raw_pkval_(pk, from_db=False) oid = diff['_id_'] objmap[oid] = obj def id2obj(attr, val): return objmap[val] if attr.reverse and val is not None else val user = get_current_user() def user_has_no_rights_to(operation, x): user_groups = get_user_groups(user) s = 'attribute %s' % x if isinstance(x, Attribute) else 'object %s' % x throw(PermissionError, 'The current user %s which belongs to groups %s ' 'has no rights to 
%s the %s on the frontend' % (user, sorted(user_groups), operation, s)) for diff in changes['objects']: entity_name = diff['class'] entity = database.entities[entity_name] oldvals = {} newvals = {} oldadict = {} newadict = {} for name, val in diff.items(): if name not in ('class', '_pk_', '_id_', '_status_'): attr = entity._adict_[name] if not attr.is_collection: if type(val) is dict: if 'old' in val: oldvals[attr.name] = oldadict[attr] = attr.validate(id2obj(attr, val['old'])) if 'new' in val: newvals[attr.name] = newadict[attr] = attr.validate(id2obj(attr, val['new'])) else: newvals[attr.name] = newadict[attr] = attr.validate(id2obj(attr, val)) oid = diff['_id_'] status = diff['_status_'] if status == 'c': assert not oldvals for attr in newadict: if not can_create(user, attr): user_has_no_rights_to('initialize', attr) obj = entity(**newvals) if observer: flush() # in order to get obj.id observer('create', obj, newvals) objmap[oid] = obj if not can_edit(user, obj): user_has_no_rights_to('create', obj) else: obj = objmap[oid] if status == 'd': if not can_delete(user, obj): user_has_no_rights_to('delete', obj) if observer: observer('delete', obj) obj.delete() elif status == 'u': if not can_edit(user, obj): user_has_no_rights_to('update', obj) if newvals: for attr in newadict: if not can_edit(user, attr): user_has_no_rights_to('edit', attr) assert oldvals if observer: observer('update', obj, newvals, oldvals) obj._db_set_(oldadict) # oldadict can be modified here for attr in oldadict: attr.__get__(obj) obj.set(**newvals) else: assert not oldvals objmap[oid] = obj flush() for diff in changes['objects']: if diff['_status_'] == 'd': continue obj = objmap[diff['_id_']] entity = obj.__class__ for name, val in diff.items(): if name not in ('class', '_pk_', '_id_', '_status_'): attr = entity._adict_[name] if attr.is_collection and attr.reverse.is_collection and attr < attr.reverse: removed = [ objmap[oid] for oid in val.get('removed', ()) ] added = [ objmap[oid] for oid in val.get('added', ()) ] if (added or removed) and not can_edit(user, attr): user_has_no_rights_to('edit', attr) collection = attr.__get__(obj) if removed: observer('remove', obj, {name: removed}) collection.remove(removed) if added: observer('add', obj, {name: added}) collection.add(added) flush() def deserialize(x): t = type(x) if t is list: return list(imap(deserialize, x)) if t is dict: if '_id_' not in x: return {key: deserialize(val) for key, val in iteritems(x)} obj = objmap.get(x['_id_']) if obj is None: entity_name = x['class'] entity = database.entities[entity_name] pk = x['_pk_'] obj = entity[pk] return obj return x return deserialize(changes['data']) def basic_converter(x): if isinstance(x, (datetime.datetime, datetime.date, Decimal)): return str(x) if isinstance(x, dict): return dict(x) if isinstance(x, Entity): pkval = x._get_raw_pkval_() return pkval[0] if len(pkval) == 1 else pkval if hasattr(x, '__iter__'): return list(x) throw(TypeError, 'The following object cannot be converted to JSON: %r' % x) @cut_traceback def perm(*args, **kwargs): if local.perms_context is None: throw(OrmError, "'perm' function can be called within 'set_perms_for' context manager only") database, entities = local.perms_context permissions = _split_names('Permission', args) groups = pop_names_from_kwargs('Group', kwargs, 'group', 'groups') roles = pop_names_from_kwargs('Role', kwargs, 'role', 'roles') labels = pop_names_from_kwargs('Label', kwargs, 'label', 'labels') for kwname in kwargs: throw(TypeError, 'Unknown keyword argument name: %s' % 
kwname) return AccessRule(database, entities, permissions, groups, roles, labels) def _split_names(typename, names): if names is None: return set() if isinstance(names, basestring): names = names.replace(',', ' ').split() else: try: namelist = list(names) except: throw(TypeError, '%s name should be string. Got: %s' % (typename, names)) names = [] for name in namelist: names.extend(_split_names(typename, name)) for name in names: if not is_ident(name): throw(TypeError, '%s name should be identifier. Got: %s' % (typename, name)) return set(names) def pop_names_from_kwargs(typename, kwargs, *kwnames): result = set() for kwname in kwnames: kwarg = kwargs.pop(kwname, None) if kwarg is not None: result.update(_split_names(typename, kwarg)) return result class AccessRule(object): def __init__(rule, database, entities, permissions, groups, roles, labels): rule.database = database rule.entities = entities if not permissions: throw(TypeError, 'At least one permission should be specified') rule.permissions = permissions rule.groups = groups rule.groups.add('anybody') rule.roles = roles rule.labels = labels rule.entities_to_exclude = set() rule.attrs_to_exclude = set() for entity in entities: for perm in rule.permissions: entity._access_rules_[perm].add(rule) def exclude(rule, *args): for arg in args: if isinstance(arg, EntityMeta): entity = arg rule.entities_to_exclude.add(entity) rule.entities_to_exclude.update(entity._subclasses_) elif isinstance(arg, Attribute): attr = arg if attr.pk_offset is not None: throw(TypeError, 'Primary key attribute %s cannot be excluded' % attr) rule.attrs_to_exclude.add(attr) else: throw(TypeError, 'Entity or attribute expected. Got: %r' % arg) @cut_traceback def has_perm(user, perm, x): if isinstance(x, EntityMeta): entity = x elif isinstance(x, Entity): entity = x.__class__ elif isinstance(x, Attribute): if x.hidden: return False entity = x.entity else: throw(TypeError, "The third parameter of 'has_perm' function should be entity class, entity instance " "or attribute. 
Got: %r" % x) access_rules = entity._access_rules_.get(perm) if not access_rules: return False cache = entity._database_._get_cache() perm_cache = cache.perm_cache[user][perm] result = perm_cache.get(x) if result is not None: return result user_groups = get_user_groups(user) result = False if isinstance(x, EntityMeta): for rule in access_rules: if user_groups.issuperset(rule.groups) and entity not in rule.entities_to_exclude: result = True break elif isinstance(x, Attribute): attr = x for rule in access_rules: if user_groups.issuperset(rule.groups) and entity not in rule.entities_to_exclude \ and attr not in rule.attrs_to_exclude: result = True break reverse = attr.reverse if reverse: reverse_rules = reverse.entity._access_rules_.get(perm) if not reverse_rules: return False for reverse_rule in reverse_rules: if user_groups.issuperset(reverse_rule.groups) \ and reverse.entity not in reverse_rule.entities_to_exclude \ and reverse not in reverse_rule.attrs_to_exclude: result = True break if result: break else: obj = x user_roles = get_user_roles(user, obj) obj_labels = get_object_labels(obj) for rule in access_rules: if x in rule.entities_to_exclude: continue elif not user_groups.issuperset(rule.groups): pass elif not user_roles.issuperset(rule.roles): pass elif not obj_labels.issuperset(rule.labels): pass else: result = True break perm_cache[x] = result return result def can_view(user, x): return has_perm(user, 'view', x) or has_perm(user, 'edit', x) def can_edit(user, x): return has_perm(user, 'edit', x) def can_create(user, x): return has_perm(user, 'create', x) def can_delete(user, x): return has_perm(user, 'delete', x) def get_current_user(): return local.current_user def set_current_user(user): local.current_user = user anybody_frozenset = frozenset(['anybody']) def get_user_groups(user): result = local.user_groups_cache.get(user) if result is not None: return result if user is None: return anybody_frozenset result = {'anybody'} for cls, func in usergroup_functions: if cls is None or isinstance(user, cls): groups = func(user) if isinstance(groups, basestring): # single group name result.add(groups) elif groups is not None: result.update(groups) result = frozenset(result) local.user_groups_cache[user] = result return result def get_user_roles(user, obj): if user is None: return frozenset() roles_cache = local.user_roles_cache[user] result = roles_cache.get(obj) if result is not None: return result result = set() if user is obj: result.add('self') for user_cls, obj_cls, func in userrole_functions: if user_cls is None or isinstance(user, user_cls): if obj_cls is None or isinstance(obj, obj_cls): roles = func(user, obj) if isinstance(roles, basestring): # single role name result.add(roles) elif roles is not None: result.update(roles) result = frozenset(result) roles_cache[obj] = result return result def get_object_labels(obj): cache = obj._database_._get_cache() obj_labels_cache = cache.obj_labels_cache result = obj_labels_cache.get(obj) if result is None: result = set() for obj_cls, func in objlabel_functions: if obj_cls is None or isinstance(obj, obj_cls): labels = func(obj) if isinstance(labels, basestring): # single label name result.add(labels) elif labels is not None: result.update(labels) obj_labels_cache[obj] = result return result usergroup_functions = [] def user_groups_getter(cls=None): def decorator(func): if func not in usergroup_functions: usergroup_functions.append((cls, func)) return func return decorator userrole_functions = [] def user_roles_getter(user_cls=None, 
obj_cls=None): def decorator(func): if func not in userrole_functions: userrole_functions.append((user_cls, obj_cls, func)) return func return decorator objlabel_functions = [] def obj_labels_getter(cls=None): def decorator(func): if func not in objlabel_functions: objlabel_functions.append((cls, func)) return func return decorator class DbLocal(localbase): def __init__(dblocal): dblocal.stats = {None: QueryStat(None)} dblocal.last_sql = None class QueryStat(object): def __init__(stat, sql, duration=None): if duration is not None: stat.min_time = stat.max_time = stat.sum_time = duration stat.db_count = 1 stat.cache_count = 0 else: stat.min_time = stat.max_time = stat.sum_time = None stat.db_count = 0 stat.cache_count = 1 stat.sql = sql def copy(stat): result = object.__new__(QueryStat) result.__dict__.update(stat.__dict__) return result def query_executed(stat, duration): if stat.db_count: stat.min_time = builtins.min(stat.min_time, duration) stat.max_time = builtins.max(stat.max_time, duration) stat.sum_time += duration else: stat.min_time = stat.max_time = stat.sum_time = duration stat.db_count += 1 def merge(stat, stat2): assert stat.sql == stat2.sql if not stat2.db_count: pass elif stat.db_count: stat.min_time = builtins.min(stat.min_time, stat2.min_time) stat.max_time = builtins.max(stat.max_time, stat2.max_time) stat.sum_time += stat2.sum_time else: stat.min_time = stat2.min_time stat.max_time = stat2.max_time stat.sum_time = stat2.sum_time stat.db_count += stat2.db_count stat.cache_count += stat2.cache_count @property def avg_time(stat): if not stat.db_count: return None return stat.sum_time / stat.db_count num_counter = itertools.count() class SessionCache(object): def __init__(cache, database): cache.is_alive = True cache.num = next(num_counter) cache.database = database cache.objects = set() cache.indexes = defaultdict(dict) cache.seeds = defaultdict(set) cache.max_id_cache = {} cache.collection_statistics = {} cache.for_update = set() cache.noflush_counter = 0 cache.modified_collections = defaultdict(set) cache.objects_to_save = [] cache.saved_objects = [] cache.query_results = {} cache.dbvals_deduplication_cache = {} cache.modified = False cache.db_session = db_session = local.db_session cache.immediate = db_session is not None and db_session.immediate cache.connection = None cache.in_transaction = False cache.saved_fk_state = None cache.perm_cache = defaultdict(lambda : defaultdict(dict)) # user -> perm -> cls_or_attr_or_obj -> bool cache.user_roles_cache = defaultdict(dict) # user -> obj -> roles cache.obj_labels_cache = {} # obj -> labels def connect(cache): assert cache.connection is None if cache.in_transaction: throw(ConnectionClosedError, 'Transaction cannot be continued because database connection failed') database = cache.database provider = database.provider connection, is_new_connection = provider.connect() if is_new_connection: database.call_on_connect(connection) try: provider.set_transaction_mode(connection, cache) # can set cache.in_transaction except: provider.drop(connection, cache) raise cache.connection = connection return connection def reconnect(cache, exc): provider = cache.database.provider if exc is not None: exc = getattr(exc, 'original_exc', exc) if not provider.should_reconnect(exc): reraise(*sys.exc_info()) if local.debug: log_orm('CONNECTION FAILED: %s' % exc) connection = cache.connection assert connection is not None cache.connection = None provider.drop(connection, cache) else: assert cache.connection is None return cache.connect() def 
prepare_connection_for_query_execution(cache): db_session = local.db_session if db_session is not None and cache.db_session is None: # This situation can arise when a transaction was started # in the interactive mode, outside of the db_session if cache.in_transaction or cache.modified: local.db_session = None try: cache.flush_and_commit() finally: local.db_session = db_session cache.db_session = db_session cache.immediate = cache.immediate or db_session.immediate else: assert cache.db_session is db_session, (cache.db_session, db_session) connection = cache.connection if connection is None: connection = cache.connect() elif cache.immediate and not cache.in_transaction: provider = cache.database.provider try: provider.set_transaction_mode(connection, cache) # can set cache.in_transaction except Exception as e: connection = cache.reconnect(e) if not cache.noflush_counter and cache.modified: cache.flush() return connection def flush_and_commit(cache): try: cache.flush() except: cache.rollback() raise try: cache.commit() except: transact_reraise(CommitException, [sys.exc_info()]) def commit(cache): assert cache.is_alive try: if cache.modified: cache.flush() if cache.in_transaction: assert cache.connection is not None cache.database.provider.commit(cache.connection, cache) cache.for_update.clear() cache.query_results.clear() cache.max_id_cache.clear() cache.immediate = True except: cache.rollback() raise def rollback(cache): cache.close(rollback=True) def release(cache): cache.close(rollback=False) def close(cache, rollback=True): assert cache.is_alive if not rollback: assert not cache.in_transaction database = cache.database x = local.db2cache.pop(database); assert x is cache cache.is_alive = False provider = database.provider connection = cache.connection if connection is None: return cache.connection = None try: if rollback: try: provider.rollback(connection, cache) except: provider.drop(connection, cache) raise provider.release(connection, cache) finally: db_session = cache.db_session or local.db_session if db_session and db_session.strict: for obj in cache.objects: obj._vals_ = obj._dbvals_ = obj._session_cache_ = None cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None else: for obj in cache.objects: obj._dbvals_ = obj._session_cache_ = None for attr, setdata in iteritems(obj._vals_): if attr.is_collection: if not setdata.is_fully_loaded: obj._vals_[attr] = None cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ = cache.modified_collections = cache.collection_statistics = cache.dbvals_deduplication_cache = None @contextmanager def flush_disabled(cache): cache.noflush_counter += 1 try: yield finally: cache.noflush_counter -= 1 def flush(cache): if cache.noflush_counter: return assert cache.is_alive assert not cache.saved_objects prev_immediate = cache.immediate cache.immediate = True try: for i in xrange(50): if not cache.modified: return with cache.flush_disabled(): for obj in cache.objects_to_save: # can grow during iteration if obj is not None: obj._before_save_() cache.query_results.clear() modified_m2m = cache._calc_modified_m2m() for attr, (added, removed) in iteritems(modified_m2m): if not removed: continue attr.remove_m2m(removed) for obj in cache.objects_to_save: if obj is not None: obj._save_() for attr, (added, removed) in iteritems(modified_m2m): if not added: continue attr.add_m2m(added) cache.max_id_cache.clear() cache.modified_collections.clear() 
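# Illustrative sketch (not part of the PonyORM sources): flush() writes pending
# changes in a fixed order -- m2m removals, then object saves, then m2m additions --
# without ending the transaction, which is what makes auto-assigned primary keys
# visible before commit() (from_json() above relies on this with its
# 'flush()  # in order to get obj.id' call). Hypothetical session code:
#
#     with db_session:
#         p = Person(name='John Doe')  # queued in objects_to_save, no SQL yet
#         flush()                      # INSERT executed, p.id is now assigned
#         print(p.id)
#         # commit() runs implicitly when the db_session block exits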
                cache.objects_to_save[:] = ()
                cache.modified = False
                cache.call_after_save_hooks()
            else:
                if cache.modified: throw(TransactionError,
                    'Recursion depth limit reached in obj._after_save_() call')
        finally:
            if not cache.in_transaction: cache.immediate = prev_immediate
    def call_after_save_hooks(cache):
        saved_objects = cache.saved_objects
        cache.saved_objects = []
        for obj, status in saved_objects:
            obj._after_save_(status)
    def _calc_modified_m2m(cache):
        modified_m2m = {}
        for attr, objects in sorted(iteritems(cache.modified_collections),
                                    key=lambda pair: (pair[0].entity.__name__, pair[0].name)):
            if not isinstance(attr, Set): throw(NotImplementedError)
            reverse = attr.reverse
            if not reverse.is_collection:
                for obj in objects:
                    setdata = obj._vals_[attr]
                    setdata.added = setdata.removed = setdata.absent = None
                continue
            if not isinstance(reverse, Set): throw(NotImplementedError)
            if reverse in modified_m2m: continue
            added, removed = modified_m2m.setdefault(attr, (set(), set()))
            for obj in objects:
                setdata = obj._vals_[attr]
                if setdata.added:
                    for obj2 in setdata.added: added.add((obj, obj2))
                if setdata.removed:
                    for obj2 in setdata.removed: removed.add((obj, obj2))
                if obj._status_ == 'marked_to_delete': del obj._vals_[attr]
                else: setdata.added = setdata.removed = setdata.absent = None
        cache.modified_collections.clear()
        return modified_m2m
    def update_simple_index(cache, obj, attr, old_val, new_val, undo):
        assert old_val != new_val
        cache_index = cache.indexes[attr]
        if new_val is not None:
            obj2 = cache_index.setdefault(new_val, obj)
            if obj2 is not obj: throw(CacheIndexError,
                'Cannot update %s.%s: %s with key %s already exists'
                % (obj.__class__.__name__, attr.name, obj2, new_val))
        if old_val is not None: del cache_index[old_val]
        undo.append((cache_index, old_val, new_val))
    def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval):
        assert old_dbval != new_dbval
        cache_index = cache.indexes[attr]
        if new_dbval is not None:
            obj2 = cache_index.setdefault(new_dbval, obj)
            # an attribute that was created or updated recently clashes with one stored in the database
            if obj2 is not obj: throw(TransactionIntegrityError,
                '%s with unique index %s.%s already exists: %s'
                % (obj2.__class__.__name__, obj.__class__.__name__, attr.name, new_dbval))
        cache_index.pop(old_dbval, None)
    def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo):
        assert prev_vals != new_vals
        if None in prev_vals: prev_vals = None
        if None in new_vals: new_vals = None
        if prev_vals is None and new_vals is None: return
        cache_index = cache.indexes[attrs]
        if new_vals is not None:
            obj2 = cache_index.setdefault(new_vals, obj)
            if obj2 is not obj:
                attr_names = ', '.join(attr.name for attr in attrs)
                throw(CacheIndexError, 'Cannot update %r: composite key (%s) with value %s already exists for %r'
                                       % (obj, attr_names, new_vals, obj2))
        if prev_vals is not None: del cache_index[prev_vals]
        undo.append((cache_index, prev_vals, new_vals))
    def db_update_composite_index(cache, obj, attrs, prev_vals, new_vals):
        assert prev_vals != new_vals
        cache_index = cache.indexes[attrs]
        if None not in new_vals:
            obj2 = cache_index.setdefault(new_vals, obj)
            if obj2 is not obj:
                key_str = ', '.join(repr(item) for item in new_vals)
                throw(TransactionIntegrityError, '%s with unique index (%s) already exists: %s'
                      % (obj2.__class__.__name__, ', '.join(attr.name for attr in attrs), key_str))
        cache_index.pop(prev_vals, None)

class NotLoadedValueType(object):
    def __repr__(self): return 'NOT_LOADED'

NOT_LOADED = NotLoadedValueType()

class DefaultValueType(object):
    def __repr__(self): return 'DEFAULT'

DEFAULT = DefaultValueType()

class DescWrapper(object):
    def __init__(self, attr):
        self.attr = attr
    def __repr__(self):
        return '<DescWrapper %s>' % self.attr
    def __call__(self):
        return self
    def __eq__(self, other):
        return type(other) is DescWrapper and self.attr == other.attr
    def __ne__(self, other):
        return type(other) is not DescWrapper or self.attr != other.attr
    def __hash__(self):
        return hash(self.attr) + 1

attr_id_counter = itertools.count(1)

class Attribute(object):
    __slots__ = 'nullable', 'is_required', 'is_discriminator', 'is_unique', 'is_part_of_unique_index', \
                'is_pk', 'is_collection', 'is_relation', 'is_basic', 'is_string', 'is_volatile', 'is_implicit', \
                'id', 'pk_offset', 'pk_columns_offset', 'py_type', 'sql_type', 'entity', 'name', \
                'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \
                'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \
                'cascade_delete', 'index', 'reverse_index', 'original_default', 'sql_default', 'py_check', 'hidden', \
                'optimistic', 'fk_name', 'type_has_empty_value'
    def __deepcopy__(attr, memo):
        return attr  # Attribute cannot be cloned by deepcopy()
    @cut_traceback
    def __init__(attr, py_type, *args, **kwargs):
        if attr.__class__ is Attribute: throw(TypeError, "'Attribute' is an abstract type")
        attr.is_implicit = False
        attr.is_required = isinstance(attr, Required)
        attr.is_discriminator = isinstance(attr, Discriminator)
        attr.is_unique = kwargs.pop('unique', None)
        if isinstance(attr, PrimaryKey):
            if attr.is_unique is not None:
                throw(TypeError, "'unique' option cannot be set for PrimaryKey attribute")
            attr.is_unique = True
        attr.nullable = kwargs.pop('nullable', None)
        attr.is_part_of_unique_index = attr.is_unique  # can also be set to True later
        attr.is_pk = isinstance(attr, PrimaryKey)
        if attr.is_pk: attr.pk_offset = 0
        else: attr.pk_offset = None
        attr.id = next(attr_id_counter)
        if not isinstance(py_type, (type, basestring, types.FunctionType, Array)):
            if py_type is datetime: throw(TypeError,
                'datetime is a module and cannot be used as attribute type. Use datetime.datetime instead')
            throw(TypeError, 'Incorrect type of attribute: %r' % py_type)
        attr.py_type = py_type
        attr.is_string = type(py_type) is type and issubclass(py_type, basestring)
        attr.type_has_empty_value = attr.is_string or hasattr(attr.py_type, 'default_empty_value')
        attr.is_collection = isinstance(attr, Collection)
        attr.is_relation = isinstance(attr.py_type, (EntityMeta, basestring, types.FunctionType))
        attr.is_basic = not attr.is_collection and not attr.is_relation
        attr.sql_type = kwargs.pop('sql_type', None)
        attr.entity = attr.name = None
        attr.args = args
        attr.auto = kwargs.pop('auto', False)
        attr.cascade_delete = kwargs.pop('cascade_delete', None)
        attr.reverse = kwargs.pop('reverse', None)
        if not attr.reverse: pass
        elif not isinstance(attr.reverse, (basestring, Attribute)):
            throw(TypeError, "Value of 'reverse' option must be a name of the reverse attribute. Got: %r" % attr.reverse)
        elif not attr.is_relation:
            throw(TypeError, "'reverse' option cannot be set for this type: %r" % attr.py_type)
        attr.column = kwargs.pop('column', None)
        attr.columns = kwargs.pop('columns', None)
        if attr.column is not None:
            if attr.columns is not None:
                throw(TypeError, "Parameters 'column' and 'columns' cannot be specified simultaneously")
            if not isinstance(attr.column, basestring):
                throw(TypeError, "Parameter 'column' must be a string. Got: %r" % attr.column)
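        # The keyword options popped above ('unique', 'nullable', 'column', ...) come
        # straight from entity declarations. An illustrative sketch, with hypothetical
        # `db` and `Person`, assuming the usual pony.orm public API:
        #
        #     from pony.orm import Database, Required, Optional
        #     db = Database()
        #
        #     class Person(db.Entity):
        #         name  = Required(str, unique=True)                    # -> is_unique
        #         email = Optional(str, nullable=True, column='email')  # -> nullable, column
        #
        # Any option that is not popped here stays in attr.kwargs and is reported
        # later by linked() as an unknown option.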
Got: %r" % attr.column) attr.columns = [ attr.column ] elif attr.columns is not None: if not isinstance(attr.columns, (tuple, list)): throw(TypeError, "Parameter 'columns' must be a list. Got: %r'" % attr.columns) for column in attr.columns: if not isinstance(column, basestring): throw(TypeError, "Items of parameter 'columns' must be strings. Got: %r" % attr.columns) if len(attr.columns) == 1: attr.column = attr.columns[0] else: attr.columns = [] attr.index = kwargs.pop('index', None) attr.reverse_index = kwargs.pop('reverse_index', None) attr.fk_name = kwargs.pop('fk_name', None) attr.col_paths = [] attr._columns_checked = False attr.composite_keys = [] attr.lazy = kwargs.pop('lazy', getattr(py_type, 'lazy', False)) attr.lazy_sql_cache = None attr.is_volatile = kwargs.pop('volatile', False) attr.optimistic = kwargs.pop('optimistic', None) attr.sql_default = kwargs.pop('sql_default', None) attr.py_check = kwargs.pop('py_check', None) attr.hidden = kwargs.pop('hidden', False) attr.kwargs = kwargs attr.converters = [] def _init_(attr, entity, name): attr.entity = entity attr.name = name if attr.pk_offset is not None and attr.lazy: throw(TypeError, 'Primary key attribute %s cannot be lazy' % attr) if attr.cascade_delete is not None and attr.is_basic: throw(TypeError, "'cascade_delete' option cannot be set for attribute %s, " "because it is not relationship attribute" % attr) if not attr.is_required: if attr.is_unique and attr.nullable is False: throw(TypeError, 'Optional unique attribute %s must be nullable' % attr) if entity._root_ is not entity: if attr.nullable is False: throw(ERDiagramError, 'Attribute %s must be nullable due to single-table inheritance' % attr) attr.nullable = True if 'default' in attr.kwargs: attr.default = attr.original_default = attr.kwargs.pop('default') if attr.is_required: if attr.default is None: throw(TypeError, 'Default value for required attribute %s cannot be None' % attr) if attr.default == '': throw(TypeError, 'Default value for required attribute %s cannot be empty string' % attr) elif attr.default is None and not attr.nullable: throw(TypeError, 'Default value for non-nullable attribute %s cannot be set to None' % attr) elif attr.type_has_empty_value and not attr.is_required and not attr.nullable: attr.default = '' if attr.is_string else attr.py_type.default_empty_value() else: attr.default = None sql_default = attr.sql_default if isinstance(sql_default, basestring): if sql_default == '': throw(TypeError, "'sql_default' option value cannot be empty string, " "because it should be valid SQL literal or expression. " "Try to use \"''\", or just specify default='' instead.") elif attr.sql_default not in (None, True, False): throw(TypeError, "'sql_default' option of %s attribute must be of string or bool type. 
Got: %s" % (attr, attr.sql_default)) if attr.py_check is not None and not callable(attr.py_check): throw(TypeError, "'py_check' parameter of %s attribute should be callable" % attr) # composite keys will be checked later inside EntityMeta.__init__ if attr.py_type == float: if attr.is_pk: throw(TypeError, 'PrimaryKey attribute %s cannot be of type float' % attr) elif attr.is_unique: throw(TypeError, 'Unique attribute %s cannot be of type float' % attr) if attr.is_volatile and (attr.is_pk or attr.is_collection): throw(TypeError, '%s attribute %s cannot be volatile' % (attr.__class__.__name__, attr)) def linked(attr): reverse = attr.reverse if attr.cascade_delete is None: attr.cascade_delete = attr.is_collection and reverse.is_required elif attr.cascade_delete: if reverse.cascade_delete: throw(TypeError, "'cascade_delete' option cannot be set for both sides of relationship " "(%s and %s) simultaneously" % (attr, reverse)) if reverse.is_collection: throw(TypeError, "'cascade_delete' option cannot be set for attribute %s, " "because reverse attribute %s is collection" % (attr, reverse)) if attr.is_collection and not reverse.is_collection: if attr.fk_name is not None: throw(TypeError, 'You should specify fk_name in %s instead of %s' % (reverse, attr)) for option in attr.kwargs: throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option)) @cut_traceback def __repr__(attr): owner_name = attr.entity.__name__ if attr.entity else '?' return '%s.%s' % (owner_name, attr.name or '?') def __lt__(attr, other): return attr.id < other.id def _get_entity(attr, obj, entity): if entity is not None: return entity if obj is not None: return obj.__class__ return attr.entity def validate(attr, val, obj=None, entity=None, from_db=False): val = deref_proxy(val) if val is None: if not attr.nullable and not from_db and not attr.is_required: # for required attribute the exception will be thrown later with another message throw(ValueError, 'Attribute %s cannot be set to None' % attr) return val assert val is not NOT_LOADED if val is DEFAULT: default = attr.default if default is None: return None if callable(default): val = default() else: val = default entity = attr._get_entity(obj, entity) reverse = attr.reverse if not reverse: if isinstance(val, Entity): throw(TypeError, 'Attribute %s must be of %s type. Got: %s' % (attr, attr.py_type.__name__, val)) if not attr.converters: return val if type(val) is attr.py_type else attr.py_type(val) if len(attr.converters) != 1: throw(NotImplementedError) converter = attr.converters[0] if converter is not None: try: if from_db: return converter.sql2py(val) val = converter.validate(val, obj) except UnicodeDecodeError as e: throw(ValueError, 'Value for attribute %s cannot be converted to %s: %s' % (attr, unicode.__name__, truncate_repr(val))) else: rentity = reverse.entity if not isinstance(val, rentity): vals = val if type(val) is tuple else (val,) if len(vals) != len(rentity._pk_columns_): throw(TypeError, 'Invalid number of columns were specified for attribute %s. Expected: %d, got: %d' % (attr, len(rentity._pk_columns_), len(vals))) try: val = rentity._get_by_raw_pkval_(vals, from_db=from_db) except TypeError: throw(TypeError, 'Attribute %s must be of %s type. 
Got: %r' % (attr, rentity.__name__, val)) else: if obj is not None and obj._status_ is not None: cache = obj._session_cache_ else: cache = entity._database_._get_cache() if cache is not val._session_cache_: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') if attr.py_check is not None and not attr.py_check(val): throw(ValueError, 'Check for attribute %s failed. Value: %s' % (attr, truncate_repr(val))) return val def parse_value(attr, row, offsets, dbvals_deduplication_cache): assert len(attr.columns) == len(offsets) if not attr.reverse: if len(offsets) > 1: throw(NotImplementedError) offset = offsets[0] dbval = attr.validate(row[offset], None, attr.entity, from_db=True) dbval = deduplicate(dbval, dbvals_deduplication_cache) else: dbvals = [ row[offset] for offset in offsets ] if None in dbvals: assert len(set(dbvals)) == 1 dbval = None else: dbval = attr.py_type._get_by_raw_pkval_(dbvals) return dbval def load(attr, obj): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('load attribute', obj, attr) if not attr.columns: reverse = attr.reverse assert reverse is not None and reverse.columns dbval = reverse.entity._find_in_db_({reverse : obj}) if dbval is None: obj._vals_[attr] = None else: assert obj._vals_[attr] == dbval return dbval if attr.lazy: entity = attr.entity database = entity._database_ if not attr.lazy_sql_cache: select_list = [ 'ALL' ] + [ [ 'COLUMN', None, column ] for column in attr.columns ] from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] pk_columns = entity._pk_columns_ pk_converters = entity._pk_converters_ criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ] sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ] sql, adapter = database._ast2sql(sql_ast) offsets = tuple(xrange(len(attr.columns))) attr.lazy_sql_cache = sql, adapter, offsets else: sql, adapter, offsets = attr.lazy_sql_cache arguments = adapter(obj._get_raw_pkval_()) cursor = database._exec_sql(sql, arguments) row = cursor.fetchone() dbval = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) attr.db_set(obj, dbval) else: obj._load_() return obj._vals_[attr] @cut_traceback def __get__(attr, obj, cls=None): if obj is None: return attr if attr.pk_offset is not None: return attr.get(obj) value = attr.get(obj) bit = obj._bits_except_volatile_[attr] wbits = obj._wbits_ if wbits is not None and not wbits & bit: obj._rbits_ |= bit return value def get(attr, obj): if attr.pk_offset is None and obj._status_ in ('deleted', 'cancelled'): throw_object_was_deleted(obj) vals = obj._vals_ if vals is None: throw_db_session_is_over('read value of', obj, attr) val = vals[attr] if attr in vals else attr.load(obj) if val is not None and attr.reverse and val._subclasses_ and val._status_ not in ('deleted', 'cancelled'): cache = obj._session_cache_ if cache is not None and val in cache.seeds[val._pk_attrs_]: val._load_() return val @cut_traceback def __set__(attr, obj, new_val, undo_funcs=None): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) reverse = attr.reverse new_val = attr.validate(new_val, obj, from_db=False) if attr.pk_offset is not None: pkval = obj._pkval_ if pkval is None: pass elif obj._pk_is_composite_: if new_val == 
pkval[attr.pk_offset]: return elif new_val == pkval: return throw(TypeError, 'Cannot change value of primary key') with cache.flush_disabled(): old_val = obj._vals_.get(attr, NOT_LOADED) if old_val is NOT_LOADED and reverse and not reverse.is_collection: old_val = attr.load(obj) status = obj._status_ wbits = obj._wbits_ bit = obj._bits_[attr] objects_to_save = cache.objects_to_save objects_to_save_needs_undo = False if wbits is not None and bit: obj._wbits_ = wbits | bit if status != 'modified': assert status in ('loaded', 'inserted', 'updated') assert obj._save_pos_ is None obj._status_ = 'modified' obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) objects_to_save_needs_undo = True cache.modified = True if not attr.reverse and not attr.is_part_of_unique_index: obj._vals_[attr] = new_val return is_reverse_call = undo_funcs is not None if not is_reverse_call: undo_funcs = [] undo = [] def undo_func(): obj._status_ = status obj._wbits_ = wbits if objects_to_save_needs_undo: assert objects_to_save obj2 = objects_to_save.pop() assert obj2 is obj and obj._save_pos_ == len(objects_to_save) obj._save_pos_ = None if old_val is NOT_LOADED: obj._vals_.pop(attr) else: obj._vals_[attr] = old_val for cache_index, old_key, new_key in undo: if new_key is not None: del cache_index[new_key] if old_key is not None: cache_index[old_key] = obj undo_funcs.append(undo_func) if old_val == new_val: return try: if attr.is_unique: cache.update_simple_index(obj, attr, old_val, new_val, undo) get_val = obj._vals_.get for attrs, i in attr.composite_keys: vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! prev_vals = tuple(vals) vals[i] = new_val new_vals = tuple(vals) cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo) obj._vals_[attr] = new_val if not reverse: pass elif not is_reverse_call: attr.update_reverse(obj, old_val, new_val, undo_funcs) elif old_val not in (None, NOT_LOADED): if not reverse.is_collection: if new_val is not None: if reverse.is_required: throw(ConstraintError, 'Cannot unlink %r from previous %s object, because %r attribute is required' % (old_val, obj, reverse)) reverse.__set__(old_val, None, undo_funcs) elif isinstance(reverse, Set): reverse.reverse_remove((old_val,), obj, undo_funcs) else: throw(NotImplementedError) except: if not is_reverse_call: for undo_func in reversed(undo_funcs): undo_func() raise def db_set(attr, obj, new_dbval, is_reverse_call=False): cache = obj._session_cache_ assert cache is not None and cache.is_alive assert obj._status_ not in created_or_deleted_statuses assert attr.pk_offset is None if new_dbval is NOT_LOADED: assert is_reverse_call old_dbval = obj._dbvals_.get(attr, NOT_LOADED) if old_dbval is not NOT_LOADED: if old_dbval == new_dbval or ( not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): return bit = obj._bits_except_volatile_[attr] if obj._rbits_ & bit: assert old_dbval is not NOT_LOADED msg = 'Value of %s for %s was updated outside of current transaction' % (attr, obj) if new_dbval is not NOT_LOADED: msg = '%s (was: %s, now: %s)' % (msg, old_dbval, new_dbval) elif isinstance(attr.reverse, Optional): assert old_dbval is not None msg = "Multiple %s objects linked with the same %s object. 
" \ "Maybe %s attribute should be Set instead of Optional" \ % (attr.entity.__name__, old_dbval, attr.reverse) throw(UnrepeatableReadError, msg) if new_dbval is NOT_LOADED: obj._dbvals_.pop(attr, None) else: obj._dbvals_[attr] = new_dbval wbit = bool(obj._wbits_ & bit) if not wbit: old_val = obj._vals_.get(attr, NOT_LOADED) assert old_val == old_dbval, (old_val, old_dbval) if attr.is_part_of_unique_index: if attr.is_unique: cache.db_update_simple_index(obj, attr, old_val, new_dbval) get_val = obj._vals_.get for attrs, i in attr.composite_keys: vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! old_vals = tuple(vals) vals[i] = new_dbval new_vals = tuple(vals) cache.db_update_composite_index(obj, attrs, old_vals, new_vals) if new_dbval is NOT_LOADED: obj._vals_.pop(attr, None) elif attr.reverse: obj._vals_[attr] = new_dbval else: assert len(attr.converters) == 1 obj._vals_[attr] = attr.converters[0].dbval2val(new_dbval, obj) reverse = attr.reverse if not reverse: pass elif not is_reverse_call: attr.db_update_reverse(obj, old_dbval, new_dbval) elif old_dbval not in (None, NOT_LOADED): if not reverse.is_collection: if new_dbval is not NOT_LOADED: reverse.db_set(old_dbval, NOT_LOADED, is_reverse_call=True) elif isinstance(reverse, Set): reverse.db_reverse_remove((old_dbval,), obj) else: throw(NotImplementedError) def update_reverse(attr, obj, old_val, new_val, undo_funcs): reverse = attr.reverse if not reverse.is_collection: if old_val not in (None, NOT_LOADED): if attr.cascade_delete: old_val._delete_(undo_funcs) elif reverse.is_required: throw(ConstraintError, 'Cannot unlink %r from previous %s object, because %r attribute is required' % (old_val, obj, reverse)) else: reverse.__set__(old_val, None, undo_funcs) if new_val is not None: reverse.__set__(new_val, obj, undo_funcs) elif isinstance(reverse, Set): if old_val not in (None, NOT_LOADED): reverse.reverse_remove((old_val,), obj, undo_funcs) if new_val is not None: reverse.reverse_add((new_val,), obj, undo_funcs) else: throw(NotImplementedError) def db_update_reverse(attr, obj, old_dbval, new_dbval): reverse = attr.reverse if not reverse.is_collection: if old_dbval not in (None, NOT_LOADED): reverse.db_set(old_dbval, NOT_LOADED, True) if new_dbval is not None: reverse.db_set(new_dbval, obj, True) elif isinstance(reverse, Set): if old_dbval not in (None, NOT_LOADED): reverse.db_reverse_remove((old_dbval,), obj) if new_dbval is not None: reverse.db_reverse_add((new_dbval,), obj) else: throw(NotImplementedError) def __delete__(attr, obj): throw(NotImplementedError) def get_raw_values(attr, val): reverse = attr.reverse if not reverse: return (val,) rentity = reverse.entity if val is None: return rentity._pk_nones_ return val._get_raw_pkval_() def get_columns(attr): assert not attr.is_collection assert not isinstance(attr.py_type, basestring) if attr._columns_checked: return attr.columns provider = attr.entity._database_.provider reverse = attr.reverse if not reverse: # attr is not part of relationship if not attr.columns: attr.columns = provider.get_default_column_names(attr) elif len(attr.columns) > 1: throw(MappingError, "Too many columns were specified for %s" % attr) attr.col_paths = [ attr.name ] attr.converters = [ provider.get_converter_by_attr(attr) ] else: def generate_columns(): reverse_pk_columns = reverse.entity._get_pk_columns_() reverse_pk_col_paths = reverse.entity._pk_paths_ if not attr.columns: attr.columns = provider.get_default_column_names(attr, reverse_pk_columns) elif 
len(attr.columns) != len(reverse_pk_columns): throw(MappingError, 'Invalid number of columns specified for %s' % attr) attr.col_paths = [ '-'.join((attr.name, paths)) for paths in reverse_pk_col_paths ] attr.converters = [] for a in reverse.entity._pk_attrs_: attr.converters.extend(a.converters) if reverse.is_collection: # one-to-many: generate_columns() # one-to-one: elif attr.is_required: assert not reverse.is_required generate_columns() elif attr.columns: generate_columns() elif reverse.columns: pass elif reverse.is_required: pass elif attr.entity.__name__ > reverse.entity.__name__: pass else: generate_columns() attr._columns_checked = True if len(attr.columns) == 1: attr.column = attr.columns[0] else: attr.column = None return attr.columns @property def asc(attr): return attr @property def desc(attr): return DescWrapper(attr) def describe(attr): t = attr.py_type if isinstance(t, type): t = t.__name__ options = [] if attr.args: options.append(', '.join(imap(str, attr.args))) if attr.auto: options.append('auto=True') for k, v in sorted(attr.kwargs.items()): options.append('%s=%r' % (k, v)) if not isinstance(attr, PrimaryKey) and attr.is_unique: options.append('unique=True') if attr.default is not None: options.append('default=%r' % attr.default) if not options: options = '' else: options = ', ' + ', '.join(options) result = "%s(%s%s)" % (attr.__class__.__name__, t, options) return "%s = %s" % (attr.name, result) class Optional(Attribute): __slots__ = [] class Required(Attribute): __slots__ = [] def validate(attr, val, obj=None, entity=None, from_db=False): val = Attribute.validate(attr, val, obj, entity, from_db) if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): if not from_db: throw(ValueError, 'Attribute %s is required' % ( attr if obj is None or obj._status_ is None else '%r.%s' % (obj, attr.name))) else: warnings.warn('Database contains %s for required attribute %s' % ('NULL' if val is None else 'empty string', attr), DatabaseContainsIncorrectEmptyValue) return val class Discriminator(Required): __slots__ = [ 'code2cls' ] def __init__(attr, py_type, *args, **kwargs): Attribute.__init__(attr, py_type, *args, **kwargs) attr.code2cls = {} def _init_(attr, entity, name): if entity._root_ is not entity: throw(ERDiagramError, 'Discriminator attribute %s cannot be declared in subclass' % attr) Required._init_(attr, entity, name) entity._discriminator_attr_ = attr @staticmethod def create_default_attr(entity): if hasattr(entity, 'classtype'): throw(ERDiagramError, "Cannot create discriminator column for %s automatically " "because name 'classtype' is already in use" % entity.__name__) attr = Discriminator(str, column='classtype') attr.is_implicit = True attr._init_(entity, 'classtype') entity._attrs_.append(attr) entity._new_attrs_.append(attr) entity._adict_['classtype'] = attr entity.classtype = attr attr.process_entity_inheritance(entity) def process_entity_inheritance(attr, entity): if '_discriminator_' not in entity.__dict__: entity._discriminator_ = entity.__name__ discr_value = entity._discriminator_ if discr_value is not None: try: entity._discriminator_ = discr_value = attr.validate(discr_value, None, entity) except ValueError: throw(TypeError, "Incorrect discriminator value is set for %s attribute '%s' of '%s' type: %r" % (entity.__name__, attr.name, attr.py_type.__name__, discr_value)) elif issubclass(attr.py_type, basestring): discr_value = entity._discriminator_ = entity.__name__ else: throw(TypeError, "Discriminator value for entity %s " 
"with custom discriminator column '%s' of '%s' type is not set" % (entity.__name__, attr.name, attr.py_type.__name__)) attr.code2cls[discr_value] = entity def validate(attr, val, obj=None, entity=None, from_db=False): if from_db: return val entity = attr._get_entity(obj, entity) if val is DEFAULT: assert entity is not None return entity._discriminator_ if val != entity._discriminator_: for cls in entity._subclasses_: if val == cls._discriminator_: break else: throw(TypeError, 'Invalid discriminator attribute value for %s. Expected: %r, got: %r' % (entity.__name__, entity._discriminator_, val)) return Attribute.validate(attr, val, obj, entity) def load(attr, obj): assert False # pragma: no cover def __get__(attr, obj, cls=None): if obj is None: return attr return obj._discriminator_ def __set__(attr, obj, new_val): throw(TypeError, 'Cannot assign value to discriminator attribute') def db_set(attr, obj, new_dbval): assert False # pragma: no cover def update_reverse(attr, obj, old_val, new_val, undo_funcs): assert False # pragma: no cover class Index(object): __slots__ = 'entity', 'attrs', 'is_pk', 'is_unique' def __init__(index, *attrs, **options): index.entity = None index.attrs = list(attrs) index.is_pk = options.pop('is_pk', False) index.is_unique = options.pop('is_unique', True) assert not options def _init_(index, entity): index.entity = entity attrs = index.attrs for i, attr in enumerate(index.attrs): if isinstance(attr, basestring): try: attr = getattr(entity, attr) except AttributeError: throw(AttributeError, 'Entity %s does not have attribute %s' % (entity.__name__, attr)) attrs[i] = attr index.attrs = attrs = tuple(attrs) for i, attr in enumerate(attrs): if not isinstance(attr, Attribute): func_name = 'PrimaryKey' if index.is_pk else 'composite_key' if index.is_unique else 'composite_index' throw(TypeError, '%s() arguments must be attributes. 
            if index.is_unique: attr.is_part_of_unique_index = True
            if len(attrs) > 1: attr.composite_keys.append((attrs, i))
            if not issubclass(entity, attr.entity):
                throw(ERDiagramError, 'Invalid use of attribute %s in entity %s' % (attr, entity.__name__))
            key_type = 'primary key' if index.is_pk else 'unique index' if index.is_unique else 'index'
            if attr.is_collection or (index.is_pk and not attr.is_required and not attr.auto):
                throw(TypeError, '%s attribute %s cannot be part of %s'
                                 % (attr.__class__.__name__, attr, key_type))
            if isinstance(attr.py_type, type) and issubclass(attr.py_type, float):
                throw(TypeError, 'Attribute %s of type float cannot be part of %s' % (attr, key_type))
            if index.is_pk and attr.is_volatile:
                throw(TypeError, 'Volatile attribute %s cannot be part of primary key' % attr)
            if not attr.is_required:
                if attr.nullable is False:
                    throw(TypeError, 'Optional attribute %s must be nullable, '
                                     'because it is part of composite key' % attr)
                attr.nullable = True
                if attr.is_string and attr.default == '' and not hasattr(attr, 'original_default'):
                    attr.default = None

def _define_index(func_name, attrs, is_unique=False):
    if len(attrs) < 2: throw(TypeError,
        '%s() must receive at least two attributes as arguments' % func_name)
    cls_dict = sys._getframe(2).f_locals
    indexes = cls_dict.setdefault('_indexes_', [])
    indexes.append(Index(*attrs, is_pk=False, is_unique=is_unique))

def composite_index(*attrs):
    _define_index('composite_index', attrs)

def composite_key(*attrs):
    _define_index('composite_key', attrs, is_unique=True)

class PrimaryKey(Required):
    __slots__ = []
    def __new__(cls, *args, **kwargs):
        if not args: throw(TypeError, 'PrimaryKey must receive at least one positional argument')
        cls_dict = sys._getframe(1).f_locals
        attrs = tuple(a for a in args if isinstance(a, Attribute))
        non_attrs = [ a for a in args if not isinstance(a, Attribute) ]
        if not attrs:
            return Required.__new__(cls)
        elif non_attrs or kwargs:
            throw(TypeError, 'PrimaryKey got invalid arguments: %r %r' % (args, kwargs))
        elif len(attrs) == 1:
            attr = attrs[0]
            attr_name = 'something'
            for key, val in iteritems(cls_dict):
                if val is attr:
                    attr_name = key
                    break
            py_type = attr.py_type
            type_str = py_type.__name__ if type(py_type) is type else repr(py_type)
            throw(TypeError, 'Just use %s = PrimaryKey(%s, ...) directly instead of PrimaryKey(%s)'
                             % (attr_name, type_str, attr_name))
        for i, attr in enumerate(attrs):
            attr.is_part_of_unique_index = True
            attr.composite_keys.append((attrs, i))
        indexes = cls_dict.setdefault('_indexes_', [])
        indexes.append(Index(*attrs, is_pk=True))
        return None

class Collection(Attribute):
    __slots__ = 'table', 'wrapper_class', 'symmetric', 'reverse_column', 'reverse_columns', \
                'nplus1_threshold', 'cached_load_sql', 'cached_add_m2m_sql', 'cached_remove_m2m_sql', \
                'cached_count_sql', 'cached_empty_sql', 'reverse_fk_name'
    def __init__(attr, py_type, *args, **kwargs):
        if attr.__class__ is Collection: throw(TypeError, "'Collection' is an abstract type")
        table = kwargs.pop('table', None)  # TODO: rename table to link_table or m2m_table
        if table is not None and not isinstance(table, basestring):
            if not isinstance(table, (list, tuple)):
                throw(TypeError, "Parameter 'table' must be a string. Got: %r" % table)
            for name_part in table:
                if not isinstance(name_part, basestring):
                    throw(TypeError, 'Each part of table name must be a string. Got: %r' % name_part)
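            # The 'table' option being validated here names the intermediate table of
            # a many-to-many relationship; a list/tuple value appears to be treated as
            # a multi-part (e.g. schema-qualified) name, hence the per-part string
            # check. Illustrative sketch (hypothetical entities, usual pony.orm API):
            #
            #     class Student(db.Entity):
            #         courses = Set('Course', table='student_courses')
            #     class Course(db.Entity):
            #         students = Set(Student)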
            table = tuple(table)
        attr.table = table
        Attribute.__init__(attr, py_type, *args, **kwargs)
        if attr.auto: throw(TypeError, "'auto' option cannot be set for collection attribute")
        kwargs = attr.kwargs
        attr.reverse_column = kwargs.pop('reverse_column', None)
        attr.reverse_columns = kwargs.pop('reverse_columns', None)
        if attr.reverse_column is not None:
            if attr.reverse_columns is not None and attr.reverse_columns != [ attr.reverse_column ]:
                throw(TypeError, "Parameters 'reverse_column' and 'reverse_columns' cannot be specified simultaneously")
            if not isinstance(attr.reverse_column, basestring):
                throw(TypeError, "Parameter 'reverse_column' must be a string. Got: %r" % attr.reverse_column)
            attr.reverse_columns = [ attr.reverse_column ]
        elif attr.reverse_columns is not None:
            if not isinstance(attr.reverse_columns, (tuple, list)):
                throw(TypeError, "Parameter 'reverse_columns' must be a list. Got: %r" % attr.reverse_columns)
            for reverse_column in attr.reverse_columns:
                if not isinstance(reverse_column, basestring):
                    throw(TypeError, "Parameter 'reverse_columns' must be a list of strings. Got: %r" % attr.reverse_columns)
            if len(attr.reverse_columns) == 1: attr.reverse_column = attr.reverse_columns[0]
        else: attr.reverse_columns = []
        attr.reverse_fk_name = kwargs.pop('reverse_fk_name', None)
        attr.nplus1_threshold = kwargs.pop('nplus1_threshold', 1)
        attr.cached_load_sql = {}
        attr.cached_add_m2m_sql = None
        attr.cached_remove_m2m_sql = None
        attr.cached_count_sql = None
        attr.cached_empty_sql = None
    def _init_(attr, entity, name):
        Attribute._init_(attr, entity, name)
        if attr.is_unique: throw(TypeError,
            "'unique' option cannot be set for attribute %s because it is a collection" % attr)
        if attr.default is not None:
            throw(TypeError, 'Default value cannot be set for collection attribute')
        attr.symmetric = (attr.py_type == entity.__name__ and attr.reverse == name)
        if not attr.symmetric:
            if attr.reverse_columns: throw(TypeError,
                "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only")
            if attr.reverse_index: throw(TypeError,
                "'reverse_index' option can be set for symmetric relations only")
        if attr.py_check is not None:
            throw(NotImplementedError, "'py_check' parameter is not supported for collection attributes")
    def load(attr, obj):
        assert False, 'Abstract method'  # pragma: no cover
    def __get__(attr, obj, cls=None):
        assert False, 'Abstract method'  # pragma: no cover
    def __set__(attr, obj, val):
        assert False, 'Abstract method'  # pragma: no cover
    def __delete__(attr, obj):
        assert False, 'Abstract method'  # pragma: no cover
    def prepare(attr, obj, val, fromdb=False):
        assert False, 'Abstract method'  # pragma: no cover
    def set(attr, obj, val, fromdb=False):
        assert False, 'Abstract method'  # pragma: no cover

class SetData(set):
    __slots__ = 'is_fully_loaded', 'added', 'removed', 'absent', 'count'
    def __init__(setdata):
        setdata.is_fully_loaded = False
        setdata.added = setdata.removed = setdata.absent = None
        setdata.count = None

def construct_batchload_criteria_list(alias, columns, converters, batch_size, row_value_syntax, start=0, from_seeds=True):
    assert batch_size > 0
    def param(i, j, converter):
        if from_seeds: return [ 'PARAM', (i, None, j), converter ]
        else: return [ 'PARAM', (i, j, None), converter ]
    if batch_size == 1:
        return [ [ converter.EQ, [ 'COLUMN', alias, column ], param(start, j, converter) ]
                 for j, (column, converter) in enumerate(izip(columns, converters)) ]
    if len(columns) == 1:
        column = columns[0]
        converter = converters[0]
        param_list = [ param(i+start, 
0, converter) for i in xrange(batch_size) ] condition = [ 'IN', [ 'COLUMN', alias, column ], param_list ] return [ condition ] elif row_value_syntax: row = [ 'ROW' ] + [ [ 'COLUMN', alias, column ] for column in columns ] param_list = [ [ 'ROW' ] + [ param(i+start, j, converter) for j, converter in enumerate(converters) ] for i in xrange(batch_size) ] condition = [ 'IN', row, param_list ] return [ condition ] else: conditions = [ [ 'AND' ] + [ [ converter.EQ, [ 'COLUMN', alias, column ], param(i+start, j, converter) ] for j, (column, converter) in enumerate(izip(columns, converters)) ] for i in xrange(batch_size) ] return [ [ 'OR' ] + conditions ] class Set(Collection): __slots__ = [] def validate(attr, val, obj=None, entity=None, from_db=False): val = deref_proxy(val) assert val is not NOT_LOADED if val is DEFAULT: return set() reverse = attr.reverse if val is None: throw(ValueError, 'A single %(cls)s instance or %(cls)s iterable is expected. ' 'Got: None' % dict(cls=reverse.entity.__name__)) if entity is not None: pass elif obj is not None: entity = obj.__class__ else: entity = attr.entity if not reverse: throw(NotImplementedError) if isinstance(val, reverse.entity): items = set((val,)) else: rentity = reverse.entity try: items = set(val) except TypeError: throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r' % (entity.__name__, attr.name, rentity.__name__, val)) for item in items: item = deref_proxy(item) if not isinstance(item, rentity): throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r' % (entity.__name__, attr.name, rentity.__name__, item)) if obj is not None and obj._status_ is not None: cache = obj._session_cache_ else: cache = entity._database_._get_cache() for item in items: if item._session_cache_ is not cache: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') return items def prefetch_load_all(attr, objects): entity = attr.entity database = entity._database_ cache = database._get_cache() if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over') reverse = attr.reverse rentity = reverse.entity objects = sorted(objects, key=entity._get_raw_pkval_) max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) result = set() if not reverse.is_collection: for i in xrange(0, len(objects), max_batch_size): batch = objects[i:i+max_batch_size] sql, adapter, attr_offsets = rentity._construct_batchload_sql_(len(batch), reverse) arguments = adapter(batch) cursor = database._exec_sql(sql, arguments) result.update(rentity._fetch_objects(cursor, attr_offsets)) else: pk_len = len(entity._pk_columns_) m2m_dict = defaultdict(set) for i in xrange(0, len(objects), max_batch_size): batch = objects[i:i+max_batch_size] sql, adapter = attr.construct_sql_m2m(len(batch)) arguments = adapter(batch) cursor = database._exec_sql(sql, arguments) if len(batch) > 1: for row in cursor.fetchall(): obj = entity._get_by_raw_pkval_(row[:pk_len]) item = rentity._get_by_raw_pkval_(row[pk_len:]) m2m_dict[obj].add(item) else: obj = batch[0] m2m_dict[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} for obj2, items in iteritems(m2m_dict): setdata2 = obj2._vals_.get(attr) if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData() else: phantoms = setdata2 - items if setdata2.added: phantoms -= setdata2.added if phantoms: throw(UnrepeatableReadError, 'Phantom object %s disappeared from collection %s.%s' % 
                        (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name))
                items -= setdata2
                if setdata2.removed: items -= setdata2.removed
                setdata2 |= items
                reverse.db_reverse_add(items, obj2)
                result.update(items)
        for obj in objects:
            setdata = obj._vals_.get(attr)
            if setdata is None: setdata = obj._vals_[attr] = SetData()
            setdata.is_fully_loaded = True
            setdata.absent = None
            setdata.count = len(setdata)
        return result
    def load(attr, obj, items=None):
        cache = obj._session_cache_
        if cache is None or not cache.is_alive: throw_db_session_is_over('load collection', obj, attr)
        assert obj._status_ not in del_statuses
        setdata = obj._vals_.get(attr)
        if setdata is None: setdata = obj._vals_[attr] = SetData()
        elif setdata.is_fully_loaded: return setdata
        entity = attr.entity
        reverse = attr.reverse
        rentity = reverse.entity
        database = obj._database_
        if cache is not database._get_cache():
            throw(TransactionError, 'Transaction of object %s belongs to different thread' % safe_repr(obj))
        if items:
            if not reverse.is_collection:
                items = {item for item in items if reverse not in item._vals_}
            else:
                items = set(items)
                items -= setdata
                if setdata.removed: items -= setdata.removed
            if not items: return setdata
        if items and (attr.lazy or not setdata):
            items = list(items)
            if not reverse.is_collection:
                sql, adapter, attr_offsets = rentity._construct_batchload_sql_(len(items))
                arguments = adapter(items)
                cursor = database._exec_sql(sql, arguments)
                items = rentity._fetch_objects(cursor, attr_offsets)
                return setdata
            sql, adapter = attr.construct_sql_m2m(1, len(items))
            items.append(obj)
            arguments = adapter(items)
            cursor = database._exec_sql(sql, arguments)
            loaded_items = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()}
            setdata |= loaded_items
            reverse.db_reverse_add(loaded_items, obj)
            return setdata
        counter = cache.collection_statistics.setdefault(attr, 0)
        nplus1_threshold = attr.nplus1_threshold
        prefetching = not attr.lazy and nplus1_threshold is not None and counter >= nplus1_threshold
        objects = [ obj ]
        setdata_list = [ setdata ]
        if prefetching:
            pk_index = cache.indexes[entity._pk_attrs_]
            max_batch_size = database.provider.max_params_count // len(entity._pk_columns_)
            for obj2 in itervalues(pk_index):
                if obj2 is obj: continue
                if obj2._status_ in created_or_deleted_statuses: continue
                setdata2 = obj2._vals_.get(attr)
                if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData()
                elif setdata2.is_fully_loaded: continue
                objects.append(obj2)
                setdata_list.append(setdata2)
                if len(objects) >= max_batch_size: break
        if not reverse.is_collection:
            sql, adapter, attr_offsets = rentity._construct_batchload_sql_(len(objects), reverse)
            arguments = adapter(objects)
            cursor = database._exec_sql(sql, arguments)
            items = rentity._fetch_objects(cursor, attr_offsets)
        else:
            sql, adapter = attr.construct_sql_m2m(len(objects))
            arguments = adapter(objects)
            cursor = database._exec_sql(sql, arguments)
            pk_len = len(entity._pk_columns_)
            d = {}
            if len(objects) > 1:
                for row in cursor.fetchall():
                    obj2 = entity._get_by_raw_pkval_(row[:pk_len])
                    item = rentity._get_by_raw_pkval_(row[pk_len:])
                    items = d.get(obj2)
                    if items is None: items = d[obj2] = set()
                    items.add(item)
            else: d[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()}
            for obj2, items in iteritems(d):
                setdata2 = obj2._vals_.get(attr)
                if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData()
                else:
                    phantoms = setdata2 - items
                    if setdata2.added: phantoms -= setdata2.added
                    if phantoms: throw(UnrepeatableReadError,
                        'Phantom object %s disappeared from collection %s.%s'
                        % (safe_repr(phantoms.pop()), 
safe_repr(obj2), attr.name)) items -= setdata2 if setdata2.removed: items -= setdata2.removed setdata2 |= items reverse.db_reverse_add(items, obj2) for setdata2 in setdata_list: setdata2.is_fully_loaded = True setdata2.absent = None setdata2.count = len(setdata2) cache.collection_statistics[attr] = counter + 1 return setdata def construct_sql_m2m(attr, batch_size=1, items_count=0): if items_count: assert batch_size == 1 cache_key = -items_count else: cache_key = batch_size cached_sql = attr.cached_load_sql.get(cache_key) if cached_sql is not None: return cached_sql reverse = attr.reverse assert reverse is not None and reverse.is_collection and issubclass(reverse.py_type, Entity) table_name = attr.table assert table_name is not None select_list = [ 'ALL' ] if not attr.symmetric: columns = attr.columns converters = attr.converters rcolumns = reverse.columns rconverters = reverse.converters else: columns = attr.reverse_columns rcolumns = attr.columns converters = rconverters = attr.converters if batch_size > 1: select_list.extend([ 'COLUMN', 'T1', column ] for column in rcolumns) select_list.extend([ 'COLUMN', 'T1', column ] for column in columns) from_list = [ 'FROM', [ 'T1', 'TABLE', table_name ]] database = attr.entity._database_ row_value_syntax = database.provider.translator_cls.row_value_syntax where_list = [ 'WHERE' ] where_list += construct_batchload_criteria_list( 'T1', rcolumns, rconverters, batch_size, row_value_syntax, items_count) if items_count: where_list += construct_batchload_criteria_list( 'T1', columns, converters, items_count, row_value_syntax) sql_ast = [ 'SELECT', select_list, from_list, where_list ] sql, adapter = attr.cached_load_sql[cache_key] = database._ast2sql(sql_ast) return sql, adapter def copy(attr, obj): if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) reverse = attr.reverse if not reverse.is_collection and reverse.pk_offset is None: added = setdata.added or () for item in setdata: if item in added: continue bit = item._bits_except_volatile_[reverse] assert item._wbits_ is not None if not item._wbits_ & bit: item._rbits_ |= bit return set(setdata) @cut_traceback def __get__(attr, obj, cls=None): if obj is None: return attr if obj._status_ in del_statuses: throw_object_was_deleted(obj) rentity = attr.py_type wrapper_class = rentity._get_set_wrapper_subclass_() return wrapper_class(obj, attr) @cut_traceback def __set__(attr, obj, new_items, undo_funcs=None): if isinstance(new_items, SetInstance) and new_items._obj_ is obj and new_items._attr_ is attr: return # after += or -= cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): new_items = attr.validate(new_items, obj) reverse = attr.reverse if not reverse: throw(NotImplementedError) setdata = obj._vals_.get(attr) if setdata is None: if obj._status_ == 'created': setdata = obj._vals_[attr] = SetData() setdata.is_fully_loaded = True setdata.count = 0 else: setdata = attr.load(obj) elif not setdata.is_fully_loaded: setdata = attr.load(obj) if new_items == setdata: return to_add = new_items - setdata to_remove = setdata - new_items is_reverse_call = undo_funcs is not None if not is_reverse_call: undo_funcs = [] try: if not reverse.is_collection: if 
attr.cascade_delete: for item in to_remove: item._delete_(undo_funcs) else: for item in to_remove: reverse.__set__(item, None, undo_funcs) for item in to_add: reverse.__set__(item, obj, undo_funcs) else: reverse.reverse_remove(to_remove, obj, undo_funcs) reverse.reverse_add(to_add, obj, undo_funcs) except: if not is_reverse_call: for undo_func in reversed(undo_funcs): undo_func() raise setdata.clear() setdata |= new_items if setdata.count is not None: setdata.count = len(new_items) added = setdata.added removed = setdata.removed if to_add: if removed: (to_add, setdata.removed) = (to_add - removed, removed - to_add) if added: added |= to_add else: setdata.added = to_add # added may be None if to_remove: if added: (to_remove, setdata.added) = (to_remove - added, added - to_remove) if removed: removed |= to_remove else: setdata.removed = to_remove # removed may be None cache.modified_collections[attr].add(obj) cache.modified = True def __delete__(attr, obj): throw(NotImplementedError) def reverse_add(attr, objects, item, undo_funcs): undo = [] cache = item._session_cache_ objects_with_modified_collections = cache.modified_collections[attr] for obj in objects: setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() else: assert item not in setdata if setdata.added is None: setdata.added = set() else: assert item not in setdata.added in_removed = setdata.removed and item in setdata.removed was_modified_earlier = obj in objects_with_modified_collections undo.append((obj, in_removed, was_modified_earlier)) setdata.add(item) if setdata.count is not None: setdata.count += 1 if in_removed: setdata.removed.remove(item) else: setdata.added.add(item) objects_with_modified_collections.add(obj) def undo_func(): for obj, in_removed, was_modified_earlier in undo: setdata = obj._vals_[attr] setdata.remove(item) if setdata.count is not None: setdata.count -= 1 if in_removed: setdata.removed.add(item) else: setdata.added.remove(item) if not was_modified_earlier: objects_with_modified_collections.remove(obj) undo_funcs.append(undo_func) def db_reverse_add(attr, objects, item): for obj in objects: setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.is_fully_loaded: throw(UnrepeatableReadError, 'Phantom object %s appeared in collection %s.%s' % (safe_repr(item), safe_repr(obj), attr.name)) setdata.add(item) def reverse_remove(attr, objects, item, undo_funcs): undo = [] cache = item._session_cache_ objects_with_modified_collections = cache.modified_collections[attr] for obj in objects: setdata = obj._vals_.get(attr) assert setdata is not None assert item in setdata if setdata.removed is None: setdata.removed = set() else: assert item not in setdata.removed in_added = setdata.added and item in setdata.added was_modified_earlier = obj in objects_with_modified_collections undo.append((obj, in_added, was_modified_earlier)) objects_with_modified_collections.add(obj) setdata.remove(item) if setdata.count is not None: setdata.count -= 1 if in_added: setdata.added.remove(item) else: setdata.removed.add(item) def undo_func(): for obj, in_removed, was_modified_earlier in undo: setdata = obj._vals_[attr] setdata.add(item) if setdata.count is not None: setdata.count += 1 if in_added: setdata.added.add(item) else: setdata.removed.remove(item) if not was_modified_earlier: objects_with_modified_collections.remove(obj) undo_funcs.append(undo_func) def db_reverse_remove(attr, objects, item): for obj in objects: setdata = obj._vals_[attr] 
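            # The database reported that `item` is no longer linked to `obj`, so the
            # cached SetData is trimmed to match. Unlike reverse_remove() above, no
            # undo function is registered here: this mirrors state already present in
            # the database rather than an in-session edit that may need rolling back.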
setdata.remove(item) def get_m2m_columns(attr, is_reverse=False): reverse = attr.reverse entity = attr.entity pk_length = len(entity._get_pk_columns_()) provider = entity._database_.provider if attr.symmetric or entity is reverse.entity: if attr._columns_checked: if not attr.symmetric: return attr.columns if not is_reverse: return attr.columns return attr.reverse_columns if not attr.symmetric: assert not reverse._columns_checked if attr.columns: if len(attr.columns) != pk_length: throw(MappingError, 'Invalid number of columns for %s' % reverse) else: attr.columns = provider.get_default_m2m_column_names(entity) attr._columns_checked = True attr.converters = entity._pk_converters_ if attr.symmetric: if not attr.reverse_columns: attr.reverse_columns = [ column + '_2' for column in attr.columns ] elif len(attr.reverse_columns) != pk_length: throw(MappingError, "Invalid number of reverse columns for symmetric attribute %s" % attr) return attr.columns if not is_reverse else attr.reverse_columns else: if not reverse.columns: reverse.columns = [ column + '_2' for column in attr.columns ] reverse._columns_checked = True reverse.converters = entity._pk_converters_ return attr.columns if not is_reverse else reverse.columns if attr._columns_checked: return reverse.columns elif reverse.columns: if len(reverse.columns) != pk_length: throw(MappingError, 'Invalid number of columns for %s' % reverse) else: reverse.columns = provider.get_default_m2m_column_names(entity) reverse.converters = entity._pk_converters_ attr._columns_checked = True return reverse.columns def remove_m2m(attr, removed): assert removed entity = attr.entity database = entity._database_ cached_sql = attr.cached_remove_m2m_sql if cached_sql is None: reverse = attr.reverse where_list = [ 'WHERE' ] if attr.symmetric: columns = attr.columns + attr.reverse_columns converters = attr.converters + attr.converters else: columns = reverse.columns + attr.columns converters = reverse.converters + attr.converters for i, (column, converter) in enumerate(izip(columns, converters)): where_list.append([ converter.EQ, ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ]) from_ast = [ 'FROM', [ None, 'TABLE', attr.table ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) attr.cached_remove_m2m_sql = sql, adapter else: sql, adapter = cached_sql arguments_list = [ adapter(obj._get_raw_pkval_() + robj._get_raw_pkval_()) for obj, robj in removed ] database._exec_sql(sql, arguments_list) def add_m2m(attr, added): assert added entity = attr.entity database = entity._database_ cached_sql = attr.cached_add_m2m_sql if cached_sql is None: reverse = attr.reverse if attr.symmetric: columns = attr.columns + attr.reverse_columns converters = attr.converters + attr.converters else: columns = reverse.columns + attr.columns converters = reverse.converters + attr.converters params = [ [ 'PARAM', (i, None, None), converter ] for i, converter in enumerate(converters) ] sql_ast = [ 'INSERT', attr.table, columns, params ] sql, adapter = database._ast2sql(sql_ast) attr.cached_add_m2m_sql = sql, adapter else: sql, adapter = cached_sql arguments_list = [ adapter(obj._get_raw_pkval_() + robj._get_raw_pkval_()) for obj, robj in added ] database._exec_sql(sql, arguments_list) @cut_traceback @db_session(ddl=True) def drop_table(attr, with_all_data=False): if attr.reverse.is_collection: table_name = attr.table else: table_name = attr.entity._table_ attr.entity._database_._drop_tables([ table_name ], True, with_all_data) def 
unpickle_setwrapper(obj, attrname, items): attr = getattr(obj.__class__, attrname) wrapper_cls = attr.py_type._get_set_wrapper_subclass_() wrapper = wrapper_cls(obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() setdata.is_fully_loaded = True setdata.absent = None setdata.count = len(setdata) return wrapper class SetIterator(object): def __init__(self, wrapper): self._wrapper = wrapper self._query = None self._iter = None def __iter__(self): return self def next(self): if self._iter is None: self._iter = iter(self._wrapper.copy()) return next(self._iter) __next__ = next def _get_query(self): if self._query is None: self._query = self._wrapper.select() return self._query def _get_type_(self): return QueryType(self._get_query()) def _normalize_var(self, query_type): return query_type, self._get_query() class SetInstance(object): __slots__ = '_obj_', '_attr_', '_attrnames_' _parent_ = None def __init__(wrapper, obj, attr): wrapper._obj_ = obj wrapper._attr_ = attr wrapper._attrnames_ = (attr.name,) def __reduce__(wrapper): return unpickle_setwrapper, (wrapper._obj_, wrapper._attr_.name, wrapper.copy()) @cut_traceback def copy(wrapper): return wrapper._attr_.copy(wrapper._obj_) @cut_traceback def __repr__(wrapper): return '<%s %r.%s>' % (wrapper.__class__.__name__, wrapper._obj_, wrapper._attr_.name) @cut_traceback def __str__(wrapper): cache = wrapper._obj_._session_cache_ if cache is None or not cache.is_alive: content = '...' else: content = ', '.join(imap(str, wrapper)) return '%s([%s])' % (wrapper.__class__.__name__, content) @cut_traceback def __nonzero__(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = attr.load(obj) if setdata: return True if not setdata.is_fully_loaded: setdata = attr.load(obj) return bool(setdata) @cut_traceback def is_empty(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.is_fully_loaded: return not setdata elif setdata: return False elif setdata.count is not None: return not setdata.count entity = attr.entity reverse = attr.reverse rentity = reverse.entity database = entity._database_ cached_sql = attr.cached_empty_sql if cached_sql is None: where_list = [ 'WHERE' ] for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)): where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = rentity._table_ select_list, attr_offsets = rentity._construct_select_clause_() else: table_name = attr.table select_list = [ 'ALL' ] + [ [ 'COLUMN', None, column ] for column in attr.columns ] attr_offsets = None sql_ast = [ 'SELECT', select_list, [ 'FROM', [ None, 'TABLE', table_name ] ], where_list, [ 'LIMIT', 1 ] ] sql, adapter = database._ast2sql(sql_ast) attr.cached_empty_sql = sql, adapter, attr_offsets else: sql, adapter, attr_offsets = cached_sql arguments = adapter(obj._get_raw_pkval_()) cursor = database._exec_sql(sql, arguments) if reverse.is_collection: row = cursor.fetchone() if row is not None: loaded_item = rentity._get_by_raw_pkval_(row) setdata.add(loaded_item) 
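                # The LIMIT 1 probe above answers "is this collection empty?" with at
                # most one fetched row, instead of loading or counting the whole
                # collection. Illustrative usage (hypothetical entities):
                #
                #     if not person.cars.is_empty():  # one SELECT ... LIMIT 1
                #         ...
                #
                # By contrast, len(person.cars) loads the full collection, and
                # count() issues a SELECT COUNT aggregate query.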
reverse.db_reverse_add((loaded_item,), obj) else: rentity._fetch_objects(cursor, attr_offsets) if setdata: return False setdata.is_fully_loaded = True setdata.absent = None setdata.count = 0 return True @cut_traceback def __len__(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) return len(setdata) @cut_traceback def count(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ cache = obj._session_cache_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.count is not None: return setdata.count if cache is None or not cache.is_alive: throw_db_session_is_over('read value of', obj, attr) entity = attr.entity reverse = attr.reverse database = entity._database_ cached_sql = attr.cached_count_sql if cached_sql is None: where_list = [ 'WHERE' ] for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)): where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = reverse.entity._table_ else: table_name = attr.table sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ None, 'TABLE', table_name ] ], where_list ] sql, adapter = database._ast2sql(sql_ast) attr.cached_count_sql = sql, adapter else: sql, adapter = cached_sql arguments = adapter(obj._get_raw_pkval_()) with cache.flush_disabled(): cursor = database._exec_sql(sql, arguments) setdata.count = cursor.fetchone()[0] if setdata.added: setdata.count += len(setdata.added) if setdata.removed: setdata.count -= len(setdata.removed) return setdata.count @cut_traceback def __iter__(wrapper): return SetIterator(wrapper) @cut_traceback def __eq__(wrapper, other): if isinstance(other, SetInstance): if wrapper._obj_ is other._obj_ and wrapper._attr_ is other._attr_: return True else: other = other.copy() elif not isinstance(other, set): other = set(other) items = wrapper.copy() return items == other @cut_traceback def __ne__(wrapper, other): return not wrapper.__eq__(other) @cut_traceback def __add__(wrapper, new_items): return wrapper.copy().union(new_items) @cut_traceback def __sub__(wrapper, items): return wrapper.copy().difference(items) @cut_traceback def __contains__(wrapper, item): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) if not isinstance(item, attr.py_type): return False if item._session_cache_ is not obj._session_cache_: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') reverse = attr.reverse if not reverse.is_collection: obj2 = item._vals_[reverse] if reverse in item._vals_ else reverse.load(item) wbits = item._wbits_ if wbits is not None: bit = item._bits_except_volatile_[reverse] if not wbits & bit: item._rbits_ |= bit return obj is obj2 setdata = obj._vals_.get(attr) if setdata is not None: if item in setdata: return True if setdata.is_fully_loaded: return False if setdata.absent is not None and item in setdata.absent: return False else: reverse_setdata = item._vals_.get(reverse) if reverse_setdata is not None 
and reverse_setdata.is_fully_loaded: return obj in reverse_setdata setdata = attr.load(obj, (item,)) if item in setdata: return True if setdata.absent is None: setdata.absent = set() setdata.absent.add(item) return False @cut_traceback def create(wrapper, **kwargs): attr = wrapper._attr_ reverse = attr.reverse if reverse.name in kwargs: throw(TypeError, 'When using %s.%s.create(), %r attribute should not be passed explicitly' % (attr.entity.__name__, attr.name, reverse.name)) kwargs[reverse.name] = wrapper._obj_ item_type = attr.py_type item = item_type(**kwargs) return item @cut_traceback def add(wrapper, new_items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): reverse = attr.reverse if not reverse: throw(NotImplementedError) new_items = attr.validate(new_items, obj) if not new_items: return setdata = obj._vals_.get(attr) if setdata is not None: new_items -= setdata if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj, new_items) new_items -= setdata undo_funcs = [] try: if not reverse.is_collection: for item in new_items: reverse.__set__(item, obj, undo_funcs) else: reverse.reverse_add(new_items, obj, undo_funcs) except: for undo_func in reversed(undo_funcs): undo_func() raise setdata |= new_items if setdata.count is not None: setdata.count += len(new_items) added = setdata.added removed = setdata.removed if removed: (new_items, setdata.removed) = (new_items-removed, removed-new_items) if added: added |= new_items else: setdata.added = new_items # added may be None cache.modified_collections[attr].add(obj) cache.modified = True @cut_traceback def __iadd__(wrapper, items): wrapper.add(items) return wrapper @cut_traceback def remove(wrapper, items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): reverse = attr.reverse if not reverse: throw(NotImplementedError) items = attr.validate(items, obj) setdata = obj._vals_.get(attr) if setdata is not None and setdata.removed: items -= setdata.removed if not items: return if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj, items) items &= setdata undo_funcs = [] try: if not reverse.is_collection: if attr.cascade_delete: for item in items: item._delete_(undo_funcs) else: for item in items: reverse.__set__(item, None, undo_funcs) else: reverse.reverse_remove(items, obj, undo_funcs) except: for undo_func in reversed(undo_funcs): undo_func() raise setdata -= items if setdata.count is not None: setdata.count -= len(items) added = setdata.added removed = setdata.removed if added: (items, setdata.added) = (items - added, added - items) if removed: removed |= items else: setdata.removed = items # removed may be None cache.modified_collections[attr].add(obj) cache.modified = True @cut_traceback def __isub__(wrapper, items): wrapper.remove(items) return wrapper @cut_traceback def clear(wrapper): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ if cache is None or not obj._session_cache_.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) attr.__set__(obj, ()) @cut_traceback def load(wrapper): 
wrapper._attr_.load(wrapper._obj_) @cut_traceback def select(wrapper, *args): obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) attr = wrapper._attr_ reverse = attr.reverse query = reverse.entity._select_all() s = 'lambda item: JOIN(obj in item.%s)' if reverse.is_collection else 'lambda item: item.%s == obj' query = query.filter(s % reverse.name, {'obj' : obj, 'JOIN': JOIN}) if args: func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+1) query = query.filter(func, globals, locals) return query filter = select def limit(wrapper, limit=None, offset=None): return wrapper.select().limit(limit, offset) def page(wrapper, pagenum, pagesize=10): return wrapper.select().page(pagenum, pagesize) def order_by(wrapper, *args): return wrapper.select().order_by(*args) def sort_by(wrapper, *args): return wrapper.select().sort_by(*args) def random(wrapper, limit): return wrapper.select().random(limit) def unpickle_multiset(obj, attrnames, items): entity = obj.__class__ for name in attrnames: attr = entity._adict_[name] if attr.reverse: entity = attr.py_type else: entity = None break if entity is None: multiset_cls = Multiset else: multiset_cls = entity._get_multiset_subclass_() return multiset_cls(obj, attrnames, items) class Multiset(object): __slots__ = [ '_obj_', '_attrnames_', '_items_' ] @cut_traceback def __init__(multiset, obj, attrnames, items): multiset._obj_ = obj multiset._attrnames_ = attrnames if type(items) is dict: multiset._items_ = items else: multiset._items_ = utils.distinct(items) def __reduce__(multiset): return unpickle_multiset, (multiset._obj_, multiset._attrnames_, multiset._items_) @cut_traceback def distinct(multiset): return multiset._items_.copy() @cut_traceback def __repr__(multiset): cache = multiset._obj_._session_cache_ if cache is not None and cache.is_alive: size = builtins.sum(itervalues(multiset._items_)) if size == 1: size_str = ' (1 item)' else: size_str = ' (%d items)' % size else: size_str = '' return '<%s %r.%s%s>' % (multiset.__class__.__name__, multiset._obj_, '.'.join(multiset._attrnames_), size_str) @cut_traceback def __str__(multiset): items_str = '{%s}' % ', '.join('%r: %r' % pair for pair in sorted(iteritems(multiset._items_))) return '%s(%s)' % (multiset.__class__.__name__, items_str) @cut_traceback def __nonzero__(multiset): return bool(multiset._items_) @cut_traceback def __len__(multiset): return builtins.sum(multiset._items_.values()) @cut_traceback def __iter__(multiset): for item, cnt in iteritems(multiset._items_): for i in xrange(cnt): yield item @cut_traceback def __eq__(multiset, other): if isinstance(other, Multiset): return multiset._items_ == other._items_ if isinstance(other, dict): return multiset._items_ == other if hasattr(other, 'keys'): return multiset._items_ == dict(other) return multiset._items_ == utils.distinct(other) @cut_traceback def __ne__(multiset, other): return not multiset.__eq__(other) @cut_traceback def __contains__(multiset, item): return item in multiset._items_ ##class List(Collection): pass ##class Dict(Collection): pass ##class Relation(Collection): pass class EntityIter(object): def __init__(self, entity): self.entity = entity def next(self): throw(TypeError, 'Use select(...) function or %s.select(...) 
method for iteration' % self.entity.__name__) if not PY2: __next__ = next entity_id_counter = itertools.count(1) new_instance_id_counter = itertools.count(1) select_re = re.compile(r'select\b', re.IGNORECASE) lambda_re = re.compile(r'lambda\b') class EntityMeta(type): def __new__(meta, name, bases, cls_dict): if 'Entity' in globals(): if '__slots__' in cls_dict: throw(TypeError, 'Entity classes cannot contain __slots__ variable') cls_dict['__slots__'] = () return super(EntityMeta, meta).__new__(meta, name, bases, cls_dict) @cut_traceback def __init__(entity, name, bases, cls_dict): super(EntityMeta, entity).__init__(name, bases, cls_dict) entity._database_ = None if name == 'Entity': return if not entity.__name__[:1].isupper(): throw(ERDiagramError, 'Entity class name should start with a capital letter. Got: %s' % entity.__name__) databases = set() for base_class in bases: if isinstance(base_class, EntityMeta): database = base_class._database_ if database is None: throw(ERDiagramError, 'Base Entity does not belong to any database') databases.add(database) if not databases: assert False # pragma: no cover elif len(databases) > 1: throw(ERDiagramError, 'With multiple inheritance of entities, all entities must belong to the same database') database = databases.pop() if entity.__name__ in database.entities: throw(ERDiagramError, 'Entity %s already exists' % entity.__name__) assert entity.__name__ not in database.__dict__ if database.schema is not None: throw(ERDiagramError, 'Cannot define entity %r: database mapping has already been generated' % entity.__name__) entity._database_ = database entity._id_ = next(entity_id_counter) direct_bases = [ c for c in entity.__bases__ if issubclass(c, Entity) and c.__name__ != 'Entity' ] entity._direct_bases_ = direct_bases all_bases = entity._all_bases_ = set() entity._subclasses_ = set() for base in direct_bases: all_bases.update(base._all_bases_) all_bases.add(base) for base in all_bases: base._subclasses_.add(entity) if direct_bases: root = entity._root_ = direct_bases[0]._root_ for base in direct_bases[1:]: if base._root_ is not root: throw(ERDiagramError, 'Multiple inheritance graph must be diamond-like. ' "Entity %s inherits from %s and %s entities which don't have common base class." % (name, root.__name__, base._root_.__name__)) if root._discriminator_attr_ is None: assert root._discriminator_ is None Discriminator.create_default_attr(root) else: entity._root_ = entity entity._discriminator_attr_ = None base_attrs = [] base_attrs_dict = {} for base in direct_bases: for a in base._attrs_: prev = base_attrs_dict.get(a.name) if prev is None: base_attrs_dict[a.name] = a base_attrs.append(a) elif prev is not a: throw(ERDiagramError, 'Attribute "%s" clashes with attribute "%s" in derived entity "%s"' % (prev, a, entity.__name__)) entity._base_attrs_ = base_attrs new_attrs = [] for name, attr in items_list(entity.__dict__): if name in base_attrs_dict: throw(ERDiagramError, "Name '%s' hides base attribute %s" % (name,base_attrs_dict[name])) if not isinstance(attr, Attribute): continue if name.startswith('_') and name.endswith('_'): throw(ERDiagramError, 'Attribute name cannot both start and end with underscore. 
Got: %s' % name) if attr.entity is not None: throw(ERDiagramError, 'Duplicate use of attribute %s in entity %s' % (attr, entity.__name__)) attr._init_(entity, name) new_attrs.append(attr) new_attrs.sort(key=attrgetter('id')) indexes = entity._indexes_ = entity.__dict__.get('_indexes_', []) for attr in new_attrs: if attr.is_unique: indexes.append(Index(attr, is_pk=isinstance(attr, PrimaryKey))) for index in indexes: index._init_(entity) primary_keys = {index.attrs for index in indexes if index.is_pk} if direct_bases: if primary_keys: throw(ERDiagramError, 'Primary key cannot be redefined in derived classes') base_indexes = [] for base in direct_bases: for index in base._indexes_: if index not in base_indexes and index not in indexes: base_indexes.append(index) indexes[:0] = base_indexes primary_keys = {index.attrs for index in indexes if index.is_pk} if len(primary_keys) > 1: throw(ERDiagramError, 'Only one primary key can be defined in each entity class') elif not primary_keys: if hasattr(entity, 'id'): throw(ERDiagramError, "Cannot create default primary key attribute for %s because name 'id' is already in use." " Please create a PrimaryKey attribute for entity %s or rename the 'id' attribute" % (entity.__name__, entity.__name__)) attr = PrimaryKey(int, auto=True) attr.is_implicit = True attr._init_(entity, 'id') entity.id = attr new_attrs.insert(0, attr) pk_attrs = (attr,) index = Index(attr, is_pk=True) indexes.insert(0, index) index._init_(entity) else: pk_attrs = primary_keys.pop() for i, attr in enumerate(pk_attrs): attr.pk_offset = i entity._pk_columns_ = None entity._pk_attrs_ = pk_attrs entity._pk_is_composite_ = len(pk_attrs) > 1 entity._pk_ = pk_attrs if len(pk_attrs) > 1 else pk_attrs[0] entity._keys_ = [ index.attrs for index in indexes if index.is_unique and not index.is_pk ] entity._simple_keys_ = [ key[0] for key in entity._keys_ if len(key) == 1 ] entity._composite_keys_ = [ key for key in entity._keys_ if len(key) > 1 ] entity._new_attrs_ = new_attrs entity._attrs_ = base_attrs + new_attrs entity._adict_ = {attr.name: attr for attr in entity._attrs_} entity._subclass_attrs_ = [] entity._subclass_adict_ = {} for base in entity._all_bases_: for attr in new_attrs: if attr.is_collection: continue prev = base._subclass_adict_.setdefault(attr.name, attr) if prev is not attr: throw(ERDiagramError, 'Attribute %s conflicts with attribute %s because both entities inherit from %s. ' 'To fix this, move attribute definition to base class' % (attr, prev, entity._root_.__name__)) base._subclass_attrs_.append(attr) entity._attrnames_cache_ = {} try: table_name = entity.__dict__['_table_'] except KeyError: entity._table_ = None else: if not isinstance(table_name, basestring): if not isinstance(table_name, (list, tuple)): throw(TypeError, '%s._table_ property must be a string. Got: %r' % (entity.__name__, table_name)) for name_part in table_name: if not isinstance(name_part, basestring):throw(TypeError, 'Each part of table name must be a string. 
Got: %r' % name_part) entity._table_ = table_name = tuple(table_name) database.entities[entity.__name__] = entity setattr(database, entity.__name__, entity) entity._cached_max_id_sql_ = None entity._find_sql_cache_ = {} entity._load_sql_cache_ = {} entity._batchload_sql_cache_ = {} entity._insert_sql_cache_ = {} entity._update_sql_cache_ = {} entity._delete_sql_cache_ = {} entity._propagation_mixin_ = None entity._set_wrapper_subclass_ = None entity._multiset_subclass_ = None if '_discriminator_' not in entity.__dict__: entity._discriminator_ = None if entity._discriminator_ is not None and not entity._discriminator_attr_: Discriminator.create_default_attr(entity) if entity._discriminator_attr_: entity._discriminator_attr_.process_entity_inheritance(entity) iter_name = entity._default_iter_name_ = ( ''.join(letter for letter in entity.__name__ if letter.isupper()).lower() or entity.__name__ ) for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), ast.Name('.0'), []) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) entity._default_genexpr_ = inner_expr entity._access_rules_ = defaultdict(set) def _initialize_bits_(entity): entity._bits_ = {} entity._bits_except_volatile_ = {} offset_counter = itertools.count() all_bits = all_bits_except_volatile = 0 for attr in entity._attrs_: if attr.is_collection or attr.is_discriminator or attr.pk_offset is not None: bit = 0 elif not attr.columns: bit = 0 else: bit = 1 << next(offset_counter) all_bits |= bit entity._bits_[attr] = bit if attr.is_volatile: bit = 0 all_bits_except_volatile |= bit entity._bits_except_volatile_[attr] = bit entity._all_bits_ = all_bits entity._all_bits_except_volatile_ = all_bits_except_volatile def _resolve_attr_types_(entity): database = entity._database_ for attr in entity._new_attrs_: py_type = attr.py_type if isinstance(py_type, basestring): rentity = database.entities.get(py_type) if rentity is None: throw(ERDiagramError, 'Entity definition %s was not found' % py_type) attr.py_type = py_type = rentity elif isinstance(py_type, types.FunctionType): rentity = py_type() if not isinstance(rentity, EntityMeta): throw(TypeError, 'Invalid type of attribute %s: expected entity class, got %r' % (attr, rentity)) attr.py_type = py_type = rentity if isinstance(py_type, EntityMeta) and py_type.__name__ == 'Entity': throw(TypeError, 'Cannot link attribute %s to abstract Entity class. Use specific Entity subclass instead' % attr) def _link_reverse_attrs_(entity): database = entity._database_ for attr in entity._new_attrs_: py_type = attr.py_type if not isinstance(py_type, EntityMeta): continue entity2 = py_type if entity2._database_ is not database: throw(ERDiagramError, 'Interrelated entities must belong to same database. ' 'Entities %s and %s belongs to different databases' % (entity.__name__, entity2.__name__)) reverse = attr.reverse if isinstance(reverse, basestring): attr2 = getattr(entity2, reverse, None) if attr2 is None: throw(ERDiagramError, 'Reverse attribute %s.%s not found' % (entity2.__name__, reverse)) elif isinstance(reverse, Attribute): attr2 = reverse if attr2.entity is not entity2: throw(ERDiagramError, 'Incorrect reverse attribute %s used in %s' % (attr2, attr)) ### elif reverse is not None: throw(ERDiagramError, "Value of 'reverse' option must be string. 
Got: %r" % type(reverse)) else: candidates1 = [] candidates2 = [] for attr2 in entity2._new_attrs_: if attr2.py_type not in (entity, entity.__name__): continue reverse2 = attr2.reverse if reverse2 in (attr, attr.name): candidates1.append(attr2) elif not reverse2: if attr2 is attr: continue candidates2.append(attr2) msg = "Ambiguous reverse attribute for %s. Use the 'reverse' parameter for pointing to right attribute" if len(candidates1) > 1: throw(ERDiagramError, msg % attr) elif len(candidates1) == 1: attr2 = candidates1[0] elif len(candidates2) > 1: throw(ERDiagramError, msg % attr) elif len(candidates2) == 1: attr2 = candidates2[0] else: throw(ERDiagramError, 'Reverse attribute for %s not found' % attr) type2 = attr2.py_type if type2 != entity: throw(ERDiagramError, 'Inconsistent reverse attributes %s and %s' % (attr, attr2)) reverse2 = attr2.reverse if reverse2 not in (None, attr, attr.name): throw(ERDiagramError, 'Inconsistent reverse attributes %s and %s' % (attr, attr2)) if attr.is_required and attr2.is_required: throw(ERDiagramError, "At least one attribute of one-to-one relationship %s - %s must be optional" % (attr, attr2)) attr.reverse = attr2 attr2.reverse = attr attr.linked() attr2.linked() def _check_table_options_(entity): if entity._root_ is not entity: if '_table_options_' in entity.__dict__: throw(TypeError, 'Cannot redefine %s options in %s entity' % (entity._root_.__name__, entity.__name__)) elif not hasattr(entity, '_table_options_'): entity._table_options_ = {} def _get_pk_columns_(entity): if entity._pk_columns_ is not None: return entity._pk_columns_ pk_columns = [] pk_converters = [] pk_paths = [] for attr in entity._pk_attrs_: attr_columns = attr.get_columns() attr_col_paths = attr.col_paths attr.pk_columns_offset = len(pk_columns) pk_columns.extend(attr_columns) pk_converters.extend(attr.converters) pk_paths.extend(attr_col_paths) entity._pk_columns_ = pk_columns entity._pk_converters_ = pk_converters entity._pk_nones_ = (None,) * len(pk_columns) entity._pk_paths_ = pk_paths return pk_columns def __iter__(entity): return EntityIter(entity) @cut_traceback def __getitem__(entity, key): if type(key) is not tuple: key = (key,) if len(key) == len(entity._pk_attrs_): kwargs = {attr.name: value for attr, value in izip(entity._pk_attrs_, key)} return entity._find_one_(kwargs) if len(key) == len(entity._pk_columns_): return entity._get_by_raw_pkval_(key, from_db=False, seed=False) throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)' % (entity.__name__, len(key), len(entity._pk_attrs_))) @cut_traceback def exists(entity, *args, **kwargs): if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).exists() try: obj = entity._find_one_(kwargs) except ObjectNotFound: return False except MultipleObjectsFoundError: return True return True @cut_traceback def get(entity, *args, **kwargs): if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).get() try: return entity._find_one_(kwargs) # can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_for_update(entity, *args, **kwargs): nowait = kwargs.pop('nowait', False) skip_locked = kwargs.pop('skip_locked', False) if nowait and skip_locked: throw(TypeError, 'nowait and skip_locked options are mutually exclusive') if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1) \ .for_update(nowait, skip_locked).get() try: return entity._find_one_(kwargs, True, nowait, skip_locked) 
# can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_by_sql(entity, sql, globals=None, locals=None): objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=cut_traceback_depth+1) # can throw MultipleObjectsFoundError if not objects: return None assert len(objects) == 1 return objects[0] @cut_traceback def select(entity, *args): return entity._query_from_args_(args, kwargs=None, frame_depth=cut_traceback_depth+1) @cut_traceback def select_by_sql(entity, sql, globals=None, locals=None): return entity._find_by_sql_(None, sql, globals, locals, frame_depth=cut_traceback_depth+1) @cut_traceback def select_random(entity, limit): if entity._pk_is_composite_: return entity.select().random(limit) pk = entity._pk_attrs_[0] if not issubclass(pk.py_type, int) or entity._discriminator_ is not None and entity._root_ is not entity: return entity.select().random(limit) database = entity._database_ cache = database._get_cache() if cache.modified: cache.flush() max_id = cache.max_id_cache.get(pk) if max_id is None: max_id_sql = entity._cached_max_id_sql_ if max_id_sql is None: sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', None, [ 'COLUMN', None, pk.column ] ] ], [ 'FROM', [ None, 'TABLE', entity._table_ ] ] ] max_id_sql, adapter = database._ast2sql(sql_ast) entity._cached_max_id_sql_ = max_id_sql cursor = database._exec_sql(max_id_sql) max_id = cursor.fetchone()[0] cache.max_id_cache[pk] = max_id if max_id is None: return [] if max_id <= limit * 2: return entity.select().random(limit) cache_index = cache.indexes[entity._pk_attrs_] result = [] tried_ids = set() found_in_cache = False for i in xrange(5): ids = [] n = (limit - len(result)) * (i+1) for j in xrange(n * 2): id = randint(1, max_id) if id in tried_ids: continue if id in ids: continue obj = cache_index.get(id) if obj is not None: found_in_cache = True tried_ids.add(id) result.append(obj) n -= 1 else: ids.append(id) if len(ids) >= n: break if len(result) >= limit: break if not ids: continue sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(ids), from_seeds=False) arguments = adapter([ (id,) for id in ids ]) cursor = database._exec_sql(sql, arguments) objects = entity._fetch_objects(cursor, attr_offsets) result.extend(objects) tried_ids.update(ids) if len(result) >= limit: break if len(result) < limit: return entity.select().random(limit) result = result[:limit] if entity._subclasses_: seeds = cache.seeds[entity._pk_attrs_] if seeds: for obj in result: if obj in seeds: obj._load_() if found_in_cache: shuffle(result) return result def _find_one_(entity, kwargs, for_update=False, nowait=False, skip_locked=False): if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) avdict = {} get_attr = entity._adict_.get for name, val in iteritems(kwargs): attr = get_attr(name) if attr is None: throw(TypeError, 'Unknown attribute %r' % name) avdict[attr] = attr.validate(val, None, entity, from_db=False) if entity._pk_is_composite_: pkval = tuple(imap(avdict.get, entity._pk_attrs_)) if None in pkval: pkval = None else: pkval = avdict.get(entity._pk_attrs_[0]) for attr in avdict: if attr.is_collection: throw(TypeError, 'Collection attribute %s cannot be specified as search criteria' % attr) obj, unique = entity._find_in_cache_(pkval, avdict, for_update) if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait, skip_locked) if obj is None: throw(ObjectNotFound, entity, pkval) return obj def _find_in_cache_(entity, 
                       pkval, avdict, for_update=False):
        cache = entity._database_._get_cache()
        cache_indexes = cache.indexes
        obj = None
        unique = False
        if pkval is not None:
            unique = True
            obj = cache_indexes[entity._pk_attrs_].get(pkval)
        if obj is None:
            for attr in entity._simple_keys_:
                val = avdict.get(attr)
                if val is not None:
                    unique = True
                    obj = cache_indexes[attr].get(val)
                    if obj is not None: break
        if obj is None:
            for attrs in entity._composite_keys_:
                get_val = avdict.get
                vals = tuple(get_val(attr) for attr in attrs)
                if None in vals: continue
                unique = True
                cache_index = cache_indexes.get(attrs)
                if cache_index is None: continue
                obj = cache_index.get(vals)
                if obj is not None: break
        if obj is None:
            for attr, val in iteritems(avdict):
                if val is None: continue
                reverse = attr.reverse
                if reverse and not reverse.is_collection:
                    obj = reverse.__get__(val)
                    break
        if obj is not None:
            if obj._discriminator_ is not None:
                if obj._subclasses_:
                    cls = obj.__class__
                    if not issubclass(entity, cls) and not issubclass(cls, entity):
                        throw(ObjectNotFound, entity, pkval)
                    seeds = cache.seeds[entity._pk_attrs_]
                    if obj in seeds: obj._load_()
                if not isinstance(obj, entity): throw(ObjectNotFound, entity, pkval)
            if obj._status_ == 'marked_to_delete': throw(ObjectNotFound, entity, pkval)
            for attr, val in iteritems(avdict):
                if val != attr.__get__(obj): throw(ObjectNotFound, entity, pkval)
            if for_update and obj not in cache.for_update:
                return None, unique  # object is found, but it is not locked
            entity._set_rbits((obj,), avdict)
            return obj, unique
        return None, unique
    def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False, skip_locked=False):
        database = entity._database_
        query_attrs = {attr: value is None for attr, value in iteritems(avdict)}
        limit = 2 if not unique else None
        sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait, skip_locked)
        arguments = adapter(avdict)
        if for_update: database._get_cache().immediate = True
        cursor = database._exec_sql(sql, arguments)
        objects = entity._fetch_objects(cursor, attr_offsets, 1, for_update, avdict)
        return objects[0] if objects else None
    def _find_by_sql_(entity, max_fetch_count, sql, globals, locals, frame_depth):
        if not isinstance(sql, basestring): throw(TypeError)
        database = entity._database_
        cursor = database._exec_raw_sql(sql, globals, locals, frame_depth+1)
        col_names = [ column_info[0].upper() for column_info in cursor.description ]
        attr_offsets = {}
        used_columns = set()
        for attr in chain(entity._attrs_with_columns_, entity._subclass_attrs_):
            offsets = []
            for column in attr.columns:
                try: offset = col_names.index(column.upper())
                except ValueError: break
                offsets.append(offset)
                used_columns.add(offset)
            else: attr_offsets[attr] = offsets
        if len(used_columns) < len(col_names):
            for i in xrange(len(col_names)):
                if i not in used_columns:
                    throw(NameError, 'Column %s does not belong to entity %s'
                                     % (cursor.description[i][0], entity.__name__))
        for attr in entity._pk_attrs_:
            if attr not in attr_offsets:
                throw(ValueError, 'Primary key attribute %s was not found in query result set' % attr)
        objects = entity._fetch_objects(cursor, attr_offsets, max_fetch_count)
        return objects
    def _construct_select_clause_(entity, alias=None, distinct=False, query_attrs=(), all_attributes=False):
        attr_offsets = {}
        select_list = [ 'DISTINCT' ] if distinct else [ 'ALL' ]
        root = entity._root_
        pc = local.prefetch_context
        attrs_to_prefetch = pc.attrs_to_prefetch_dict.get(entity, ()) if pc else ()
        for attr in chain(root._attrs_, root._subclass_attrs_):
            if not
all_attributes and not issubclass(attr.entity, entity) \ and not issubclass(entity, attr.entity): continue if attr.is_collection: continue if not attr.columns: continue if not attr.lazy or attr in query_attrs or attr in attrs_to_prefetch: attr_offsets[attr] = offsets = [] for column in attr.columns: offsets.append(len(select_list) - 1) select_list.append([ 'COLUMN', alias, column ]) return select_list, attr_offsets def _construct_discriminator_criteria_(entity, alias=None): discr_attr = entity._discriminator_attr_ if discr_attr is None: return None discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in entity._subclasses_ ] discr_values.append([ 'VALUE', entity._discriminator_]) return [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True): pc = local.prefetch_context attrs_to_prefetch = pc.get_frozen_attrs_to_prefetch(entity) if pc is not None else () query_key = batch_size, attr, from_seeds, attrs_to_prefetch cached_sql = entity._batchload_sql_cache_.get(query_key) if cached_sql is not None: return cached_sql select_list, attr_offsets = entity._construct_select_clause_(all_attributes=True) from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ]] if attr is None: columns = entity._pk_columns_ converters = entity._pk_converters_ else: columns = attr.columns converters = attr.converters row_value_syntax = entity._database_.provider.translator_cls.row_value_syntax criteria_list = construct_batchload_criteria_list( None, columns, converters, batch_size, row_value_syntax, from_seeds=from_seeds) sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ] database = entity._database_ sql, adapter = database._ast2sql(sql_ast) cached_sql = sql, adapter, attr_offsets entity._batchload_sql_cache_[query_key] = cached_sql return cached_sql def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False, skip_locked=False): if nowait or skip_locked: assert for_update sorted_query_attrs = tuple(sorted(query_attrs.items())) query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait, skip_locked cached_sql = entity._find_sql_cache_.get(query_key) if cached_sql is not None: return cached_sql select_list, attr_offsets = entity._construct_select_clause_(query_attrs=query_attrs) from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ]] where_list = [ 'WHERE' ] discr_attr = entity._discriminator_attr_ if discr_attr and query_attrs.get(discr_attr) != False: discr_criteria = entity._construct_discriminator_criteria_() if discr_criteria: where_list.append(discr_criteria) for attr, attr_is_none in sorted_query_attrs: if not attr.reverse: if attr_is_none: where_list.append([ 'IS_NULL', [ 'COLUMN', None, attr.column ] ]) else: if len(attr.converters) > 1: throw(NotImplementedError) converter = attr.converters[0] where_list.append([ converter.EQ, [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), converter ] ]) elif not attr.columns: throw(NotImplementedError) else: attr_entity = attr.py_type; assert attr_entity == attr.reverse.entity if attr_is_none: for column in attr.columns: where_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ]) else: for j, (column, converter) in enumerate(izip(attr.columns, attr_entity._pk_converters_)): where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ]) if not for_update: sql_ast = [ 'SELECT', select_list, from_list, where_list ] else: sql_ast = [ 
                              'SELECT_FOR_UPDATE', nowait, skip_locked, select_list, from_list, where_list ]
        if order_by_pk: sql_ast.append([ 'ORDER_BY' ] + [ [ 'COLUMN', None, column ] for column in entity._pk_columns_ ])
        if limit is not None: sql_ast.append([ 'LIMIT', limit ])
        database = entity._database_
        sql, adapter = database._ast2sql(sql_ast)
        cached_sql = sql, adapter, attr_offsets
        entity._find_sql_cache_[query_key] = cached_sql
        return cached_sql
    def _fetch_objects(entity, cursor, attr_offsets, max_fetch_count=None, for_update=False, used_attrs=()):
        if max_fetch_count is None: max_fetch_count = options.MAX_FETCH_COUNT
        if max_fetch_count is not None:
            rows = cursor.fetchmany(max_fetch_count + 1)
            if len(rows) == max_fetch_count + 1:
                if max_fetch_count == 1:
                    throw(MultipleObjectsFoundError,
                          'Multiple objects were found. Use %s.select(...) to retrieve them' % entity.__name__)
                throw(TooManyObjectsFoundError,
                      'Found more than pony.options.MAX_FETCH_COUNT=%d objects' % options.MAX_FETCH_COUNT)
        else: rows = cursor.fetchall()
        objects = []
        if attr_offsets is None:
            objects = [ entity._get_by_raw_pkval_(row, for_update) for row in rows ]
            entity._load_many_(objects)
        else:
            for row in rows:
                real_entity_subclass, pkval, avdict = entity._parse_row_(row, attr_offsets)
                obj = real_entity_subclass._get_from_identity_map_(pkval, 'loaded', for_update)
                if obj._status_ in del_statuses: continue
                obj._db_set_(avdict)
                objects.append(obj)
            if used_attrs: entity._set_rbits(objects, used_attrs)
        return objects
    def _set_rbits(entity, objects, attrs):
        rbits_dict = {}
        get_rbits = rbits_dict.get
        for obj in objects:
            wbits = obj._wbits_
            if wbits is None: continue
            rbits = get_rbits(obj.__class__)
            if rbits is None:
                rbits = sum(obj._bits_except_volatile_.get(attr, 0) for attr in attrs)
                rbits_dict[obj.__class__] = rbits
            obj._rbits_ |= rbits & ~wbits
    def _parse_row_(entity, row, attr_offsets):
        discr_attr = entity._discriminator_attr_
        if not discr_attr:
            discr_value = None
            real_entity_subclass = entity
        else:
            discr_offset = attr_offsets[discr_attr][0]
            discr_value = discr_attr.validate(row[discr_offset], None, entity, from_db=True)
            real_entity_subclass = discr_attr.code2cls[discr_value]
            discr_value = real_entity_subclass._discriminator_  # To convert unicode to str in Python 2.x
        database = entity._database_
        cache = local.db2cache[database]
        avdict = {}
        for attr in real_entity_subclass._attrs_:
            offsets = attr_offsets.get(attr)
            if offsets is None: continue
            if attr.is_discriminator: avdict[attr] = discr_value
            else: avdict[attr] = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache)
        pkval = tuple(avdict.pop(attr) for attr in entity._pk_attrs_)
        assert None not in pkval
        if not entity._pk_is_composite_: pkval = pkval[0]
        return real_entity_subclass, pkval, avdict
    def _load_many_(entity, objects):
        database = entity._database_
        cache = database._get_cache()
        seeds = cache.seeds[entity._pk_attrs_]
        if not seeds: return
        objects = {obj for obj in objects if obj in seeds}
        objects = sorted(objects, key=attrgetter('_pkval_'))
        max_batch_size = database.provider.max_params_count // len(entity._pk_columns_)
        while objects:
            batch = objects[:max_batch_size]
            objects = objects[max_batch_size:]
            sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(batch))
            arguments = adapter(batch)
            cursor = database._exec_sql(sql, arguments)
            result = entity._fetch_objects(cursor, attr_offsets)
            if len(result) < len(batch):
                for obj in batch:  # report the batch objects that vanished from the DB
                    if obj not in result:
                        throw(UnrepeatableReadError, 'Phantom object %s disappeared' % safe_repr(obj))
    def _select_all(entity):
        return
Query(entity._default_iter_name_, entity._default_genexpr_, {}, { '.0' : entity }) def _query_from_args_(entity, args, kwargs, frame_depth): if not args and not kwargs: return entity._select_all() func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth+1) if type(func) is types.FunctionType: names = get_lambda_args(func) code_key = id(func.func_code if PY2 else func.__code__) cond_expr, external_names, cells = decompile(func) elif isinstance(func, basestring): code_key = func lambda_ast = string2ast(func) if not isinstance(lambda_ast, ast.Lambda): throw(TypeError, 'Lambda function is expected. Got: %s' % func) names = get_lambda_args(lambda_ast) cond_expr = lambda_ast.code cells = None else: assert False # pragma: no cover if len(names) != 1: throw(TypeError, 'Lambda query requires exactly one parameter name, like %s.select(lambda %s: ...). ' 'Got: %d parameters' % (entity.__name__, entity.__name__[0].lower(), len(names))) name = names[0] if_expr = ast.GenExprIf(cond_expr) for_expr = ast.GenExprFor(ast.AssName(name, 'OP_ASSIGN'), ast.Name('.0'), [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(name), [ for_expr ]) locals = locals.copy() if locals is not None else {} locals['.0'] = entity return Query(code_key, inner_expr, globals, locals, cells) def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs=None, obj_to_init=None): cache = entity._database_._get_cache() pk_attrs = entity._pk_attrs_ cache_index = cache.indexes[pk_attrs] if pkval is None: obj = None else: obj = cache_index.get(pkval) if obj is None: pass elif status == 'created': if entity._pk_is_composite_: pkval = ', '.join(str(item) for item in pkval) throw(CacheIndexError, 'Cannot create %s: instance with primary key %s already exists' % (obj.__class__.__name__, pkval)) elif obj.__class__ is entity: pass elif issubclass(obj.__class__, entity): pass elif not issubclass(entity, obj.__class__): throw(TransactionError, 'Unexpected class change from %s to %s for object with primary key %r' % (obj.__class__, entity, obj._pkval_)) elif obj._rbits_ or obj._wbits_: throw(NotImplementedError) else: obj.__class__ = entity if obj is None: with cache.flush_disabled(): obj = obj_to_init if obj_to_init is None: obj = object.__new__(entity) cache.objects.add(obj) obj._pkval_ = pkval obj._status_ = status obj._vals_ = {} obj._dbvals_ = {} obj._save_pos_ = None obj._session_cache_ = cache if pkval is not None: cache_index[pkval] = obj obj._newid_ = None else: obj._newid_ = next(new_instance_id_counter) if obj._pk_is_composite_: pairs = izip(pk_attrs, pkval) else: pairs = ((pk_attrs[0], pkval),) if status == 'loaded': assert undo_funcs is None obj._rbits_ = obj._wbits_ = 0 for attr, val in pairs: obj._vals_[attr] = val if attr.reverse: attr.db_update_reverse(obj, NOT_LOADED, val) cache.seeds[pk_attrs].add(obj) elif status == 'created': assert undo_funcs is not None obj._rbits_ = obj._wbits_ = None for attr, val in pairs: obj._vals_[attr] = val if attr.reverse: attr.update_reverse(obj, NOT_LOADED, val, undo_funcs) cache.for_update.add(obj) else: assert False # pragma: no cover if for_update: assert cache.in_transaction cache.for_update.add(obj) return obj def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True, seed=True): i = 0 pkval = [] for attr in entity._pk_attrs_: if attr.column is not None: val = raw_pkval[i] i += 1 if not attr.reverse: val = attr.validate(val, None, entity, from_db=from_db) else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db, seed=seed) else: if 
not attr.reverse: throw(NotImplementedError) vals = raw_pkval[i:i+len(attr.columns)] val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db, seed=seed) i += len(attr.columns) pkval.append(val) if not entity._pk_is_composite_: pkval = pkval[0] else: pkval = tuple(pkval) if seed: obj = entity._get_from_identity_map_(pkval, 'loaded', for_update) else: obj = entity[pkval] assert obj._status_ != 'cancelled' return obj def _get_propagation_mixin_(entity): mixin = entity._propagation_mixin_ if mixin is not None: return mixin cls_dict = { '_entity_' : entity } for attr in entity._attrs_: if not attr.reverse: def fget(wrapper, attr=attr): attrnames = wrapper._attrnames_ + (attr.name,) items = [ x for x in (attr.__get__(item) for item in wrapper) if x is not None ] if attr.py_type is Json: return [ item.get_untracked() if isinstance(item, TrackedValue) else item for item in items ] return Multiset(wrapper._obj_, attrnames, items) elif not attr.is_collection: def fget(wrapper, attr=attr): attrnames = wrapper._attrnames_ + (attr.name,) items = [ x for x in (attr.__get__(item) for item in wrapper) if x is not None ] rentity = attr.py_type cls = rentity._get_multiset_subclass_() return cls(wrapper._obj_, attrnames, items) else: def fget(wrapper, attr=attr): cache = attr.entity._database_._get_cache() cache.collection_statistics.setdefault(attr, attr.nplus1_threshold) attrnames = wrapper._attrnames_ + (attr.name,) items = [ subitem for item in wrapper for subitem in attr.__get__(item) ] rentity = attr.py_type cls = rentity._get_multiset_subclass_() return cls(wrapper._obj_, attrnames, items) cls_dict[attr.name] = property(fget) result_cls_name = entity.__name__ + 'SetMixin' result_cls = type(result_cls_name, (object,), cls_dict) entity._propagation_mixin_ = result_cls return result_cls def _get_multiset_subclass_(entity): result_cls = entity._multiset_subclass_ if result_cls is None: mixin = entity._get_propagation_mixin_() cls_name = entity.__name__ + 'Multiset' result_cls = type(cls_name, (Multiset, mixin), {}) entity._multiset_subclass_ = result_cls return result_cls def _get_set_wrapper_subclass_(entity): result_cls = entity._set_wrapper_subclass_ if result_cls is None: mixin = entity._get_propagation_mixin_() cls_name = entity.__name__ + 'Set' result_cls = type(cls_name, (SetInstance, mixin), {}) entity._set_wrapper_subclass_ = result_cls return result_cls @cut_traceback def describe(entity): result = [] parents = ','.join(cls.__name__ for cls in entity.__bases__) result.append('class %s(%s):' % (entity.__name__, parents)) if entity._base_attrs_: result.append('# inherited attrs') result.extend(attr.describe() for attr in entity._base_attrs_) result.append('# attrs introduced in %s' % entity.__name__) result.extend(attr.describe() for attr in entity._new_attrs_) if entity._pk_is_composite_: result.append('PrimaryKey(%s)' % ', '.join(attr.name for attr in entity._pk_attrs_)) return '\n '.join(result) @cut_traceback @db_session(ddl=True) def drop_table(entity, with_all_data=False): entity._database_._drop_tables([ entity._table_ ], True, with_all_data) def _get_attrs_(entity, only=None, exclude=None, with_collections=False, with_lazy=False): if only and not isinstance(only, basestring): only = tuple(only) if exclude and not isinstance(exclude, basestring): exclude = tuple(exclude) key = (only, exclude, with_collections, with_lazy) attrs = entity._attrnames_cache_.get(key) if not attrs: attrs = [] append = attrs.append if only: if isinstance(only, basestring): only = only.replace(',', ' ').split() 
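# ----------------------------------------------------------------------
# Illustrative sketch (hypothetical Group/Student schema from the earlier
# sketch): the propagation mixin built above is what lets an attribute be
# read from a whole collection at once, yielding a Multiset.

with orm.db_session:
    g = Group[1]                 # assumes such a row exists
    names = g.students.name      # Multiset of all student names, via the mixin
    print(names.distinct())      # dict mapping each name to its multiplicity
    print(Group.describe())      # textual entity summary from describe() above
# ----------------------------------------------------------------------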
            get_attr = entity._adict_.get
            for attrname in only:
                attr = get_attr(attrname)
                if attr is None:
                    throw(AttributeError, 'Entity %s does not have attribute %s' % (entity.__name__, attrname))
                else: append(attr)
        else:
            for attr in entity._attrs_:
                if attr.is_collection:
                    if with_collections: append(attr)
                elif attr.lazy:
                    if with_lazy: append(attr)
                else: append(attr)
        if exclude:
            if isinstance(exclude, basestring): exclude = exclude.replace(',', ' ').split()
            for attrname in exclude:
                if attrname not in entity._adict_:
                    throw(AttributeError, 'Entity %s does not have attribute %s' % (entity.__name__, attrname))
            attrs = (attr for attr in attrs if attr.name not in exclude)
        attrs = tuple(attrs)
        entity._attrnames_cache_[key] = attrs
    return attrs

def populate_criteria_list(criteria_list, columns, converters, operations, params_count=0, table_alias=None, optimistic=False):
    for column, op, converter in izip(columns, operations, converters):
        if op == 'IS_NULL':
            criteria_list.append([ op, [ 'COLUMN', None, column ] ])
        else:
            criteria_list.append([ op, [ 'COLUMN', table_alias, column ],
                                   [ 'PARAM', (params_count, None, None), converter, optimistic ] ])
            params_count += 1
    return params_count

statuses = {'created', 'cancelled', 'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted'}
del_statuses = {'marked_to_delete', 'deleted', 'cancelled'}
created_or_deleted_statuses = {'created'} | del_statuses
saved_statuses = {'inserted', 'updated', 'deleted'}

def throw_object_was_deleted(obj):
    assert obj._status_ in del_statuses
    throw(OperationWithDeletedObjectError,
          '%s was %s' % (safe_repr(obj), obj._status_.replace('_', ' ')))

def unpickle_entity(d):
    entity = d.pop('__class__')
    cache = entity._database_._get_cache()
    if not entity._pk_is_composite_: pkval = d.get(entity._pk_attrs_[0].name)
    else: pkval = tuple(d[attr.name] for attr in entity._pk_attrs_)
    assert pkval is not None
    obj = entity._get_from_identity_map_(pkval, 'loaded')
    if obj._status_ in del_statuses: return obj
    avdict = {}
    for attrname, val in iteritems(d):
        attr = entity._adict_[attrname]
        if attr.pk_offset is not None: continue
        avdict[attr] = val
    obj._db_set_(avdict, unpickling=True)
    return obj

def safe_repr(obj):
    return Entity.__repr__(obj)

def make_proxy(obj):
    proxy = EntityProxy(obj)
    return proxy

class EntityProxy(object):
    def __init__(self, obj):
        entity = obj.__class__
        object.__setattr__(self, '_entity_', entity)
        pkval = obj.get_pk()
        if pkval is None:
            cache = obj._session_cache_
            if obj._status_ in del_statuses or cache is None or not cache.is_alive:
                throw(ValueError, 'Cannot make a proxy for %s object: primary key is not specified' % entity.__name__)
            flush()
            pkval = obj.get_pk()
            assert pkval is not None
        object.__setattr__(self, '_obj_pk_', pkval)
    def __repr__(self):
        entity = self._entity_
        pkval = self._obj_pk_
        pkrepr = ','.join(repr(item) for item in pkval) if isinstance(pkval, tuple) else repr(pkval)
        return '<EntityProxy(%s[%s])>' % (entity.__name__, pkrepr)
    def _get_object(self):
        entity = self._entity_
        pkval = self._obj_pk_
        cache = entity._database_._get_cache()
        attrs = entity._pk_attrs_
        if attrs in cache.indexes and pkval in cache.indexes[attrs]:
            obj = cache.indexes[attrs][pkval]
        else: obj = entity[pkval]
        return obj
    def __getattr__(self, name):
        obj = self._get_object()
        return getattr(obj, name)
    def __setattr__(self, name, value):
        obj = self._get_object()
        setattr(obj, name, value)
    def __eq__(self, other):
        entity = self._entity_
        pkval = self._obj_pk_
        if isinstance(other, EntityProxy):
            entity2 = other._entity_
            pkval2 = other._obj_pk_
            return entity == entity2 and
pkval == pkval2 elif isinstance(other, entity): return pkval == other._pkval_ return False def __ne__(self, other): return not self.__eq__(other) class Entity(with_metaclass(EntityMeta)): __slots__ = '_session_cache_', '_status_', '_pkval_', '_newid_', '_dbvals_', '_vals_', '_rbits_', '_wbits_', '_save_pos_', '__weakref__' def __reduce__(obj): if obj._status_ in del_statuses: throw( OperationWithDeletedObjectError, 'Deleted object %s cannot be pickled' % safe_repr(obj)) if obj._status_ in ('created', 'modified'): throw( OrmError, '%s object %s has to be stored in DB before it can be pickled' % (obj._status_.capitalize(), safe_repr(obj))) d = {'__class__' : obj.__class__} for attr, val in iteritems(obj._vals_): if not attr.is_collection: d[attr.name] = val return unpickle_entity, (d,) @cut_traceback def __init__(obj, *args, **kwargs): obj._status_ = None entity = obj.__class__ if args: raise TypeError('%s constructor accept only keyword arguments. Got: %d positional argument%s' % (entity.__name__, len(args), len(args) > 1 and 's' or '')) if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) avdict = {} for name in kwargs: if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name) for attr in entity._attrs_: val = kwargs.get(attr.name, DEFAULT) avdict[attr] = attr.validate(val, obj, from_db=False) if entity._pk_is_composite_: pkval = tuple(imap(avdict.get, entity._pk_attrs_)) if None in pkval: pkval = None else: pkval = avdict.get(entity._pk_attrs_[0]) undo_funcs = [] cache = entity._database_._get_cache() cache_indexes = cache.indexes indexes_update = {} with cache.flush_disabled(): for attr in entity._simple_keys_: val = avdict[attr] if val is None: continue if val in cache_indexes[attr]: throw(CacheIndexError, 'Cannot create %s: value %r for key %s already exists' % (entity.__name__, val, attr.name)) indexes_update[attr] = val for attrs in entity._composite_keys_: vals = tuple(avdict[attr] for attr in attrs) if None in vals: continue if vals in cache_indexes[attrs]: attr_names = ', '.join(attr.name for attr in attrs) throw(CacheIndexError, 'Cannot create %s: value %s for composite key (%s) already exists' % (entity.__name__, vals, attr_names)) indexes_update[attrs] = vals try: entity._get_from_identity_map_(pkval, 'created', undo_funcs=undo_funcs, obj_to_init=obj) for attr, val in iteritems(avdict): if attr.pk_offset is not None: continue elif not attr.is_collection: obj._vals_[attr] = val if attr.reverse: attr.update_reverse(obj, None, val, undo_funcs) else: attr.__set__(obj, val, undo_funcs) except: for undo_func in reversed(undo_funcs): undo_func() raise if pkval is not None: cache_indexes[entity._pk_attrs_][pkval] = obj for key, vals in iteritems(indexes_update): cache_indexes[key][vals] = obj objects_to_save = cache.objects_to_save obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True @cut_traceback def get_pk(obj): pkval = obj._get_raw_pkval_() if len(pkval) == 1: return pkval[0] return pkval def _get_raw_pkval_(obj): pkval = obj._pkval_ if not obj._pk_is_composite_: if not obj._pk_attrs_[0].reverse: return (pkval,) else: return pkval._get_raw_pkval_() raw_pkval = [] append, extend = raw_pkval.append, raw_pkval.extend for attr, val in izip(obj._pk_attrs_, pkval): if not attr.reverse: append(val) else: extend(val._get_raw_pkval_()) return tuple(raw_pkval) @cut_traceback def __lt__(entity, other): return entity._cmp_(other) < 0 @cut_traceback def __le__(entity, 
other): return entity._cmp_(other) <= 0 @cut_traceback def __gt__(entity, other): return entity._cmp_(other) > 0 @cut_traceback def __ge__(entity, other): return entity._cmp_(other) >= 0 def _cmp_(entity, other): if entity is other: return 0 if isinstance(other, Entity): pkval = entity._pkval_ other_pkval = other._pkval_ if pkval is not None: if other_pkval is None: return -1 result = cmp(pkval, other_pkval) else: if other_pkval is not None: return 1 result = cmp(entity._newid_, other._newid_) if result: return result return cmp(id(entity), id(other)) @cut_traceback def __repr__(obj): pkval = obj._pkval_ if pkval is None: return '%s[new:%d]' % (obj.__class__.__name__, obj._newid_) if obj._pk_is_composite_: pkval = ','.join(imap(repr, pkval)) else: pkval = repr(pkval) return '%s[%s]' % (obj.__class__.__name__, pkval) @classmethod def _prefetch_load_all_(entity, objects): objects = sorted(objects, key=entity._get_raw_pkval_) database = entity._database_ cache = database._get_cache() if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over') max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) for i in xrange(0, len(objects), max_batch_size): batch = objects[i:i+max_batch_size] sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(batch)) arguments = adapter(batch) cursor = database._exec_sql(sql, arguments) entity._fetch_objects(cursor, attr_offsets) def _load_(obj): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) entity = obj.__class__ database = entity._database_ if cache is not database._get_cache(): throw(TransactionError, "Object %s doesn't belong to current transaction" % safe_repr(obj)) seeds = cache.seeds[entity._pk_attrs_] max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) objects = [ obj ] for seed in seeds: if len(objects) >= max_batch_size: break if seed is not obj: objects.append(seed) sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(objects)) arguments = adapter(objects) cursor = database._exec_sql(sql, arguments) objects = entity._fetch_objects(cursor, attr_offsets) if obj not in objects: throw(UnrepeatableReadError, 'Phantom object %s disappeared' % safe_repr(obj)) @cut_traceback def load(obj, *attrs): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) entity = obj.__class__ database = entity._database_ if cache is not database._get_cache(): throw(TransactionError, "Object %s doesn't belong to current transaction" % safe_repr(obj)) if obj._status_ in created_or_deleted_statuses: return if not attrs: attrs = tuple(attr for attr, bit in iteritems(entity._bits_) if bit and attr not in obj._vals_) else: args = attrs attrs = set() for arg in args: if isinstance(arg, basestring): attr = entity._adict_.get(arg) if attr is None: if not is_ident(arg): throw(ValueError, 'Invalid attribute name: %r' % arg) throw(AttributeError, 'Object %s does not have attribute %r' % (obj, arg)) elif isinstance(arg, Attribute): attr = arg if not isinstance(obj, attr.entity): throw(AttributeError, 'Attribute %s does not belong to object %s' % (attr, obj)) else: throw(TypeError, 'Invalid argument type: %r' % arg) if attr.is_collection: throw(NotImplementedError, 'The load() method does not support collection attributes yet. 
Got: %s' % attr.name) if entity._bits_[attr] and attr not in obj._vals_: attrs.add(attr) attrs = tuple(sorted(attrs, key=attrgetter('id'))) sql_cache = entity._root_._load_sql_cache_ cached_sql = sql_cache.get(attrs) if cached_sql is None: if entity._discriminator_attr_ is not None: attrs = (entity._discriminator_attr_,) + attrs attrs = entity._pk_attrs_ + attrs attr_offsets = {} select_list = [ 'ALL' ] for attr in attrs: attr_offsets[attr] = offsets = [] for column in attr.columns: offsets.append(len(select_list) - 1) select_list.append([ 'COLUMN', None, column ]) from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ]] criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(obj._pk_columns_, obj._pk_converters_)) ] where_list = [ 'WHERE' ] + criteria_list sql_ast = [ 'SELECT', select_list, from_list, where_list ] sql, adapter = database._ast2sql(sql_ast) cached_sql = sql, adapter, attr_offsets sql_cache[attrs] = cached_sql else: sql, adapter, attr_offsets = cached_sql arguments = adapter(obj._get_raw_pkval_()) cursor = database._exec_sql(sql, arguments) objects = entity._fetch_objects(cursor, attr_offsets) if obj not in objects: throw(UnrepeatableReadError, 'Phantom object %s disappeared' % safe_repr(obj)) def _attr_changed_(obj, attr): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) status = obj._status_ wbits = obj._wbits_ bit = obj._bits_[attr] objects_to_save = cache.objects_to_save if wbits is not None and bit: obj._wbits_ |= bit if status != 'modified': assert status in ('loaded', 'inserted', 'updated') assert obj._save_pos_ is None obj._status_ = 'modified' obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True def _db_set_(obj, avdict, unpickling=False): assert obj._status_ not in created_or_deleted_statuses cache = obj._session_cache_ assert cache is not None and cache.is_alive cache.seeds[obj._pk_attrs_].discard(obj) if not avdict: return get_val = obj._vals_.get get_dbval = obj._dbvals_.get rbits = obj._rbits_ wbits = obj._wbits_ for attr, new_dbval in items_list(avdict): assert attr.pk_offset is None assert new_dbval is not NOT_LOADED old_dbval = get_dbval(attr, NOT_LOADED) if old_dbval is not NOT_LOADED: if unpickling or old_dbval == new_dbval or ( not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): del avdict[attr] continue if unpickling: new_vals = avdict new_dbvals = {attr: attr.converters[0].val2dbval(val, obj) if not attr.reverse else val for attr, val in iteritems(avdict)} else: new_dbvals = avdict new_vals = {attr: attr.converters[0].dbval2val(dbval, obj) if not attr.reverse else dbval for attr, dbval in iteritems(avdict)} for attr, new_val in items_list(new_vals): new_dbval = new_dbvals[attr] old_dbval = get_dbval(attr, NOT_LOADED) bit = obj._bits_except_volatile_[attr] if rbits & bit: errormsg = 'Please contact PonyORM developers so they can ' \ 'reproduce your error and fix a bug: support@ponyorm.org' assert old_dbval is not NOT_LOADED, errormsg throw(UnrepeatableReadError, 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' % (obj.__class__.__name__, attr.name, obj, old_dbval, new_dbval)) if attr.reverse: attr.db_update_reverse(obj, old_dbval, new_dbval) obj._dbvals_[attr] = new_dbval if wbits & bit: del new_vals[attr] for attr, new_val in 
iteritems(new_vals): if attr.is_unique: old_val = get_val(attr) if old_val != new_val: cache.db_update_simple_index(obj, attr, old_val, new_val) for attrs in obj._composite_keys_: if any(attr in new_vals for attr in attrs): key_vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! prev_key_vals = tuple(key_vals) for i, attr in enumerate(attrs): if attr in new_vals: key_vals[i] = new_vals[attr] new_key_vals = tuple(key_vals) if prev_key_vals != new_key_vals: cache.db_update_composite_index(obj, attrs, prev_key_vals, new_key_vals) obj._vals_.update(new_vals) def _delete_(obj, undo_funcs=None): status = obj._status_ if status in del_statuses: return is_recursive_call = undo_funcs is not None if not is_recursive_call: undo_funcs = [] cache = obj._session_cache_ assert cache is not None and cache.is_alive with cache.flush_disabled(): get_val = obj._vals_.get undo_list = [] objects_to_save = cache.objects_to_save save_pos = obj._save_pos_ def undo_func(): if obj._status_ == 'marked_to_delete': assert objects_to_save obj2 = objects_to_save.pop() assert obj2 is obj if save_pos is not None: assert objects_to_save[save_pos] is None objects_to_save[save_pos] = obj obj._save_pos_ = save_pos obj._status_ = status for cache_index, old_key in undo_list: cache_index[old_key] = obj undo_funcs.append(undo_func) try: for attr in obj._attrs_: if not attr.is_collection: continue if isinstance(attr, Set): set_wrapper = attr.__get__(obj) if not set_wrapper.__nonzero__(): pass elif attr.cascade_delete: for robj in set_wrapper: robj._delete_(undo_funcs) elif not attr.reverse.is_required: attr.__set__(obj, (), undo_funcs) else: throw(ConstraintError, "Cannot delete object %s, because it has non-empty set of %s, " "and 'cascade_delete' option of %s is not set" % (obj, attr.name, attr)) else: throw(NotImplementedError) for attr in obj._attrs_: if not attr.is_collection: reverse = attr.reverse if not reverse: continue if not reverse.is_collection: val = get_val(attr) if attr in obj._vals_ else attr.load(obj) if val is None: continue if attr.cascade_delete: val._delete_(undo_funcs) elif not reverse.is_required: reverse.__set__(val, None, undo_funcs) else: throw(ConstraintError, "Cannot delete object %s, because it has associated %s, " "and 'cascade_delete' option of %s is not set" % (obj, attr.name, attr)) elif isinstance(reverse, Set): if attr not in obj._vals_: continue val = get_val(attr) if val is None: continue reverse.reverse_remove((val,), obj, undo_funcs) else: throw(NotImplementedError) cache_indexes = cache.indexes for attr in obj._simple_keys_: val = get_val(attr) if val is None: continue cache_index = cache_indexes[attr] obj2 = cache_index.pop(val) assert obj2 is obj undo_list.append((cache_index, val)) for attrs in obj._composite_keys_: vals = tuple(get_val(attr) for attr in attrs) if None in vals: continue cache_index = cache_indexes[attrs] obj2 = cache_index.pop(vals) assert obj2 is obj undo_list.append((cache_index, vals)) if status == 'created': assert save_pos is not None objects_to_save[save_pos] = None obj._save_pos_ = None obj._status_ = 'cancelled' if obj._pkval_ is not None: pk_index = cache_indexes[obj._pk_attrs_] obj2 = pk_index.pop(obj._pkval_) assert obj2 is obj undo_list.append((pk_index, obj._pkval_)) else: if status == 'modified': assert save_pos is not None objects_to_save[save_pos] = None else: assert status in ('loaded', 'inserted', 'updated') assert save_pos is None obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) obj._status_ = 
'marked_to_delete' cache.modified = True except: if not is_recursive_call: for undo_func in reversed(undo_funcs): undo_func() raise @cut_traceback def delete(obj): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('delete object', obj) obj._delete_() @cut_traceback def set(obj, **kwargs): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('change object', obj) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): avdict, collection_avdict = obj._keyargs_to_avdicts_(kwargs) status = obj._status_ wbits = obj._wbits_ get_val = obj._vals_.get objects_to_save = cache.objects_to_save if avdict: for attr in avdict: if attr not in obj._vals_ and attr.reverse and not attr.reverse.is_collection: attr.load(obj) # loading of one-to-one relations if wbits is not None: new_wbits = wbits for attr in avdict: new_wbits |= obj._bits_[attr] obj._wbits_ = new_wbits if status != 'modified': assert status in ('loaded', 'inserted', 'updated') assert obj._save_pos_ is None obj._status_ = 'modified' obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True if not collection_avdict: if not any(attr.reverse or attr.is_part_of_unique_index for attr in avdict): obj._vals_.update(avdict) return for attr, value in items_list(avdict): if value == get_val(attr): avdict.pop(attr) undo_funcs = [] undo = [] def undo_func(): obj._status_ = status obj._wbits_ = wbits if status in ('loaded', 'inserted', 'updated'): assert objects_to_save obj2 = objects_to_save.pop() assert obj2 is obj and obj._save_pos_ == len(objects_to_save) obj._save_pos_ = None for cache_index, old_key, new_key in undo: if new_key is not None: del cache_index[new_key] if old_key is not None: cache_index[old_key] = obj try: for attr in obj._simple_keys_: if attr not in avdict: continue new_val = avdict[attr] old_val = get_val(attr) cache.update_simple_index(obj, attr, old_val, new_val, undo) for attrs in obj._composite_keys_: if any(attr in avdict for attr in attrs): vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! 
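# A minimal usage sketch of the set()/delete() machinery implemented in this
# class (see _delete_() and set() above). `Person` and its related objects are
# hypothetical; cascade behaviour depends on each attribute's cascade_delete option:
#
#     with db_session:
#         p = Person[1]
#         p.set(name='John', age=30)  # bulk attribute update, keeps unique-index caches in sync
#         p.delete()                  # status becomes 'marked_to_delete'; related objects
#                                     # cascade or raise ConstraintError, and the actual
#                                     # DELETE statement is issued on flush()/commit()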
prev_vals = tuple(vals) for i, attr in enumerate(attrs): if attr in avdict: vals[i] = avdict[attr] new_vals = tuple(vals) cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo) for attr, new_val in iteritems(avdict): if not attr.reverse: continue old_val = get_val(attr) attr.update_reverse(obj, old_val, new_val, undo_funcs) for attr, new_val in iteritems(collection_avdict): attr.__set__(obj, new_val, undo_funcs) except: for undo_func in undo_funcs: undo_func() raise obj._vals_.update(avdict) def _keyargs_to_avdicts_(obj, kwargs): avdict, collection_avdict = {}, {} get_attr = obj._adict_.get for name, new_val in kwargs.items(): attr = get_attr(name) if attr is None: throw(TypeError, 'Unknown attribute %r' % name) new_val = attr.validate(new_val, obj, from_db=False) if attr.is_collection: collection_avdict[attr] = new_val elif attr.pk_offset is None: avdict[attr] = new_val elif obj._vals_.get(attr, new_val) != new_val: throw(TypeError, 'Cannot change value of primary key attribute %s' % attr.name) return avdict, collection_avdict @classmethod def _attrs_with_bit_(entity, attrs, mask=-1): get_bit = entity._bits_.get for attr in attrs: if get_bit(attr) & mask: yield attr def _construct_optimistic_criteria_(obj): optimistic_columns = [] optimistic_converters = [] optimistic_values = [] optimistic_operations = [] for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): converters = attr.converters assert converters optimistic = attr.optimistic if attr.optimistic is not None else converters[0].optimistic if not optimistic: continue dbval = obj._dbvals_[attr] optimistic_columns.extend(attr.columns) optimistic_converters.extend(attr.converters) values = attr.get_raw_values(dbval) optimistic_values.extend(values) optimistic_operations.extend('IS_NULL' if dbval is None else converter.EQ for converter in converters) return optimistic_operations, optimistic_columns, optimistic_converters, optimistic_values def _save_principal_objects_(obj, dependent_objects): if dependent_objects is None: dependent_objects = [] elif obj in dependent_objects: chain = ' -> '.join(obj2.__class__.__name__ for obj2 in dependent_objects) throw(UnresolvableCyclicDependency, 'Cannot save cyclic chain: ' + chain) dependent_objects.append(obj) status = obj._status_ if status == 'created': attrs = obj._attrs_with_columns_ elif status == 'modified': attrs = obj._attrs_with_bit_(obj._attrs_with_columns_, obj._wbits_) else: assert False # pragma: no cover for attr in attrs: if not attr.reverse: continue val = obj._vals_[attr] if val is not None and val._status_ == 'created': val._save_(dependent_objects) def _update_dbvals_(obj, after_create, new_dbvals): bits = obj._bits_ vals = obj._vals_ dbvals = obj._dbvals_ cache_indexes = obj._session_cache_.indexes for attr in obj._attrs_with_columns_: if not bits.get(attr): continue if attr not in vals: continue val = vals[attr] if attr.is_volatile: if val is not None: if attr.is_unique: cache_indexes[attr].pop(val, None) get_val = vals.get for key, i in attr.composite_keys: keyval = tuple(get_val(attr) for attr in key) cache_indexes[key].pop(keyval, None) elif after_create and val is None: obj._rbits_ &= ~bits[attr] else: if attr in new_dbvals: dbvals[attr] = new_dbvals[attr] continue # Clear value of volatile attribute or null values after create, because the value may be changed in the DB del vals[attr] dbvals.pop(attr, None) def _save_created_(obj): auto_pk = (obj._pkval_ is None) attrs = [] values = [] new_dbvals = {} for attr in obj._attrs_with_columns_: if 
auto_pk and attr.is_pk: continue val = obj._vals_[attr] if val is not None: attrs.append(attr) if not attr.reverse: assert len(attr.converters) == 1 dbval = attr.converters[0].val2dbval(val, obj) new_dbvals[attr] = dbval values.append(dbval) else: new_dbvals[attr] = val values.extend(attr.get_raw_values(val)) attrs = tuple(attrs) database = obj._database_ cached_sql = obj._insert_sql_cache_.get(attrs) if cached_sql is None: columns = [] converters = [] for attr in attrs: columns.extend(attr.columns) converters.extend(attr.converters) assert len(columns) == len(converters) params = [ [ 'PARAM', (i, None, None), converter ] for i, converter in enumerate(converters) ] entity = obj.__class__ if not columns and database.provider.dialect == 'Oracle': sql_ast = [ 'INSERT', entity._table_, obj._pk_columns_, [ [ 'DEFAULT' ] for column in obj._pk_columns_ ] ] else: sql_ast = [ 'INSERT', entity._table_, columns, params ] if auto_pk: sql_ast.append(entity._pk_columns_[0]) sql, adapter = database._ast2sql(sql_ast) entity._insert_sql_cache_[attrs] = sql, adapter else: sql, adapter = cached_sql arguments = adapter(values) try: if auto_pk: new_id = database._exec_sql(sql, arguments, returning_id=True, start_transaction=True) else: database._exec_sql(sql, arguments, start_transaction=True) except IntegrityError as e: msg = " ".join(tostring(arg) for arg in e.args) throw(TransactionIntegrityError, 'Object %r cannot be stored in the database. %s: %s' % (obj, e.__class__.__name__, msg), e) except DatabaseError as e: msg = " ".join(tostring(arg) for arg in e.args) throw(UnexpectedError, 'Object %r cannot be stored in the database. %s: %s' % (obj, e.__class__.__name__, msg), e) if auto_pk: pk_attrs = obj._pk_attrs_ cache_index = obj._session_cache_.indexes[pk_attrs] obj2 = cache_index.setdefault(new_id, obj) if obj2 is not obj: throw(TransactionIntegrityError, 'Newly auto-generated id value %s was already used in transaction cache for another object' % new_id) obj._pkval_ = obj._vals_[pk_attrs[0]] = new_id obj._newid_ = None obj._status_ = 'inserted' obj._rbits_ = obj._all_bits_except_volatile_ obj._wbits_ = 0 obj._update_dbvals_(True, new_dbvals) def _save_updated_(obj): update_columns = [] values = [] new_dbvals = {} for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._wbits_): update_columns.extend(attr.columns) val = obj._vals_[attr] if not attr.reverse: assert len(attr.converters) == 1 dbval = attr.converters[0].val2dbval(val, obj) new_dbvals[attr] = dbval values.append(dbval) else: new_dbvals[attr] = val values.extend(attr.get_raw_values(val)) if update_columns: for attr in obj._pk_attrs_: val = obj._vals_[attr] values.extend(attr.get_raw_values(val)) cache = obj._session_cache_ optimistic_session = cache.db_session is None or cache.db_session.optimistic if optimistic_session and obj not in cache.for_update: optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) else: optimistic_columns = optimistic_converters = optimistic_ops = () query_key = tuple(update_columns), tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._update_sql_cache_.get(query_key) if cached_sql is None: update_converters = [] for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._wbits_): update_converters.extend(attr.converters) assert len(update_columns) == len(update_converters) update_params = [ [ 'PARAM', (i, None, None), converter ] for i, converter in 
enumerate(update_converters) ] params_count = len(update_params) where_list = [ 'WHERE' ] pk_columns = obj._pk_columns_ pk_converters = obj._pk_converters_ params_count = populate_criteria_list(where_list, pk_columns, pk_converters, repeat('EQ'), params_count) if optimistic_columns: populate_criteria_list( where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count, optimistic=True) sql_ast = [ 'UPDATE', obj._table_, list(izip(update_columns, update_params)), where_list ] sql, adapter = database._ast2sql(sql_ast) obj._update_sql_cache_[query_key] = sql, adapter else: sql, adapter = cached_sql arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) if cursor.rowcount == 0 and cache.db_session.optimistic: throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'updated' obj._rbits_ |= obj._wbits_ & obj._all_bits_except_volatile_ obj._wbits_ = 0 obj._update_dbvals_(False, new_dbvals) def _save_deleted_(obj): values = [] values.extend(obj._get_raw_pkval_()) cache = obj._session_cache_ optimistic_session = cache.db_session is None or cache.db_session.optimistic if optimistic_session and obj not in cache.for_update: optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) else: optimistic_columns = optimistic_converters = optimistic_ops = () query_key = tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._delete_sql_cache_.get(query_key) if cached_sql is None: where_list = [ 'WHERE' ] params_count = populate_criteria_list(where_list, obj._pk_columns_, obj._pk_converters_, repeat('EQ')) if optimistic_columns: populate_criteria_list( where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count, optimistic=True) from_ast = [ 'FROM', [ None, 'TABLE', obj._table_ ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) obj.__class__._delete_sql_cache_[query_key] = sql, adapter else: sql, adapter = cached_sql arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) if cursor.rowcount == 0 and cache.db_session.optimistic: throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) def find_updated_attributes(obj): entity = obj.__class__ attrs_to_select = [] attrs_to_select.extend(entity._pk_attrs_) discr = entity._discriminator_attr_ if discr is not None and discr.pk_offset is None: attrs_to_select.append(discr) for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): optimistic = attr.optimistic if attr.optimistic is not None else attr.converters[0].optimistic if optimistic: attrs_to_select.append(attr) optimistic_converters = [] attr_offsets = {} select_list = [ 'ALL' ] for attr in attrs_to_select: optimistic_converters.extend(attr.converters) attr_offsets[attr] = offsets = [] for columns in attr.columns: select_list.append([ 'COLUMN', None, columns]) offsets.append(len(select_list) - 2) from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] pk_columns = entity._pk_columns_ pk_converters = entity._pk_converters_ criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ] sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ] database = entity._database_ sql, 
adapter = database._ast2sql(sql_ast) arguments = adapter(obj._get_raw_pkval_()) cursor = database._exec_sql(sql, arguments) row = cursor.fetchone() if row is None: return "Object %s was deleted outside of current transaction" % safe_repr(obj) real_entity_subclass, pkval, avdict = entity._parse_row_(row, attr_offsets) diff = [] for attr, new_dbval in avdict.items(): old_dbval = obj._dbvals_[attr] converter = attr.converters[0] if old_dbval != new_dbval and ( attr.reverse or not converter.dbvals_equal(old_dbval, new_dbval)): diff.append('%s (%r -> %r)' % (attr.name, old_dbval, new_dbval)) return "Object %s was updated outside of current transaction%s" % ( safe_repr(obj), ('. Changes: %s' % ', '.join(diff) if diff else '')) def _save_(obj, dependent_objects=None): status = obj._status_ if status in ('created', 'modified'): obj._save_principal_objects_(dependent_objects) if status == 'created': obj._save_created_() elif status == 'modified': obj._save_updated_() elif status == 'marked_to_delete': obj._save_deleted_() else: assert False, "_save_() called for object %r with incorrect status %s" % (obj, status) # pragma: no cover assert obj._status_ in saved_statuses cache = obj._session_cache_ assert cache is not None and cache.is_alive cache.saved_objects.append((obj, obj._status_)) objects_to_save = cache.objects_to_save save_pos = obj._save_pos_ if save_pos == len(objects_to_save) - 1: objects_to_save.pop() else: objects_to_save[save_pos] = None obj._save_pos_ = None def flush(obj): if obj._status_ not in ('created', 'modified', 'marked_to_delete'): return assert obj._save_pos_ is not None, 'save_pos is None for %s object' % obj._status_ cache = obj._session_cache_ assert cache is not None and cache.is_alive and not cache.saved_objects with cache.flush_disabled(): obj._before_save_() # should be inside flush_disabled to prevent infinite recursion # TODO: add to documentation that flush is disabled inside before_xxx hooks obj._save_() cache.call_after_save_hooks() def _before_save_(obj): status = obj._status_ if status == 'created': obj.before_insert() elif status == 'modified': obj.before_update() elif status == 'marked_to_delete': obj.before_delete() def before_insert(obj): pass def before_update(obj): pass def before_delete(obj): pass def _after_save_(obj, status): if status == 'inserted': obj.after_insert() elif status == 'updated': obj.after_update() elif status == 'deleted': obj.after_delete() def after_insert(obj): pass def after_update(obj): pass def after_delete(obj): pass @cut_traceback def to_dict(obj, only=None, exclude=None, with_collections=False, with_lazy=False, related_objects=False): cache = obj._session_cache_ if cache is not None and cache.is_alive and cache.modified: cache.flush() attrs = obj.__class__._get_attrs_(only, exclude, with_collections, with_lazy) result = {} for attr in attrs: value = attr.__get__(obj) if attr.is_collection: if related_objects: value = sorted(value) elif len(attr.reverse.entity._pk_columns_) > 1: value = sorted(item._get_raw_pkval_() for item in value) else: value = sorted(item._get_raw_pkval_()[0] for item in value) elif attr.is_relation and not related_objects and value is not None: value = value._get_raw_pkval_() if len(value) == 1: value = value[0] result[attr.name] = value return result def to_json(obj, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return obj._database_.to_json(obj, include, exclude, converter, with_schema, schema_hash) def string2ast(s): result = string2ast_cache.get(s) if result is not None: 
        return result
    if PY2:
        if isinstance(s, str):
            try: s.encode('ascii')
            except UnicodeDecodeError: throw(TypeError,
                'The bytestring %r contains non-ascii symbols. Try to pass unicode string instead' % s)
        else: s = s.encode('ascii', 'backslashreplace')
    module_node = parse('(%s)' % s)
    if not isinstance(module_node, ast.Module): throw(TypeError)
    stmt_node = module_node.node
    if not isinstance(stmt_node, ast.Stmt) or len(stmt_node.nodes) != 1: throw(TypeError)
    discard_node = stmt_node.nodes[0]
    if not isinstance(discard_node, ast.Discard): throw(TypeError)
    result = string2ast_cache[s] = discard_node.expr
    # result = deepcopy(result)  # no need for now, but may be needed later
    return result

def get_globals_and_locals(args, kwargs, frame_depth, from_generator=False):
    args_len = len(args)
    assert args_len > 0
    func = args[0]
    if from_generator:
        if not isinstance(func, (basestring, types.GeneratorType)): throw(TypeError,
            'The first positional argument must be generator expression or its text source. Got: %r' % func)
    else:
        if not isinstance(func, (basestring, types.FunctionType)): throw(TypeError,
            'The first positional argument must be lambda function or its text source. Got: %r' % func)
    if args_len > 1:
        globals = args[1]
        if not hasattr(globals, 'keys'): throw(TypeError,
            'The second positional argument should be a globals dictionary. Got: %r' % globals)
        if args_len > 2:
            locals = args[2]
            if locals is not None and not hasattr(locals, 'keys'): throw(TypeError,
                'The third positional argument should be a locals dictionary. Got: %r' % locals)
        else:
            locals = {}
        if type(func) is types.GeneratorType:
            locals = locals.copy()
            locals.update(func.gi_frame.f_locals)
        if len(args) > 3: throw(TypeError, 'Excess positional argument%s: %s'
            % (len(args) > 4 and 's' or '', ', '.join(imap(repr, args[3:]))))
    else:
        locals = {}
        if frame_depth is not None:
            locals.update(sys._getframe(frame_depth+1).f_locals)
        if type(func) is types.GeneratorType:
            globals = func.gi_frame.f_globals
            locals.update(func.gi_frame.f_locals)
        elif frame_depth is not None:
            globals = sys._getframe(frame_depth+1).f_globals
    if kwargs: throw(TypeError,
        'Keyword arguments cannot be specified together with positional arguments')
    return func, globals, locals

def make_query(args, frame_depth, left_join=False):
    gen, globals, locals = get_globals_and_locals(
        args, kwargs=None, frame_depth=frame_depth+1 if frame_depth is not None else None, from_generator=True)
    if isinstance(gen, types.GeneratorType):
        tree, external_names, cells = decompile(gen)
        code_key = id(gen.gi_frame.f_code)
    elif isinstance(gen, basestring):
        tree = string2ast(gen)
        if not isinstance(tree, ast.GenExpr): throw(TypeError, 'Source code should represent generator.
Got: %s' % gen) code_key = gen cells = None else: assert False return Query(code_key, tree.code, globals, locals, cells, left_join) @cut_traceback def select(*args): return make_query(args, frame_depth=cut_traceback_depth+1) @cut_traceback def left_join(*args): return make_query(args, frame_depth=cut_traceback_depth+1, left_join=True) @cut_traceback def get(*args): return make_query(args, frame_depth=cut_traceback_depth+1).get() @cut_traceback def exists(*args): return make_query(args, frame_depth=cut_traceback_depth+1).exists() @cut_traceback def delete(*args): return make_query(args, frame_depth=cut_traceback_depth+1).delete() def make_aggrfunc(std_func): def aggrfunc(*args, **kwargs): if not args: return std_func(**kwargs) arg = args[0] if type(arg) is types.GeneratorType: try: iterator = arg.gi_frame.f_locals['.0'] except: return std_func(*args, **kwargs) if isinstance(iterator, EntityIter): return getattr(select(arg), std_func.__name__)(*args[1:], **kwargs) return std_func(*args, **kwargs) aggrfunc.__name__ = std_func.__name__ return aggrfunc count = make_aggrfunc(utils.count) sum = make_aggrfunc(builtins.sum) min = make_aggrfunc(builtins.min) max = make_aggrfunc(builtins.max) avg = make_aggrfunc(utils.avg) group_concat = make_aggrfunc(utils.group_concat) distinct = make_aggrfunc(utils.distinct) def JOIN(expr): return expr def desc(expr): if isinstance(expr, Attribute): return expr.desc if isinstance(expr, DescWrapper): return expr.attr if isinstance(expr, int_types): return -expr if isinstance(expr, basestring): return 'desc(%s)' % expr return expr def extract_vars(code_key, filter_num, extractors, globals, locals, cells=None): if cells: locals = locals.copy() for name, cell in cells.items(): try: locals[name] = cell.cell_contents except ValueError: throw(NameError, 'Free variable `%s` referenced before assignment in enclosing scope' % name) vars = {} vartypes = HashableDict() for src, extractor in iteritems(extractors): varkey = filter_num, src, code_key try: value = extractor(globals, locals) except Exception as cause: raise ExprEvalError(src, cause) if isinstance(value, types.GeneratorType): value = make_query((value,), frame_depth=None) if isinstance(value, QueryResultIterator): qr = value._query_result value = qr if not qr._items else tuple(qr._items[value._position:]) if isinstance(value, QueryResult) and value._items: value = tuple(value._items) if isinstance(value, (Query, QueryResult, SetIterator)): query = value._get_query() vars.update(query._vars) vartypes.update(query._translator.vartypes) if src == 'None' and value is not None: throw(TranslationError) if src == 'True' and value is not True: throw(TranslationError) if src == 'False' and value is not False: throw(TranslationError) try: vartypes[varkey], value = normalize(value) except TypeError: if not isinstance(value, dict): unsupported = False try: value = tuple(value) except: unsupported = True else: unsupported = True if unsupported: typename = type(value).__name__ if src == '.0': throw(TypeError, 'Query cannot iterate over anything but entity class or another query') throw(TypeError, 'Expression `%s` has unsupported type %r' % (src, typename)) vartypes[varkey], value = normalize(value) vars[varkey] = value return vars, vartypes def unpickle_query(query_result): return query_result class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) tree, extractors = create_extractors(code_key, tree, globals, locals, special_functions, 
const_functions) filter_num = 0 vars, vartypes = extract_vars(code_key, filter_num, extractors, globals, locals, cells) node = tree.quals[0].iter varkey = filter_num, node.src, code_key origin = vars[varkey] if isinstance(origin, Query): prev_query = origin elif isinstance(origin, QueryResult): prev_query = origin._query elif isinstance(origin, QueryResultIterator): prev_query = origin._query_result._query elif isinstance(origin, SetIterator): prev_query = origin._query else: prev_query = None if not isinstance(origin, EntityMeta): if node.src == '.0': throw(TypeError, 'Query can only iterate over entity or another query (not a list of objects)') throw(TypeError, 'Cannot iterate over non-entity object %s' % node.src) database = origin._database_ if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) if prev_query is not None: database = prev_query._translator.database filter_num = prev_query._filter_num + 1 vars, vartypes = extract_vars(code_key, filter_num, extractors, globals, locals, cells) query._filter_num = filter_num database.provider.normalize_vars(vars, vartypes) query._code_key = code_key query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) query._database = database translator, vars = query._get_translator(query._key, vars) query._vars = vars if translator is None: pickled_tree = pickle_ast(tree) tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls try: translator = translator_cls(tree_copy, None, code_key, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) except UseAnotherTranslator as e: translator = e.translator name_path = translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) try: translator = translator_cls(tree_copy, None, code_key, filter_num, extractors, vars, vartypes.copy(), left_join=True, optimize=name_path) except UseAnotherTranslator as e: translator = e.translator except OptimizationFailed: translator.optimization_failed = True translator.pickled_tree = pickled_tree if translator.can_be_cached: database._translator_cache[query._key] = translator query._translator = translator query._filters = () query._next_kwarg_id = 0 query._for_update = query._nowait = query._skip_locked = False query._distinct = None query._prefetch = False query._prefetch_context = PrefetchContext(query._database) def _get_query(query): return query def _get_type_(query): return QueryType(query) def _normalize_var(query, query_type): return query_type, query def _clone(query, **kwargs): new_query = object.__new__(Query) new_query.__dict__.update(query.__dict__) new_query.__dict__.update(kwargs) return new_query def __reduce__(query): return unpickle_query, (query._fetch(),) def _get_translator(query, query_key, vars): new_vars = vars.copy() database = query._database translator = database._translator_cache.get(query_key) all_func_vartypes = {} if translator is not None: if translator.func_extractors_map: for func, func_extractors in iteritems(translator.func_extractors_map): func_id = id(func.func_code if PY2 else func.__code__) func_filter_num = translator.filter_num, 'func', func_id func_vars, func_vartypes = extract_vars( func_id, func_filter_num, func_extractors, func.__globals__, {}, func.__closure__) # todo closures database.provider.normalize_vars(func_vars, 
func_vartypes) new_vars.update(func_vars) all_func_vartypes.update(func_vartypes) if all_func_vartypes != translator.func_vartypes: return None, vars.copy() for key, attrname in iteritems(translator.getattr_values): assert key in new_vars if attrname != new_vars[key]: del database._translator_cache[query_key] return None, vars.copy() return translator, new_vars def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type attrs_to_prefetch_dict = query._prefetch_context.attrs_to_prefetch_dict if isinstance(expr_type, EntityMeta) and attrs_to_prefetch_dict: attrs_to_prefetch = tuple(sorted(attrs_to_prefetch_dict.get(expr_type, ()))) else: attrs_to_prefetch = () sql_key = HashableDict( query._key, vartypes=HashableDict(query._translator.vartypes), getattr_values=HashableDict(translator.getattr_values), limit=limit, offset=offset, distinct=query._distinct, aggr_func=(aggr_func_name, aggr_func_distinct, sep), for_update=query._for_update, nowait=query._nowait, skip_locked=query._skip_locked, inner_join_syntax=options.INNER_JOIN_SYNTAX, attrs_to_prefetch=attrs_to_prefetch ) database = query._database cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( limit, offset, query._distinct, aggr_func_name, aggr_func_distinct, sep, query._for_update, query._nowait, query._skip_locked) cache = database._get_cache() sql, adapter = database.provider.ast2sql(sql_ast) cache_entry = sql, adapter, attr_offsets database._constructed_sql_cache[sql_key] = cache_entry else: sql, adapter, attr_offsets = cache_entry arguments = adapter(query._vars) if query._translator.query_result_is_cacheable: arguments_key = HashableDict(arguments) if type(arguments) is dict else arguments try: hash(arguments_key) except: query_key = None # arguments are unhashable else: query_key = HashableDict(sql_key, arguments_key=arguments_key) else: query_key = None return sql, arguments, attr_offsets, query_key def get_sql(query): sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments() return sql def _actual_fetch(query, limit=None, offset=None): translator = query._translator with query._prefetch_context: sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(limit, offset) database = query._database cache = database._get_cache() if query._for_update: cache.immediate = True cache.prepare_connection_for_query_execution() # may clear cache.query_results items = cache.query_results.get(query_key) if items is None: cursor = database._exec_sql(sql, arguments) if isinstance(translator.expr_type, EntityMeta): entity = translator.expr_type items = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, used_attrs=translator.get_used_attrs()) elif len(translator.row_layout) == 1: func, slice_or_offset, src = translator.row_layout[0] items = list(starmap(func, cursor.fetchall())) else: items = [ tuple(func(sql_row[slice_or_offset]) for func, slice_or_offset, src in translator.row_layout) for sql_row in cursor.fetchall() ] for i, t in enumerate(translator.expr_type): if isinstance(t, EntityMeta) and t._subclasses_: t._load_many_(row[i] for row in items) if query_key is not None: cache.query_results[query_key] = items else: stats = database._dblocal.stats stat = stats.get(sql) if stat is not None: stat.cache_count += 1 else: stats[sql] = QueryStat(sql) if query._prefetch: 
query._do_prefetch(items) return items @cut_traceback def prefetch(query, *args): query = query._clone(_prefetch_context=query._prefetch_context.copy()) query._prefetch = True prefetch_context = query._prefetch_context for arg in args: if isinstance(arg, EntityMeta): entity = arg if query._database is not entity._database_: throw(TypeError, 'Entity %s belongs to different database and cannot be prefetched' % entity.__name__) prefetch_context.entities_to_prefetch.add(entity) elif isinstance(arg, Attribute): attr = arg entity = attr.entity if query._database is not entity._database_: throw(TypeError, 'Entity of attribute %s belongs to different database and cannot be prefetched' % attr) if isinstance(attr.py_type, EntityMeta) or attr.lazy: prefetch_context.attrs_to_prefetch_dict[entity].add(attr) else: throw(TypeError, 'Argument of prefetch() query method must be entity class or attribute. ' 'Got: %r' % arg) return query def _do_prefetch(query, query_result): expr_type = query._translator.expr_type all_objects = set() objects_to_process = set() objects_to_prefetch = set() if isinstance(expr_type, EntityMeta): objects_to_process.update(query_result) all_objects.update(query_result) elif type(expr_type) is tuple: obj_indexes = [ i for i, t in enumerate(expr_type) if isinstance(t, EntityMeta) ] if obj_indexes: for row in query_result: objects_to_prefetch.update(row[i] for i in obj_indexes) all_objects.update(objects_to_prefetch) prefetch_context = local.prefetch_context assert prefetch_context collection_prefetch_dict = defaultdict(set) objects_to_prefetch_dict = defaultdict(set) while objects_to_process or objects_to_prefetch: for obj in objects_to_process: entity = obj.__class__ relations_to_prefetch = prefetch_context.get_relations_to_prefetch(entity) for attr in relations_to_prefetch: if attr.is_collection: collection_prefetch_dict[attr].add(obj) else: obj2 = attr.get(obj) if obj2 is not None and obj2 not in all_objects: all_objects.add(obj2) objects_to_prefetch.add(obj2) next_objects_to_process = set() for attr, objects in collection_prefetch_dict.items(): items = attr.prefetch_load_all(objects) if attr.reverse.is_collection: objects_to_prefetch.update(items) else: next_objects_to_process.update(item for item in items if item not in all_objects) collection_prefetch_dict.clear() for obj in objects_to_prefetch: objects_to_prefetch_dict[obj.__class__._root_].add(obj) objects_to_prefetch.clear() for entity, objects in objects_to_prefetch_dict.items(): next_objects_to_process.update(objects) entity._prefetch_load_all_(objects) objects_to_prefetch_dict.clear() objects_to_process = next_objects_to_process @cut_traceback def show(query, width=None, stream=None): query._fetch().show(width, stream) @cut_traceback def get(query): objects = query[:2] if not objects: return None if len(objects) > 1: throw(MultipleObjectsFoundError, 'Multiple objects were found. Use select(...) 
to retrieve them') return objects[0] @cut_traceback def first(query): translator = query._translator if translator.order: pass elif type(translator.expr_type) is tuple: query = query.order_by(*[i+1 for i in xrange(len(query._translator.expr_type))]) else: query = query.order_by(1) objects = query.without_distinct()[:1] if not objects: return None return objects[0] @cut_traceback def without_distinct(query): return query._clone(_distinct=False) @cut_traceback def distinct(query): return query._clone(_distinct=True) @cut_traceback def exists(query): objects = query[:1] return bool(objects) @cut_traceback def delete(query, bulk=None): if not bulk: if not isinstance(query._translator.expr_type, EntityMeta): throw(TypeError, 'Delete query should be applied to a single entity. Got: %s' % ast2src(query._translator.tree.expr)) objects = query._actual_fetch() for obj in objects: obj._delete_() return len(objects) translator = query._translator sql_key = HashableDict(query._key, sql_command='DELETE') database = query._database cache = database._get_cache() cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: sql_ast = translator.construct_delete_sql_ast() cache_entry = database.provider.ast2sql(sql_ast) database._constructed_sql_cache[sql_key] = cache_entry sql, adapter = cache_entry arguments = adapter(query._vars) cache.immediate = True cache.prepare_connection_for_query_execution() # may clear cache.query_results cursor = database._exec_sql(sql, arguments) cache.query_results.clear() return cursor.rowcount @cut_traceback def __len__(query): return len(query._actual_fetch()) @cut_traceback def __iter__(query): return iter(query._fetch(lazy=True)) @cut_traceback def order_by(query, *args): return query._order_by('order_by', *args) @cut_traceback def sort_by(query, *args): return query._order_by('sort_by', *args) def _order_by(query, method_name, *args): if not args: throw(TypeError, '%s() method requires at least one argument' % method_name) if args[0] is None: if len(args) > 1: throw(TypeError, 'When first argument of %s() method is None, it must be the only argument' % method_name) tup = (('without_order',),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup new_translator, new_vars = query._get_translator(new_key, query._vars) if new_translator is None: new_translator = query._translator.without_order() query._database._translator_cache[new_key] = new_translator return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) if isinstance(args[0], (basestring, types.FunctionType)): func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+2) return query._process_lambda(func, globals, locals, order_by=True) if isinstance(args[0], RawSQL): raw = args[0] return query.order_by(lambda: raw) attributes = numbers = False for arg in args: if isinstance(arg, int_types): numbers = True elif isinstance(arg, (Attribute, DescWrapper)): attributes = True else: throw(TypeError, "order_by() method receive an argument of invalid type: %r" % arg) if numbers and attributes: throw(TypeError, 'order_by() method receive invalid combination of arguments') tup = (('order_by_numbers' if numbers else 'order_by_attributes', args),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup new_translator, new_vars = query._get_translator(new_key, query._vars) if new_translator is None: if numbers: new_translator = 
query._translator.order_by_numbers(args) else: new_translator = query._translator.order_by_attributes(args) query._database._translator_cache[new_key] = new_translator return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) def _process_lambda(query, func, globals, locals, order_by=False, original_names=False): prev_translator = query._translator argnames = () if isinstance(func, basestring): func_id = func func_ast = string2ast(func) if isinstance(func_ast, ast.Lambda): argnames = get_lambda_args(func_ast) func_ast = func_ast.code cells = None elif type(func) is types.FunctionType: argnames = get_lambda_args(func) func_id = id(func.func_code if PY2 else func.__code__) func_ast, external_names, cells = decompile(func) elif not order_by: throw(TypeError, 'Argument of filter() method must be a lambda functon or its text. Got: %r' % func) else: assert False # pragma: no cover if argnames: if original_names: for name in argnames: if name not in prev_translator.namespace: throw(TypeError, 'Lambda argument `%s` does not correspond to any variable in original query' % name) else: expr_type = prev_translator.expr_type expr_count = len(expr_type) if type(expr_type) is tuple else 1 if len(argnames) != expr_count: throw(TypeError, 'Incorrect number of lambda arguments. ' 'Expected: %d, got: %d' % (expr_count, len(argnames))) else: original_names = True new_filter_num = query._filter_num + 1 func_ast, extractors = create_extractors( func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.namespace) if extractors: vars, vartypes = extract_vars(func_id, new_filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) new_vars = query._vars.copy() new_vars.update(vars) else: new_vars, vartypes = query._vars, HashableDict() tup = (('order_by' if order_by else 'where' if original_names else 'filter', func_id, vartypes),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + (('apply_lambda', func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: prev_optimized = prev_translator.optimize new_translator = prev_translator.apply_lambda(func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) translator_cls = prev_translator.__class__ try: new_translator = translator_cls( tree_copy, None, prev_translator.original_code_key, prev_translator.original_filter_num, prev_translator.extractors, None, prev_translator.vartypes.copy(), left_join=True, optimize=name_path) except UseAnotherTranslator: assert False new_translator = query._reapply_filters(new_translator) new_translator = new_translator.apply_lambda(func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) query._database._translator_cache[new_key] = new_translator return query._clone(_filter_num=new_filter_num, _vars=new_vars, _key=new_key, _filters=new_filters, _translator=new_translator) def _reapply_filters(query, translator): for tup in query._filters: method_name, args = tup[0], tup[1:] translator_method = getattr(translator, method_name) translator = translator_method(*args) return translator 
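# A minimal usage sketch of the lambda-based query refinement implemented by
# _process_lambda()/_reapply_filters() above. `Person` is a hypothetical entity
# with `name` and `age` attributes:
#
#     with db_session:
#         q = select(p for p in Person)
#         q = q.filter(lambda p: p.age > 18)       # parsed via _process_lambda()
#         q = q.where(lambda p: p.name != 'Bob')   # same path, with original_names=True
#         adults = q.order_by(lambda p: p.name)[:]
#
# Each step returns a new Query built with _clone(); the accumulated `filters`
# tuple becomes part of query._key, so an identical chain reuses the translator
# cached in database._translator_cache instead of re-decompiling the lambdas.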
@cut_traceback def filter(query, *args, **kwargs): if args: if isinstance(args[0], RawSQL): raw = args[0] return query.filter(lambda: raw) func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1) return query._process_lambda(func, globals, locals, order_by=False) if not kwargs: return query entity = query._translator.expr_type if not isinstance(entity, EntityMeta): throw(TypeError, 'Keyword arguments are not allowed: since query result type is not an entity, filter() method can accept only lambda') return query._apply_kwargs(kwargs) @cut_traceback def where(query, *args, **kwargs): if args: if isinstance(args[0], RawSQL): raw = args[0] return query.where(lambda: raw) func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1) return query._process_lambda(func, globals, locals, order_by=False, original_names=True) if not kwargs: return query if len(query._translator.tree.quals) > 1: throw(TypeError, 'Keyword arguments are not allowed: query iterates over more than one entity') return query._apply_kwargs(kwargs, original_names=True) def _apply_kwargs(query, kwargs, original_names=False): translator = query._translator if original_names: tablerefs = translator.sqlquery.tablerefs alias = translator.tree.quals[0].assign.name tableref = tablerefs[alias] entity = tableref.entity else: entity = translator.expr_type get_attr = entity._adict_.get filterattrs = [] value_dict = {} next_id = query._next_kwarg_id for attrname, val in sorted(iteritems(kwargs)): attr = get_attr(attrname) if attr is None: throw(AttributeError, 'Entity %s does not have attribute %s' % (entity.__name__, attrname)) if attr.is_collection: throw(TypeError, '%s attribute %s cannot be used as a keyword argument for filtering' % (attr.__class__.__name__, attr)) val = attr.validate(val, None, entity, from_db=False) id = next_id next_id += 1 filterattrs.append((attr, id, val is None)) value_dict[id] = val filterattrs = tuple(filterattrs) tup = (('apply_kwfilters', filterattrs, original_names),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup new_vars = query._vars.copy() new_vars.update(value_dict) new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: new_translator = translator.apply_kwfilters(filterattrs, original_names) query._database._translator_cache[new_key] = new_translator return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator, _next_kwarg_id=next_id, _vars=new_vars) @cut_traceback def __getitem__(query, key): if not isinstance(key, slice): throw(TypeError, 'If you want apply index to a query, convert it to list first') step = key.step if step is not None and step != 1: throw(TypeError, "Parameter 'step' of slice object is not allowed here") start = key.start if start is None: start = 0 elif start < 0: throw(TypeError, "Parameter 'start' of slice object cannot be negative") stop = key.stop if stop is None: if not start: return query._fetch() else: return query._fetch(limit=None, offset=start) if start >= stop: return query._fetch(limit=0) return query._fetch(limit=stop-start, offset=start) def _fetch(query, limit=None, offset=None, lazy=False): return QueryResult(query, limit, offset, lazy=lazy) @cut_traceback def fetch(query, limit=None, offset=None): return query._fetch(limit, offset) @cut_traceback def limit(query, limit=None, offset=None): return query._fetch(limit, offset, lazy=True) @cut_traceback def page(query, 
pagenum, pagesize=10): offset = (pagenum - 1) * pagesize return query._fetch(pagesize, offset, lazy=True) def _aggregate(query, aggr_func_name, distinct=None, sep=None): translator = query._translator sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments( aggr_func_name=aggr_func_name, aggr_func_distinct=distinct, sep=sep) cache = query._database._get_cache() try: result = cache.query_results[query_key] except KeyError: cursor = query._database._exec_sql(sql, arguments) row = cursor.fetchone() if row is not None: result = row[0] else: result = None if result is None and aggr_func_name == 'SUM': result = 0 if result is None: pass elif aggr_func_name == 'COUNT': pass else: if aggr_func_name == 'AVG': expr_type = float elif aggr_func_name == 'GROUP_CONCAT': expr_type = basestring else: expr_type = translator.expr_type provider = query._database.provider converter = provider.get_converter_by_py_type(expr_type) result = converter.sql2py(result) if query_key is not None: cache.query_results[query_key] = result return result @cut_traceback def sum(query, distinct=None): return query._aggregate('SUM', distinct) @cut_traceback def avg(query, distinct=None): return query._aggregate('AVG', distinct) @cut_traceback def group_concat(query, sep=None, distinct=None): if sep is not None: if not isinstance(sep, basestring): throw(TypeError, '`sep` option for `group_concat` should be of type str. Got: %s' % type(sep).__name__) return query._aggregate('GROUP_CONCAT', distinct, sep) @cut_traceback def min(query): return query._aggregate('MIN') @cut_traceback def max(query): return query._aggregate('MAX') @cut_traceback def count(query, distinct=None): return query._aggregate('COUNT', distinct) @cut_traceback def for_update(query, nowait=False, skip_locked=False): if nowait and skip_locked: throw(TypeError, 'nowait and skip_locked options are mutually exclusive') return query._clone(_for_update=True, _nowait=nowait, _skip_locked=skip_locked) def random(query, limit): return query.order_by('random()')[:limit] def to_json(query, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return query._database.to_json(query[:], include, exclude, converter, with_schema, schema_hash) class QueryResultIterator(object): __slots__ = '_query_result', '_position' def __init__(self, query_result): self._query_result = query_result self._position = 0 def _get_type_(self): if self._position != 0: throw(NotImplementedError, 'Cannot use partially exhausted iterator, please convert to list') return self._query_result._get_type_() def _normalize_var(self, query_type): if self._position != 0: throw(NotImplementedError) return self._query_result._normalize_var(query_type) def next(self): qr = self._query_result if qr._items is None: qr._items = qr._query._actual_fetch(qr._limit, qr._offset) if self._position >= len(qr._items): raise StopIteration item = qr._items[self._position] self._position += 1 return item __next__ = next def __length_hint__(self): return len(self._query_result) - self._position def make_query_result_method_error_stub(name, title=None): def func(self, *args, **kwargs): throw(TypeError, 'In order to do %s, cast QueryResult to list first' % (title or name)) return func class QueryResult(object): __slots__ = '_query', '_limit', '_offset', '_items', '_expr_type', '_col_names' def __init__(self, query, limit, offset, lazy): translator = query._translator self._query = query self._limit = limit self._offset = offset self._items = None if lazy else 
self._query._actual_fetch(limit, offset)
        self._expr_type = translator.expr_type
        self._col_names = translator.col_names
    def _get_query(self):
        return self._query
    def _get_type_(self):
        if self._items is None:
            return QueryType(self._query, self._limit, self._offset)
        item_type = self._query._translator.expr_type
        return tuple(item_type for item in self._items)
    def _normalize_var(self, query_type):
        if self._items is None:
            return query_type, self._query
        items = tuple(normalize(item) for item in self._items)
        item_type = self._query._translator.expr_type
        return tuple(item_type for item in items), items
    def _get_items(self):
        if self._items is None:
            self._items = self._query._actual_fetch(self._limit, self._offset)
        return self._items
    def __getstate__(self):
        return self._get_items(), self._limit, self._offset, self._expr_type, self._col_names
    def __setstate__(self, state):
        self._query = None
        self._items, self._limit, self._offset, self._expr_type, self._col_names = state
    def __repr__(self):
        if self._items is not None:
            return self.__str__()
        # The original string literal was lost; presumably a placeholder repr such as:
        return '<QueryResult object at %s>' % hex(id(self))
    def __str__(self):
        return repr(self._get_items())
    def __iter__(self):
        return QueryResultIterator(self)
    def __len__(self):
        if self._items is None:
            self._items = self._query._actual_fetch(self._limit, self._offset)
        return len(self._items)
    def __getitem__(self, key):
        if self._items is None:
            self._items = self._query._actual_fetch(self._limit, self._offset)
        return self._items[key]
    def __contains__(self, item):
        return item in self._get_items()
    def index(self, item):
        return self._get_items().index(item)
    def _other_items(self, other):
        return other._get_items() if isinstance(other, QueryResult) else other
    def __eq__(self, other):
        return self._get_items() == self._other_items(other)
    def __ne__(self, other):
        return self._get_items() != self._other_items(other)
    def __lt__(self, other):
        return self._get_items() < self._other_items(other)
    def __le__(self, other):
        return self._get_items() <= self._other_items(other)
    def __gt__(self, other):
        return self._get_items() > self._other_items(other)
    def __ge__(self, other):
        return self._get_items() >= self._other_items(other)
    def __reversed__(self):
        return reversed(self._get_items())
    def reverse(self):
        self._get_items().reverse()
    def sort(self, *args, **kwargs):
        self._get_items().sort(*args, **kwargs)
    def shuffle(self):
        shuffle(self._get_items())
    @cut_traceback
    def show(self, width=None, stream=None):
        if stream is None:
            stream = sys.stdout
        def writeln(s):
            stream.write(s)
            stream.write('\n')
        if self._items is None:
            self._items = self._query._actual_fetch(self._limit, self._offset)
        if not width: width = options.CONSOLE_WIDTH
        max_columns = width // 5
        expr_type = self._expr_type
        col_names = self._col_names
        def to_str(x):
            return tostring(x).replace('\n', ' ')
        if isinstance(expr_type, EntityMeta):
            entity = expr_type
            col_names = [ attr.name for attr in entity._attrs_
                          if not attr.is_collection and not attr.lazy ][:max_columns]
            if len(col_names) == 1:
                col_name = col_names[0]
                row_maker = lambda obj: (getattr(obj, col_name),)
            else:
                row_maker = attrgetter(*col_names)
            rows = [ tuple(to_str(value) for value in row_maker(obj)) for obj in self._items ]
        elif len(col_names) == 1:
            rows = [ (to_str(obj),) for obj in self._items ]
        else:
            rows = [ tuple(to_str(value) for value in row) for row in self._items ]
        remaining_columns = {}
        for col_num, colname in enumerate(col_names):
            if not rows:
                max_len = len(colname)
            else:
                max_len = max(len(colname), max(len(row[col_num]) for row in rows))
            remaining_columns[col_num] = max_len
        width_dict = {}
available_width = width - len(col_names) + 1 while remaining_columns: base_len = (available_width - len(remaining_columns) + 1) // len(remaining_columns) for col_num, max_len in remaining_columns.items(): if max_len <= base_len: width_dict[col_num] = max_len del remaining_columns[col_num] available_width -= max_len break else: break if remaining_columns: base_len = available_width // len(remaining_columns) for col_num, max_len in remaining_columns.items(): width_dict[col_num] = base_len writeln(strjoin('|', (strcut(colname, width_dict[i]) for i, colname in enumerate(col_names)))) writeln(strjoin('+', ('-' * width_dict[i] for i in xrange(len(col_names))))) for row in rows: writeln(strjoin('|', (strcut(item, width_dict[i]) for i, item in enumerate(row)))) stream.flush() def to_json(self, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return self._query._database.to_json(self, include, exclude, converter, with_schema, schema_hash) def __add__(self, other): result = [] result.extend(self) result.extend(other) return result def __radd__(self, other): result = [] result.extend(other) result.extend(self) return result def to_list(self): return list(self) __setitem__ = make_query_result_method_error_stub('__setitem__', 'item assignment') __delitem__ = make_query_result_method_error_stub('__delitem__', 'item deletion') __iadd__ = make_query_result_method_error_stub('__iadd__', '+=') __imul__ = make_query_result_method_error_stub('__imul__', '*=') __mul__ = make_query_result_method_error_stub('__mul__', '*') __rmul__ = make_query_result_method_error_stub('__rmul__', '*') append = make_query_result_method_error_stub('append', 'append') clear = make_query_result_method_error_stub('clear', 'clear') extend = make_query_result_method_error_stub('extend', 'extend') insert = make_query_result_method_error_stub('insert', 'insert') pop = make_query_result_method_error_stub('pop', 'pop') remove = make_query_result_method_error_stub('remove', 'remove') def strcut(s, width): if len(s) <= width: return s + ' ' * (width - len(s)) else: return s[:width-3] + '...' 
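# QueryResult (defined above) behaves like a read-only list: iteration, indexing,
# slicing and comparisons work, while mutating methods raise TypeError through
# make_query_result_method_error_stub(). A short sketch with a hypothetical
# `Person` entity:
#
#     with db_session:
#         result = select(p for p in Person)[:]  # fetches into a QueryResult
#         print(len(result), result[0])
#         result.show()                          # tabular output via strcut()/strjoin()
#         items = result.to_list()               # plain mutable list copy
#         # result.append(obj)  -> TypeError: cast QueryResult to list first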
@cut_traceback
def show(entity):
    x = entity
    if isinstance(x, EntityMeta):
        print(x.describe())
    elif isinstance(x, Entity):
        print('instance of ' + x.__class__.__name__)
        # width = options.CONSOLE_WIDTH
        # for attr in x._attrs_:
        #     if attr.is_collection or attr.lazy: continue
        #     value = str(attr.__get__(x)).replace('\n', ' ')
        #     print(' %s: %s' % (attr.name, strcut(value, width-len(attr.name)-4)))
        # print()
        # Note: the call below uses a different argument order than the
        # QueryResult(query, limit, offset, lazy) signature defined above,
        # so it fails for entity instances.
        QueryResult([ x ], None, x.__class__, None).show()
    elif isinstance(x, (basestring, types.GeneratorType)):
        select(x).show()
    elif hasattr(x, 'show'):
        x.show()
    else:
        from pprint import pprint
        pprint(x)

special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr}
const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time,
                   datetime.timedelta}


# ========================= pony-0.7.11/pony/orm/dbapiprovider.py =========================

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, basestring, unicode, buffer, int_types, iteritems

import os, re, json
from decimal import Decimal, InvalidOperation
from datetime import datetime, date, time, timedelta
from uuid import uuid4, UUID

import pony
from pony.utils import is_utf8, decorator, throw, localbase, deprecated
from pony.converting import str2date, str2time, str2datetime, str2timedelta
from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedArray, Json, QueryType, Array

class DBException(Exception):
    def __init__(exc, original_exc, *args):
        args = args or getattr(original_exc, 'args', ())
        Exception.__init__(exc, *args)
        exc.original_exc = original_exc

# Exception inheritance layout of DBAPI 2.0-compatible provider:
#
# Exception
#   Warning
#   Error
#     InterfaceError
#     DatabaseError
#       DataError
#       OperationalError
#       IntegrityError
#       InternalError
#       ProgrammingError
#       NotSupportedError

class Warning(DBException): pass
class Error(DBException): pass
class InterfaceError(Error): pass
class DatabaseError(Error): pass
class DataError(DatabaseError): pass
class OperationalError(DatabaseError): pass
class IntegrityError(DatabaseError): pass
class InternalError(DatabaseError): pass
class ProgrammingError(DatabaseError): pass
class NotSupportedError(DatabaseError): pass

@decorator
def wrap_dbapi_exceptions(func, provider, *args, **kwargs):
    dbapi_module = provider.dbapi_module
    try:
        if provider.dialect != 'SQLite':
            return func(provider, *args, **kwargs)
        else:
            provider.local_exceptions.keep_traceback = True
            try: return func(provider, *args, **kwargs)
            finally: provider.local_exceptions.keep_traceback = False
    except dbapi_module.NotSupportedError as e: raise NotSupportedError(e)
    except dbapi_module.ProgrammingError as e:
        if provider.dialect == 'PostgreSQL':
            msg = str(e)
            if msg.startswith('operator does not exist:') and ' json ' in msg:
                msg += ' (Note: use column type `jsonb` instead of `json`)'
            raise ProgrammingError(e, msg, *e.args[1:])
        raise ProgrammingError(e)
    except dbapi_module.InternalError as e: raise InternalError(e)
    except dbapi_module.IntegrityError as e: raise IntegrityError(e)
    except dbapi_module.OperationalError as e:
        if provider.dialect == 'SQLite': provider.restore_exception()
        raise OperationalError(e)
    except dbapi_module.DataError as e: raise DataError(e)
    except dbapi_module.DatabaseError as e: raise DatabaseError(e)
    except dbapi_module.InterfaceError as e:
        if e.args == (0, '') and getattr(dbapi_module,
'__name__', None) == 'MySQLdb': throw(InterfaceError, e, 'MySQL server misconfiguration') raise InterfaceError(e) except dbapi_module.Error as e: raise Error(e) except dbapi_module.Warning as e: raise Warning(e) def unexpected_args(attr, args): throw(TypeError, 'Unexpected positional argument{} for attribute {}: {}'.format( len(args) > 1 and 's' or '', attr, ', '.join(repr(arg) for arg in args)) ) version_re = re.compile('[0-9\.]+') def get_version_tuple(s): m = version_re.match(s) if m is not None: components = m.group(0).split('.') return tuple(int(component) for component in components) return None class DBAPIProvider(object): paramstyle = 'qmark' quote_char = '"' max_params_count = 999 max_name_len = 128 table_if_not_exists_syntax = True index_if_not_exists_syntax = True max_time_precision = default_time_precision = 6 uint64_support = False # SQLite and PostgreSQL does not limit varchar max length. varchar_default_max_len = None dialect = None dbapi_module = None dbschema_cls = None translator_cls = None sqlbuilder_cls = None array_converter_cls = None name_before_table = 'schema_name' default_schema_name = None fk_types = { 'SERIAL' : 'INTEGER', 'BIGSERIAL' : 'BIGINT' } def __init__(provider, *args, **kwargs): pool_mockup = kwargs.pop('pony_pool_mockup', None) call_on_connect = kwargs.pop('pony_call_on_connect', None) if pool_mockup: provider.pool = pool_mockup else: provider.pool = provider.get_pool(*args, **kwargs) connection, is_new_connection = provider.connect() if call_on_connect: call_on_connect(connection) provider.inspect_connection(connection) provider.release(connection) @wrap_dbapi_exceptions def inspect_connection(provider, connection): pass def normalize_name(provider, name): return name[:provider.max_name_len] def get_default_entity_table_name(provider, entity): return provider.normalize_name(entity.__name__) def get_default_m2m_table_name(provider, attr, reverse): if attr.symmetric: assert reverse is attr name = attr.entity.__name__ + '_' + attr.name else: name = attr.entity.__name__ + '_' + reverse.entity.__name__ return provider.normalize_name(name) def get_default_column_names(provider, attr, reverse_pk_columns=None): normalize_name = provider.normalize_name if reverse_pk_columns is None: return [ normalize_name(attr.name) ] elif len(reverse_pk_columns) == 1: return [ normalize_name(attr.name) ] else: prefix = attr.name + '_' return [ normalize_name(prefix + column) for column in reverse_pk_columns ] def get_default_m2m_column_names(provider, entity): normalize_name = provider.normalize_name columns = entity._get_pk_columns_() if len(columns) == 1: return [ normalize_name(entity.__name__.lower()) ] else: prefix = entity.__name__.lower() + '_' return [ normalize_name(prefix + column) for column in columns ] def get_default_index_name(provider, table_name, column_names, is_pk=False, is_unique=False, m2m=False): if is_pk: index_name = 'pk_%s' % provider.base_name(table_name) else: if is_unique: template = 'unq_%(tname)s__%(cnames)s' elif m2m: template = 'idx_%(tname)s' else: template = 'idx_%(tname)s__%(cnames)s' index_name = template % dict(tname=provider.base_name(table_name), cnames='_'.join(name for name in column_names)) return provider.normalize_name(index_name.lower()) def get_default_fk_name(provider, child_table_name, parent_table_name, child_column_names): fk_name = 'fk_%s__%s' % (provider.base_name(child_table_name), '__'.join(child_column_names)) return provider.normalize_name(fk_name.lower()) def split_table_name(provider, table_name): if 
isinstance(table_name, basestring): return provider.default_schema_name, table_name if not table_name: throw(TypeError, 'Invalid table name: %r' % table_name) if len(table_name) != 2: size = len(table_name) throw(TypeError, '%s qualified table name must have two components: ' '%s and table_name. Got %d component%s: %s' % (provider.dialect, provider.name_before_table, size, 's' if size != 1 else '', table_name)) return table_name[0], table_name[1] def base_name(provider, name): if not isinstance(name, basestring): assert type(name) is tuple name = name[-1] assert isinstance(name, basestring) return name def quote_name(provider, name): quote_char = provider.quote_char if isinstance(name, basestring): name = name.replace(quote_char, quote_char+quote_char) return quote_char + name + quote_char return '.'.join(provider.quote_name(item) for item in name) def format_table_name(provider, name): return provider.quote_name(name) def normalize_vars(provider, vars, vartypes): for key, value in iteritems(vars): vartype = vartypes[key] if isinstance(vartype, QueryType): vartypes[key], vars[key] = value._normalize_var(vartype) def ast2sql(provider, ast): builder = provider.sqlbuilder_cls(provider, ast) return builder.sql, builder.adapter def should_reconnect(provider, exc): return False @wrap_dbapi_exceptions def connect(provider): return provider.pool.connect() @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): pass @wrap_dbapi_exceptions def commit(provider, connection, cache=None): core = pony.orm.core if core.local.debug: core.log_orm('COMMIT') connection.commit() if cache is not None: cache.in_transaction = False @wrap_dbapi_exceptions def rollback(provider, connection, cache=None): core = pony.orm.core if core.local.debug: core.log_orm('ROLLBACK') connection.rollback() if cache is not None: cache.in_transaction = False @wrap_dbapi_exceptions def release(provider, connection, cache=None): core = pony.orm.core if cache is not None and cache.db_session is not None and cache.db_session.ddl: provider.drop(connection, cache) else: if core.local.debug: core.log_orm('RELEASE CONNECTION') provider.pool.release(connection) @wrap_dbapi_exceptions def drop(provider, connection, cache=None): core = pony.orm.core if core.local.debug: core.log_orm('CLOSE CONNECTION') provider.pool.drop(connection) if cache is not None: cache.in_transaction = False @wrap_dbapi_exceptions def disconnect(provider): core = pony.orm.core if core.local.debug: core.log_orm('DISCONNECT') provider.pool.disconnect() @wrap_dbapi_exceptions def execute(provider, cursor, sql, arguments=None, returning_id=False): if type(arguments) is list: assert arguments and not returning_id cursor.executemany(sql, arguments) else: if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) if returning_id: return cursor.lastrowid converter_classes = [] def _get_converter_type_by_py_type(provider, py_type): if isinstance(py_type, type): for t, converter_cls in provider.converter_classes: if issubclass(py_type, t): return converter_cls if issubclass(py_type, Array): converter_cls = provider.array_converter_cls if converter_cls is None: throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect) return converter_cls if isinstance(py_type, RawSQLType): return Converter # for cases like select(raw_sql(...) 
for x in X) throw(TypeError, 'No database converter found for type %s' % py_type) def get_converter_by_py_type(provider, py_type): converter_cls = provider._get_converter_type_by_py_type(py_type) return converter_cls(provider, py_type) def get_converter_by_attr(provider, attr): py_type = attr.py_type converter_cls = provider._get_converter_type_by_py_type(py_type) return converter_cls(provider, py_type, attr) def get_pool(provider, *args, **kwargs): return Pool(provider.dbapi_module, *args, **kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): throw(NotImplementedError) def index_exists(provider, connection, table_name, index_name, case_sensitive=True): throw(NotImplementedError) def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): throw(NotImplementedError) def table_has_data(provider, connection, table_name): cursor = connection.cursor() cursor.execute('SELECT 1 FROM %s LIMIT 1' % provider.quote_name(table_name)) return cursor.fetchone() is not None def disable_fk_checks(provider, connection): pass def enable_fk_checks(provider, connection, prev_state): pass def drop_table(provider, connection, table_name): cursor = connection.cursor() sql = 'DROP TABLE %s' % provider.quote_name(table_name) cursor.execute(sql) class Pool(localbase): forked_connections = [] def __init__(pool, dbapi_module, *args, **kwargs): # called separately in each thread pool.dbapi_module = dbapi_module pool.args = args pool.kwargs = kwargs pool.con = pool.pid = None def connect(pool): pid = os.getpid() if pool.con is not None and pool.pid != pid: pool.forked_connections.append((pool.con, pool.pid)) pool.con = pool.pid = None core = pony.orm.core is_new_connection = False if pool.con is None: if core.local.debug: core.log_orm('GET NEW CONNECTION') is_new_connection = True pool._connect() pool.pid = pid elif core.local.debug: core.log_orm('GET CONNECTION FROM THE LOCAL POOL') return pool.con, is_new_connection def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) def release(pool, con): assert con is pool.con try: con.rollback() except: pool.drop(con) raise def drop(pool, con): assert con is pool.con, (con, pool.con) pool.con = None con.close() def disconnect(pool): con = pool.con pool.con = None if con is not None: con.close() class Converter(object): EQ = 'EQ' NE = 'NE' optimistic = True def __deepcopy__(converter, memo): return converter # Converter instances are "immutable" def __init__(converter, provider, py_type, attr=None): converter.provider = provider converter.py_type = py_type converter.attr = attr if attr is None: return kwargs = attr.kwargs.copy() converter.init(kwargs) for option in kwargs: throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option)) def init(converter, kwargs): attr = converter.attr if attr and attr.args: unexpected_args(attr, attr.args) def validate(converter, val, obj=None): return val def py2sql(converter, val): return val def sql2py(converter, val): return val def val2dbval(self, val, obj=None): return val def dbval2val(self, dbval, obj=None): return dbval def dbvals_equal(self, x, y): return x == y def get_sql_type(converter, attr=None): if attr is not None and attr.sql_type is not None: return attr.sql_type attr = converter.attr if attr.sql_type is not None: assert len(attr.columns) == 1 return converter.get_fk_type(attr.sql_type) if attr is not None and attr.reverse and not attr.is_collection: i = attr.converters.index(converter) rentity = attr.reverse.entity rpk_converters = 
rentity._pk_converters_ assert rpk_converters is not None and len(attr.converters) == len(rpk_converters) rconverter = rpk_converters[i] return rconverter.sql_type() return converter.sql_type() def get_fk_type(converter, sql_type): fk_types = converter.provider.fk_types if sql_type.isupper(): return fk_types.get(sql_type, sql_type) sql_type = sql_type.upper() return fk_types.get(sql_type, sql_type).lower() class NoneConverter(Converter): # used for raw_sql() parameters only def __init__(converter, provider, py_type, attr=None): if attr is not None: throw(TypeError, 'Attribute %s has invalid type NoneType' % attr) Converter.__init__(converter, provider, py_type) def get_sql_type(converter, attr=None): assert False def get_fk_type(converter, sql_type): assert False class BoolConverter(Converter): def validate(converter, val, obj=None): return bool(val) def sql2py(converter, val): return bool(val) def sql_type(converter): return "BOOLEAN" class StrConverter(Converter): def __init__(converter, provider, py_type, attr=None): converter.max_len = None converter.db_encoding = None Converter.__init__(converter, provider, py_type, attr) def init(converter, kwargs): attr = converter.attr max_len = kwargs.pop('max_len', None) if len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) elif attr.args: if max_len is not None: throw(TypeError, 'Max length option specified twice: as a positional argument and as a `max_len` named argument') max_len = attr.args[0] if issubclass(attr.py_type, (LongStr, LongUnicode)): if max_len is not None: throw(TypeError, 'Max length is not supported for CLOBs') elif max_len is None: max_len = converter.provider.varchar_default_max_len elif not isinstance(max_len, int_types): throw(TypeError, 'Max length argument must be int. Got: %r' % max_len) converter.max_len = max_len converter.db_encoding = kwargs.pop('db_encoding', None) converter.autostrip = kwargs.pop('autostrip', True) def validate(converter, val, obj=None): if PY2 and isinstance(val, str): val = val.decode('ascii') elif not isinstance(val, unicode): throw(TypeError, 'Value type for attribute %s must be %s. Got: %r' % (converter.attr, unicode.__name__, type(val))) if converter.autostrip: val = val.strip() max_len = converter.max_len val_len = len(val) if max_len and val_len > max_len: throw(ValueError, 'Value for attribute %s is too long. Max length is %d, value length is %d' % (converter.attr, max_len, val_len)) return val def sql_type(converter): if converter.max_len: return 'VARCHAR(%d)' % converter.max_len return 'TEXT' class IntConverter(Converter): signed_types = {None: 'INTEGER', 8: 'TINYINT', 16: 'SMALLINT', 24: 'MEDIUMINT', 32: 'INTEGER', 64: 'BIGINT'} unsigned_types = None def init(converter, kwargs): Converter.init(converter, kwargs) attr = converter.attr min_val = kwargs.pop('min', None) if min_val is not None and not isinstance(min_val, int_types): throw(TypeError, "'min' argument for attribute %s must be int. Got: %r" % (attr, min_val)) max_val = kwargs.pop('max', None) if max_val is not None and not isinstance(max_val, int_types): throw(TypeError, "'max' argument for attribute %s must be int. Got: %r" % (attr, max_val)) size = kwargs.pop('size', None) if size is None: if attr.py_type.__name__ == 'long': deprecated(9, "Attribute %s: 'long' attribute type is deprecated. " "Please use 'int' type with size=64 option instead" % attr) attr.py_type = int size = 64 elif attr.py_type.__name__ == 'long': throw(TypeError, "Attribute %s: 'size' option cannot be used with long type. 
            "Please use int type instead" % attr)
        elif not isinstance(size, int_types):
            throw(TypeError, "'size' option for attribute %s must be of int type. Got: %r" % (attr, size))
        elif size not in (8, 16, 24, 32, 64):
            throw(TypeError, "incorrect value of 'size' option for attribute %s. "
                             "Should be 8, 16, 24, 32 or 64. Got: %d" % (attr, size))
        unsigned = kwargs.pop('unsigned', False)
        if unsigned is not None and not isinstance(unsigned, bool):
            throw(TypeError, "'unsigned' option for attribute %s must be of bool type. Got: %r" % (attr, unsigned))
        if size == 64 and unsigned and not converter.provider.uint64_support:
            throw(TypeError, 'Attribute %s: %s provider does not support unsigned bigint type'
                             % (attr, converter.provider.dialect))
        if unsigned is not None and size is None: size = 32
        lowest = highest = None
        if size:
            highest = 2 ** size - 1 if unsigned else 2 ** (size - 1) - 1
            lowest = 0 if unsigned else -(2 ** (size - 1))
        if highest is not None and max_val is not None and max_val > highest:
            throw(ValueError, "'max' argument should be less or equal to %d because of size=%d and unsigned=%s. "
                              "Got: %d" % (highest, size, unsigned, max_val))
        if lowest is not None and min_val is not None and min_val < lowest:
            throw(ValueError, "'min' argument should be greater or equal to %d because of size=%d and unsigned=%s. "
                              "Got: %d" % (lowest, size, unsigned, min_val))
        converter.min_val = min_val or lowest
        converter.max_val = max_val or highest
        converter.size = size
        converter.unsigned = unsigned
    def validate(converter, val, obj=None):
        if isinstance(val, int_types): pass
        elif hasattr(val, '__index__'): val = val.__index__()
        elif isinstance(val, basestring):
            try: val = int(val)
            except ValueError:
                throw(ValueError, 'Value type for attribute %s must be int. Got string %r' % (converter.attr, val))
        else: throw(TypeError, 'Value type for attribute %s must be int. Got: %r' % (converter.attr, type(val)))
        if converter.min_val and val < converter.min_val:
            throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r'
                              % (val, converter.attr, converter.min_val))
        if converter.max_val and val > converter.max_val:
            throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r'
                              % (val, converter.attr, converter.max_val))
        return val
    def sql2py(converter, val):
        return int(val)
    def sql_type(converter):
        if not converter.unsigned:
            return converter.signed_types.get(converter.size)
        if converter.unsigned_types is None:
            return converter.signed_types.get(converter.size) + ' UNSIGNED'
        return converter.unsigned_types.get(converter.size)

class RealConverter(Converter):
    EQ = 'FLOAT_EQ'
    NE = 'FLOAT_NE'
    # The tolerance is necessary for Oracle, because it has different representation of float numbers.
# For other databases the default tolerance is set because the precision can be lost during # Python -> JavaScript -> Python conversion default_tolerance = 1e-14 optimistic = False def init(converter, kwargs): Converter.init(converter, kwargs) min_val = kwargs.pop('min', None) if min_val is not None: try: min_val = float(min_val) except ValueError: throw(TypeError, "Invalid value for 'min' argument for attribute %s: %r" % (converter.attr, min_val)) max_val = kwargs.pop('max', None) if max_val is not None: try: max_val = float(max_val) except ValueError: throw(TypeError, "Invalid value for 'max' argument for attribute %s: %r" % (converter.attr, max_val)) converter.min_val = min_val converter.max_val = max_val converter.tolerance = kwargs.pop('tolerance', converter.default_tolerance) def validate(converter, val, obj=None): try: val = float(val) except ValueError: throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val)) if converter.min_val and val < converter.min_val: throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r' % (val, converter.attr, converter.min_val)) if converter.max_val and val > converter.max_val: throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' % (val, converter.attr, converter.max_val)) return val def dbvals_equal(converter, x, y): tolerance = converter.tolerance if tolerance is None or x is None or y is None: return x == y denominator = max(abs(x), abs(y)) if not denominator: return True diff = abs(x-y) / denominator return diff <= tolerance def sql2py(converter, val): return float(val) def sql_type(converter): return 'REAL' class DecimalConverter(Converter): def __init__(converter, provider, py_type, attr=None): converter.exp = None # for the case when attr is None Converter.__init__(converter, provider, py_type, attr) def init(converter, kwargs): attr = converter.attr args = attr.args if len(args) > 2: throw(TypeError, 'Too many positional parameters for Decimal ' '(expected: precision and scale), got: %s' % args) if args: precision = args[0] else: precision = kwargs.pop('precision', 12) if not isinstance(precision, int_types): throw(TypeError, "'precision' positional argument for attribute %s must be int. Got: %r" % (attr, precision)) if precision <= 0: throw(TypeError, "'precision' positional argument for attribute %s must be positive. Got: %r" % (attr, precision)) if len(args) == 2: scale = args[1] else: scale = kwargs.pop('scale', 2) if not isinstance(scale, int_types): throw(TypeError, "'scale' positional argument for attribute %s must be int. Got: %r" % (attr, scale)) if scale <= 0: throw(TypeError, "'scale' positional argument for attribute %s must be positive. 
Got: %r" % (attr, scale)) if scale > precision: throw(ValueError, "'scale' must be less or equal 'precision'") converter.precision = precision converter.scale = scale converter.exp = Decimal(10) ** -scale min_val = kwargs.pop('min', None) if min_val is not None: try: min_val = Decimal(min_val) except TypeError: throw(TypeError, "Invalid value for 'min' argument for attribute %s: %r" % (attr, min_val)) max_val = kwargs.pop('max', None) if max_val is not None: try: max_val = Decimal(max_val) except TypeError: throw(TypeError, "Invalid value for 'max' argument for attribute %s: %r" % (attr, max_val)) converter.min_val = min_val converter.max_val = max_val def validate(converter, val, obj=None): if isinstance(val, float): s = str(val) if float(s) != val: s = repr(val) val = Decimal(s) try: val = Decimal(val) except InvalidOperation as exc: throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val)) if converter.min_val is not None and val < converter.min_val: throw(ValueError, 'Value %r of attr %s is less than the minimum allowed value %r' % (val, converter.attr, converter.min_val)) if converter.max_val is not None and val > converter.max_val: throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' % (val, converter.attr, converter.max_val)) return val def sql2py(converter, val): return Decimal(val) def sql_type(converter): return 'DECIMAL(%d, %d)' % (converter.precision, converter.scale) class BlobConverter(Converter): def validate(converter, val, obj=None): if isinstance(val, buffer): return val if isinstance(val, str): return buffer(val) throw(TypeError, "Attribute %r: expected type is 'buffer'. Got: %r" % (converter.attr, type(val))) def sql2py(converter, val): if not isinstance(val, buffer): try: val = buffer(val) except: pass elif PY2 and converter.attr is not None and converter.attr.is_part_of_unique_index: try: hash(val) except TypeError: val = buffer(val) return val def sql_type(converter): return 'BLOB' class DateConverter(Converter): def validate(converter, val, obj=None): if isinstance(val, datetime): return val.date() if isinstance(val, date): return val if isinstance(val, basestring): return str2date(val) throw(TypeError, "Attribute %r: expected type is 'date'. Got: %r" % (converter.attr, val)) def sql2py(converter, val): if not isinstance(val, date): throw(ValueError, 'Value of unexpected type received from database: instead of date got %s' % type(val)) return val def sql_type(converter): return 'DATE' class ConverterWithMicroseconds(Converter): def __init__(converter, provider, py_type, attr=None): converter.precision = None # for the case when attr is None Converter.__init__(converter, provider, py_type, attr) def init(converter, kwargs): attr = converter.attr args = attr.args if len(args) > 1: throw(TypeError, 'Too many positional parameters for attribute %s. ' 'Expected: precision, got: %r' % (attr, args)) provider = attr.entity._database_.provider if args: precision = args[0] if 'precision' in kwargs: throw(TypeError, 'Precision for attribute %s has both positional and keyword value' % attr) else: precision = kwargs.pop('precision', provider.default_time_precision) if not isinstance(precision, int) or not 0 <= precision <= 6: throw(ValueError, 'Precision value of attribute %s must be between 0 and 6. 
Got: %r' % (attr, precision)) if precision > provider.max_time_precision: throw(ValueError, 'Precision value (%d) of attribute %s exceeds max datetime precision (%d) of %s %s' % (precision, attr, provider.max_time_precision, provider.dialect, provider.server_version)) converter.precision = precision def round_microseconds_to_precision(converter, microseconds, precision): # returns None if no change is required if not precision: result = 0 elif precision < 6: rounding = 10 ** (6-precision) result = (microseconds // rounding) * rounding else: return None return result if result != microseconds else None def sql_type(converter): attr = converter.attr precision = converter.precision if not attr or precision == attr.entity._database_.provider.default_time_precision: return converter.sql_type_name return converter.sql_type_name + '(%d)' % precision class TimeConverter(ConverterWithMicroseconds): sql_type_name = 'TIME' def validate(converter, val, obj=None): if isinstance(val, time): pass elif isinstance(val, basestring): val = str2time(val) else: throw(TypeError, "Attribute %r: expected type is 'time'. Got: %r" % (converter.attr, val)) mcs = converter.round_microseconds_to_precision(val.microsecond, converter.precision) if mcs is not None: val = val.replace(microsecond=mcs) return val def sql2py(converter, val): if not isinstance(val, time): throw(ValueError, 'Value of unexpected type received from database: instead of time got %s' % type(val)) return val class TimedeltaConverter(ConverterWithMicroseconds): sql_type_name = 'INTERVAL' def validate(converter, val, obj=None): if isinstance(val, timedelta): pass elif isinstance(val, basestring): val = str2timedelta(val) else: throw(TypeError, "Attribute %r: expected type is 'timedelta'. Got: %r" % (converter.attr, val)) mcs = converter.round_microseconds_to_precision(val.microseconds, converter.precision) if mcs is not None: val = timedelta(val.days, val.seconds, mcs) return val def sql2py(converter, val): if not isinstance(val, timedelta): throw(ValueError, 'Value of unexpected type received from database: instead of time got %s' % type(val)) return val class DatetimeConverter(ConverterWithMicroseconds): sql_type_name = 'DATETIME' def validate(converter, val, obj=None): if isinstance(val, datetime): pass elif isinstance(val, basestring): val = str2datetime(val) else: throw(TypeError, "Attribute %r: expected type is 'datetime'. Got: %r" % (converter.attr, val)) mcs = converter.round_microseconds_to_precision(val.microsecond, converter.precision) if mcs is not None: val = val.replace(microsecond=mcs) return val def sql2py(converter, val): if not isinstance(val, datetime): throw(ValueError, 'Value of unexpected type received from database: instead of datetime got %s' % type(val)) return val class UuidConverter(Converter): def __init__(converter, provider, py_type, attr=None): if attr is not None and attr.auto: attr.auto = False if not attr.default: attr.default = uuid4 Converter.__init__(converter, provider, py_type, attr) def validate(converter, val, obj=None): if isinstance(val, UUID): return val if isinstance(val, buffer): return UUID(bytes=val) if isinstance(val, basestring): if len(val) == 16: return UUID(bytes=val) return UUID(hex=val) if isinstance(val, int): return UUID(int=val) if converter.attr is not None: throw(ValueError, 'Value type of attribute %s must be UUID. 
            Got: %r' % (converter.attr, type(val)))
        else:
            throw(ValueError, 'Expected UUID value, got: %r' % type(val))
    def py2sql(converter, val):
        return buffer(val.bytes)
    sql2py = validate
    def sql_type(converter):
        return "UUID"

class JsonConverter(Converter):
    json_kwargs = {}
    class JsonEncoder(json.JSONEncoder):
        def default(converter, obj):
            if isinstance(obj, Json):
                return obj.wrapped
            return json.JSONEncoder.default(converter, obj)
    def validate(converter, val, obj=None):
        if obj is None or converter.attr is None:
            return val
        if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr:
            return val
        return TrackedValue.make(obj, converter.attr, val)
    def val2dbval(converter, val, obj=None):
        return json.dumps(val, cls=converter.JsonEncoder, **converter.json_kwargs)
    def dbval2val(converter, dbval, obj=None):
        if isinstance(dbval, (int, bool, float, type(None))):
            return dbval
        val = json.loads(dbval)
        if obj is None:
            return val
        return TrackedValue.make(obj, converter.attr, val)
    def dbvals_equal(converter, x, y):
        if x == y: return True  # optimization
        if isinstance(x, basestring): x = json.loads(x)
        if isinstance(y, basestring): y = json.loads(y)
        return x == y
    def sql_type(converter):
        return "JSON"

class ArrayConverter(Converter):
    array_types = {
        int: ('int', IntConverter),
        unicode: ('text', StrConverter),
        float: ('real', RealConverter)
    }
    def __init__(converter, provider, py_type, attr=None):
        Converter.__init__(converter, provider, py_type, attr)
        converter.item_converter = converter.array_types[converter.py_type.item_type][1]
    def validate(converter, val, obj=None):
        if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr:
            return val
        if isinstance(val, basestring) or not hasattr(val, '__len__'):
            items = [val]
        else:
            items = list(val)
        item_type = converter.py_type.item_type
        if item_type == float:
            item_type = (float, int)
        for i, v in enumerate(items):
            if PY2 and isinstance(v, str):
                v = v.decode('ascii')
            if not isinstance(v, item_type):
                if hasattr(v, '__index__'):
                    items[i] = v.__index__()
                else:
                    throw(TypeError, 'Cannot store %s item in array of %s'
                          % (type(v).__name__, converter.py_type.item_type.__name__))
        if obj is None or converter.attr is None:
            return items
        return TrackedArray(obj, converter.attr, items)
    def dbval2val(converter, dbval, obj=None):
        if obj is None or dbval is None:
            return dbval
        return TrackedArray(obj, converter.attr, dbval)
    def val2dbval(converter, val, obj=None):
        return list(val)
    def sql_type(converter):
        return '%s[]' % converter.array_types[converter.py_type.item_type][0]

# pony-0.7.11/pony/orm/dbproviders/__init__.py  (empty)

# pony-0.7.11/pony/orm/dbproviders/mysql.py

from __future__ import absolute_import
from pony.py23compat import PY2, imap, basestring, buffer, int_types

import json
from decimal import Decimal
from datetime import datetime, date, time, timedelta
from uuid import UUID

NoneType = type(None)

import warnings
warnings.filterwarnings('ignore', '^Table.+already exists$',
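# Why this filter exists (an inference from the pattern above, the message is a
# hypothetical example): since table_if_not_exists_syntax is enabled for MySQL,
# generate_mapping(create_tables=True) emits CREATE TABLE IF NOT EXISTS, and when
# the table is already present the MySQL drivers surface the server's note as a
# Python Warning such as "Table 'person' already exists". The filter silences
# exactly that message pattern when raised from pony.orm.dbapiprovider.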
Warning, '^pony\\.orm\\.dbapiprovider$') try: import MySQLdb as mysql_module from MySQLdb import string_literal import MySQLdb.converters as mysql_converters from MySQLdb.constants import FIELD_TYPE, FLAG, CLIENT mysql_module_name = 'MySQLdb' except ImportError: try: import pymysql as mysql_module except ImportError: raise ImportError('In order to use PonyORM with MySQL please install MySQLdb or pymysql') from pymysql.converters import escape_str as string_literal import pymysql.converters as mysql_converters from pymysql.constants import FIELD_TYPE, FLAG, CLIENT mysql_module_name = 'pymysql' from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator, TranslationError from pony.orm.sqlbuilding import Value, Param, SQLBuilder, join from pony.utils import throw from pony.converting import str2timedelta, timedelta2str class MySQLColumn(dbschema.Column): auto_template = '%(type)s PRIMARY KEY AUTO_INCREMENT' class MySQLSchema(dbschema.DBSchema): dialect = 'MySQL' inline_fk_syntax = False column_class = MySQLColumn class MySQLTranslator(SQLTranslator): dialect = 'MySQL' json_path_wildcard_syntax = True class MySQLValue(Value): __slots__ = [] def __unicode__(self): value = self.value if isinstance(value, timedelta): if value.microseconds: return "INTERVAL '%s' HOUR_MICROSECOND" % timedelta2str(value) return "INTERVAL '%s' HOUR_SECOND" % timedelta2str(value) return Value.__unicode__(self) if not PY2: __str__ = __unicode__ class MySQLBuilder(SQLBuilder): dialect = 'MySQL' value_class = MySQLValue def CONCAT(builder, *args): return 'concat(', join(', ', imap(builder, args)), ')' def TRIM(builder, expr, chars=None): if chars is None: return 'trim(', builder(expr), ')' return 'trim(both ', builder(chars), ' from ' ,builder(expr), ')' def LTRIM(builder, expr, chars=None): if chars is None: return 'ltrim(', builder(expr), ')' return 'trim(leading ', builder(chars), ' from ' ,builder(expr), ')' def RTRIM(builder, expr, chars=None): if chars is None: return 'rtrim(', builder(expr), ')' return 'trim(trailing ', builder(chars), ' from ' ,builder(expr), ')' def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS SIGNED)' def TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS DOUBLE)' def TO_STR(builder, expr): return 'CAST(', builder(expr), ' AS CHAR)' def YEAR(builder, expr): return 'year(', builder(expr), ')' def MONTH(builder, expr): return 'month(', builder(expr), ')' def DAY(builder, expr): return 'day(', builder(expr), ')' def HOUR(builder, expr): return 'hour(', builder(expr), ')' def MINUTE(builder, expr): return 'minute(', builder(expr), ')' def SECOND(builder, expr): return 'second(', builder(expr), ')' def DATE_ADD(builder, expr, delta): if delta[0] == 'VALUE' and isinstance(delta[1], time): return 'ADDTIME(', builder(expr), ', ', builder(delta), ')' return 'ADDDATE(', builder(expr), ', ', builder(delta), ')' def DATE_SUB(builder, expr, delta): if delta[0] == 'VALUE' and isinstance(delta[1], time): return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' return 'SUBDATE(', builder(expr), ', ', builder(delta), ')' def DATE_DIFF(builder, expr1, expr2): return 'TIMEDIFF(', builder(expr1), ', ', builder(expr2), ')' def DATETIME_ADD(builder, expr, delta): return builder.DATE_ADD(expr, delta) def DATETIME_SUB(builder, expr, delta): return builder.DATE_SUB(expr, delta) def 
DATETIME_DIFF(builder, expr1, expr2): return 'TIMEDIFF(', builder(expr1), ', ', builder(expr2), ')' def JSON_QUERY(builder, expr, path): path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'json_extract(', builder(expr), ', ', path_sql, ')' def JSON_VALUE(builder, expr, path, type): path_sql, has_params, has_wildcards = builder.build_json_path(path) result = 'json_extract(', builder(expr), ', ', path_sql, ')' if type is NoneType: return 'NULLIF(', result, ", CAST('null' as JSON))" if type in (bool, int): return 'CAST(', result, ' AS SIGNED)' if type is float: return 'CAST(', result, ' AS DOUBLE)' return 'json_unquote(', result, ')' def JSON_NONZERO(builder, expr): return 'COALESCE(CAST(', builder(expr), ''' as CHAR), 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' def JSON_EQ(builder, left, right): return '(', builder(left), ' = CAST(', builder(right), ' AS JSON))' def JSON_NE(builder, left, right): return '(', builder(left), ' != CAST(', builder(right), ' AS JSON))' def JSON_CONTAINS(builder, expr, path, key): key_sql = builder(key) if isinstance(key_sql, Value): wrapped_key = builder.value_class(builder.paramstyle, json.dumps([ key_sql.value ])) elif isinstance(key_sql, Param): wrapped_key = builder.make_composite_param( (key_sql.paramkey,), [key_sql], builder.wrap_param_to_json_array) else: assert False expr_sql = builder(expr) result = [ '(json_contains(', expr_sql, ', ', wrapped_key ] path_sql, has_params, has_wildcards = builder.build_json_path(path) if has_wildcards: throw(TranslationError, 'Wildcards are not allowed in json_contains()') path_with_key_sql, _, _ = builder.build_json_path(path + [key]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result @classmethod def wrap_param_to_json_array(cls, values): return json.dumps(values) def JSON_PARAM(builder, expr): return 'CAST(', builder(expr), ' AS JSON)' class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): result = 'VARCHAR(%d)' % converter.max_len if converter.max_len else 'LONGTEXT' if converter.db_encoding: result += ' CHARACTER SET %s' % converter.db_encoding return result class MySQLRealConverter(dbapiprovider.RealConverter): def sql_type(converter): return 'DOUBLE' class MySQLBlobConverter(dbapiprovider.BlobConverter): def sql_type(converter): return 'LONGBLOB' class MySQLTimeConverter(dbapiprovider.TimeConverter): def sql2py(converter, val): if isinstance(val, timedelta): # MySQLdb returns timedeltas instead of times total_seconds = val.days * (24 * 60 * 60) + val.seconds if 0 <= total_seconds <= 24 * 60 * 60: minutes, seconds = divmod(total_seconds, 60) hours, minutes = divmod(minutes, 60) return time(hours, minutes, seconds, val.microseconds) elif not isinstance(val, time): throw(ValueError, 'Value of unexpected type received from database%s: instead of time or timedelta got %s' % ('for attribute %s' % converter.attr if converter.attr else '', type(val))) return val class MySQLTimedeltaConverter(dbapiprovider.TimedeltaConverter): sql_type_name = 'TIME' class MySQLUuidConverter(dbapiprovider.UuidConverter): def sql_type(converter): return 'BINARY(16)' class MySQLJsonConverter(dbapiprovider.JsonConverter): EQ = 'JSON_EQ' NE = 'JSON_NE' def init(self, kwargs): if self.provider.server_version < (5, 7, 8): version = '.'.join(imap(str, self.provider.server_version)) raise NotImplementedError("MySQL %s has no JSON support" % 
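# Sketch of the failure mode (hypothetical version string, not from the original
# source): on a pre-5.7.8 server, declaring a Json attribute fails at mapping
# time with an error like
#     NotImplementedError: MySQL 5.6.40 has no JSON support
# because the native JSON column type first appeared in MySQL 5.7.8, which is
# what the (5, 7, 8) check above guards against.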
version) class MySQLProvider(DBAPIProvider): dialect = 'MySQL' paramstyle = 'format' quote_char = "`" max_name_len = 64 max_params_count = 10000 table_if_not_exists_syntax = True index_if_not_exists_syntax = False max_time_precision = default_time_precision = 0 varchar_default_max_len = 255 uint64_support = True dbapi_module = mysql_module dbschema_cls = MySQLSchema translator_cls = MySQLTranslator sqlbuilder_cls = MySQLBuilder fk_types = { 'SERIAL' : 'BIGINT UNSIGNED' } converter_classes = [ (NoneType, dbapiprovider.NoneConverter), (bool, dbapiprovider.BoolConverter), (basestring, MySQLStrConverter), (int_types, dbapiprovider.IntConverter), (float, MySQLRealConverter), (Decimal, dbapiprovider.DecimalConverter), (datetime, dbapiprovider.DatetimeConverter), (date, dbapiprovider.DateConverter), (time, MySQLTimeConverter), (timedelta, MySQLTimedeltaConverter), (UUID, MySQLUuidConverter), (buffer, MySQLBlobConverter), (ormtypes.Json, MySQLJsonConverter), ] def normalize_name(provider, name): return name[:provider.max_name_len].lower() @wrap_dbapi_exceptions def inspect_connection(provider, connection): cursor = connection.cursor() cursor.execute('select version()') row = cursor.fetchone() assert row is not None provider.server_version = get_version_tuple(row[0]) if provider.server_version >= (5, 6, 4): provider.max_time_precision = 6 cursor.execute('select database()') provider.default_schema_name = cursor.fetchone()[0] cursor.execute('set session group_concat_max_len = 4294967295') def should_reconnect(provider, exc): return isinstance(exc, mysql_module.OperationalError) and exc.args[0] in (2006, 2013) def get_pool(provider, *args, **kwargs): if 'conv' not in kwargs: conv = mysql_converters.conversions.copy() if mysql_module_name == 'MySQLdb': conv[FIELD_TYPE.BLOB] = [(FLAG.BINARY, buffer)] else: if PY2: def encode_buffer(val, encoders=None): return string_literal(str(val), encoders) conv[buffer] = encode_buffer def encode_timedelta(val, encoders=None): return string_literal(timedelta2str(val), encoders) conv[timedelta] = encode_timedelta conv[FIELD_TYPE.TIMESTAMP] = str2datetime conv[FIELD_TYPE.DATETIME] = str2datetime conv[FIELD_TYPE.TIME] = str2timedelta kwargs['conv'] = conv if 'charset' not in kwargs: kwargs['charset'] = 'utf8' kwargs['client_flag'] = kwargs.get('client_flag', 0) | CLIENT.FOUND_ROWS return Pool(mysql_module, *args, **kwargs) @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction db_session = cache.db_session if db_session is not None and db_session.ddl: cursor = connection.cursor() cursor.execute("SHOW VARIABLES LIKE 'foreign_key_checks'") fk = cursor.fetchone() if fk is not None: fk = (fk[1] == 'ON') if fk: sql = 'SET foreign_key_checks = 0' if core.local.debug: log_orm(sql) cursor.execute(sql) cache.saved_fk_state = bool(fk) cache.in_transaction = True cache.immediate = True if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' if core.local.debug: log_orm(sql) cursor.execute(sql) cache.in_transaction = True @wrap_dbapi_exceptions def release(provider, connection, cache=None): if cache is not None: db_session = cache.db_session if db_session is not None and db_session.ddl and cache.saved_fk_state: try: cursor = connection.cursor() sql = 'SET foreign_key_checks = 1' if core.local.debug: log_orm(sql) cursor.execute(sql) except: provider.pool.drop(connection) raise DBAPIProvider.release(provider, connection, cache) def 
    table_exists(provider, connection, table_name, case_sensitive=True):
        db_name, table_name = provider.split_table_name(table_name)
        cursor = connection.cursor()
        if case_sensitive:
            sql = 'SELECT table_name FROM information_schema.tables ' \
                  'WHERE table_schema=%s and table_name=%s'
        else:
            sql = 'SELECT table_name FROM information_schema.tables ' \
                  'WHERE table_schema=%s and UPPER(table_name)=UPPER(%s)'
        cursor.execute(sql, [ db_name, table_name ])
        row = cursor.fetchone()
        return row[0] if row is not None else None
    def index_exists(provider, connection, table_name, index_name, case_sensitive=True):
        db_name, table_name = provider.split_table_name(table_name)
        if case_sensitive:
            sql = 'SELECT index_name FROM information_schema.statistics ' \
                  'WHERE table_schema=%s and table_name=%s and index_name=%s'
        else:
            sql = 'SELECT index_name FROM information_schema.statistics ' \
                  'WHERE table_schema=%s and table_name=%s and UPPER(index_name)=UPPER(%s)'
        cursor = connection.cursor()
        cursor.execute(sql, [ db_name, table_name, index_name ])
        row = cursor.fetchone()
        return row[0] if row is not None else None
    def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True):
        db_name, table_name = provider.split_table_name(table_name)
        if case_sensitive:
            sql = 'SELECT constraint_name FROM information_schema.table_constraints ' \
                  'WHERE table_schema=%s and table_name=%s ' \
                  "and constraint_type='FOREIGN KEY' and constraint_name=%s"
        else:
            sql = 'SELECT constraint_name FROM information_schema.table_constraints ' \
                  'WHERE table_schema=%s and table_name=%s ' \
                  "and constraint_type='FOREIGN KEY' and UPPER(constraint_name)=UPPER(%s)"
        cursor = connection.cursor()
        cursor.execute(sql, [ db_name, table_name, fk_name ])
        row = cursor.fetchone()
        return row[0] if row is not None else None

provider_cls = MySQLProvider

def str2datetime(s):
    if 19 < len(s) < 26: s += '000000'[:26-len(s)]
    s = s.replace('-', ' ').replace(':', ' ').replace('.', ' ').replace('T', ' ')
    try: return datetime(*imap(int, s.split()))
    except ValueError: return None  # for incorrect values like 0000-00-00 00:00:00

# pony-0.7.11/pony/orm/dbproviders/oracle.py

from __future__ import absolute_import
from pony.py23compat import PY2, iteritems, basestring, unicode, buffer, int_types

import os
os.environ["NLS_LANG"] = "AMERICAN_AMERICA.UTF8"

import re
from datetime import datetime, date, time, timedelta
from decimal import Decimal
from uuid import UUID

import cx_Oracle

from pony.orm import core, dbapiprovider, sqltranslation
from pony.orm.core import log_orm, log_sql, DatabaseError, TranslationError
from pony.orm.dbschema import DBSchema, DBObject, Table, Column
from pony.orm.ormtypes import Json
from pony.orm.sqlbuilding import SQLBuilder
from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple
from pony.utils import throw, is_ident

NoneType = type(None)

class OraTable(Table):
    def get_objects_to_create(table, created_tables=None):
        result = Table.get_objects_to_create(table, created_tables)
        for column in table.column_list:
            if column.is_pk == 'auto':
                sequence_name = column.converter.attr.kwargs.get('sequence_name')
                sequence = OraSequence(table, sequence_name)
                trigger = OraTrigger(table, column, sequence)
                result.extend((sequence, trigger))
                break
        return result

class OraSequence(DBObject):
    typename = 'Sequence'
    def __init__(sequence, table, name=None):
sequence.table = table table_name = table.name if name is not None: sequence.name = name elif isinstance(table_name, basestring): sequence.name = table_name + '_SEQ' else: sequence.name = tuple(table_name[:-1]) + (table_name[0] + '_SEQ',) def exists(sequence, provider, connection, case_sensitive=True): if case_sensitive: sql = 'SELECT sequence_name FROM all_sequences ' \ 'WHERE sequence_owner = :so and sequence_name = :sn' else: sql = 'SELECT sequence_name FROM all_sequences ' \ 'WHERE sequence_owner = :so and upper(sequence_name) = upper(:sn)' owner_name, sequence_name = provider.split_table_name(sequence.name) cursor = connection.cursor() cursor.execute(sql, dict(so=owner_name, sn=sequence_name)) row = cursor.fetchone() return row[0] if row is not None else None def get_create_command(sequence): schema = sequence.table.schema seq_name = schema.provider.quote_name(sequence.name) return schema.case('CREATE SEQUENCE %s NOCACHE') % seq_name trigger_template = """ CREATE TRIGGER %s BEFORE INSERT ON %s FOR EACH ROW BEGIN IF :new.%s IS NULL THEN SELECT %s.nextval INTO :new.%s FROM DUAL; END IF; END;""".strip() class OraTrigger(DBObject): typename = 'Trigger' def __init__(trigger, table, column, sequence): trigger.table = table trigger.column = column trigger.sequence = sequence table_name = table.name if not isinstance(table_name, basestring): table_name = table_name[-1] trigger.name = table_name + '_BI' # Before Insert def exists(trigger, provider, connection, case_sensitive=True): if case_sensitive: sql = 'SELECT trigger_name FROM all_triggers ' \ 'WHERE table_name = :tbn AND table_owner = :o ' \ 'AND trigger_name = :trn AND owner = :o' else: sql = 'SELECT trigger_name FROM all_triggers ' \ 'WHERE table_name = :tbn AND table_owner = :o ' \ 'AND upper(trigger_name) = upper(:trn) AND owner = :o' owner_name, table_name = provider.split_table_name(trigger.table.name) cursor = connection.cursor() cursor.execute(sql, dict(tbn=table_name, trn=trigger.name, o=owner_name)) row = cursor.fetchone() return row[0] if row is not None else None def get_create_command(trigger): schema = trigger.table.schema quote_name = schema.provider.quote_name trigger_name = quote_name(trigger.name) table_name = quote_name(trigger.table.name) column_name = quote_name(trigger.column.name) seq_name = quote_name(trigger.sequence.name) return schema.case(trigger_template) % (trigger_name, table_name, column_name, seq_name, column_name) class OraColumn(Column): auto_template = None class OraSchema(DBSchema): dialect = 'Oracle' table_class = OraTable column_class = OraColumn class OraNoneMonad(sqltranslation.NoneMonad): def __init__(monad, value=None): assert value in (None, '') sqltranslation.ConstMonad.__init__(monad, None) class OraConstMonad(sqltranslation.ConstMonad): @staticmethod def new(value): if value == '': value = None return sqltranslation.ConstMonad.new(value) class OraTranslator(sqltranslation.SQLTranslator): dialect = 'Oracle' rowid_support = True json_path_wildcard_syntax = True json_values_are_comparable = False NoneMonad = OraNoneMonad ConstMonad = OraConstMonad class OraBuilder(SQLBuilder): dialect = 'Oracle' def INSERT(builder, table_name, columns, values, returning=None): result = SQLBuilder.INSERT(builder, table_name, columns, values) if returning is not None: result.extend((' RETURNING ', builder.quote_name(returning), ' INTO :new_id')) return result def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): assert not builder.indent nowait = ' NOWAIT' if nowait else '' skip_locked = ' SKIP LOCKED' 
if skip_locked else '' last_section = sections[-1] if last_section[0] != 'LIMIT': return builder.SELECT(*sections), 'FOR UPDATE', nowait, skip_locked, '\n' from_section = sections[1] assert from_section[0] == 'FROM' if len(from_section) > 2: throw(NotImplementedError, 'Table joins are not supported for Oracle queries which have both FOR UPDATE and ROWNUM') order_by_section = None for section in sections: if section[0] == 'ORDER_BY': order_by_section = section table_ast = from_section[1] assert len(table_ast) == 3 and table_ast[1] == 'TABLE' table_alias = table_ast[0] rowid = [ 'COLUMN', table_alias, 'ROWID' ] sql_ast = [ 'SELECT', sections[0], [ 'FROM', table_ast ], [ 'WHERE', [ 'IN', rowid, ('SELECT', [ 'ROWID', ['AS', rowid, 'row-id' ] ]) + sections[1:] ] ] ] if order_by_section: sql_ast.append(order_by_section) result = builder(sql_ast) return result, 'FOR UPDATE', nowait, skip_locked, '\n' def SELECT(builder, *sections): prev_suppress_aliases = builder.suppress_aliases builder.suppress_aliases = False try: last_section = sections[-1] limit = offset = None if last_section[0] == 'LIMIT': limit = last_section[1] if len(last_section) > 2: offset = last_section[2] sections = sections[:-1] result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent if sections[0][0] == 'ROWID': indent0 = builder.indent_spaces x = 't."row-id"' else: indent0 = '' x = 't.*' if not limit and not offset: pass elif not offset: result = [ indent0, 'SELECT * FROM (\n' ] builder.indent += 1 result.extend(builder._subquery(*sections)) builder.indent -= 1 result.extend((indent, ') WHERE ROWNUM <= %d\n' % limit)) else: indent2 = indent + builder.indent_spaces result = [ indent0, 'SELECT %s FROM (\n' % x, indent2, 'SELECT t.*, ROWNUM "row-num" FROM (\n' ] builder.indent += 2 result.extend(builder._subquery(*sections)) builder.indent -= 2 if limit is None: result.append('%s) t\n' % indent2) result.append('%s) t WHERE "row-num" > %d\n' % (indent, offset)) else: result.append('%s) t WHERE ROWNUM <= %d\n' % (indent2, limit + offset)) result.append('%s) t WHERE "row-num" > %d\n' % (indent, offset)) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' return result finally: builder.suppress_aliases = prev_suppress_aliases def ROWID(builder, *expr_list): return builder.ALL(*expr_list) def LIMIT(builder, limit, offset=None): assert False # pragma: no cover def TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS NUMBER)' def TO_STR(builder, expr): return 'TO_CHAR(', builder(expr), ')' def DATE(builder, expr): return 'TRUNC(', builder(expr), ')' def RANDOM(builder): return 'dbms_random.value' def MOD(builder, a, b): return 'MOD(', builder(a), ', ', builder(b), ')' def DATE_ADD(builder, expr, delta): return '(', builder(expr), ' + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' def DATE_DIFF(builder, expr1, expr2): return builder(expr1), ' - ', builder(expr2) def DATETIME_ADD(builder, expr, delta): return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' def DATETIME_DIFF(builder, expr1, expr2): return builder(expr1), ' - ', builder(expr2) def build_json_path(builder, path): path_sql, has_params, has_wildcards = SQLBuilder.build_json_path(builder, path) if has_params: throw(TranslationError, "Oracle doesn't allow parameters in JSON paths") return path_sql, has_params, has_wildcards def 
JSON_QUERY(builder, expr, path): expr_sql = builder(expr) path_sql, has_params, has_wildcards = builder.build_json_path(path) if has_wildcards: return 'JSON_QUERY(', expr_sql, ', ', path_sql, ' WITH WRAPPER)' return 'REGEXP_REPLACE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '(^\\[|\\]$)', '')" json_value_type_mapping = {bool: 'NUMBER', int: 'NUMBER', float: 'NUMBER'} def JSON_VALUE(builder, expr, path, type): if type is Json: return builder.JSON_QUERY(expr, path) path_sql, has_params, has_wildcards = builder.build_json_path(path) type_name = builder.json_value_type_mapping.get(type, 'VARCHAR2') return 'JSON_VALUE(', builder(expr), ', ', path_sql, ' RETURNING ', type_name, ')' def JSON_NONZERO(builder, expr): return 'COALESCE(', builder(expr), ''', 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) path_sql, has_params, has_wildcards = builder.build_json_path(path) path_with_key_sql, _, _ = builder.build_json_path(path + [ key ]) expr_sql = builder(expr) result = 'JSON_EXISTS(', expr_sql, ', ', path_with_key_sql, ')' if json_item_re.match(key[1]): item = r'"([^"]|\\")*"' list_start = r'\[\s*(%s\s*,\s*)*' % item list_end = r'\s*(,\s*%s\s*)*\]' % item pattern = r'%s"%s"%s' % (list_start, key[1], list_end) if has_wildcards: sublist = r'\[[^]]*\]' item_or_sublist = '(%s|%s)' % (item, sublist) wrapper_list_start = r'^\[\s*(%s\s*,\s*)*' % item_or_sublist wrapper_list_end = r'\s*(,\s*%s\s*)*\]$' % item_or_sublist pattern = r'%s%s%s' % (wrapper_list_start, pattern, wrapper_list_end) result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '%s')" % pattern else: pattern = '^%s$' % pattern result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, "), '%s')" % pattern return result def JSON_ARRAY_LENGTH(builder, value): throw(TranslationError, 'Oracle does not provide `length` function for JSON arrays') def GROUP_CONCAT(builder, distinct, expr, sep=None): assert distinct in (None, True, False) if distinct and builder.provider.server_version >= (19,): distinct = 'DISTINCT ' else: distinct = '' result = 'LISTAGG(', distinct, builder(expr) if sep is not None: result = result, ', ', builder(sep) else: result = result, ", ','" return result, ') WITHIN GROUP(ORDER BY 1)' json_item_re = re.compile('[\w\s]*') class OraBoolConverter(dbapiprovider.BoolConverter): if not PY2: def py2sql(converter, val): # Fixes cx_Oracle 5.1.3 Python 3 bug: # "DatabaseError: OCI-22062: invalid input string [True]" return int(val) def sql2py(converter, val): return bool(val) # TODO: True/False, T/F, Y/N, Yes/No, etc. 
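    # Example round-trip (illustrative): booleans are stored in NUMBER(1) columns,
    # so py2sql(True) sends 1 to the server (sidestepping the cx_Oracle 5.1.3 bug
    # mentioned above) and sql2py(1) restores Python True on the way back.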
    def sql_type(converter):
        return "NUMBER(1)"

class OraStrConverter(dbapiprovider.StrConverter):
    def validate(converter, val, obj=None):
        if val == '': return None
        return dbapiprovider.StrConverter.validate(converter, val)
    def sql2py(converter, val):
        if isinstance(val, cx_Oracle.LOB):
            val = val.read()
            if PY2: val = val.decode('utf8')
        return val
    def sql_type(converter):
        # TODO: Add support for NVARCHAR2 and NCLOB datatypes
        if converter.max_len:
            return 'VARCHAR2(%d CHAR)' % converter.max_len
        return 'CLOB'

class OraIntConverter(dbapiprovider.IntConverter):
    signed_types = {None: 'NUMBER(38)', 8: 'NUMBER(3)', 16: 'NUMBER(5)',
                    24: 'NUMBER(7)', 32: 'NUMBER(10)', 64: 'NUMBER(19)'}
    unsigned_types = {None: 'NUMBER(38)', 8: 'NUMBER(3)', 16: 'NUMBER(5)',
                      24: 'NUMBER(8)', 32: 'NUMBER(10)', 64: 'NUMBER(20)'}
    def init(self, kwargs):
        dbapiprovider.IntConverter.init(self, kwargs)
        sequence_name = kwargs.pop('sequence_name', None)
        if sequence_name is not None and not (self.attr.auto and self.attr.is_pk):
            throw(TypeError, "Parameter 'sequence_name' can be used only for PrimaryKey attributes with auto=True")

class OraRealConverter(dbapiprovider.RealConverter):
    # Note that Oracle has a different representation of float numbers
    def sql_type(converter):
        return 'NUMBER'

class OraDecimalConverter(dbapiprovider.DecimalConverter):
    def sql_type(converter):
        return 'NUMBER(%d, %d)' % (converter.precision, converter.scale)

class OraBlobConverter(dbapiprovider.BlobConverter):
    def sql2py(converter, val):
        return buffer(val.read())

class OraDateConverter(dbapiprovider.DateConverter):
    def sql2py(converter, val):
        if isinstance(val, datetime): return val.date()
        if not isinstance(val, date):
            throw(ValueError, 'Value of unexpected type received from database: instead of date got %s' % type(val))
        return val

class OraTimeConverter(dbapiprovider.TimeConverter):
    sql_type_name = 'INTERVAL DAY(0) TO SECOND'
    def __init__(converter, provider, py_type, attr=None):
        dbapiprovider.TimeConverter.__init__(converter, provider, py_type, attr)
        if attr is not None and converter.precision > 0:
            # cx_Oracle 5.1.3 corrupts microseconds for values of DAY TO SECOND type
            converter.precision = 0
    def sql2py(converter, val):
        if isinstance(val, timedelta):
            total_seconds = val.days * (24 * 60 * 60) + val.seconds
            if 0 <= total_seconds <= 24 * 60 * 60:
                minutes, seconds = divmod(total_seconds, 60)
                hours, minutes = divmod(minutes, 60)
                return time(hours, minutes, seconds, val.microseconds)
        elif not isinstance(val, time):
            throw(ValueError, 'Value of unexpected type received from database%s: instead of time or timedelta got %s'
                  % (' for attribute %s' % converter.attr if converter.attr else '', type(val)))
        return val
    def py2sql(converter, val):
        return timedelta(hours=val.hour, minutes=val.minute, seconds=val.second, microseconds=val.microsecond)

class OraTimedeltaConverter(dbapiprovider.TimedeltaConverter):
    sql_type_name = 'INTERVAL DAY TO SECOND'
    def __init__(converter, provider, py_type, attr=None):
        dbapiprovider.TimedeltaConverter.__init__(converter, provider, py_type, attr)
        if attr is not None and converter.precision > 0:
            # cx_Oracle 5.1.3 corrupts microseconds for values of DAY TO SECOND type
            converter.precision = 0

class OraDatetimeConverter(dbapiprovider.DatetimeConverter):
    sql_type_name = 'TIMESTAMP'

class OraUuidConverter(dbapiprovider.UuidConverter):
    def sql_type(converter):
        return 'RAW(16)'

class OraJsonConverter(dbapiprovider.JsonConverter):
    json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False}
    optimistic = False  # CLOBs cannot be
compared with strings, and TO_CHAR(CLOB) returns first 4000 chars only def sql2py(converter, dbval): if hasattr(dbval, 'read'): dbval = dbval.read() return dbapiprovider.JsonConverter.sql2py(converter, dbval) def sql_type(converter): return 'CLOB' class OraProvider(DBAPIProvider): dialect = 'Oracle' paramstyle = 'named' max_name_len = 30 table_if_not_exists_syntax = False index_if_not_exists_syntax = False varchar_default_max_len = 1000 uint64_support = True dbapi_module = cx_Oracle dbschema_cls = OraSchema translator_cls = OraTranslator sqlbuilder_cls = OraBuilder name_before_table = 'owner' converter_classes = [ (NoneType, dbapiprovider.NoneConverter), (bool, OraBoolConverter), (basestring, OraStrConverter), (int_types, OraIntConverter), (float, OraRealConverter), (Decimal, OraDecimalConverter), (datetime, OraDatetimeConverter), (date, OraDateConverter), (time, OraTimeConverter), (timedelta, OraTimedeltaConverter), (UUID, OraUuidConverter), (buffer, OraBlobConverter), (Json, OraJsonConverter), ] @wrap_dbapi_exceptions def inspect_connection(provider, connection): cursor = connection.cursor() cursor.execute('SELECT version FROM product_component_version ' "WHERE product LIKE 'Oracle Database %'") provider.server_version = get_version_tuple(cursor.fetchone()[0]) cursor.execute("SELECT sys_context( 'userenv', 'current_schema' ) FROM DUAL") provider.default_schema_name = cursor.fetchone()[0] def should_reconnect(provider, exc): reconnect_error_codes = ( 3113, # ORA-03113: end-of-file on communication channel 3114, # ORA-03114: not connected to ORACLE ) return isinstance(exc, cx_Oracle.OperationalError) \ and exc.args[0].code in reconnect_error_codes def normalize_name(provider, name): return name[:provider.max_name_len].upper() def normalize_vars(provider, vars, vartypes): DBAPIProvider.normalize_vars(provider, vars, vartypes) for key, value in iteritems(vars): if value == '': vars[key] = None vartypes[key] = NoneType @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction db_session = cache.db_session if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' if core.local.debug: log_orm(sql) cursor.execute(sql) cache.immediate = True if db_session is not None and (db_session.serializable or db_session.ddl): cache.in_transaction = True @wrap_dbapi_exceptions def execute(provider, cursor, sql, arguments=None, returning_id=False): if type(arguments) is list: assert arguments and not returning_id set_input_sizes(cursor, arguments[0]) cursor.executemany(sql, arguments) else: if arguments is not None: set_input_sizes(cursor, arguments) if returning_id: var = cursor.var(cx_Oracle.STRING, 40, cursor.arraysize, outconverter=int) arguments['new_id'] = var if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) value = var.getvalue() if isinstance(value, list): assert len(value) == 1 value = value[0] return value if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) def get_pool(provider, *args, **kwargs): user = password = dsn = None if len(args) == 1: conn_str = args[0] if '/' in conn_str: user, tail = conn_str.split('/', 1) if '@' in tail: password, dsn = tail.split('@', 1) if None in (user, password, dsn): throw(ValueError, "Incorrect connection string (must be in form of 'user/password@dsn')") elif len(args) == 2: user, password = args elif len(args) == 3: user, password, dsn = args elif args: throw(ValueError, 
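    # Connection-string sketch (illustrative credentials, not from the original
    # source): the parsing above turns
    #
    #     db.bind('oracle', 'scott/tiger@localhost/XE')
    #
    # into user='scott', password='tiger', dsn='localhost/XE', which are then
    # passed through to cx_Oracle.SessionPool via OraPool(**kwargs) below.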
'Invalid number of positional arguments') def setdefault(kwargs, key, value): kwargs_value = kwargs.setdefault(key, value) if value is not None and value != kwargs_value: throw(ValueError, 'Ambiguous value for ' + key) setdefault(kwargs, 'user', user) setdefault(kwargs, 'password', password) setdefault(kwargs, 'dsn', dsn) kwargs.setdefault('threaded', True) kwargs.setdefault('min', 1) kwargs.setdefault('max', 10) kwargs.setdefault('increment', 1) return OraPool(**kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): owner_name, table_name = provider.split_table_name(table_name) cursor = connection.cursor() if case_sensitive: sql = 'SELECT table_name FROM all_tables WHERE owner = :o AND table_name = :tn' else: sql = 'SELECT table_name FROM all_tables WHERE owner = :o AND upper(table_name) = upper(:tn)' cursor.execute(sql, dict(o=owner_name, tn=table_name)) row = cursor.fetchone() return row[0] if row is not None else None def index_exists(provider, connection, table_name, index_name, case_sensitive=True): owner_name, table_name = provider.split_table_name(table_name) if not isinstance(index_name, basestring): throw(NotImplementedError) if case_sensitive: sql = 'SELECT index_name FROM all_indexes WHERE owner = :o ' \ 'AND index_name = :i AND table_owner = :o AND table_name = :t' else: sql = 'SELECT index_name FROM all_indexes WHERE owner = :o ' \ 'AND upper(index_name) = upper(:i) AND table_owner = :o AND table_name = :t' cursor = connection.cursor() cursor.execute(sql, dict(o=owner_name, i=index_name, t=table_name)) row = cursor.fetchone() return row[0] if row is not None else None def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): owner_name, table_name = provider.split_table_name(table_name) if not isinstance(fk_name, basestring): throw(NotImplementedError) if case_sensitive: sql = "SELECT constraint_name FROM user_constraints WHERE constraint_type = 'R' " \ 'AND table_name = :tn AND constraint_name = :cn AND owner = :o' else: sql = "SELECT constraint_name FROM user_constraints WHERE constraint_type = 'R' " \ 'AND table_name = :tn AND upper(constraint_name) = upper(:cn) AND owner = :o' cursor = connection.cursor() cursor.execute(sql, dict(tn=table_name, cn=fk_name, o=owner_name)) row = cursor.fetchone() return row[0] if row is not None else None def table_has_data(provider, connection, table_name): cursor = connection.cursor() cursor.execute('SELECT 1 FROM %s WHERE ROWNUM = 1' % provider.quote_name(table_name)) return cursor.fetchone() is not None def drop_table(provider, connection, table_name): cursor = connection.cursor() sql = 'DROP TABLE %s CASCADE CONSTRAINTS' % provider.quote_name(table_name) cursor.execute(sql) provider_cls = OraProvider def to_int_or_decimal(val): val = val.replace(',', '.') if '.' 
in val:
        return Decimal(val)
    return int(val)

def to_decimal(val):
    return Decimal(val.replace(',', '.'))

def output_type_handler(cursor, name, defaultType, size, precision, scale):
    if defaultType == cx_Oracle.NUMBER:
        if scale == 0:
            if precision: return cursor.var(cx_Oracle.STRING, 40, cursor.arraysize, outconverter=int)
            return cursor.var(cx_Oracle.STRING, 40, cursor.arraysize, outconverter=to_int_or_decimal)
        if scale != -127:
            return cursor.var(cx_Oracle.STRING, 100, cursor.arraysize, outconverter=to_decimal)
    elif defaultType in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR):
        return cursor.var(unicode, size, cursor.arraysize)  # from cx_Oracle example
    return None

class OraPool(object):
    forked_pools = []
    def __init__(pool, **kwargs):
        pool.kwargs = kwargs
        pool.cx_pool = cx_Oracle.SessionPool(**kwargs)
        pool.pid = os.getpid()
    def connect(pool):
        pid = os.getpid()
        if pool.pid != pid:
            pool.forked_pools.append((pool.cx_pool, pool.pid))
            pool.cx_pool = cx_Oracle.SessionPool(**pool.kwargs)
            pool.pid = os.getpid()
        if core.local.debug: log_orm('GET CONNECTION')
        con = pool.cx_pool.acquire()
        con.outputtypehandler = output_type_handler
        return con, True
    def release(pool, con):
        pool.cx_pool.release(con)
    def drop(pool, con):
        pool.cx_pool.drop(con)
    def disconnect(pool):
        pass

def get_inputsize(arg):
    if isinstance(arg, datetime):
        return cx_Oracle.TIMESTAMP
    return None

def set_input_sizes(cursor, arguments):
    if type(arguments) is dict:
        input_sizes = {}
        for name, arg in iteritems(arguments):
            size = get_inputsize(arg)
            if size is not None: input_sizes[name] = size
        cursor.setinputsizes(**input_sizes)
    elif type(arguments) is tuple:
        input_sizes = map(get_inputsize, arguments)
        cursor.setinputsizes(*input_sizes)
    else: assert False, type(arguments)  # pragma: no cover

# pony-0.7.11/pony/orm/dbproviders/postgres.py

from __future__ import absolute_import
from pony.py23compat import PY2, basestring, unicode, buffer, int_types

from decimal import Decimal
from datetime import datetime, date, time, timedelta
from uuid import UUID

try:
    import psycopg2
except ImportError:
    try:
        from psycopg2cffi import compat
    except ImportError:
        raise ImportError('In order to use PonyORM with PostgreSQL please install psycopg2 or psycopg2cffi')
    else:
        compat.register()

from psycopg2 import extensions

import psycopg2.extras
psycopg2.extras.register_uuid()
psycopg2.extras.register_default_json(loads=lambda x: x)
psycopg2.extras.register_default_jsonb(loads=lambda x: x)

from pony.orm import core, dbschema, dbapiprovider, sqltranslation, ormtypes
from pony.orm.core import log_orm
from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions
from pony.orm.sqltranslation import SQLTranslator
from pony.orm.sqlbuilding import Value, SQLBuilder, join
from pony.converting import timedelta2str
from pony.utils import is_ident

NoneType = type(None)

class PGColumn(dbschema.Column):
    auto_template = 'SERIAL PRIMARY KEY'

class PGSchema(dbschema.DBSchema):
    dialect = 'PostgreSQL'
    column_class = PGColumn

class PGTranslator(SQLTranslator):
    dialect = 'PostgreSQL'

class PGValue(Value):
    __slots__ = []
    def __unicode__(self):
        value = self.value
        if isinstance(value, bool):
            return value and 'true' or 'false'
        return Value.__unicode__(self)
    if not PY2: __str__ = __unicode__

class PGSQLBuilder(SQLBuilder):
    dialect = 'PostgreSQL'
    value_class = PGValue
    def INSERT(builder, table_name, columns, values,
returning=None): if not values: result = [ 'INSERT INTO ', builder.quote_name(table_name) ,' DEFAULT VALUES' ] else: result = SQLBuilder.INSERT(builder, table_name, columns, values) if returning is not None: result.extend([' RETURNING ', builder.quote_name(returning) ]) return result def TO_INT(builder, expr): return '(', builder(expr), ')::int' def TO_STR(builder, expr): return '(', builder(expr), ')::text' def TO_REAL(builder, expr): return '(', builder(expr), ')::double precision' def DATE(builder, expr): return '(', builder(expr), ')::date' def RANDOM(builder): return 'random()' def DATE_ADD(builder, expr, delta): return '(', builder(expr), ' + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' def DATE_DIFF(builder, expr1, expr2): return builder(expr1), ' - ', builder(expr2) def DATETIME_ADD(builder, expr, delta): return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' def DATETIME_DIFF(builder, expr1, expr2): return builder(expr1), ' - ', builder(expr2) def eval_json_path(builder, values): result = [] for value in values: if isinstance(value, int): result.append(str(value)) elif isinstance(value, basestring): result.append(value if is_ident(value) else '"%s"' % value.replace('"', '\\"')) else: assert False, value return '{%s}' % ','.join(result) def JSON_QUERY(builder, expr, path): path_sql, has_params, has_wildcards = builder.build_json_path(path) return '(', builder(expr), " #> ", path_sql, ')' json_value_type_mapping = {bool: 'boolean', int: 'int', float: 'real'} def JSON_VALUE(builder, expr, path, type): if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) path_sql, has_params, has_wildcards = builder.build_json_path(path) sql = '(', builder(expr), " #>> ", path_sql, ')' type_name = builder.json_value_type_mapping.get(type, 'text') return sql if type_name == 'text' else (sql, '::', type_name) def JSON_NONZERO(builder, expr): return 'coalesce(', builder(expr), ", 'null'::jsonb) NOT IN (" \ "'null'::jsonb, 'false'::jsonb, '0'::jsonb, '\"\"'::jsonb, '[]'::jsonb, '{}'::jsonb)" def JSON_CONCAT(builder, left, right): return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): return (builder.JSON_QUERY(expr, path) if path else builder(expr)), ' ? 
', builder(key) def JSON_ARRAY_LENGTH(builder, value): return 'jsonb_array_length(', builder(value), ')' def GROUP_CONCAT(builder, distinct, expr, sep=None): assert distinct in (None, True, False) result = distinct and 'string_agg(distinct ' or 'string_agg(', builder(expr), '::text' if sep is not None: result = result, ', ', builder(sep) else: result = result, ", ','" return result, ')' def ARRAY_INDEX(builder, col, index): return builder(col), '[', builder(index), ']' def ARRAY_CONTAINS(builder, key, not_in, col): if not_in: return builder(key), ' <> ALL(', builder(col), ')' return builder(key), ' = ANY(', builder(col), ')' def ARRAY_SUBSET(builder, array1, not_in, array2): result = builder(array1), ' <@ ', builder(array2) if not_in: result = 'NOT (', result, ')' return result def ARRAY_LENGTH(builder, array): return 'COALESCE(ARRAY_LENGTH(', builder(array), ', 1), 0)' def ARRAY_SLICE(builder, array, start, stop): return builder(array), '[', builder(start) if start else '', ':', builder(stop) if stop else '', ']' def MAKE_ARRAY(builder, *items): return 'ARRAY[', join(', ', (builder(item) for item in items)), ']' class PGStrConverter(dbapiprovider.StrConverter): if PY2: def py2sql(converter, val): return val.encode('utf-8') def sql2py(converter, val): if isinstance(val, unicode): return val return val.decode('utf-8') class PGIntConverter(dbapiprovider.IntConverter): signed_types = {None: 'INTEGER', 8: 'SMALLINT', 16: 'SMALLINT', 24: 'INTEGER', 32: 'INTEGER', 64: 'BIGINT'} unsigned_types = {None: 'INTEGER', 8: 'SMALLINT', 16: 'INTEGER', 24: 'INTEGER', 32: 'BIGINT'} class PGRealConverter(dbapiprovider.RealConverter): def sql_type(converter): return 'DOUBLE PRECISION' class PGBlobConverter(dbapiprovider.BlobConverter): def sql_type(converter): return 'BYTEA' class PGTimedeltaConverter(dbapiprovider.TimedeltaConverter): sql_type_name = 'INTERVAL DAY TO SECOND' class PGDatetimeConverter(dbapiprovider.DatetimeConverter): sql_type_name = 'TIMESTAMP' class PGUuidConverter(dbapiprovider.UuidConverter): def py2sql(converter, val): return val class PGJsonConverter(dbapiprovider.JsonConverter): def sql_type(self): return "JSONB" class PGArrayConverter(dbapiprovider.ArrayConverter): array_types = { int: ('int', PGIntConverter), unicode: ('text', PGStrConverter), float: ('double precision', PGRealConverter) } class PGPool(Pool): def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) if 'client_encoding' not in pool.kwargs: pool.con.set_client_encoding('UTF8') def release(pool, con): assert con is pool.con try: con.rollback() con.autocommit = True cursor = con.cursor() cursor.execute('DISCARD ALL') con.autocommit = False except: pool.drop(con) raise class PGProvider(DBAPIProvider): dialect = 'PostgreSQL' paramstyle = 'pyformat' max_name_len = 63 max_params_count = 10000 index_if_not_exists_syntax = False dbapi_module = psycopg2 dbschema_cls = PGSchema translator_cls = PGTranslator sqlbuilder_cls = PGSQLBuilder array_converter_cls = PGArrayConverter default_schema_name = 'public' fk_types = { 'SERIAL' : 'INTEGER', 'BIGSERIAL' : 'BIGINT' } def normalize_name(provider, name): return name[:provider.max_name_len].lower() @wrap_dbapi_exceptions def inspect_connection(provider, connection): provider.server_version = connection.server_version provider.table_if_not_exists_syntax = provider.server_version >= 90100 def should_reconnect(provider, exc): return isinstance(exc, psycopg2.OperationalError) and exc.pgcode is None def get_pool(provider, *args, **kwargs): return 
PGPool(provider.dbapi_module, *args, **kwargs) @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction if cache.immediate and connection.autocommit: connection.autocommit = False if core.local.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') db_session = cache.db_session if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' if core.local.debug: log_orm(sql) cursor.execute(sql) elif not cache.immediate and not connection.autocommit: connection.autocommit = True if core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') if db_session is not None and (db_session.serializable or db_session.ddl): cache.in_transaction = True @wrap_dbapi_exceptions def execute(provider, cursor, sql, arguments=None, returning_id=False): if PY2 and isinstance(sql, unicode): sql = sql.encode('utf8') if type(arguments) is list: assert arguments and not returning_id cursor.executemany(sql, arguments) else: if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) if returning_id: return cursor.fetchone()[0] def table_exists(provider, connection, table_name, case_sensitive=True): schema_name, table_name = provider.split_table_name(table_name) cursor = connection.cursor() if case_sensitive: sql = 'SELECT tablename FROM pg_catalog.pg_tables ' \ 'WHERE schemaname = %s AND tablename = %s' else: sql = 'SELECT tablename FROM pg_catalog.pg_tables ' \ 'WHERE schemaname = %s AND lower(tablename) = lower(%s)' cursor.execute(sql, (schema_name, table_name)) row = cursor.fetchone() return row[0] if row is not None else None def index_exists(provider, connection, table_name, index_name, case_sensitive=True): schema_name, table_name = provider.split_table_name(table_name) cursor = connection.cursor() if case_sensitive: sql = 'SELECT indexname FROM pg_catalog.pg_indexes ' \ 'WHERE schemaname = %s AND tablename = %s AND indexname = %s' else: sql = 'SELECT indexname FROM pg_catalog.pg_indexes ' \ 'WHERE schemaname = %s AND tablename = %s AND lower(indexname) = lower(%s)' cursor.execute(sql, [ schema_name, table_name, index_name ]) row = cursor.fetchone() return row[0] if row is not None else None def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): schema_name, table_name = provider.split_table_name(table_name) if case_sensitive: sql = 'SELECT con.conname FROM pg_class cls ' \ 'JOIN pg_namespace ns ON cls.relnamespace = ns.oid ' \ 'JOIN pg_constraint con ON con.conrelid = cls.oid ' \ 'WHERE ns.nspname = %s AND cls.relname = %s ' \ "AND con.contype = 'f' AND con.conname = %s" else: sql = 'SELECT con.conname FROM pg_class cls ' \ 'JOIN pg_namespace ns ON cls.relnamespace = ns.oid ' \ 'JOIN pg_constraint con ON con.conrelid = cls.oid ' \ 'WHERE ns.nspname = %s AND cls.relname = %s ' \ "AND con.contype = 'f' AND lower(con.conname) = lower(%s)" cursor = connection.cursor() cursor.execute(sql, [ schema_name, table_name, fk_name ]) row = cursor.fetchone() return row[0] if row is not None else None def drop_table(provider, connection, table_name): cursor = connection.cursor() sql = 'DROP TABLE %s CASCADE' % provider.quote_name(table_name) cursor.execute(sql) converter_classes = [ (NoneType, dbapiprovider.NoneConverter), (bool, dbapiprovider.BoolConverter), (basestring, PGStrConverter), (int_types, PGIntConverter), (float, PGRealConverter), (Decimal, dbapiprovider.DecimalConverter), (datetime, PGDatetimeConverter), (date, dbapiprovider.DateConverter), 
        (time, dbapiprovider.TimeConverter),
        (timedelta, PGTimedeltaConverter),
        (UUID, PGUuidConverter),
        (buffer, PGBlobConverter),
        (ormtypes.Json, PGJsonConverter),
    ]

provider_cls = PGProvider

# pony-0.7.11/pony/orm/dbproviders/sqlite.py

from __future__ import absolute_import
from pony.py23compat import PY2, imap, basestring, buffer, int_types, unicode

import os.path, sys, re, json
import sqlite3 as sqlite
from decimal import Decimal
from datetime import datetime, date, time, timedelta
from random import random
from time import strptime
from threading import Lock
from uuid import UUID
from binascii import hexlify
from functools import wraps

from pony.orm import core, dbschema, dbapiprovider
from pony.orm.core import log_orm
from pony.orm.ormtypes import Json, TrackedArray
from pony.orm.sqltranslation import SQLTranslator, StringExprMonad
from pony.orm.sqlbuilding import SQLBuilder, Value, join, make_unary_func
from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions
from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise, \
     cut_traceback_depth

class SqliteExtensionUnavailable(Exception):
    pass

NoneType = type(None)

class SQLiteForeignKey(dbschema.ForeignKey):
    def get_create_command(foreign_key):
        assert False  # pragma: no cover

class SQLiteSchema(dbschema.DBSchema):
    dialect = 'SQLite'
    named_foreign_keys = False
    fk_class = SQLiteForeignKey

def make_overriden_string_func(sqlop):
    def func(translator, monad):
        sql = monad.getsql()
        assert len(sql) == 1
        translator = monad.translator
        return StringExprMonad(monad.type, [ sqlop, sql[0] ])
    func.__name__ = sqlop
    return func

class SQLiteTranslator(SQLTranslator):
    dialect = 'SQLite'
    sqlite_version = sqlite.sqlite_version_info
    row_value_syntax = False
    rowid_support = True

    StringMixin_UPPER = make_overriden_string_func('PY_UPPER')
    StringMixin_LOWER = make_overriden_string_func('PY_LOWER')

class SQLiteValue(Value):
    __slots__ = []
    def __unicode__(self):
        value = self.value
        if isinstance(value, datetime):
            return self.quote_str(datetime2timestamp(value))
        if isinstance(value, date):
            return self.quote_str(str(value))
        if isinstance(value, timedelta):
            return repr(value.total_seconds() / (24 * 60 * 60))
        return Value.__unicode__(self)
    if not PY2: __str__ = __unicode__

class SQLiteBuilder(SQLBuilder):
    dialect = 'SQLite'
    least_func_name = 'min'
    greatest_func_name = 'max'
    value_class = SQLiteValue
    def __init__(builder, provider, ast):
        builder.json1_available = provider.json1_available
        SQLBuilder.__init__(builder, provider, ast)
    def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections):
        assert not builder.indent
        return builder.SELECT(*sections)
    def INSERT(builder, table_name, columns, values, returning=None):
        if not values: return 'INSERT INTO %s DEFAULT VALUES' % builder.quote_name(table_name)
        return SQLBuilder.INSERT(builder, table_name, columns, values, returning)
    def TODAY(builder):
        return "date('now', 'localtime')"
    def NOW(builder):
        return "datetime('now', 'localtime')"
    def YEAR(builder, expr):
        return 'cast(substr(', builder(expr), ', 1, 4) as integer)'
    def MONTH(builder, expr):
        return 'cast(substr(', builder(expr), ', 6, 2) as integer)'
    def DAY(builder, expr):
        return 'cast(substr(', builder(expr), ', 9, 2) as integer)'
    def HOUR(builder, expr):
        return 'cast(substr(', builder(expr), ', 12, 2) as integer)'
    def MINUTE(builder,
expr): return 'cast(substr(', builder(expr), ', 15, 2) as integer)' def SECOND(builder, expr): return 'cast(substr(', builder(expr), ', 18, 2) as integer)' def datetime_add(builder, funcname, expr, td): assert isinstance(td, timedelta) modifiers = [] seconds = td.seconds + td.days * 24 * 3600 sign = '+' if seconds > 0 else '-' seconds = abs(seconds) if seconds >= (24 * 3600): days = seconds // (24 * 3600) modifiers.append(", '%s%d days'" % (sign, days)) seconds -= days * 24 * 3600 if seconds >= 3600: hours = seconds // 3600 modifiers.append(", '%s%d hours'" % (sign, hours)) seconds -= hours * 3600 if seconds >= 60: minutes = seconds // 60 modifiers.append(", '%s%d minutes'" % (sign, minutes)) seconds -= minutes * 60 if seconds: modifiers.append(", '%s%d seconds'" % (sign, seconds)) if not modifiers: return builder(expr) return funcname, '(', builder(expr), modifiers, ')' def DATE_ADD(builder, expr, delta): if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): return builder.datetime_add('date', expr, delta[1]) return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): return builder.datetime_add('date', expr, -delta[1]) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' def DATE_DIFF(builder, expr1, expr2): return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')' def DATETIME_ADD(builder, expr, delta): if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): return builder.datetime_add('datetime', expr, delta[1]) return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): return builder.datetime_add('datetime', expr, -delta[1]) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' def DATETIME_DIFF(builder, expr1, expr2): return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')' def RANDOM(builder): return 'rand()' # return '(random() / 9223372036854775807.0 + 1.0) / 2.0' PY_UPPER = make_unary_func('py_upper') PY_LOWER = make_unary_func('py_lower') def FLOAT_EQ(builder, a, b): a, b = builder(a), builder(b) return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' def FLOAT_NE(builder, a, b): a, b = builder(a), builder(b) return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14' def JSON_QUERY(builder, expr, path): fname = 'json_extract' if builder.json1_available else 'py_json_extract' path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', path_sql, '))' json_value_type_mapping = {unicode: 'text', bool: 'integer', int: 'integer', float: 'real'} def JSON_VALUE(builder, expr, path, type): func_name = 'json_extract' if builder.json1_available else 'py_json_extract' path_sql, has_params, has_wildcards = builder.build_json_path(path) type_name = builder.json_value_type_mapping.get(type) result = func_name, '(', builder(expr), ', ', path_sql, ')' if type_name is not None: result = 'CAST(', result, ' as ', type_name, ')' return result def JSON_NONZERO(builder, expr): return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): func_name = 'json_array_length' if builder.json1_available else 'py_json_array_length' return func_name, '(', builder(value), ')' def JSON_CONTAINS(builder, 
expr, path, key): path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'py_json_contains(', builder(expr), ', ', path_sql, ', ', builder(key), ')' def ARRAY_INDEX(builder, col, index): return 'py_array_index(', builder(col), ', ', builder(index), ')' def ARRAY_CONTAINS(builder, key, not_in, col): return ('NOT ' if not_in else ''), 'py_array_contains(', builder(col), ', ', builder(key), ')' def ARRAY_SUBSET(builder, array1, not_in, array2): return ('NOT ' if not_in else ''), 'py_array_subset(', builder(array2), ', ', builder(array1), ')' def ARRAY_LENGTH(builder, array): return 'py_array_length(', builder(array), ')' def ARRAY_SLICE(builder, array, start, stop): return 'py_array_slice(', builder(array), ', ', \ builder(start) if start else 'null', ',',\ builder(stop) if stop else 'null', ')' def MAKE_ARRAY(builder, *items): return 'py_make_array(', join(', ', (builder(item) for item in items)), ')' class SQLiteIntConverter(dbapiprovider.IntConverter): def sql_type(converter): attr = converter.attr if attr is not None and attr.auto: return 'INTEGER' # Only this type can have AUTOINCREMENT option return dbapiprovider.IntConverter.sql_type(converter) class SQLiteDecimalConverter(dbapiprovider.DecimalConverter): inf = Decimal('infinity') neg_inf = Decimal('-infinity') NaN = Decimal('NaN') def sql2py(converter, val): try: val = Decimal(str(val)) except: return val exp = converter.exp if exp is not None: val = val.quantize(exp) return val def py2sql(converter, val): if type(val) is not Decimal: val = Decimal(val) exp = converter.exp if exp is not None: if val in (converter.inf, converter.neg_inf, converter.NaN): throw(ValueError, 'Cannot store %s Decimal value in database' % val) val = val.quantize(exp) return str(val) class SQLiteDateConverter(dbapiprovider.DateConverter): def sql2py(converter, val): try: time_tuple = strptime(val[:10], '%Y-%m-%d') return date(*time_tuple[:3]) except: return val def py2sql(converter, val): return val.strftime('%Y-%m-%d') class SQLiteTimeConverter(dbapiprovider.TimeConverter): def sql2py(converter, val): try: if len(val) <= 8: dt = datetime.strptime(val, '%H:%M:%S') else: dt = datetime.strptime(val, '%H:%M:%S.%f') return dt.time() except: return val def py2sql(converter, val): return val.isoformat() class SQLiteTimedeltaConverter(dbapiprovider.TimedeltaConverter): def sql2py(converter, val): return timedelta(days=val) def py2sql(converter, val): return val.days + (val.seconds + val.microseconds / 1000000.0) / 86400.0 class SQLiteDatetimeConverter(dbapiprovider.DatetimeConverter): def sql2py(converter, val): try: return timestamp2datetime(val) except: return val def py2sql(converter, val): return datetime2timestamp(val) class SQLiteJsonConverter(dbapiprovider.JsonConverter): json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} def dumps(items): return json.dumps(items, **SQLiteJsonConverter.json_kwargs) class SQLiteArrayConverter(dbapiprovider.ArrayConverter): array_types = { int: ('int', SQLiteIntConverter), unicode: ('text', dbapiprovider.StrConverter), float: ('real', dbapiprovider.RealConverter) } def dbval2val(converter, dbval, obj=None): if not dbval: return None items = json.loads(dbval) if obj is None: return items return TrackedArray(obj, converter.attr, items) def val2dbval(converter, val, obj=None): return dumps(val) class LocalExceptions(localbase): def __init__(self): self.exc_info = None self.keep_traceback = False local_exceptions = LocalExceptions() def keep_exception(func): @wraps(func) def 
new_func(*args): local_exceptions.exc_info = None try: return func(*args) except Exception: local_exceptions.exc_info = sys.exc_info() if not local_exceptions.keep_traceback: local_exceptions.exc_info = local_exceptions.exc_info[:2] + (None,) raise finally: local_exceptions.keep_traceback = False return new_func class SQLiteProvider(DBAPIProvider): dialect = 'SQLite' local_exceptions = local_exceptions max_name_len = 1024 dbapi_module = sqlite dbschema_cls = SQLiteSchema translator_cls = SQLiteTranslator sqlbuilder_cls = SQLiteBuilder array_converter_cls = SQLiteArrayConverter name_before_table = 'db_name' server_version = sqlite.sqlite_version_info converter_classes = [ (NoneType, dbapiprovider.NoneConverter), (bool, dbapiprovider.BoolConverter), (basestring, dbapiprovider.StrConverter), (int_types, SQLiteIntConverter), (float, dbapiprovider.RealConverter), (Decimal, SQLiteDecimalConverter), (datetime, SQLiteDatetimeConverter), (date, SQLiteDateConverter), (time, SQLiteTimeConverter), (timedelta, SQLiteTimedeltaConverter), (UUID, dbapiprovider.UuidConverter), (buffer, dbapiprovider.BlobConverter), (Json, SQLiteJsonConverter) ] def __init__(provider, *args, **kwargs): DBAPIProvider.__init__(provider, *args, **kwargs) provider.pre_transaction_lock = Lock() provider.transaction_lock = Lock() @wrap_dbapi_exceptions def inspect_connection(provider, conn): DBAPIProvider.inspect_connection(provider, conn) provider.json1_available = provider.check_json1(conn) def restore_exception(provider): if provider.local_exceptions.exc_info is not None: try: reraise(*provider.local_exceptions.exc_info) finally: provider.local_exceptions.exc_info = None def acquire_lock(provider): provider.pre_transaction_lock.acquire() try: provider.transaction_lock.acquire() finally: provider.pre_transaction_lock.release() def release_lock(provider): provider.transaction_lock.release() @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction if cache.immediate: provider.acquire_lock() try: cursor = connection.cursor() db_session = cache.db_session if db_session is not None and db_session.ddl: cursor.execute('PRAGMA foreign_keys') fk = cursor.fetchone() if fk is not None: fk = fk[0] if fk: sql = 'PRAGMA foreign_keys = false' if core.local.debug: log_orm(sql) cursor.execute(sql) cache.saved_fk_state = bool(fk) assert cache.immediate if cache.immediate: sql = 'BEGIN IMMEDIATE TRANSACTION' if core.local.debug: log_orm(sql) cursor.execute(sql) cache.in_transaction = True elif core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') finally: if cache.immediate and not cache.in_transaction: provider.release_lock() def commit(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction try: DBAPIProvider.commit(provider, connection, cache) finally: if in_transaction: cache.in_transaction = False provider.release_lock() def rollback(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction try: DBAPIProvider.rollback(provider, connection, cache) finally: if in_transaction: cache.in_transaction = False provider.release_lock() def drop(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction try: DBAPIProvider.drop(provider, connection, cache) finally: if in_transaction: cache.in_transaction = False provider.release_lock() @wrap_dbapi_exceptions def release(provider, connection, cache=None): if cache is not None: db_session = cache.db_session if db_session is not None and 
db_session.ddl and cache.saved_fk_state: try: cursor = connection.cursor() sql = 'PRAGMA foreign_keys = true' if core.local.debug: log_orm(sql) cursor.execute(sql) except: provider.pool.drop(connection) raise DBAPIProvider.release(provider, connection, cache) def get_pool(provider, filename, create_db=False, **kwargs): if filename != ':memory:': # When relative filename is specified, it is considered # not relative to cwd, but to user module where # Database instance is created # the list of frames: # 7 - user code: db = Database(...) # 6 - cut_traceback decorator wrapper # 5 - cut_traceback decorator # 4 - pony.orm.Database.__init__() / .bind() # 3 - pony.orm.Database._bind() # 2 - pony.dbapiprovider.DBAPIProvider.__init__() # 1 - SQLiteProvider.__init__() # 0 - pony.dbproviders.sqlite.get_pool() filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5) return SQLitePool(filename, create_db, **kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): return provider._exists(connection, table_name, None, case_sensitive) def index_exists(provider, connection, table_name, index_name, case_sensitive=True): return provider._exists(connection, table_name, index_name, case_sensitive) def _exists(provider, connection, table_name, index_name=None, case_sensitive=True): db_name, table_name = provider.split_table_name(table_name) if db_name is None: catalog_name = 'sqlite_master' else: catalog_name = (db_name, 'sqlite_master') catalog_name = provider.quote_name(catalog_name) cursor = connection.cursor() if index_name is not None: sql = "SELECT name FROM %s WHERE type='index' AND name=?" % catalog_name if not case_sensitive: sql += ' COLLATE NOCASE' cursor.execute(sql, [ index_name ]) else: sql = "SELECT name FROM %s WHERE type='table' AND name=?" 
% catalog_name if not case_sensitive: sql += ' COLLATE NOCASE' cursor.execute(sql, [ table_name ]) row = cursor.fetchone() return row[0] if row is not None else None def fk_exists(provider, connection, table_name, fk_name): assert False # pragma: no cover def check_json1(provider, connection): cursor = connection.cursor() sql = ''' select json('{"this": "is", "a": ["test"]}')''' try: cursor.execute(sql) return True except sqlite.OperationalError: return False provider_cls = SQLiteProvider def _text_factory(s): return s.decode('utf8', 'replace') def make_string_function(name, base_func): def func(value): if value is None: return None t = type(value) if t is not unicode: if t is buffer: value = hexlify(value).decode('ascii') else: value = unicode(value) result = base_func(value) return result func.__name__ = name return func py_upper = make_string_function('py_upper', unicode.upper) py_lower = make_string_function('py_lower', unicode.lower) def py_json_unwrap(value): # [null,some-value] -> some-value if value is None: return None assert value.startswith('[null,'), value return value[6:-1] path_cache = {} json_path_re = re.compile(r'\[(-?\d+)\]|\.(?:(\w+)|"([^"]*)")', re.UNICODE) def _parse_path(path): if path in path_cache: return path_cache[path] keys = None if isinstance(path, basestring) and path.startswith('$'): keys = [] pos = 1 path_len = len(path) while pos < path_len: match = json_path_re.match(path, pos) if match is not None: g1, g2, g3 = match.groups() keys.append(int(g1) if g1 else g2 or g3) pos = match.end() else: keys = None break else: keys = tuple(keys) path_cache[path] = keys return keys def _traverse(obj, keys): if keys is None: return None list_or_dict = (list, dict) for key in keys: if type(obj) not in list_or_dict: return None try: obj = obj[key] except (KeyError, IndexError): return None return obj def _extract(expr, *paths): expr = json.loads(expr) if isinstance(expr, basestring) else expr result = [] for path in paths: keys = _parse_path(path) result.append(_traverse(expr, keys)) return result[0] if len(paths) == 1 else result def py_json_extract(expr, *paths): result = _extract(expr, *paths) if type(result) in (list, dict): result = json.dumps(result, **SQLiteJsonConverter.json_kwargs) return result def py_json_query(expr, path, with_wrapper): result = _extract(expr, path) if type(result) not in (list, dict): if not with_wrapper: return None result = [result] return json.dumps(result, **SQLiteJsonConverter.json_kwargs) def py_json_value(expr, path): result = _extract(expr, path) return result if type(result) not in (list, dict) else None def py_json_contains(expr, path, key): expr = json.loads(expr) if isinstance(expr, basestring) else expr keys = _parse_path(path) expr = _traverse(expr, keys) return type(expr) in (list, dict) and key in expr def py_json_nonzero(expr, path): expr = json.loads(expr) if isinstance(expr, basestring) else expr keys = _parse_path(path) expr = _traverse(expr, keys) return bool(expr) def py_json_array_length(expr, path=None): expr = json.loads(expr) if isinstance(expr, basestring) else expr if path: keys = _parse_path(path) expr = _traverse(expr, keys) return len(expr) if type(expr) is list else 0 def wrap_array_func(func): @wraps(func) def new_func(array, *args): if array is None: return None array = json.loads(array) return func(array, *args) return new_func @wrap_array_func def py_array_index(array, index): try: return array[index] except IndexError: return None @wrap_array_func def py_array_contains(array, item): return item in array 
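# --- Illustrative sketch (not part of the original module): behaviour of the
# pure-Python JSON path fallbacks defined above. _parse_path understands a
# subset of the '$.key[index]' syntax, and py_json_extract re-serializes
# containers with the same compact separators SQLiteJsonConverter uses.
assert _parse_path('$.items[0].name') == ('items', 0, 'name')
assert py_json_extract('{"items": [{"name": "pony"}]}', '$.items[0].name') == 'pony'
assert py_json_extract('{"items": [1, 2]}', '$.items') == '[1,2]'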
@wrap_array_func
def py_array_subset(array, items):
    if items is None: return None
    items = json.loads(items)
    return set(items).issubset(set(array))

@wrap_array_func
def py_array_length(array):
    return len(array)

@wrap_array_func
def py_array_slice(array, start, stop):
    return dumps(array[start:stop])

def py_make_array(*items):
    return dumps(items)

class SQLitePool(Pool):
    def __init__(pool, filename, create_db, **kwargs):  # called separately in each thread
        pool.filename = filename
        pool.create_db = create_db
        pool.kwargs = kwargs
        pool.con = None
    def _connect(pool):
        filename = pool.filename
        if filename != ':memory:' and not pool.create_db and not os.path.exists(filename):
            throw(IOError, "Database file is not found: %r" % filename)
        pool.con = con = sqlite.connect(filename, isolation_level=None, **pool.kwargs)
        con.text_factory = _text_factory
        def create_function(name, num_params, func):
            func = keep_exception(func)
            con.create_function(name, num_params, func)
        create_function('power', 2, pow)
        create_function('rand', 0, random)
        create_function('py_upper', 1, py_upper)
        create_function('py_lower', 1, py_lower)
        create_function('py_json_unwrap', 1, py_json_unwrap)
        create_function('py_json_extract', -1, py_json_extract)
        create_function('py_json_contains', 3, py_json_contains)
        create_function('py_json_nonzero', 2, py_json_nonzero)
        create_function('py_json_array_length', -1, py_json_array_length)
        create_function('py_array_index', 2, py_array_index)
        create_function('py_array_contains', 2, py_array_contains)
        create_function('py_array_subset', 2, py_array_subset)
        create_function('py_array_length', 1, py_array_length)
        create_function('py_array_slice', 3, py_array_slice)
        create_function('py_make_array', -1, py_make_array)
        if sqlite.sqlite_version_info >= (3, 6, 19):
            con.execute('PRAGMA foreign_keys = true')
        con.execute('PRAGMA case_sensitive_like = true')
    def disconnect(pool):
        if pool.filename != ':memory:':
            Pool.disconnect(pool)
    def drop(pool, con):
        if pool.filename != ':memory:':
            Pool.drop(pool, con)
        else: con.rollback()
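# --- Illustrative sketch (not part of the original module): SQLitePool._connect
# above registers the py_* helpers as SQL functions on every connection, so
# generated SQL can fall back to them when the json1 extension is unavailable.
# This standalone snippet mimics that registration on a throwaway connection.
import sqlite3 as _sqlite3

_con = _sqlite3.connect(':memory:')
_con.create_function('py_json_extract', -1, py_json_extract)
_row = _con.execute('''SELECT py_json_extract('{"a": [10, 20]}', '$.a[1]')''').fetchone()
assert _row[0] == 20
_con.close()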
# pony-0.7.11/pony/orm/dbschema.py

from __future__ import absolute_import, print_function, division
from pony.py23compat import itervalues, basestring, int_types

from operator import attrgetter

from pony.orm import core
from pony.orm.core import log_sql, DBSchemaError, MappingError
from pony.utils import throw

class DBSchema(object):
    dialect = None
    inline_fk_syntax = True
    named_foreign_keys = True
    def __init__(schema, provider, uppercase=True):
        schema.provider = provider
        schema.tables = {}
        schema.constraints = {}
        schema.indent = '  '
        schema.command_separator = ';\n\n'
        schema.uppercase = uppercase
        schema.names = {}
    def column_list(schema, columns):
        quote_name = schema.provider.quote_name
        return '(%s)' % ', '.join(quote_name(column.name) for column in columns)
    def case(schema, s):
        if schema.uppercase:
            return s.upper().replace('%S', '%s') \
                .replace(')S', ')s').replace('%R', '%r').replace(')R', ')r')
        else: return s.lower()
    def add_table(schema, table_name, entity=None):
        return schema.table_class(table_name, schema, entity)
    def order_tables_to_create(schema):
        tables = []
        created_tables = set()
        split = schema.provider.split_table_name
        tables_to_create = sorted(itervalues(schema.tables), key=lambda table: split(table.name))
        while tables_to_create:
            for table in tables_to_create:
                if table.parent_tables.issubset(created_tables):
                    created_tables.add(table)
                    tables_to_create.remove(table)
                    break
            else: table = tables_to_create.pop()
            tables.append(table)
        return tables
    def generate_create_script(schema):
        created_tables = set()
        commands = []
        for table in schema.order_tables_to_create():
            for db_object in table.get_objects_to_create(created_tables):
                commands.append(db_object.get_create_command())
        return schema.command_separator.join(commands)
    def create_tables(schema, provider, connection):
        created_tables = set()
        for table in schema.order_tables_to_create():
            for db_object in table.get_objects_to_create(created_tables):
                base_name = provider.base_name(db_object.name)
                name = db_object.exists(provider, connection, case_sensitive=False)
                if name is None: db_object.create(provider, connection)
                elif name != base_name:
                    quote_name = schema.provider.quote_name
                    n1, n2 = quote_name(db_object.name), quote_name(name)
                    tn1, tn2 = db_object.typename, db_object.typename.lower()
                    throw(DBSchemaError, '%s %s cannot be created, because %s %s '
                                         '(with a different letter case) already exists in the database. '
                                         'Try to delete %s %s first.' % (tn1, n1, tn2, n2, n2, tn2))
    def check_tables(schema, provider, connection):
        cursor = connection.cursor()
        split = provider.split_table_name
        for table in sorted(itervalues(schema.tables), key=lambda table: split(table.name)):
            alias = provider.base_name(table.name)
            sql_ast = [ 'SELECT',
                        [ 'ALL', ] + [ [ 'COLUMN', alias, column.name ] for column in table.column_list ],
                        [ 'FROM', [ alias, 'TABLE', table.name ] ],
                        [ 'WHERE', [ 'EQ', [ 'VALUE', 0 ], [ 'VALUE', 1 ] ] ] ]
            sql, adapter = provider.ast2sql(sql_ast)
            if core.local.debug: log_sql(sql)
            provider.execute(cursor, sql)

class DBObject(object):
    def create(table, provider, connection):
        sql = table.get_create_command()
        if core.local.debug: log_sql(sql)
        cursor = connection.cursor()
        provider.execute(cursor, sql)

class Table(DBObject):
    typename = 'Table'
    def __init__(table, name, schema, entity=None):
        if name in schema.tables:
            throw(DBSchemaError, "Table %r already exists in database schema" % name)
        if name in schema.names:
            throw(DBSchemaError, "Table %r cannot be created, name is already in use" % name)
        schema.tables[name] = table
        schema.names[name] = table
        table.schema = schema
        table.name = name
        table.column_list = []
        table.column_dict = {}
        table.indexes = {}
        table.pk_index = None
        table.foreign_keys = {}
        table.parent_tables = set()
        table.child_tables = set()
        table.entities = set()
        table.options = {}
        if entity is not None:
            table.entities.add(entity)
            table.options = entity._table_options_
        table.m2m = set()
    def __repr__(table):
        return '<Table %s>' % table.schema.provider.format_table_name(table.name)
    def add_entity(table, entity):
        for e in table.entities:
            if e._root_ is not entity._root_:
                throw(MappingError, "Entities %s and %s cannot be mapped to table %s "
                                    "because they don't belong to the same hierarchy"
                                    % (e, entity, table.name))
        assert '_table_options_' not in entity.__dict__
        table.entities.add(entity)
    def exists(table, provider, connection, case_sensitive=True):
        return provider.table_exists(connection, table.name, case_sensitive)
    def get_create_command(table):
        schema = table.schema
        case = schema.case
        provider = schema.provider
        quote_name = provider.quote_name
        if_not_exists = False  # provider.table_if_not_exists_syntax and provider.index_if_not_exists_syntax
        cmd = []
        if not if_not_exists: cmd.append(case('CREATE TABLE %s (') % quote_name(table.name))
        else: cmd.append(case('CREATE TABLE IF NOT EXISTS %s (') % quote_name(table.name))
        for column in table.column_list:
cmd.append(schema.indent + column.get_sql() + ',') if len(table.pk_index.columns) > 1: cmd.append(schema.indent + table.pk_index.get_sql() + ',') indexes = [ index for index in itervalues(table.indexes) if not index.is_pk and index.is_unique and len(index.columns) > 1 ] for index in indexes: assert index.name is not None indexes.sort(key=attrgetter('name')) for index in indexes: cmd.append(schema.indent+index.get_sql() + ',') if not schema.named_foreign_keys: for foreign_key in sorted(itervalues(table.foreign_keys), key=lambda fk: fk.name): if schema.inline_fk_syntax and len(foreign_key.child_columns) == 1: continue cmd.append(schema.indent+foreign_key.get_sql() + ',') cmd[-1] = cmd[-1][:-1] cmd.append(')') for name, value in sorted(table.options.items()): option = table.format_option(name, value) if option: cmd.append(option) return '\n'.join(cmd) def format_option(table, name, value): if value is True: return name if value is False: return None return '%s %s' % (name, value) def get_objects_to_create(table, created_tables=None): if created_tables is None: created_tables = set() created_tables.add(table) result = [ table ] indexes = [ index for index in itervalues(table.indexes) if not index.is_pk and not index.is_unique ] for index in indexes: assert index.name is not None indexes.sort(key=attrgetter('name')) result.extend(indexes) schema = table.schema if schema.named_foreign_keys: for foreign_key in sorted(itervalues(table.foreign_keys), key=lambda fk: fk.name): if foreign_key.parent_table not in created_tables: continue result.append(foreign_key) for child_table in table.child_tables: if child_table not in created_tables: continue for foreign_key in sorted(itervalues(child_table.foreign_keys), key=lambda fk: fk.name): if foreign_key.parent_table is not table: continue result.append(foreign_key) return result def add_column(table, column_name, sql_type, converter, is_not_null=None, sql_default=None): return table.schema.column_class(column_name, table, sql_type, converter, is_not_null, sql_default) def add_index(table, index_name, columns, is_pk=False, is_unique=None, m2m=False): assert index_name is not False if index_name is True: index_name = None if index_name is None and not is_pk: provider = table.schema.provider index_name = provider.get_default_index_name(table.name, (column.name for column in columns), is_pk=is_pk, is_unique=is_unique, m2m=m2m) index = table.indexes.get(columns) if index and index.name == index_name and index.is_pk == is_pk and index.is_unique == is_unique: return index return table.schema.index_class(index_name, table, columns, is_pk, is_unique) def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None, on_delete=False): if fk_name is None: provider = table.schema.provider child_column_names = tuple(column.name for column in child_columns) fk_name = provider.get_default_fk_name(table.name, parent_table.name, child_column_names) return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name, on_delete) class Column(object): auto_template = '%(type)s PRIMARY KEY AUTOINCREMENT' def __init__(column, name, table, sql_type, converter, is_not_null=None, sql_default=None): if name in table.column_dict: throw(DBSchemaError, "Column %r already exists in table %r" % (name, table.name)) table.column_dict[name] = column table.column_list.append(column) column.table = table column.name = name column.sql_type = sql_type column.converter = converter column.is_not_null = is_not_null 
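        # Note: the is_pk/is_pk_part/is_unique flags below start out False;
        # DBIndex.__init__ (further down in this module) raises them when the
        # corresponding indexes are declared.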
        column.sql_default = sql_default
        column.is_pk = False
        column.is_pk_part = False
        column.is_unique = False
    def __repr__(column):
        return '<Column %s.%s>' % (column.table.name, column.name)
    def get_sql(column):
        table = column.table
        schema = table.schema
        quote_name = schema.provider.quote_name
        case = schema.case
        result = []
        append = result.append
        append(quote_name(column.name))
        if column.is_pk == 'auto' and column.auto_template and column.converter.py_type in int_types:
            append(case(column.auto_template % dict(type=column.sql_type)))
        else: append(case(column.sql_type))
        if column.is_pk:
            if schema.dialect == 'SQLite': append(case('NOT NULL'))
            append(case('PRIMARY KEY'))
        else:
            if column.is_unique: append(case('UNIQUE'))
            if column.is_not_null: append(case('NOT NULL'))
        if column.sql_default not in (None, True, False):
            append(case('DEFAULT'))
            append(column.sql_default)
        if schema.inline_fk_syntax and not schema.named_foreign_keys:
            foreign_key = table.foreign_keys.get((column,))
            if foreign_key is not None:
                parent_table = foreign_key.parent_table
                append(case('REFERENCES'))
                append(quote_name(parent_table.name))
                append(schema.column_list(foreign_key.parent_columns))
                if foreign_key.on_delete:
                    append('ON DELETE %s' % foreign_key.on_delete)
        return ' '.join(result)

class Constraint(DBObject):
    def __init__(constraint, name, schema):
        if name is not None:
            assert name not in schema.names
            if name in schema.constraints:
                throw(DBSchemaError, "Constraint with name %r already exists" % name)
            schema.names[name] = constraint
            schema.constraints[name] = constraint
        constraint.schema = schema
        constraint.name = name

class DBIndex(Constraint):
    typename = 'Index'
    def __init__(index, name, table, columns, is_pk=False, is_unique=None):
        assert len(columns) > 0
        for column in columns:
            if column.table is not table:
                throw(DBSchemaError, "Column %r does not belong to table %r and cannot be part of its index"
                                     % (column.name, table.name))
        if columns in table.indexes:
            if len(columns) == 1:
                throw(DBSchemaError, "Index for column %r already exists" % columns[0].name)
            else:
                throw(DBSchemaError, "Index for columns (%s) already exists"
                                     % ', '.join(repr(column.name) for column in columns))
        if is_pk:
            if table.pk_index is not None:
                throw(DBSchemaError, 'Primary key for table %r is already defined' % table.name)
            table.pk_index = index
            if is_unique is None: is_unique = True
            elif not is_unique:
                throw(DBSchemaError, "Incompatible combination of is_unique=False and is_pk=True")
        elif is_unique is None: is_unique = False
        schema = table.schema
        if name is not None and name in schema.names:
            throw(DBSchemaError, 'Index %s cannot be created, name is already in use' % name)
        Constraint.__init__(index, name, schema)
        for column in columns:
            column.is_pk = column.is_pk or (len(columns) == 1 and is_pk)
            column.is_pk_part = column.is_pk_part or bool(is_pk)
            column.is_unique = column.is_unique or (is_unique and len(columns) == 1)
        table.indexes[columns] = index
        index.table = table
        index.columns = columns
        index.is_pk = is_pk
        index.is_unique = is_unique
    def exists(index, provider, connection, case_sensitive=True):
        return provider.index_exists(connection, index.table.name, index.name, case_sensitive)
    def get_sql(index):
        return index._get_create_sql(inside_table=True)
    def get_create_command(index):
        return index._get_create_sql(inside_table=False)
    def _get_create_sql(index, inside_table):
        schema = index.schema
        case = schema.case
        quote_name = schema.provider.quote_name
        cmd = []
        append = cmd.append
        if not inside_table:
            if index.is_pk:
                throw(DBSchemaError, 'Primary key index cannot be
defined outside of table definition') append(case('CREATE')) if index.is_unique: append(case('UNIQUE')) append(case('INDEX')) # if schema.provider.index_if_not_exists_syntax: # append(case('IF NOT EXISTS')) append(quote_name(index.name)) append(case('ON')) append(quote_name(index.table.name)) converter = index.columns[0].converter if isinstance(converter.py_type, core.Array) and converter.provider.dialect == 'PostgreSQL': append(case('USING GIN')) else: if index.name: append(case('CONSTRAINT')) append(quote_name(index.name)) if index.is_pk: append(case('PRIMARY KEY')) elif index.is_unique: append(case('UNIQUE')) else: append(case('INDEX')) append(schema.column_list(index.columns)) return ' '.join(cmd) class ForeignKey(Constraint): typename = 'Foreign key' def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name, on_delete): schema = parent_table.schema if schema is not child_table.schema: throw(DBSchemaError, 'Parent and child tables of foreign_key cannot belong to different schemata') for column in parent_columns: if column.table is not parent_table: throw(DBSchemaError, 'Column %r does not belong to table %r' % (column.name, parent_table.name)) for column in child_columns: if column.table is not child_table: throw(DBSchemaError, 'Column %r does not belong to table %r' % (column.name, child_table.name)) if len(parent_columns) != len(child_columns): throw(DBSchemaError, 'Foreign key columns count do not match') if child_columns in child_table.foreign_keys: if len(child_columns) == 1: throw(DBSchemaError, 'Foreign key for column %r already defined' % child_columns[0].name) else: throw(DBSchemaError, 'Foreign key for columns (%s) already defined' % ', '.join(repr(column.name) for column in child_columns)) if name is not None and name in schema.names: throw(DBSchemaError, 'Foreign key %s cannot be created, name is already in use' % name) Constraint.__init__(foreign_key, name, schema) child_table.foreign_keys[child_columns] = foreign_key if child_table is not parent_table: child_table.parent_tables.add(parent_table) parent_table.child_tables.add(child_table) foreign_key.parent_table = parent_table foreign_key.parent_columns = parent_columns foreign_key.child_table = child_table foreign_key.child_columns = child_columns foreign_key.on_delete = on_delete if index_name is not False: child_columns_len = len(child_columns) if all(columns[:child_columns_len] != child_columns for columns in child_table.indexes): child_table.add_index(index_name, child_columns, is_pk=False, is_unique=False, m2m=bool(child_table.m2m)) def exists(foreign_key, provider, connection, case_sensitive=True): return provider.fk_exists(connection, foreign_key.child_table.name, foreign_key.name, case_sensitive) def get_sql(foreign_key): return foreign_key._get_create_sql(inside_table=True) def get_create_command(foreign_key): return foreign_key._get_create_sql(inside_table=False) def _get_create_sql(foreign_key, inside_table): schema = foreign_key.schema case = schema.case quote_name = schema.provider.quote_name cmd = [] append = cmd.append if not inside_table: append(case('ALTER TABLE')) append(quote_name(foreign_key.child_table.name)) append(case('ADD')) if schema.named_foreign_keys and foreign_key.name: append(case('CONSTRAINT')) append(quote_name(foreign_key.name)) append(case('FOREIGN KEY')) append(schema.column_list(foreign_key.child_columns)) append(case('REFERENCES')) append(quote_name(foreign_key.parent_table.name)) append(schema.column_list(foreign_key.parent_columns)) if 
foreign_key.on_delete:
            append(case('ON DELETE %s' % foreign_key.on_delete))
        return ' '.join(cmd)

DBSchema.table_class = Table
DBSchema.column_class = Column
DBSchema.index_class = DBIndex
DBSchema.fk_class = ForeignKey

# pony-0.7.11/pony/orm/decompiling.py

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, izip, xrange, PY37, PYPY

import sys, types, inspect
from opcode import opname as opnames, HAVE_ARGUMENT, EXTENDED_ARG, cmp_op
from opcode import hasconst, hasname, hasjrel, haslocal, hascompare, hasfree
from collections import defaultdict

from pony.thirdparty.compiler import ast, parse
from pony.utils import throw, get_codeobject_id

##ast.And.__repr__ = lambda self: "And(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),)
##ast.Or.__repr__ = lambda self: "Or(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),)

class DecompileError(NotImplementedError):
    pass

ast_cache = {}

def decompile(x):
    cells = {}
    t = type(x)
    if t is types.CodeType:
        codeobject = x
    elif t is types.GeneratorType:
        codeobject = x.gi_frame.f_code
    elif t is types.FunctionType:
        codeobject = x.func_code if PY2 else x.__code__
        if PY2:
            if x.func_closure:
                cells = dict(izip(codeobject.co_freevars, x.func_closure))
        else:
            if x.__closure__:
                cells = dict(izip(codeobject.co_freevars, x.__closure__))
    else: throw(TypeError)
    key = get_codeobject_id(codeobject)
    result = ast_cache.get(key)
    if result is None:
        decompiler = Decompiler(codeobject)
        result = decompiler.ast, decompiler.external_names
        ast_cache[key] = result
    return result + (cells,)

def simplify(clause):
    if isinstance(clause, ast.And):
        if len(clause.nodes) == 1: result = clause.nodes[0]
        else: return clause
    elif isinstance(clause, ast.Or):
        if len(clause.nodes) == 1: result = ast.Not(clause.nodes[0])
        else: return clause
    else: return clause
    if getattr(result, 'endpos', 0) < clause.endpos:
        result.endpos = clause.endpos
    return result

class InvalidQuery(Exception):
    pass

def binop(node_type, args_holder=tuple):
    def method(decompiler):
        oper2 = decompiler.stack.pop()
        oper1 = decompiler.stack.pop()
        return node_type(args_holder((oper1, oper2)))
    return method

if not PY2:
    ord = lambda x: x

class Decompiler(object):
    def __init__(decompiler, code, start=0, end=None):
        decompiler.code = code
        decompiler.start = decompiler.pos = start
        if end is None: end = len(code.co_code)
        decompiler.end = end
        decompiler.stack = []
        decompiler.jump_map = defaultdict(list)
        decompiler.targets = {}
        decompiler.ast = None
        decompiler.names = set()
        decompiler.assnames = set()
        decompiler.conditions_end = 0
        decompiler.instructions = []
        decompiler.instructions_map = {}
        decompiler.or_jumps = set()
        decompiler.get_instructions()
        decompiler.analyze_jumps()
        decompiler.decompile()
        decompiler.ast = decompiler.stack.pop()
        decompiler.external_names = decompiler.names - decompiler.assnames
        assert not decompiler.stack, decompiler.stack
    def get_instructions(decompiler):
        PY36 = sys.version_info >= (3, 6)
        before_yield = True
        code = decompiler.code
        co_code = code.co_code
        free = code.co_cellvars + code.co_freevars
        decompiler.abs_jump_to_top = decompiler.for_iter_pos = -1
        while decompiler.pos < decompiler.end:
            i = decompiler.pos
            op = ord(code.co_code[i])
            if PY36:
                extended_arg = 0
                oparg = ord(code.co_code[i+1])
                while op == EXTENDED_ARG:
                    extended_arg = (extended_arg | oparg) << 8
                    i += 2
                    op =
ord(code.co_code[i]) oparg = ord(code.co_code[i+1]) oparg = None if op < HAVE_ARGUMENT else oparg | extended_arg i += 2 else: i += 1 if op >= HAVE_ARGUMENT: oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 i += 2 if op == EXTENDED_ARG: op = ord(code.co_code[i]) i += 1 oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 + oparg * 65536 i += 2 if op >= HAVE_ARGUMENT: if op in hasconst: arg = [code.co_consts[oparg]] elif op in hasname: arg = [code.co_names[oparg]] elif op in hasjrel: arg = [i + oparg] elif op in haslocal: arg = [code.co_varnames[oparg]] elif op in hascompare: arg = [cmp_op[oparg]] elif op in hasfree: arg = [free[oparg]] else: arg = [oparg] else: arg = [] opname = opnames[op].replace('+', '_') if opname == 'FOR_ITER': decompiler.for_iter_pos = decompiler.pos if opname == 'JUMP_ABSOLUTE' and arg[0] == decompiler.for_iter_pos: decompiler.abs_jump_to_top = decompiler.pos if before_yield: if 'JUMP' in opname: endpos = arg[0] if endpos < decompiler.pos: decompiler.conditions_end = i decompiler.jump_map[endpos].append(decompiler.pos) decompiler.instructions_map[decompiler.pos] = len(decompiler.instructions) decompiler.instructions.append((decompiler.pos, i, opname, arg)) if opname == 'YIELD_VALUE': before_yield = False decompiler.pos = i def analyze_jumps(decompiler): if PYPY: targets = decompiler.jump_map.pop(decompiler.abs_jump_to_top, []) decompiler.jump_map[decompiler.for_iter_pos] = targets for i, (x, y, opname, arg) in enumerate(decompiler.instructions): if 'JUMP' in opname: target = arg[0] if target == decompiler.abs_jump_to_top: decompiler.instructions[i] = (x, y, opname, [decompiler.for_iter_pos]) decompiler.conditions_end = y i = decompiler.instructions_map[decompiler.conditions_end] while i > 0: pos, next_pos, opname, arg = decompiler.instructions[i] if pos in decompiler.jump_map: for jump_start_pos in decompiler.jump_map[pos]: if jump_start_pos > pos: continue for or_jump_start_pos in decompiler.or_jumps: if pos > or_jump_start_pos > jump_start_pos: break # And jump else: decompiler.or_jumps.add(jump_start_pos) i -= 1 def decompile(decompiler): for pos, next_pos, opname, arg in decompiler.instructions: if pos in decompiler.targets: decompiler.process_target(pos) method = getattr(decompiler, opname, None) if method is None: throw(DecompileError('Unsupported operation: %s' % opname)) decompiler.pos = pos decompiler.next_pos = next_pos x = method(*arg) if x is not None: decompiler.stack.append(x) def pop_items(decompiler, size): if not size: return () result = decompiler.stack[-size:] decompiler.stack[-size:] = [] return result def store(decompiler, node): stack = decompiler.stack if not stack: stack.append(node); return top = stack[-1] if isinstance(top, (ast.AssTuple, ast.AssList)) and len(top.nodes) < top.count: top.nodes.append(node) if len(top.nodes) == top.count: decompiler.store(stack.pop()) elif isinstance(top, ast.GenExprFor): assert top.assign is None top.assign = node else: stack.append(node) BINARY_POWER = binop(ast.Power) BINARY_MULTIPLY = binop(ast.Mul) BINARY_DIVIDE = binop(ast.Div) BINARY_FLOOR_DIVIDE = binop(ast.FloorDiv) BINARY_ADD = binop(ast.Add) BINARY_SUBTRACT = binop(ast.Sub) BINARY_LSHIFT = binop(ast.LeftShift) BINARY_RSHIFT = binop(ast.RightShift) BINARY_AND = binop(ast.Bitand, list) BINARY_XOR = binop(ast.Bitxor, list) BINARY_OR = binop(ast.Bitor, list) BINARY_TRUE_DIVIDE = BINARY_DIVIDE BINARY_MODULO = binop(ast.Mod) def BINARY_SUBSCR(decompiler): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() if isinstance(oper2, ast.Sliceobj) 
and len(oper2.nodes) == 2: a, b = oper2.nodes a = None if isinstance(a, ast.Const) and a.value == None else a b = None if isinstance(b, ast.Const) and b.value == None else b return ast.Slice(oper1, 'OP_APPLY', a, b) elif isinstance(oper2, ast.Tuple): return ast.Subscript(oper1, 'OP_APPLY', list(oper2.nodes)) else: return ast.Subscript(oper1, 'OP_APPLY', [ oper2 ]) def BUILD_CONST_KEY_MAP(decompiler, length): keys = decompiler.stack.pop() assert isinstance(keys, ast.Const) keys = [ ast.Const(key) for key in keys.value ] values = decompiler.pop_items(length) pairs = list(izip(keys, values)) return ast.Dict(pairs) def BUILD_LIST(decompiler, size): return ast.List(decompiler.pop_items(size)) def BUILD_MAP(decompiler, length): if sys.version_info < (3, 5): return ast.Dict(()) data = decompiler.pop_items(2 * length) # [key1, value1, key2, value2, ...] it = iter(data) pairs = list(izip(it, it)) # [(key1, value1), (key2, value2), ...] return ast.Dict(tuple(pairs)) def BUILD_SET(decompiler, size): return ast.Set(decompiler.pop_items(size)) def BUILD_SLICE(decompiler, size): return ast.Sliceobj(decompiler.pop_items(size)) def BUILD_TUPLE(decompiler, size): return ast.Tuple(decompiler.pop_items(size)) def BUILD_STRING(decompiler, count): values = list(reversed([decompiler.stack.pop() for _ in range(count)])) return ast.JoinedStr(values) def CALL_FUNCTION(decompiler, argc, star=None, star2=None): pop = decompiler.stack.pop kwarg, posarg = divmod(argc, 256) args = [] for i in xrange(kwarg): arg = pop() key = pop().value args.append(ast.Keyword(key, arg)) for i in xrange(posarg): args.append(pop()) args.reverse() return decompiler._call_function(args, star, star2) def _call_function(decompiler, args, star=None, star2=None): tos = decompiler.stack.pop() if isinstance(tos, ast.GenExpr): assert len(args) == 1 and star is None and star2 is None genexpr = tos qual = genexpr.code.quals[0] assert isinstance(qual.iter, ast.Name) assert qual.iter.name in ('.0', '[outmost-iterable]') qual.iter = args[0] return genexpr else: return ast.CallFunc(tos, args, star, star2) def CALL_FUNCTION_VAR(decompiler, argc): return decompiler.CALL_FUNCTION(argc, decompiler.stack.pop()) def CALL_FUNCTION_KW(decompiler, argc): if sys.version_info < (3, 6): return decompiler.CALL_FUNCTION(argc, star2=decompiler.stack.pop()) keys = decompiler.stack.pop() assert isinstance(keys, ast.Const) keys = keys.value values = decompiler.pop_items(argc) assert len(keys) <= len(values) args = values[:-len(keys)] for key, value in izip(keys, values[-len(keys):]): args.append(ast.Keyword(key, value)) return decompiler._call_function(args) def CALL_FUNCTION_VAR_KW(decompiler, argc): star2 = decompiler.stack.pop() star = decompiler.stack.pop() return decompiler.CALL_FUNCTION(argc, star, star2) def CALL_FUNCTION_EX(decompiler, argc): star2 = None if argc: if argc != 1: throw(DecompileError) star2 = decompiler.stack.pop() star = decompiler.stack.pop() return decompiler._call_function([], star, star2) def CALL_METHOD(decompiler, argc): pop = decompiler.stack.pop args = [] if argc >= 256: kwargc = argc // 256 argc = argc % 256 for i in range(kwargc): v = pop() k = pop() assert isinstance(k, ast.Const) k = k.value # ast.Name(k.value) args.append(ast.Keyword(k, v)) for i in range(argc): args.append(pop()) args.reverse() method = pop() return ast.CallFunc(method, args) def COMPARE_OP(decompiler, op): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() return ast.Compare(oper1, [(op, oper2)]) def DUP_TOP(decompiler): return decompiler.stack[-1] def 
FOR_ITER(decompiler, endpos): assign = None iter = decompiler.stack.pop() ifs = [] return ast.GenExprFor(assign, iter, ifs) def FORMAT_VALUE(decompiler, flags): if flags in (0, 1, 2, 3): value = decompiler.stack.pop() return ast.Str(value, flags) elif flags == 4: fmt_spec = decompiler.stack.pop() value = decompiler.stack.pop() return ast.FormattedValue(value, fmt_spec) def GET_ITER(decompiler): pass def JUMP_IF_FALSE(decompiler, endpos): return decompiler.conditional_jump(endpos, False) JUMP_IF_FALSE_OR_POP = JUMP_IF_FALSE def JUMP_IF_TRUE(decompiler, endpos): return decompiler.conditional_jump(endpos, True) JUMP_IF_TRUE_OR_POP = JUMP_IF_TRUE def conditional_jump(decompiler, endpos, if_true): if PY37 or PYPY: return decompiler.conditional_jump_new(endpos, if_true) return decompiler.conditional_jump_old(endpos, if_true) def conditional_jump_old(decompiler, endpos, if_true): i = decompiler.next_pos if i in decompiler.targets: decompiler.process_target(i) expr = decompiler.stack.pop() clausetype = ast.Or if if_true else ast.And clause = clausetype([expr]) clause.endpos = endpos decompiler.targets.setdefault(endpos, clause) return clause def conditional_jump_new(decompiler, endpos, if_true): expr = decompiler.stack.pop() if decompiler.pos >= decompiler.conditions_end: clausetype = ast.Or if if_true else ast.And elif decompiler.pos in decompiler.or_jumps: clausetype = ast.Or if not if_true: expr = ast.Not(expr) else: clausetype = ast.And if if_true: expr = ast.Not(expr) decompiler.stack.append(expr) if decompiler.next_pos in decompiler.targets: decompiler.process_target(decompiler.next_pos) expr = decompiler.stack.pop() clause = clausetype([ expr ]) clause.endpos = endpos decompiler.targets.setdefault(endpos, clause) return clause def process_target(decompiler, pos, partial=False): if pos is None: limit = None elif partial: limit = decompiler.targets.get(pos, None) else: limit = decompiler.targets.pop(pos, None) top = decompiler.stack.pop() while True: top = simplify(top) if top is limit: break if isinstance(top, ast.GenExprFor): break if not decompiler.stack: break top2 = decompiler.stack[-1] if isinstance(top2, ast.GenExprFor): break if partial and hasattr(top2, 'endpos') and top2.endpos == pos: break if isinstance(top2, (ast.And, ast.Or)): if top2.__class__ == top.__class__: top2.nodes.extend(top.nodes) else: top2.nodes.append(top) elif isinstance(top2, ast.IfExp): # Python 2.5 top2.else_ = top if hasattr(top, 'endpos'): top2.endpos = top.endpos if decompiler.targets.get(top.endpos) is top: decompiler.targets[top.endpos] = top2 else: throw(DecompileError('Expression is too complex to decompile, try to pass query as string, e.g. select("x for x in Something")')) top2.endpos = max(top2.endpos, getattr(top, 'endpos', 0)) top = decompiler.stack.pop() decompiler.stack.append(top) def JUMP_FORWARD(decompiler, endpos): i = decompiler.next_pos # next instruction decompiler.process_target(i, True) then = decompiler.stack.pop() decompiler.process_target(i, False) test = decompiler.stack.pop() if_exp = ast.IfExp(simplify(test), simplify(then), None) if_exp.endpos = endpos decompiler.targets.setdefault(endpos, if_exp) if decompiler.targets.get(endpos) is then: decompiler.targets[endpos] = if_exp return if_exp def LIST_APPEND(decompiler, offset=None): throw(InvalidQuery('Use generator expression (... for ... in ...) ' 'instead of list comprehension [... for ... in ...] 
inside query')) def LOAD_ATTR(decompiler, attr_name): return ast.Getattr(decompiler.stack.pop(), attr_name) def LOAD_CLOSURE(decompiler, freevar): decompiler.names.add(freevar) return ast.Name(freevar) def LOAD_CONST(decompiler, const_value): return ast.Const(const_value) def LOAD_DEREF(decompiler, freevar): decompiler.names.add(freevar) return ast.Name(freevar) def LOAD_FAST(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) def LOAD_GLOBAL(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) def LOAD_METHOD(decompiler, methname): return decompiler.LOAD_ATTR(methname) LOOKUP_METHOD = LOAD_METHOD # For PyPy def LOAD_NAME(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) def MAKE_CLOSURE(decompiler, argc): if PY2: decompiler.stack[-2:-1] = [] # ignore freevars else: decompiler.stack[-3:-2] = [] # ignore freevars return decompiler.MAKE_FUNCTION(argc) def MAKE_FUNCTION(decompiler, argc): defaults = [] flags = 0 if sys.version_info >= (3, 6): qualname = decompiler.stack.pop() tos = decompiler.stack.pop() if argc & 0x08: func_closure = decompiler.stack.pop() if argc & 0x04: annotations = decompiler.stack.pop() if argc & 0x02: kwonly_defaults = decompiler.stack.pop() if argc & 0x01: defaults = decompiler.stack.pop() throw(DecompileError) else: if not PY2: qualname = decompiler.stack.pop() tos = decompiler.stack.pop() if argc: defaults = [ decompiler.stack.pop() for i in range(argc) ] defaults.reverse() codeobject = tos.value func_decompiler = Decompiler(codeobject) # decompiler.names.update(decompiler.names) ??? if codeobject.co_varnames[:1] == ('.0',): return func_decompiler.ast # generator argnames, varargs, keywords = inspect.getargs(codeobject) if varargs: argnames.append(varargs) flags |= inspect.CO_VARARGS if keywords: argnames.append(keywords) flags |= inspect.CO_VARKEYWORDS return ast.Lambda(argnames, defaults, flags, func_decompiler.ast) POP_JUMP_IF_FALSE = JUMP_IF_FALSE POP_JUMP_IF_TRUE = JUMP_IF_TRUE def POP_TOP(decompiler): pass def RETURN_VALUE(decompiler): if decompiler.next_pos != decompiler.end: throw(DecompileError) expr = decompiler.stack.pop() return simplify(expr) def ROT_TWO(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() decompiler.stack.append(tos) decompiler.stack.append(tos1) def ROT_THREE(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() tos2 = decompiler.stack.pop() decompiler.stack.append(tos) decompiler.stack.append(tos2) decompiler.stack.append(tos1) def SETUP_LOOP(decompiler, endpos): pass def SLICE_0(decompiler): return ast.Slice(decompiler.stack.pop(), 'OP_APPLY', None, None) def SLICE_1(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() return ast.Slice(tos1, 'OP_APPLY', tos, None) def SLICE_2(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() return ast.Slice(tos1, 'OP_APPLY', None, tos) def SLICE_3(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() tos2 = decompiler.stack.pop() return ast.Slice(tos2, 'OP_APPLY', tos1, tos) def STORE_ATTR(decompiler, attrname): decompiler.store(ast.AssAttr(decompiler.stack.pop(), attrname, 'OP_ASSIGN')) def STORE_DEREF(decompiler, freevar): decompiler.assnames.add(freevar) decompiler.store(ast.AssName(freevar, 'OP_ASSIGN')) def STORE_FAST(decompiler, varname): if varname.startswith('_['): throw(InvalidQuery('Use generator expression (... for ... in ...) ' 'instead of list comprehension [... for ... in ...] 
inside query')) decompiler.assnames.add(varname) decompiler.store(ast.AssName(varname, 'OP_ASSIGN')) def STORE_MAP(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() tos2 = decompiler.stack[-1] if not isinstance(tos2, ast.Dict): assert False # pragma: no cover if tos2.items == (): tos2.items = [] tos2.items.append((tos, tos1)) def STORE_SUBSCR(decompiler): tos = decompiler.stack.pop() tos1 = decompiler.stack.pop() tos2 = decompiler.stack.pop() if not isinstance(tos1, ast.Dict): assert False # pragma: no cover if tos1.items == (): tos1.items = [] tos1.items.append((tos, tos2)) def UNARY_POSITIVE(decompiler): return ast.UnaryAdd(decompiler.stack.pop()) def UNARY_NEGATIVE(decompiler): return ast.UnarySub(decompiler.stack.pop()) def UNARY_NOT(decompiler): return ast.Not(decompiler.stack.pop()) def UNARY_CONVERT(decompiler): return ast.Backquote(decompiler.stack.pop()) def UNARY_INVERT(decompiler): return ast.Invert(decompiler.stack.pop()) def UNPACK_SEQUENCE(decompiler, count): ass_tuple = ast.AssTuple([]) ass_tuple.count = count return ass_tuple def YIELD_VALUE(decompiler): expr = decompiler.stack.pop() fors = [] while decompiler.stack: decompiler.process_target(None) top = decompiler.stack.pop() if not isinstance(top, (ast.GenExprFor)): cond = ast.GenExprIf(top) top = decompiler.stack.pop() assert isinstance(top, ast.GenExprFor) top.ifs.append(cond) fors.append(top) else: fors.append(top) fors.reverse() return ast.GenExpr(ast.GenExprInner(simplify(expr), fors)) test_lines = """ (a and b if c and d else e and f for i in T if (A and B if C and D else E and F)) (a for b in T) (a for b, c in T) (a for b in T1 for c in T2) (a for b in T1 for c in T2 for d in T3) (a for b in T if f) (a for b in T if f and h) (a for b in T if f and h or t) (a for b in T if f == 5 and r or t) (a for b in T if f and r and t) # (a for b in T if f == 5 and +r or not t) # (a for b in T if -t and ~r or `f`) (a for b in T if x and not y and z) (a for b in T if not x and y) (a for b in T if not x and y and z) (a for b in T if not x and y or z) #FIXME! 
(a**2 for b in T if t * r > y / 3) (a + 2 for b in T if t + r > y // 3) (a[2,v] for b in T if t - r > y[3]) ((a + 2) * 3 for b in T if t[r, e] > y[3, r * 4, t]) (a<<2 for b in T if t>>e > r & (y & u)) (a|b for c in T1 if t^e > r | (y & (u & (w % z)))) ([a, b, c] for d in T) ([a, b, 4] for d in T if a[4, b] > b[1,v,3]) ((a, b, c) for d in T) ({} for d in T) ({'a' : x, 'b' : y} for a, b in T) (({'a' : x, 'b' : y}, {'c' : x1, 'd' : 1}) for a, b, c, d in T) ([{'a' : x, 'b' : y}, {'c' : x1, 'd' : 1}] for a, b, c, d in T) (a[1:2] for b in T) (a[:2] for b in T) (a[2:] for b in T) (a[:] for b in T) (a[1:2:3] for b in T) (a[1:2, 3:4] for b in T) (a[2:4:6,6:8] for a, y in T) (a.b.c for d.e.f.g in T) # (a.b.c for d[g] in T) ((s,d,w) for t in T if (4 != x.a or a*3 > 20) and a * 2 < 5) ([s,d,w] for t in T if (4 != x.amount or amount * 3 > 20 or amount * 2 < 5) and amount*8 == 20) ([s,d,w] for t in T if (4 != x.a or a*3 > 20 or a*2 < 5 or 4 == 5) and a * 8 == 20) (s for s in T if s.a > 20 and (s.x.y == 123 or 'ABC' in s.p.q.r)) (a for b in T1 if c > d for e in T2 if f < g) (func1(a, a.attr, x=123) for s in T) # (func1(a, a.attr, *args) for s in T) # (func1(a, a.attr, x=123, **kwargs) for s in T) (func1(a, b, a.attr1, a.b.c, x=123, y='foo') for s in T) # (func1(a, b, a.attr1, a.b.c, x=123, y='foo', **kwargs) for s in T) # (func(a, a.attr, keyarg=123) for a in T if a.method(x, *args, **kwargs) == 4) ((x or y) and (p or q) for a in T if (a or b) and (c or d)) (x.y for x in T if (a and (b or (c and d))) or X) (a for a in T1 if a in (b for b in T2)) (a for a in T1 if a in (b for b in T2 if b == a)) (a for a in T1 if a in (b for b in T2)) (a for a in T1 if a in select(b for b in T2)) (a for a in T1 if a in (b for b in T2 if b in (c for c in T3 if c == a))) (a for a in T1 if a > x and a in (b for b in T1 if b < y) and a < z) """ ## should throw InvalidQuery due to using [] inside of a query ## (a for a in T1 if a in [b for b in T2 if b in [(c, d) for c in T3]]) ## examples of conditional expressions ## (a if b else c for x in T) ## (x for x in T if (d if e else f)) ## (a if b else c for x in T if (d if e else f)) ## (a and b or c and d if x and y or p and q else r and n or m and k for i in T) ## (i for i in T if (a and b or c and d if x and y or p and q else r and n or m and k)) ## (a and b or c and d if x and y or p and q else r and n or m and k for i in T if (A and B or C and D if X and Y or P and Q else R and N or M and K)) def test(): import sys if sys.version[:3] > '2.4': outmost_iterable_name = '.0' else: outmost_iterable_name = '[outmost-iterable]' import dis for line in test_lines.split('\n'): if not line or line.isspace(): continue line = line.strip() if line.startswith('#'): continue code = compile(line, '', 'eval').co_consts[0] ast1 = parse(line).node.nodes[0].expr ast1.code.quals[0].iter.name = outmost_iterable_name try: ast2 = Decompiler(code).ast except Exception as e: print() print(line) print() print(ast1) print() dis.dis(code) raise if str(ast1) != str(ast2): print() print(line) print() print(ast1) print() print(ast2) print() dis.dis(code) break else: print('OK: %s' % line) else: print('Done!') if __name__ == '__main__': test() ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1571864710.524732 pony-0.7.11/pony/orm/examples/0000777000000000000000000000000000000000000014320 5ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 
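# --- Illustrative usage of the decompiler above; a minimal sketch that is not
# part of the original distribution. It assumes pony 0.7.11 is importable and
# that it runs on a Python version this module supports (the code above
# targets CPython up to 3.7 and PyPy). decompile() accepts a generator,
# a function or a code object and returns an (ast, external_names, cells)
# triple, where `ast` is built from pony.thirdparty.compiler nodes.
from pony.orm.decompiling import decompile

def make_gen(items):
    # the generator is never iterated; only its code object is inspected
    return (x for x in items if x > 0)

tree, external_names, cells = decompile(make_gen([]))
print(tree)            # GenExpr(GenExprInner(...)) reconstructed from bytecode
print(external_names)  # free names the expression reads but never assigns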
pony-0.7.11/pony/orm/examples/__init__.py0000666000000000000000000000000000000000000016417 0ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/examples/bottle_example.py0000666000000000000000000000632700000000000017706 0ustar0000000000000000from __future__ import absolute_import, print_function from bottle import default_app, install, route, request, redirect, run, template # Import eStore model http://editor.ponyorm.com/user/pony/eStore from pony.orm.examples.estore import * from pony.orm.integration.bottle_plugin import PonyPlugin # After the plugin is installed each request will be processed # in a separate database session. Once the HTTP request processing # is finished the plugin does the following: # * commit the changes to the database (or rollback if an exception happened) # * clear the transaction cache # * return the database connection to the connection pool install(PonyPlugin()) @route('/') @route('/products/') def all_products(): # Get the list of all products from the database products = select(p for p in Product) return template('''

List of products

''', products=products) @route('/products/:id/') def show_product(id): # Get the instance of the Product entity by the primary key p = Product[id] # You can traverse entity relationship attributes inside the template # In this examples it is many-to-many relationship p.categories # Since the data were not loaded into the cache yet, # it will result in a separate SQL query. return template('''

{{ p.name }}

Price: {{ p.price }}

Product categories:

%for c in p.categories:
  • {{ c.name }}
%end
Edit product info Return to all products ''', p=p) @route('/products/:id/edit/') def edit_product(id): # Get the instance of the Product entity and display its attributes p = Product[id] return template('''
Product name:
Product price:

Discard changes

Return to all products ''', p=p) @route('/products/:id/edit/', method='POST') def save_product(id): # Get the instance of the Product entity p = Product[id] # Update the attributes with the new values p.name = request.forms.get('name') p.price = request.forms.get('price') # We might put the commit() command here, but it is not necessary # because PonyPlugin will take care of this. redirect("/products/%d/" % p.id) # The Bottle's redirect function raises the HTTPResponse exception. # Normally PonyPlugin closes the session with rollback # if a callback function raises an exception. But in this case # PonyPlugin understands that this exception is not the error # and closes the session with commit. run(debug=True, host='localhost', port=8080, reloader=True) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/examples/compositekeys.py0000666000000000000000000000667400000000000017605 0ustar0000000000000000from __future__ import absolute_import from datetime import date from pony.orm.core import * db = Database('sqlite', 'complex.sqlite', create_db=True) class Group(db.Entity): dept = Required('Department') year = Required(int) spec = Required(int) students = Set('Student') courses = Set('Course') lessons = Set('Lesson', columns=['building', 'number', 'dt']) PrimaryKey(dept, year, spec) class Department(db.Entity): number = PrimaryKey(int) faculty = Required('Faculty') name = Required(str) groups = Set(Group) teachers = Set('Teacher') class Faculty(db.Entity): number = PrimaryKey(int) name = Required(str) depts = Set(Department) class Student(db.Entity): name = Required(str) group = Required(Group) dob = Optional(date) grades = Set('Grade') PrimaryKey(name, group) class Grade(db.Entity): student = Required(Student, columns=['student_name', 'dept', 'year', 'spec']) task = Required('Task') date = Required(date) value = Required(int) PrimaryKey(student, task) class Task(db.Entity): course = Required('Course') type = Required(str) number = Required(int) descr = Optional(str) grades = Set(Grade) PrimaryKey(course, type, number) class Course(db.Entity): subject = Required('Subject') semester = Required(int) groups = Set(Group) tasks = Set(Task) lessons = Set('Lesson') teachers = Set('Teacher') PrimaryKey(subject, semester) class Subject(db.Entity): name = PrimaryKey(str) descr = Optional(str) courses = Set(Course) class Room(db.Entity): building = Required(str) number = Required(str) floor = Optional(int) schedules = Set('Lesson') PrimaryKey(building, number) class Teacher(db.Entity): dept = Required(Department) name = Required(str) courses = Set(Course) lessons = Set('Lesson') class Lesson(db.Entity): _table_ = 'Schedule' groups = Set(Group) course = Required(Course) room = Required(Room) teacher = Required(Teacher) date = Required(date) PrimaryKey(room, date) composite_key(teacher, date) db.generate_mapping(create_tables=True) def test_queries(): select(grade for grade in Grade if grade.task.type == 'Lab')[:] select(grade for grade in Grade if grade.task.descr.startswith('Intermediate'))[:] select(grade for grade in Grade if grade.task.course.semester == 2)[:] select(grade for grade in Grade if grade.task.course.subject.name == 'Math')[:] select(grade for grade in Grade if 'elementary' in grade.task.course.subject.descr.lower())[:] select(grade for grade in Grade if 'elementary' in grade.task.course.subject.descr.lower() and grade.task.descr.startswith('Intermediate'))[:] select(grade for grade in Grade if 
grade.task.descr.startswith('Intermediate') and 'elementary' in grade.task.course.subject.descr.lower())[:] select(s for s in Student if s.group.dept.faculty.name == 'Abc')[:] select(g for g in Group if avg(g.students.grades.value) > 4)[:] select(g for g in Group if avg(g.students.grades.value) > 4 and max(g.students.grades.date) < date(2011, 3, 2))[:] select(g for g in Group if '4-A' in g.lessons.room.number)[:] select(g for g in Group if 1 in g.lessons.room.floor)[:] select(t for t in Teacher if t not in t.courses.groups.lessons.teacher)[:] sql_debug(True) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/examples/demo.py0000666000000000000000000000457600000000000015632 0ustar0000000000000000from __future__ import absolute_import, print_function from decimal import Decimal from pony.orm import * db = Database("sqlite", "demo.sqlite", create_db=True) class Customer(db.Entity): id = PrimaryKey(int, auto=True) name = Required(str) email = Required(str, unique=True) orders = Set("Order") class Order(db.Entity): id = PrimaryKey(int, auto=True) total_price = Required(Decimal) customer = Required(Customer) items = Set("OrderItem") class Product(db.Entity): id = PrimaryKey(int, auto=True) name = Required(str) price = Required(Decimal) items = Set("OrderItem") class OrderItem(db.Entity): quantity = Required(int, default=1) order = Required(Order) product = Required(Product) PrimaryKey(order, product) sql_debug(True) db.generate_mapping(create_tables=True) def populate_database(): c1 = Customer(name='John Smith', email='john@example.com') c2 = Customer(name='Matthew Reed', email='matthew@example.com') c3 = Customer(name='Chuan Qin', email='chuanqin@example.com') c4 = Customer(name='Rebecca Lawson', email='rebecca@example.com') c5 = Customer(name='Oliver Blakey', email='oliver@example.com') p1 = Product(name='Kindle Fire HD', price=Decimal('284.00')) p2 = Product(name='Apple iPad with Retina Display', price=Decimal('478.50')) p3 = Product(name='SanDisk Cruzer 16 GB USB Flash Drive', price=Decimal('9.99')) p4 = Product(name='Kingston DataTraveler 16GB USB 2.0', price=Decimal('9.98')) p5 = Product(name='Samsung 840 Series 120GB SATA III SSD', price=Decimal('98.95')) p6 = Product(name='Crucial m4 256GB SSD SATA 6Gb/s', price=Decimal('188.67')) o1 = Order(customer=c1, total_price=Decimal('292.00')) OrderItem(order=o1, product=p1) OrderItem(order=o1, product=p4, quantity=2) o2 = Order(customer=c1, total_price=Decimal('478.50')) OrderItem(order=o2, product=p2) o3 = Order(customer=c2, total_price=Decimal('680.50')) OrderItem(order=o3, product=p2) OrderItem(order=o3, product=p4, quantity=2) OrderItem(order=o3, product=p6) o4 = Order(customer=c3, total_price=Decimal('99.80')) OrderItem(order=o4, product=p4, quantity=10) o5 = Order(customer=c4, total_price=Decimal('722.00')) OrderItem(order=o5, product=p1) OrderItem(order=o5, product=p2) commit() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/examples/estore.py0000666000000000000000000002517600000000000016206 0ustar0000000000000000from __future__ import absolute_import, print_function from decimal import Decimal from datetime import datetime from pony.converting import str2datetime from pony.orm import * db = Database("sqlite", "estore.sqlite", create_db=True) class Customer(db.Entity): email = Required(str, unique=True) password = Required(str) name = Required(str) country = Required(str) address = 
Required(str) cart_items = Set("CartItem") orders = Set("Order") class Product(db.Entity): id = PrimaryKey(int, auto=True) name = Required(str) categories = Set("Category") description = Optional(str) picture = Optional(buffer) price = Required(Decimal) quantity = Required(int) cart_items = Set("CartItem") order_items = Set("OrderItem") class CartItem(db.Entity): quantity = Required(int) customer = Required(Customer) product = Required(Product) class OrderItem(db.Entity): quantity = Required(int) price = Required(Decimal) order = Required("Order") product = Required(Product) PrimaryKey(order, product) class Order(db.Entity): id = PrimaryKey(int, auto=True) state = Required(str) date_created = Required(datetime) date_shipped = Optional(datetime) date_delivered = Optional(datetime) total_price = Required(Decimal) customer = Required(Customer) items = Set(OrderItem) class Category(db.Entity): name = Required(str, unique=True) products = Set(Product) sql_debug(True) db.generate_mapping(create_tables=True) # Order states CREATED = 'CREATED' SHIPPED = 'SHIPPED' DELIVERED = 'DELIVERED' CANCELLED = 'CANCELLED' @db_session def populate_database(): c1 = Customer(email='john@example.com', password='***', name='John Smith', country='USA', address='address 1') c2 = Customer(email='matthew@example.com', password='***', name='Matthew Reed', country='USA', address='address 2') c3 = Customer(email='chuanqin@example.com', password='***', name='Chuan Qin', country='China', address='address 3') c4 = Customer(email='rebecca@example.com', password='***', name='Rebecca Lawson', country='USA', address='address 4') c5 = Customer(email='oliver@example.com', password='***', name='Oliver Blakey', country='UK', address='address 5') tablets = Category(name='Tablets') flash_drives = Category(name='USB Flash Drives') ssd = Category(name='Solid State Drives') storage = Category(name='Data Storage') p1 = Product(name='Kindle Fire HD', price=Decimal('284.00'), quantity=120, description='Amazon tablet for web, movies, music, apps, ' 'games, reading and more', categories=[tablets]) p2 = Product(name='Apple iPad with Retina Display MD513LL/A (16GB, Wi-Fi, White)', price=Decimal('478.50'), quantity=180, description='iPad with Retina display now features an A6X chip, ' 'FaceTime HD camera, and faster Wi-Fi', categories=[tablets]) p3 = Product(name='SanDisk Cruzer 16 GB USB Flash Drive', price=Decimal('9.99'), quantity=400, description='Take it all with you on reliable ' 'SanDisk USB flash drive', categories=[flash_drives, storage]) p4 = Product(name='Kingston Digital DataTraveler SE9 16GB USB 2.0', price=Decimal('9.98'), quantity=350, description='Convenient - small, capless and pocket-sized ' 'for easy transportability', categories=[flash_drives, storage]) p5 = Product(name='Samsung 840 Series 2.5 inch 120GB SATA III SSD', price=Decimal('98.95'), quantity=0, description='Enables you to boot up your computer ' 'in as little as 15 seconds', categories=[ssd, storage]) p6 = Product(name='Crucial m4 256GB 2.5-Inch SSD SATA 6Gb/s CT256M4SSD2', price=Decimal('188.67'), quantity=60, description='The award-winning SSD delivers ' 'powerful performance gains for SATA 6Gb/s systems', categories=[ssd, storage]) CartItem(customer=c1, product=p1, quantity=1) CartItem(customer=c1, product=p2, quantity=1) CartItem(customer=c2, product=p5, quantity=2) o1 = Order(customer=c1, total_price=Decimal('292.00'), state=DELIVERED, date_created=str2datetime('2012-10-20 15:22:00'), date_shipped=str2datetime('2012-10-21 11:34:00'), 
date_delivered=str2datetime('2012-10-26 17:23:00')) OrderItem(order=o1, product=p1, price=Decimal('274.00'), quantity=1) OrderItem(order=o1, product=p4, price=Decimal('9.98'), quantity=2) o2 = Order(customer=c1, total_price=Decimal('478.50'), state=DELIVERED, date_created=str2datetime('2013-01-10 09:40:00'), date_shipped=str2datetime('2013-01-10 14:03:00'), date_delivered=str2datetime('2013-01-13 11:57:00')) OrderItem(order=o2, product=p2, price=Decimal('478.50'), quantity=1) o3 = Order(customer=c2, total_price=Decimal('680.50'), state=DELIVERED, date_created=str2datetime('2012-11-03 12:10:00'), date_shipped=str2datetime('2012-11-04 11:47:00'), date_delivered=str2datetime('2012-11-07 18:55:00')) OrderItem(order=o3, product=p2, price=Decimal('478.50'), quantity=1) OrderItem(order=o3, product=p4, price=Decimal('9.98'), quantity=2) OrderItem(order=o3, product=p6, price=Decimal('199.00'), quantity=1) o4 = Order(customer=c3, total_price=Decimal('99.80'), state=SHIPPED, date_created=str2datetime('2013-03-11 19:33:00'), date_shipped=str2datetime('2013-03-12 09:40:00')) OrderItem(order=o4, product=p4, price=Decimal('9.98'), quantity=10) o5 = Order(customer=c4, total_price=Decimal('722.00'), state=CREATED, date_created=str2datetime('2013-03-15 23:15:00')) OrderItem(order=o5, product=p1, price=Decimal('284.00'), quantity=1) OrderItem(order=o5, product=p2, price=Decimal('478.50'), quantity=1) @db_session def test_queries(): print('All USA customers') print() result = select(c for c in Customer if c.country == 'USA')[:] print(result) print() print('The number of customers for each country') print() result = select((c.country, count(c)) for c in Customer)[:] print(result) print() print('Max product price') print() result = max(p.price for p in Product) print(result) print() print('Max SSD price') print() result = max(p.price for p in Product for cat in p.categories if cat.name == 'Solid State Drives') print(result) print() print('Three most expensive products:') print() result = select(p for p in Product).order_by(desc(Product.price))[:3] print(result) print() print('Out of stock products') print() result = select(p for p in Product if p.quantity == 0)[:] print(result) print() print('Most popular product') print() result = select(p for p in Product).order_by(lambda p: desc(sum(p.order_items.quantity))).first() print(result) print() print('Products that have never been ordered') print() result = select(p for p in Product if not p.order_items)[:] print(result) print() print('Customers who made several orders') print() result = select(c for c in Customer if count(c.orders) > 1)[:] print(result) print() print('Three most valuable customers') print() result = select(c for c in Customer).order_by(lambda c: desc(sum(c.orders.total_price)))[:3] print(result) print() print('Customers whose orders were shipped') print() result = select(c for c in Customer if SHIPPED in c.orders.state)[:] print(result) print() print('The same query with the INNER JOIN instead of IN') print() result = select(c for c in Customer if JOIN(SHIPPED in c.orders.state))[:] print(result) print() print('Customers with no orders') print() result = select(c for c in Customer if not c.orders)[:] print(result) print() print('The same query with the LEFT JOIN instead of NOT EXISTS') print() result = left_join(c for c in Customer for o in c.orders if o is None)[:] print(result) print() print('Customers which ordered several different tablets') print() result = select(c for c in Customer for p in c.orders.items.product if 'Tablets' in 
p.categories.name and count(p) > 1)[:] print(result) print() print('Customers which ordered several products from the same category') print() result = select((customer, category.name) for customer in Customer for product in customer.orders.items.product for category in product.categories if count(product) > 1)[:] print(result) print() print('Customers which ordered several products from the same category in the same order') print() result = select((customer, order, category.name) for customer in Customer for order in customer.orders for product in order.items.product for category in product.categories if count(product) > 1)[:] print(result) print() print('Products whose price varies over time') print() result = select(p.name for p in Product if count(p.order_items.price) > 1)[:] print(result) print() print('The same query, but with min and max price for each product') print() result = select((p.name, min(p.order_items.price), max(p.order_items.price)) for p in Product if count(p.order_items.price) > 1)[:] print(result) print() print('Orders with a discount (order total price < sum of order item prices)') print() result = select(o for o in Order if o.total_price < sum(o.items.price * o.items.quantity))[:] print(result) print() if __name__ == '__main__': with db_session: if Customer.select().first() is None: populate_database() test_queries() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1539903198.0 pony-0.7.11/pony/orm/examples/inheritance1.py0000666000000000000000000000453700000000000017255 0ustar0000000000000000from __future__ import absolute_import, print_function from decimal import Decimal from datetime import date from pony import options options.CUT_TRACEBACK = False from pony.orm.core import * sql_debug(False) db = Database('sqlite', 'inheritance1.sqlite', create_db=True) class Person(db.Entity): id = PrimaryKey(int, auto=True) name = Required(str) dob = Optional(date) ssn = Required(str, unique=True) class Student(Person): group = Required("Group") mentor = Optional("Teacher") attend_courses = Set("Course") class Teacher(Person): teach_courses = Set("Course") apprentices = Set("Student") salary = Required(Decimal) class Assistant(Student, Teacher): pass class Professor(Teacher): position = Required(str) class Group(db.Entity): number = PrimaryKey(int) students = Set("Student") class Course(db.Entity): name = Required(str) semester = Required(int) students = Set(Student) teachers = Set(Teacher) PrimaryKey(name, semester) db.generate_mapping(create_tables=True) @db_session def populate_database(): if Person.select().first(): return # already populated p = Person(name='Person1', ssn='SSN1') g = Group(number=123) prof = Professor(name='Professor1', salary=1000, position='position1', ssn='SSN5') a1 = Assistant(name='Assistant1', group=g, salary=100, ssn='SSN4', mentor=prof) a2 = Assistant(name='Assistant2', group=g, salary=200, ssn='SSN6', mentor=prof) s1 = Student(name='Student1', group=g, ssn='SSN2', mentor=a1) s2 = Student(name='Student2', group=g, ssn='SSN3') commit() def show_all_persons(): for obj in Person.select(): print(obj) for attr in obj._attrs_: print(attr.name, "=", attr.__get__(obj)) print() if __name__ == '__main__': populate_database() # show_all_persons() sql_debug(True) with db_session: s1 = Student.get(name='Student1') if s1 is None: print('Student1 not found') else: mentor = s1.mentor print(mentor.name, 'is mentor of Student1') print('Is he assistant?', isinstance(mentor, Assistant)) print() for s in 
Student.select(lambda s: s.mentor.salary == 1000): print(s.name) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636028.0 pony-0.7.11/pony/orm/examples/test_numbers.py0000666000000000000000000000423600000000000017411 0ustar0000000000000000from __future__ import absolute_import, print_function from pony.orm.core import * db = Database() class Numbers(db.Entity): _table_ = "Numbers" id = PrimaryKey(int, auto=True) int8 = Required(int, size=8) # TINYINT int16 = Required(int, size=16) # SMALLINT int24 = Required(int, size=24) # MEDIUMINT int32 = Required(int, size=32) # INTEGER int64 = Required(int, size=64) # BIGINT uint8 = Required(int, size=8, unsigned=True) # TINYINT UNSIGNED uint16 = Required(int, size=16, unsigned=True) # SMALLINT UNSIGNED uint24 = Required(int, size=24, unsigned=True) # MEDIUMINT UNSIGNED uint32 = Required(int, size=32, unsigned=True) # INTEGER UNSIGNED # uint64 = Required(int, size=64, unsigned=True) # BIGINT UNSIGNED, supported by MySQL and Oracle sql_debug(True) # Output all SQL queries to stdout db.bind('sqlite', 'test_numbers.sqlite', create_db=True) #db.bind('mysql', host="localhost", user="pony", passwd="pony", db="test_numbers") #db.bind('postgres', user='pony', password='pony', host='localhost', database='test_numbers') #db.bind('oracle', 'test_numbers/pony@localhost') db.drop_table("Numbers", if_exists=True, with_all_data=True) db.generate_mapping(create_tables=True) @db_session def populate_database(): lo = Numbers(int8=-128, int16=-32768, int24=-8388608, int32=-2147483648, int64=-9223372036854775808, uint8=0, uint16=0, uint24=0, uint32=0) #, uint64=0) hi = Numbers(int8=127, int16=32767, int24=8388607, int32=2147483647, int64=9223372036854775807, uint8=255, uint16=65535, uint24=16777215, uint32=4294967295) # uint64=18446744073709551615) commit() @db_session def test_data(): for n in Numbers.select(): print(n.id, n.int8, n.int16, n.int24, n.int32, n.int64, n.uint8, n.uint16, n.uint24, n.uint32) #, n.uint64) if __name__ == '__main__': populate_database() test_data() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1559481756.0 pony-0.7.11/pony/orm/examples/tmp.py0000666000000000000000000001547100000000000015502 0ustar0000000000000000from collections import namedtuple from pony.orm import * class DataAccess: """ API implementation for a relation database access. """ @db_session def select(self, query, query_args=None, classname=None, output_as_dict=False): """ select() method uses the query string provided for retrieving information :param query: Query to be executed. Parameter to be passed had to be prefixed with $ sign. e.g: "select id, name from Person where name = $x" :param classname: :param query_args: Arguments that need to be passed to the query. 
e.g: {"x" : "Susan"} :param output_as_dict: When true, output will be returned as list of dict objects :e.g: [{"x" : "Susan"}], if below condition is satisfied, classname is None and "*" is not provided in select query for columns :return: """ data = self.db.select(query, query_args) # if classname or output as dict is true # creation of key values is done # else list of tuples is returned if classname or output_as_dict: column_names = query.split('select ', 1)[1].split('from')[0].strip() # if * is provided in the column names list # return the tuple list if column_names.find("*") == -1: return data column_names_list = column_names.split(',') data_as_dict_list = [] # loop over list of tuples # create a dict item # with column names from select list by index and # value as tuple element by index for item in data: dict_item = {} for idx, elem in enumerate(item): dict_item.update({column_names_list[idx].strip(): elem}) data_as_dict_list.append(dict_item) if classname: data_class_list = [] for idx, dict_item in enumerate(data_as_dict_list): data_as_dict_list.append(namedtuple(classname, dict_item.keys())(*dict_item.values())) return data_class_list else: return data_as_dict_list else: return data @db_session def create(self, query, callback=None, query_args=None): """ create() method takes the query string for creating table :param query: :param callback: :param query_args: Arguments that need to be passed to the query. e.g: {"x" : "Susan"} :return: """ data = self.db.execute(query, query_args) return data @db_session def insert(self, query, callback=None, query_args=None): """ insert() method for inserting an item :param query: :param callback: :param query_args: :return: """ data = self.db.execute(query, query_args) return data @db_session def update(self, query, callback=None, query_args=None): """ update() method for updating information of a given item :param query: :param callback: :param query_args: :return: """ data = self.db.execute(query, query_args) return data @db_session def delete(self, query, callback=None, query_args=None): """ delete() method for deleting an item :param query: :param callback: :param query_args: :return: """ data = self.db.execute(query, query_args) return data @db_session def execute(self, query, classname=None, query_args=None): """ execute() method for executing a procedure/function :param query: :param classname: :param query_args: :return: """ data = self.db.execute(query, query_args) return data @db_session def execute_procedure(self, procedure_name_params, procedure_args=None): """ execute_procedure() method for executing a procedure :param procedure_name_params: name of a procedure and if any IN params e.g: myprocedure($x, $y). ALL IN parameters had to be prefixed with $ symbol :param procedure_args: any parameters that had to be passed to procedure {'x': 10, 'y': 20} :return: """ data = self.db.execute("call " + procedure_name_params, procedure_args) return data @db_session def execute_function(self, function_name_params, function_args=None): """ execute_function() method for executing a function :param function_name_params: name of a function and if any IN params e.g: myfunction($x, $y). 
ALL IN parameters had to be prefixed with $ symbol :param function_args: any parameters that had to be passed to function {'x': 10, 'y': 20} :return: """ data = self.db.execute("select " + function_name_params, function_args) return data @db_session def begin_transaction(self): """ begin_transaction() method for initiating transaction :return: """ self.db.commit(False) @db_session def end_transaction(self): """ end_transaction() method for end a transaction :return: """ self.db.commit(True) @db_session def rollback(self): """ rollback() rollback a transaction :return: """ self.db.rollback() def __init__(self, provider, hostname, **kwargs): """ Initializes FTP, SFTP setup :param provider: [MySQL, Oracle, PostgreSQL] only these providers are supported as of now :param hostname: :param kwargs: username - username to access database password - password of the user db - database instance """ self.provider = provider self.hostname = hostname self.username = kwargs.get('username') self.password = kwargs.get('password') self.database = kwargs.get('database') if self.provider not in SupportedDatabase.__members__: db_list = ','.join(list(SupportedDatabase.__members__)) raise Exception( provider + ', is not supported at this time. Following databases are only supported : ' + db_list) self.db = Database() self.db.bind(provider=SupportedDatabase[self.provider].value, user=self.username, password=self.password, host=self.hostname, database=self.database) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1559487898.0 pony-0.7.11/pony/orm/examples/tmp2.py0000666000000000000000000000117400000000000015557 0ustar0000000000000000from pony import orm from pony.orm import db_session orm.sql_debug(True) db = orm.Database() class A(db.Entity): friends = orm.Set("B", reverse="friends") #, index="foo", reverse_index="bar") class B(db.Entity): friends = orm.Set("A", reverse="friends") db.bind(provider='sqlite', filename=":memory:", create_db=True) db.generate_mapping(check_tables=False) db.drop_all_tables(with_all_data=True) db.create_tables() if __name__ == "__main__": with db_session: c1 = A() c2 = B(friends=c1) with db_session: print(A.select()[:]) A.select().delete(bulk=True) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1561221149.0 pony-0.7.11/pony/orm/examples/tmp3.py0000666000000000000000000000134600000000000015561 0ustar0000000000000000 db.bind(provider='postgres', user=args.username, password=args.password, host=args.host ,port=args.port, database='image_labeling') db.generate_mapping(create_tables=False) with db_session: unit = WorkUnit[437] assignedworkitems = list(select(item for item in Assignedworkitem if item.workunitid.unitid == unit.unitid)) for img in assignedworkitems: assignedworkitems += list(select( item for item in Assignedworkitem if item.workitemid == img.workitemid and str(item.workunitid.associatedlabels) == str(img.workunitid.associatedlabels) and item not in assignedworkitems )) print(assignedworkitems) unit.updatedate=datetime.now()././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/examples/university1.py0000666000000000000000000001432000000000000017174 0ustar0000000000000000from __future__ import absolute_import, print_function from decimal import Decimal from datetime import date from pony.orm.core import * db = Database() class Department(db.Entity): number = PrimaryKey(int, auto=True) name = Required(str, 
unique=True) groups = Set("Group") courses = Set("Course") class Group(db.Entity): number = PrimaryKey(int) major = Required(str) dept = Required("Department") students = Set("Student") class Course(db.Entity): name = Required(str) semester = Required(int) lect_hours = Required(int) lab_hours = Required(int) credits = Required(int) dept = Required(Department) students = Set("Student") PrimaryKey(name, semester) class Student(db.Entity): # _table_ = "public", "Students" # Schema support id = PrimaryKey(int, auto=True) name = Required(str) dob = Required(date) tel = Optional(str) picture = Optional(buffer, lazy=True) gpa = Required(float, default=0) group = Required(Group) courses = Set(Course) sql_debug(True) # Output all SQL queries to stdout params = dict( sqlite=dict(provider='sqlite', filename='university1.sqlite', create_db=True), mysql=dict(provider='mysql', host="localhost", user="pony", passwd="pony", db="pony"), postgres=dict(provider='postgres', user='pony', password='pony', host='localhost', database='pony'), oracle=dict(provider='oracle', user='c##pony', password='pony', dsn='localhost/orcl') ) db.bind(**params['sqlite']) db.generate_mapping(create_tables=True) @db_session def populate_database(): if select(s for s in Student).count() > 0: return d1 = Department(name="Department of Computer Science") d2 = Department(name="Department of Mathematical Sciences") d3 = Department(name="Department of Applied Physics") c1 = Course(name="Web Design", semester=1, dept=d1, lect_hours=30, lab_hours=30, credits=3) c2 = Course(name="Data Structures and Algorithms", semester=3, dept=d1, lect_hours=40, lab_hours=20, credits=4) c3 = Course(name="Linear Algebra", semester=1, dept=d2, lect_hours=30, lab_hours=30, credits=4) c4 = Course(name="Statistical Methods", semester=2, dept=d2, lect_hours=50, lab_hours=25, credits=5) c5 = Course(name="Thermodynamics", semester=2, dept=d3, lect_hours=25, lab_hours=40, credits=4) c6 = Course(name="Quantum Mechanics", semester=3, dept=d3, lect_hours=40, lab_hours=30, credits=5) g101 = Group(number=101, major='B.E. in Computer Engineering', dept=d1) g102 = Group(number=102, major='B.S./M.S. in Computer Science', dept=d1) g103 = Group(number=103, major='B.S. in Applied Mathematics and Statistics', dept=d2) g104 = Group(number=104, major='B.S./M.S. in Pure Mathematics', dept=d2) g105 = Group(number=105, major='B.E in Electronics', dept=d3) g106 = Group(number=106, major='B.S./M.S. 
in Nuclear Engineering', dept=d3) s1 = Student(name='John Smith', dob=date(1991, 3, 20), tel='123-456', gpa=3, group=g101, courses=[c1, c2, c4, c6]) s2 = Student(name='Matthew Reed', dob=date(1990, 11, 26), gpa=3.5, group=g101, courses=[c1, c3, c4, c5]) s3 = Student(name='Chuan Qin', dob=date(1989, 2, 5), gpa=4, group=g101, courses=[c3, c5, c6]) s4 = Student(name='Rebecca Lawson', dob=date(1990, 4, 18), tel='234-567', gpa=3.3, group=g102, courses=[c1, c4, c5, c6]) s5 = Student(name='Maria Ionescu', dob=date(1991, 4, 23), gpa=3.9, group=g102, courses=[c1, c2, c4, c6]) s6 = Student(name='Oliver Blakey', dob=date(1990, 9, 8), gpa=3.1, group=g102, courses=[c1, c2, c5]) s7 = Student(name='Jing Xia', dob=date(1988, 12, 30), gpa=3.2, group=g102, courses=[c1, c3, c5, c6]) commit() def print_students(students): for s in students: print(s.name) print() @db_session def test_queries(): students = select(s for s in Student) print_students(students) students = select(s for s in Student if s.gpa > 3.4 and s.dob.year == 1990) print_students(students) students = select(s for s in Student if len(s.courses) < 4) print_students(students) students = select(s for s in Student if len(c for c in s.courses if c.dept.number == 1) < 4) print_students(students) students = select(s for s in Student if s.name.startswith("M")) print_students(students) students = select(s for s in Student if "Smith" in s.name) print_students(students) students = select(s for s in Student if "Web Design" in s.courses.name) print_students(students) print('Average GPA is', avg(s.gpa for s in Student)) print() students = select(s for s in Student if sum(c.credits for c in s.courses) < 15) print_students(students) students = select(s for s in Student if s.group.major == "B.E. in Computer Engineering") print_students(students) students = select(s for s in Student if s.group.dept.name == "Department of Computer Science") print_students(students) students = select(s for s in Student).order_by(Student.name) print_students(students) students = select(s for s in Student).order_by(Student.name)[2:4] print_students(students) students = select(s for s in Student).order_by(Student.name.desc) print_students(students) students = select(s for s in Student) \ .order_by(Student.group, Student.name.desc) print_students(students) students = select(s for s in Student if s.group.dept.name == "Department of Computer Science" and s.gpa > 3.5 and len(s.courses) > 3) print_students(students) ##if __name__ == '__main__': ## populate_database() ## test_queries() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571677707.0 pony-0.7.11/pony/orm/examples/university2.py0000666000000000000000000001104600000000000017177 0ustar0000000000000000from __future__ import absolute_import, print_function from pony.orm.core import * from decimal import Decimal from datetime import date db = Database() class Faculty(db.Entity): _table_ = 'Faculties' number = PrimaryKey(int) name = Required(str, unique=True) departments = Set('Department') class Department(db.Entity): _table_ = 'Departments' number = PrimaryKey(int) name = Required(str, unique=True) faculty = Required(Faculty) teachers = Set('Teacher') majors = Set('Major') groups = Set('Group') class Group(db.Entity): _table_ = 'Groups' number = PrimaryKey(int) grad_year = Required(int) department = Required(Department, column='dep') lessons = Set('Lesson', columns=['day_of_week', 'meeting_time', 'classroom_number', 'building']) students = Set('Student') class Student(db.Entity): _table_ 
= 'Students' name = Required(str) scholarship = Required(Decimal, 10, 2, default=Decimal('0.0')) group = Required(Group) grades = Set('Grade') class Major(db.Entity): _table_ = 'Majors' name = PrimaryKey(str) department = Required(Department) courses = Set('Course') class Subject(db.Entity): _table_ = 'Subjects' name = PrimaryKey(str) courses = Set('Course') teachers = Set('Teacher') class Course(db.Entity): _table_ = 'Courses' major = Required(Major) subject = Required(Subject) semester = Required(int) composite_key(major, subject, semester) lect_hours = Required(int) pract_hours = Required(int) credit = Required(int) lessons = Set('Lesson') grades = Set('Grade') class Lesson(db.Entity): _table_ = 'Lessons' day_of_week = Required(int) meeting_time = Required(int) classroom = Required('Classroom') PrimaryKey(day_of_week, meeting_time, classroom) course = Required(Course) teacher = Required('Teacher') groups = Set(Group) class Grade(db.Entity): _table_ = 'Grades' student = Required(Student) course = Required(Course) PrimaryKey(student, course) teacher = Required('Teacher') date = Required(date) value = Required(str) class Teacher(db.Entity): _table_ = 'Teachers' name = Required(str) degree = Optional(str) department = Required(Department) subjects = Set(Subject) lessons = Set(Lesson) grades = Set(Grade) class Building(db.Entity): _table_ = 'Buildings' number = PrimaryKey(str) description = Optional(str) classrooms = Set('Classroom') class Classroom(db.Entity): _table_ = 'Classrooms' building = Required(Building) number = Required(str) PrimaryKey(building, number) description = Optional(str) lessons = Set(Lesson) db.bind('sqlite', 'university2.sqlite', create_db=True) #db.bind('mysql', host='localhost', user='pony', passwd='pony', db='university2') #db.bind('postgres', user='pony', password='pony', host='localhost', database='university2') #db.bind('oracle', 'university2/pony@localhost') db.generate_mapping(create_tables=True) sql_debug(True) def test_queries(): # very simple query select(s for s in Student)[:] # one condition select(s for s in Student if s.scholarship > 0)[:] # multiple conditions select(s for s in Student if s.scholarship > 0 and s.group.number == 4142)[:] # no join here - attribute can be found in table Students select(s for s in Student if s.group.number == 4142)[:] # automatic join of two tables because grad_year is stored in table Groups select(s for s in Student if s.group.grad_year == 2011)[:] # still two tables are joined select(s for s in Student if s.group.department.number == 44)[:] # automatic join of tree tables select(s for s in Student if s.group.department.name == 'Ancient Philosophy')[:] # manual join of tables will produce equivalent query select(s for s in Student for g in Group if s.group == g and g.department.name == 'Ancient Philosophy')[:] # join two tables by composite foreign key select(c for c in Classroom for l in Lesson if l.classroom == c and l.course.subject.name == 'Physics')[:] # Lessons will be joined with Buildings directly without Classrooms select(s for s in Subject for l in Lesson if s == l.course.subject and l.classroom.building.description == 'some description')[:] # just another example of join of many tables select(c for c in Course if c.major.department.faculty.number == 4)[:] ././@PaxHeader0000000000000000000000000000003300000000000011451 xustar000000000000000027 mtime=1571864710.536708 pony-0.7.11/pony/orm/integration/0000777000000000000000000000000000000000000015025 
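# --- Illustrative sketch, not part of the original distribution, of the
# pattern university2.py relies on: a composite primary key such as
# Classroom(building, number) makes every foreign key that points at it a
# multi-column key, and Pony joins through all of its columns at once.
# The entity and attribute names below are invented for the illustration.
from pony.orm import *

db = Database('sqlite', ':memory:')

class Classroom(db.Entity):
    building = Required(str)
    number = Required(str)
    PrimaryKey(building, number)          # composite primary key
    lessons = Set('Lesson')

class Lesson(db.Entity):
    classroom = Required(Classroom)       # composite FK: (building, number)
    subject = Required(str)

db.generate_mapping(create_tables=True)

with db_session:
    room = Classroom(building='A', number='101')
    Lesson(classroom=room, subject='Physics')
    # this comparison compiles to a join over both foreign-key columns
    print(select(l.subject for l in Lesson if l.classroom.building == 'A')[:])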
5ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/integration/__init__.py0000666000000000000000000000000000000000000017124 0ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/integration/bottle_plugin.py0000666000000000000000000000066200000000000020252 0ustar0000000000000000from __future__ import absolute_import, print_function, division from bottle import HTTPResponse, HTTPError from pony.orm.core import db_session def is_allowed_exception(e): return isinstance(e, HTTPResponse) and not isinstance(e, HTTPError) class PonyPlugin(object): name = 'pony' api = 2 def apply(self, callback, route): return db_session(allowed_exceptions=is_allowed_exception)(callback) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/ormtypes.py0000666000000000000000000003456500000000000014753 0ustar0000000000000000from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types, iteritems import sys, types, weakref from decimal import Decimal from datetime import date, time, datetime, timedelta from functools import wraps, WRAPPER_ASSIGNMENTS from uuid import UUID from pony.utils import throw, parse_expr, deref_proxy NoneType = type(None) class LongStr(str): lazy = True if PY2: class LongUnicode(unicode): lazy = True else: LongUnicode = LongStr class SetType(object): __slots__ = 'item_type' def __deepcopy__(self, memo): return self # SetType instances are "immutable" def __init__(self, item_type): self.item_type = item_type def __eq__(self, other): return type(other) is SetType and self.item_type == other.item_type def __ne__(self, other): return type(other) is not SetType or self.item_type != other.item_type def __hash__(self): return hash(self.item_type) + 1 class FuncType(object): __slots__ = 'func' def __deepcopy__(self, memo): return self # FuncType instances are "immutable" def __init__(self, func): self.func = func def __eq__(self, other): return type(other) is FuncType and self.func == other.func def __ne__(self, other): return type(other) is not FuncType or self.func != other.func def __hash__(self): return hash(self.func) + 1 def __repr__(self): return 'FuncType(%s at %d)' % (self.func.__name__, id(self.func)) class MethodType(object): __slots__ = 'obj', 'func' def __deepcopy__(self, memo): return self # MethodType instances are "immutable" def __init__(self, method): if PY2: self.obj = method.im_self self.func = method.im_func else: self.obj = method.__self__ self.func = method.__func__ def __eq__(self, other): return type(other) is MethodType and self.obj == other.obj and self.func == other.func def __ne__(self, other): return type(other) is not MethodType or self.obj != other.obj or self.func != other.func def __hash__(self): return hash(self.obj) ^ hash(self.func) raw_sql_cache = {} def parse_raw_sql(sql): result = raw_sql_cache.get(sql) if result is not None: return result assert isinstance(sql, basestring) and len(sql) > 0 items = [] codes = [] pos = 0 while True: try: i = sql.index('$', pos) except ValueError: items.append(sql[pos:]) break items.append(sql[pos:i]) if sql[i+1] == '$': items.append('$') pos = i+2 else: try: expr, _ = parse_expr(sql, i+1) except ValueError: raise ValueError(sql[i:]) pos = i+1 + len(expr) if 
# ===== pony/orm/ormtypes.py =====

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types, iteritems

import sys, types, weakref
from decimal import Decimal
from datetime import date, time, datetime, timedelta
from functools import wraps, WRAPPER_ASSIGNMENTS
from uuid import UUID

from pony.utils import throw, parse_expr, deref_proxy

NoneType = type(None)

class LongStr(str):
    lazy = True

if PY2:
    class LongUnicode(unicode):
        lazy = True
else:
    LongUnicode = LongStr

class SetType(object):
    __slots__ = 'item_type'
    def __deepcopy__(self, memo):
        return self  # SetType instances are "immutable"
    def __init__(self, item_type):
        self.item_type = item_type
    def __eq__(self, other):
        return type(other) is SetType and self.item_type == other.item_type
    def __ne__(self, other):
        return type(other) is not SetType or self.item_type != other.item_type
    def __hash__(self):
        return hash(self.item_type) + 1

class FuncType(object):
    __slots__ = 'func'
    def __deepcopy__(self, memo):
        return self  # FuncType instances are "immutable"
    def __init__(self, func):
        self.func = func
    def __eq__(self, other):
        return type(other) is FuncType and self.func == other.func
    def __ne__(self, other):
        return type(other) is not FuncType or self.func != other.func
    def __hash__(self):
        return hash(self.func) + 1
    def __repr__(self):
        return 'FuncType(%s at %d)' % (self.func.__name__, id(self.func))

class MethodType(object):
    __slots__ = 'obj', 'func'
    def __deepcopy__(self, memo):
        return self  # MethodType instances are "immutable"
    def __init__(self, method):
        if PY2:
            self.obj = method.im_self
            self.func = method.im_func
        else:
            self.obj = method.__self__
            self.func = method.__func__
    def __eq__(self, other):
        return type(other) is MethodType and self.obj == other.obj and self.func == other.func
    def __ne__(self, other):
        return type(other) is not MethodType or self.obj != other.obj or self.func != other.func
    def __hash__(self):
        return hash(self.obj) ^ hash(self.func)

raw_sql_cache = {}

def parse_raw_sql(sql):
    result = raw_sql_cache.get(sql)
    if result is not None: return result
    assert isinstance(sql, basestring) and len(sql) > 0
    items = []
    codes = []
    pos = 0
    while True:
        try:
            i = sql.index('$', pos)
        except ValueError:
            items.append(sql[pos:])
            break
        items.append(sql[pos:i])
        if sql[i+1] == '$':
            items.append('$')
            pos = i+2
        else:
            try:
                expr, _ = parse_expr(sql, i+1)
            except ValueError:
                raise ValueError(sql[i:])
            pos = i+1 + len(expr)
            if expr.endswith(';'): expr = expr[:-1]
            code = compile(expr, '', 'eval')  # expr correction check
            codes.append(code)
            items.append((expr, code))
    result = tuple(items), tuple(codes)
    raw_sql_cache[sql] = result
    return result

def raw_sql(sql, result_type=None):
    globals = sys._getframe(1).f_globals
    locals = sys._getframe(1).f_locals
    return RawSQL(sql, globals, locals, result_type)

class RawSQL(object):
    def __deepcopy__(self, memo):
        assert False  # should not attempt to deepcopy RawSQL instances, because of locals/globals
    def __init__(self, sql, globals=None, locals=None, result_type=None):
        self.sql = sql
        self.items, self.codes = parse_raw_sql(sql)
        self.types, self.values = normalize(tuple(eval(code, globals, locals) for code in self.codes))
        self.result_type = result_type
    def _get_type_(self):
        return RawSQLType(self.sql, self.items, self.types, self.result_type)
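# Usage sketch (not part of the original file; entity and variable names are
# made up). raw_sql() is meant to be used inside a query; the $-placeholders
# are scanned by parse_raw_sql() above and evaluated in the caller's namespace:
#
#     x = 25
#     select(p for p in Person if raw_sql('abs("p"."age") > $x'))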
class RawSQLType(object):
    def __deepcopy__(self, memo):
        return self  # RawSQLType instances are "immutable"
    def __init__(self, sql, items, types, result_type):
        self.sql = sql
        self.items = items
        self.types = types
        self.result_type = result_type
    def __hash__(self):
        return hash(self.sql) ^ hash(self.types)
    def __eq__(self, other):
        return type(other) is RawSQLType and self.sql == other.sql and self.types == other.types
    def __ne__(self, other):
        return not self.__eq__(other)

class QueryType(object):
    def __init__(self, query, limit=None, offset=None):
        self.query_key = query._key
        self.translator = query._translator
        self.limit = limit
        self.offset = offset
    def __hash__(self):
        result = hash(self.query_key)
        if self.limit is not None: result ^= hash(self.limit + 3)
        if self.offset is not None: result ^= hash(self.offset)
        return result
    def __eq__(self, other):
        return type(other) is QueryType and self.query_key == other.query_key \
               and self.limit == other.limit and self.offset == other.offset
    def __ne__(self, other):
        return not self.__eq__(other)

def normalize(value):
    value = deref_proxy(value)
    t = type(value)
    if t is tuple:
        item_types, item_values = [], []
        for item in value:
            item_type, item_value = normalize(item)
            item_values.append(item_value)
            item_types.append(item_type)
        return tuple(item_types), tuple(item_values)
    if t.__name__ == 'EntityMeta':
        return SetType(value), value
    if t.__name__ == 'EntityIter':
        entity = value.entity
        return SetType(entity), entity
    if PY2 and isinstance(value, str):
        try:
            value.decode('ascii')
        except UnicodeDecodeError:
            throw(TypeError, 'The bytestring %r contains non-ascii symbols. '
                             'Try to pass unicode string instead' % value)
        else:
            return unicode, value
    elif isinstance(value, unicode):
        return unicode, value
    if t in function_types:
        return FuncType(value), value
    if t is types.MethodType:
        return MethodType(value), value
    if hasattr(value, '_get_type_'):
        return value._get_type_(), value
    return normalize_type(t), value

def normalize_type(t):
    tt = type(t)
    if tt is tuple:
        return tuple(normalize_type(item) for item in t)
    if not isinstance(t, type):
        return t
    assert t.__name__ != 'EntityMeta'
    if tt.__name__ == 'EntityMeta':
        return t
    if t is NoneType:
        return t
    t = type_normalization_dict.get(t, t)
    if t in primitive_types: return t
    if t in (slice, type(Ellipsis)): return t
    if issubclass(t, basestring): return unicode
    if issubclass(t, (dict, Json)): return Json
    if issubclass(t, Array): return t
    throw(TypeError, 'Unsupported type %r' % t.__name__)

coercions = {
    (int, float): float,
    (int, Decimal): Decimal,
    (date, datetime): datetime,
    (bool, int): int,
    (bool, float): float,
    (bool, Decimal): Decimal,
}
coercions.update(((t2, t1), t3) for ((t1, t2), t3) in items_list(coercions))

def coerce_types(t1, t2):
    if t1 == t2: return t1
    is_set_type = False
    if type(t1) is SetType:
        is_set_type = True
        t1 = t1.item_type
    if type(t2) is SetType:
        is_set_type = True
        t2 = t2.item_type
    result = coercions.get((t1, t2))
    if result is not None and is_set_type: result = SetType(result)
    return result

def are_comparable_types(t1, t2, op='=='):
    # types must be normalized already!
    tt1 = type(t1)
    tt2 = type(t2)
    t12 = {t1, t2}
    if Json in t12 and t12 < {Json, str, unicode, int, bool, float}:
        return True
    if op in ('in', 'not in'):
        if tt2 is RawSQLType: return True
        if tt2 is not SetType: return False
        op = '=='
        t2 = t2.item_type
        tt2 = type(t2)
    if op in ('is', 'is not'):
        return t1 is not None and t2 is NoneType
    if tt1 is tuple:
        if not tt2 is tuple: return False
        if len(t1) != len(t2): return False
        for item1, item2 in izip(t1, t2):
            if not are_comparable_types(item1, item2): return False
        return True
    if tt1 is RawSQLType or tt2 is RawSQLType: return True
    if op in ('==', '<>', '!='):
        if t1 is NoneType and t2 is NoneType: return False
        if t1 is NoneType or t2 is NoneType: return True
        if t1 in primitive_types:
            if t1 is t2: return True
            if (t1, t2) in coercions: return True
            if tt1 is not type or tt2 is not type: return False
            if issubclass(t1, int_types) and issubclass(t2, basestring): return True
            if issubclass(t2, int_types) and issubclass(t1, basestring): return True
            return False
        if tt1.__name__ == tt2.__name__ == 'EntityMeta':
            return t1._root_ is t2._root_
        return False
    if t1 is t2 and t1 in comparable_types: return True
    return (t1, t2) in coercions

class TrackedValue(object):
    def __init__(self, obj, attr):
        self.obj_ref = weakref.ref(obj)
        self.attr = attr
    @classmethod
    def make(cls, obj, attr, value):
        if isinstance(value, dict): return TrackedDict(obj, attr, value)
        if isinstance(value, list): return TrackedList(obj, attr, value)
        return value
    def _changed_(self):
        obj = self.obj_ref()
        if obj is not None:
            obj._attr_changed_(self.attr)
    def get_untracked(self):
        assert False, 'Abstract method'  # pragma: no cover

def tracked_method(func):
    @wraps(func, assigned=('__name__', '__doc__') if PY2 else WRAPPER_ASSIGNMENTS)
    def new_func(self, *args, **kwargs):
        obj = self.obj_ref()
        attr = self.attr
        if obj is not None:
            args = tuple(TrackedValue.make(obj, attr, arg) for arg in args)
            if kwargs:
                kwargs = {key: TrackedValue.make(obj, attr, value) for key, value in iteritems(kwargs)}
        result = func(self, *args, **kwargs)
        self._changed_()
        return result
    return new_func
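# Illustration (not part of the original file): the coercion table above
# determines the result type of mixed-type expressions in queries, e.g.
#
#     coerce_types(int, float)                     # -> float
#     coerce_types(bool, Decimal)                  # -> Decimal
#     coerce_types(SetType(int), SetType(float))   # -> SetType(float)
#     coerce_types(int, datetime)                  # -> None (not coercible)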
class TrackedDict(TrackedValue, dict):
    def __init__(self, obj, attr, value):
        TrackedValue.__init__(self, obj, attr)
        dict.__init__(self, {key: self.make(obj, attr, val) for key, val in iteritems(value)})
    def __reduce__(self):
        return dict, (dict(self),)
    __setitem__ = tracked_method(dict.__setitem__)
    __delitem__ = tracked_method(dict.__delitem__)
    _update = tracked_method(dict.update)
    def update(self, *args, **kwargs):
        args = [arg if isinstance(arg, dict) else dict(arg) for arg in args]
        return self._update(*args, **kwargs)
    setdefault = tracked_method(dict.setdefault)
    pop = tracked_method(dict.pop)
    popitem = tracked_method(dict.popitem)
    clear = tracked_method(dict.clear)
    def get_untracked(self):
        return {key: val.get_untracked() if isinstance(val, TrackedValue) else val
                for key, val in self.items()}

class TrackedList(TrackedValue, list):
    def __init__(self, obj, attr, value):
        TrackedValue.__init__(self, obj, attr)
        list.__init__(self, (self.make(obj, attr, val) for val in value))
    def __reduce__(self):
        return list, (list(self),)
    __setitem__ = tracked_method(list.__setitem__)
    __delitem__ = tracked_method(list.__delitem__)
    extend = tracked_method(list.extend)
    append = tracked_method(list.append)
    pop = tracked_method(list.pop)
    remove = tracked_method(list.remove)
    insert = tracked_method(list.insert)
    reverse = tracked_method(list.reverse)
    sort = tracked_method(list.sort)
    if PY2:
        __setslice__ = tracked_method(list.__setslice__)
    else:
        clear = tracked_method(list.clear)
    def get_untracked(self):
        return [val.get_untracked() if isinstance(val, TrackedValue) else val for val in self]

def validate_item(item_type, item):
    if PY2 and isinstance(item, str):
        item = item.decode('ascii')
    if not isinstance(item, item_type):
        if item_type is not unicode and hasattr(item, '__index__'):
            return item.__index__()
        throw(TypeError, 'Cannot store %r item in array of %r'
                         % (type(item).__name__, item_type.__name__))
    return item

class TrackedArray(TrackedList):
    def __init__(self, obj, attr, value):
        TrackedList.__init__(self, obj, attr, value)
        self.item_type = attr.py_type.item_type
    def extend(self, items):
        items = [validate_item(self.item_type, item) for item in items]
        TrackedList.extend(self, items)
    def append(self, item):
        item = validate_item(self.item_type, item)
        TrackedList.append(self, item)
    def insert(self, index, item):
        item = validate_item(self.item_type, item)
        TrackedList.insert(self, index, item)
    def __setitem__(self, index, item):
        item = validate_item(self.item_type, item)
        TrackedList.__setitem__(self, index, item)
    def __contains__(self, item):
        if not isinstance(item, basestring) and hasattr(item, '__iter__'):
            return all(it in set(self) for it in item)
        return list.__contains__(self, item)

class Json(object):
    """A wrapper over a dict or list"""
    @classmethod
    def default_empty_value(cls):
        return {}
    def __init__(self, wrapped):
        self.wrapped = wrapped
    def __repr__(self):
        return '<Json: %r>' % self.wrapped

class Array(object):
    item_type = None  # Should be overridden in subclass
    @classmethod
    def default_empty_value(cls):
        return []

class IntArray(Array):
    item_type = int

class StrArray(Array):
    item_type = unicode

class FloatArray(Array):
    item_type = float

numeric_types = {bool, int, float, Decimal}
comparable_types = {int, float, Decimal, unicode, date, time, datetime, timedelta,
                    bool, UUID, IntArray, StrArray, FloatArray}
primitive_types = comparable_types | {buffer}
function_types = {type, types.FunctionType, types.BuiltinFunctionType}
type_normalization_dict = {long: int} if PY2 else {}
array_types = {int: IntArray, float: FloatArray, unicode: StrArray}
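# Illustration (not part of the original file; `obj` stands for any object
# with an _attr_changed_(attr) method, as entity instances have):
#
#     d = TrackedValue.make(obj, attr, {'tags': ['a', 'b']})
#     d['tags'].append('c')   # the inner list is tracked too: the mutation
#                             # ends up calling obj._attr_changed_(attr)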
# ===== pony/orm/serialization.py =====

from pony.py23compat import iteritems

import json
from datetime import date, datetime
from decimal import Decimal
from collections import defaultdict

from pony.orm.core import Entity, TransactionError
from pony.utils import cut_traceback, throw

class Bag(object):
    def __init__(bag, database):
        bag.database = database
        bag.session_cache = None
        bag.entity_configs = {}
        bag.objects = defaultdict(set)
        bag.vars = {}
        bag.dicts = defaultdict(dict)
    @cut_traceback
    def config(bag, entity, only=None, exclude=None, with_collections=True,
               with_lazy=False, related_objects=True):
        if bag.database.entities.get(entity.__name__) is not entity:
            throw(TypeError, 'Entity %s does not belong to database %r'
                             % (entity.__name__, bag.database))
        attrs = entity._get_attrs_(only, exclude, with_collections, with_lazy)
        bag.entity_configs[entity] = attrs, related_objects
        return attrs, related_objects
    @cut_traceback
    def put(bag, x):
        if isinstance(x, Entity):
            bag._put_object(x)
        else:
            try:
                x = list(x)
            except:
                throw(TypeError, 'Entity instance or a sequence of instances expected. Got: %r' % x)
            for item in x:
                if not isinstance(item, Entity):
                    throw(TypeError, 'Entity instance or a sequence of instances expected. Got: %r' % item)
                bag._put_object(item)
    def _put_object(bag, obj):
        entity = obj.__class__
        if bag.database.entities.get(entity.__name__) is not entity:
            throw(TypeError, 'Entity %s does not belong to database %r'
                             % (entity.__name__, bag.database))
        cache = bag.session_cache
        if cache is None:
            cache = bag.session_cache = obj._session_cache_
        elif obj._session_cache_ is not cache:
            throw(TransactionError, 'An attempt to mix objects belonging to different transactions')
        bag.objects[entity].add(obj)
    def _reduce_composite_pk(bag, pk):
        return ','.join(str(item).replace('*', '**').replace(',', '*,') for item in pk)
    @cut_traceback
    def to_dict(bag):
        bag.dicts.clear()
        for entity, objects in iteritems(bag.objects):
            for obj in objects:
                dicts = bag.dicts[entity]
                if obj not in dicts:
                    bag._process_object(obj)
        result = defaultdict(dict)
        for entity, dicts in iteritems(bag.dicts):
            composite_pk = len(entity._pk_columns_) > 1
            for obj, d in iteritems(dicts):
                pk = obj._get_raw_pkval_()
                if composite_pk:
                    pk = bag._reduce_composite_pk(pk)
                else:
                    pk = pk[0]
                result[entity.__name__][pk] = d
        bag.dicts.clear()
        return result
    def _process_object(bag, obj, process_related=True):
        entity = obj.__class__
        try:
            attrs, related_objects = bag.entity_configs[entity]
        except KeyError:
            attrs, related_objects = bag.config(entity)
        process_related_objects = process_related and related_objects
        d = {}
        for attr in attrs:
            value = attr.__get__(obj)
            if attr.is_collection:
                if not process_related: continue
                if process_related_objects:
                    for related_obj in value:
                        if related_obj not in bag.dicts:
                            bag._process_object(related_obj, process_related=False)
                if attr.reverse.entity._pk_is_composite_:
                    value = sorted(bag._reduce_composite_pk(item._get_raw_pkval_()) for item in value)
                else:
                    value = sorted(item._get_raw_pkval_()[0] for item in value)
            elif attr.is_relation:
                if value is not None:
                    if process_related_objects:
                        bag._process_object(value, process_related=False)
                    value = value._get_raw_pkval_()
                    if len(value) == 1: value = value[0]
            d[attr.name] = value
        bag.dicts[entity][obj] = d
    @cut_traceback
    def to_json(bag):
        return json.dumps(bag.to_dict(), default=json_converter, indent=2, sort_keys=True)
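# Usage sketch (not part of the original file; db and Student are assumed to
# come from an application's model, as in the university example above):
#
#     with db_session:
#         bag = Bag(db)
#         bag.config(Student, exclude=['grades'])
#         bag.put(select(s for s in Student)[:])
#         payload = bag.to_json()   # {'Student': {pk: {...}, ...}}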
def to_dict(objects):
    if isinstance(objects, Entity):
        objects = [objects]
    objects = iter(objects)
    try:
        first_object = next(objects)
    except StopIteration:
        return {}
    if not isinstance(first_object, Entity):
        throw(TypeError, 'Entity instance or a sequence of instances expected. Got: %r' % first_object)
    database = first_object._database_
    bag = Bag(database)
    bag.put(first_object)
    bag.put(objects)
    return dict(bag.to_dict())

def to_json(objects):
    return json.dumps(to_dict(objects), default=json_converter, indent=2, sort_keys=True)

def json_converter(x):
    if isinstance(x, (datetime, date, Decimal)):
        return str(x)
    raise TypeError(x)

# ===== pony/orm/sqlbuilding.py =====

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, izip, imap, itervalues, basestring, unicode, buffer, int_types

from operator import attrgetter
from decimal import Decimal
from datetime import date, datetime, timedelta
from binascii import hexlify

from pony import options
from pony.utils import datetime2timestamp, throw, is_ident
from pony.converting import timedelta2str
from pony.orm.ormtypes import RawSQL, Json

class AstError(Exception): pass

class Param(object):
    __slots__ = 'style', 'id', 'paramkey', 'converter', 'optimistic'
    def __init__(param, paramstyle, paramkey, converter=None, optimistic=False):
        param.style = paramstyle
        param.id = None
        param.paramkey = paramkey
        param.converter = converter
        param.optimistic = optimistic
    def eval(param, values):
        varkey, i, j = param.paramkey
        value = values[varkey]
        if i is not None:
            t = type(value)
            if t is tuple: value = value[i]
            elif t is RawSQL: value = value.values[i]
            elif hasattr(value, '_get_items'): value = value._get_items()[i]
            else: assert False, t
        if j is not None:
            assert type(type(value)).__name__ == 'EntityMeta'
            value = value._get_raw_pkval_()[j]
        converter = param.converter
        if value is not None and converter is not None:
            if converter.attr is None:
                value = converter.val2dbval(value)
            value = converter.py2sql(value)
        return value
    def __unicode__(param):
        paramstyle = param.style
        if paramstyle == 'qmark': return u'?'
elif paramstyle == 'format': return u'%s' elif paramstyle == 'numeric': return u':%d' % param.id elif paramstyle == 'named': return u':p%d' % param.id elif paramstyle == 'pyformat': return u'%%(p%d)s' % param.id else: throw(NotImplementedError) if not PY2: __str__ = __unicode__ def __repr__(param): return '%s(%r)' % (param.__class__.__name__, param.paramkey) class CompositeParam(Param): __slots__ = 'items', 'func' def __init__(param, paramstyle, paramkey, items, func): for item in items: assert isinstance(item, (Param, Value)), item Param.__init__(param, paramstyle, paramkey) param.items = items param.func = func def eval(param, values): args = [ item.eval(values) if isinstance(item, Param) else item.value for item in param.items ] return param.func(args) class Value(object): __slots__ = 'paramstyle', 'value' def __init__(self, paramstyle, value): self.paramstyle = paramstyle self.value = value def __unicode__(self): value = self.value if value is None: return 'null' if isinstance(value, bool): return value and '1' or '0' if isinstance(value, basestring): return self.quote_str(value) if isinstance(value, datetime): return 'TIMESTAMP ' + self.quote_str(datetime2timestamp(value)) if isinstance(value, date): return 'DATE ' + self.quote_str(str(value)) if isinstance(value, timedelta): return "INTERVAL '%s' HOUR TO SECOND" % timedelta2str(value) if PY2: if isinstance(value, (int, long, float, Decimal)): return str(value) if isinstance(value, buffer): return "X'%s'" % hexlify(value) else: if isinstance(value, (int, float, Decimal)): return str(value) if isinstance(value, bytes): return "X'%s'" % hexlify(value).decode('ascii') assert False, repr(value) # pragma: no cover if not PY2: __str__ = __unicode__ def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.value) def quote_str(self, s): if self.paramstyle in ('format', 'pyformat'): s = s.replace('%', '%%') return "'%s'" % s.replace("'", "''") def flat(tree): stack = [ tree ] result = [] stack_pop = stack.pop stack_extend = stack.extend result_append = result.append while stack: x = stack_pop() if isinstance(x, basestring): result_append(x) else: try: stack_extend(reversed(x)) except TypeError: result_append(x) return result def flat_conditions(conditions): result = [] for condition in conditions: if condition[0] == 'AND': result.extend(flat_conditions(condition[1:])) else: result.append(condition) return result def join(delimiter, items): items = iter(items) try: result = [ next(items) ] except StopIteration: return [] for item in items: result.append(delimiter) result.append(item) return result def move_conditions_from_inner_join_to_where(sections): new_sections = list(sections) for i, section in enumerate(sections): if section[0] == 'FROM': new_from_list = [ 'FROM' ] + [ list(item) for item in section[1:] ] new_sections[i] = new_from_list if len(sections) > i+1 and sections[i+1][0] == 'WHERE': new_where_list = list(sections[i+1]) new_sections[i+1] = new_where_list else: new_where_list = [ 'WHERE' ] new_sections.insert(i+1, new_where_list) break else: return sections for join in new_from_list[2:]: if join[1] in ('TABLE', 'SELECT') and len(join) == 4: new_where_list.append(join.pop()) return new_sections def make_binary_op(symbol, default_parentheses=False): def binary_op(builder, expr1, expr2, parentheses=None): if parentheses is None: parentheses = default_parentheses if parentheses: return '(', builder(expr1), symbol, builder(expr2), ')' return builder(expr1), symbol, builder(expr2) return binary_op def make_unary_func(symbol): 
def unary_func(builder, expr): return '%s(' % symbol, builder(expr), ')' return unary_func def indentable(method): def new_method(builder, *args, **kwargs): result = method(builder, *args, **kwargs) if builder.indent <= 1: return result return builder.indent_spaces * (builder.indent-1), result new_method.__name__ = method.__name__ return new_method class SQLBuilder(object): dialect = None param_class = Param composite_param_class = CompositeParam value_class = Value indent_spaces = " " * 4 least_func_name = 'least' greatest_func_name = 'greatest' def __init__(builder, provider, ast): builder.provider = provider builder.quote_name = provider.quote_name builder.paramstyle = paramstyle = provider.paramstyle builder.ast = ast builder.indent = 0 builder.keys = {} builder.inner_join_syntax = options.INNER_JOIN_SYNTAX builder.suppress_aliases = False builder.result = flat(builder(ast)) params = tuple(x for x in builder.result if isinstance(x, Param)) layout = [] for i, param in enumerate(params): if param.id is None: param.id = i + 1 layout.append(param.paramkey) builder.layout = layout builder.sql = u''.join(imap(unicode, builder.result)).rstrip('\n') if paramstyle in ('qmark', 'format'): def adapter(values): return tuple(param.eval(values) for param in params) elif paramstyle == 'numeric': def adapter(values): return tuple(param.eval(values) for param in params) elif paramstyle in ('named', 'pyformat'): def adapter(values): return {'p%d' % param.id: param.eval(values) for param in params} else: throw(NotImplementedError, paramstyle) builder.params = params builder.adapter = adapter def __call__(builder, ast): if isinstance(ast, basestring): throw(AstError, 'An SQL AST list was expected. Got string: %r' % ast) symbol = ast[0] if not isinstance(symbol, basestring): throw(AstError, 'Invalid node name in AST: %r' % ast) method = getattr(builder, symbol, None) if method is None: throw(AstError, 'Method not found: %s' % symbol) try: return method(*ast[1:]) except TypeError: raise ## traceback = sys.exc_info()[2] ## if traceback.tb_next is None: ## del traceback ## throw(AstError, 'Invalid data for method %s: %r' ## % (symbol, ast[1:])) ## else: ## del traceback ## raise def INSERT(builder, table_name, columns, values, returning=None): return [ 'INSERT INTO ', builder.quote_name(table_name), ' (', join(', ', [builder.quote_name(column) for column in columns ]), ') VALUES (', join(', ', [builder(value) for value in values]), ')' ] def DEFAULT(builder): return 'DEFAULT' def UPDATE(builder, table_name, pairs, where=None): return [ 'UPDATE ', builder.quote_name(table_name), '\nSET ', join(', ', [ (builder.quote_name(name), ' = ', builder(param)) for name, param in pairs]), where and [ '\n', builder(where) ] or [] ] def DELETE(builder, alias, from_ast, where=None): builder.indent += 1 if alias is not None: assert isinstance(alias, basestring) if not where: return 'DELETE ', alias, ' ', builder(from_ast) return 'DELETE ', alias, ' ', builder(from_ast), builder(where) else: assert from_ast[0] == 'FROM' and len(from_ast) == 2 and from_ast[1][1] == 'TABLE' alias = from_ast[1][0] if alias is not None: builder.suppress_aliases = True if not where: return 'DELETE ', builder(from_ast) return 'DELETE ', builder(from_ast), builder(where) def _subquery(builder, *sections): builder.indent += 1 if not builder.inner_join_syntax: sections = move_conditions_from_inner_join_to_where(sections) result = [ builder(s) for s in sections ] builder.indent -= 1 return result def SELECT(builder, *sections): prev_suppress_aliases = 
builder.suppress_aliases builder.suppress_aliases = False try: result = builder._subquery(*sections) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' return result finally: builder.suppress_aliases = prev_suppress_aliases def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): assert not builder.indent result = builder.SELECT(*sections) nowait = ' NOWAIT' if nowait else '' skip_locked = ' SKIP LOCKED' if skip_locked else '' return result, 'FOR UPDATE', nowait, skip_locked, '\n' def EXISTS(builder, *sections): result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent return 'EXISTS (\n', indent, 'SELECT 1\n', result, indent, ')' def NOT_EXISTS(builder, *sections): return 'NOT ', builder.EXISTS(*sections) @indentable def ALL(builder, *expr_list): exprs = [ builder(e) for e in expr_list ] return 'SELECT ', join(', ', exprs), '\n' @indentable def DISTINCT(builder, *expr_list): exprs = [ builder(e) for e in expr_list ] return 'SELECT DISTINCT ', join(', ', exprs), '\n' @indentable def AGGREGATES(builder, *expr_list): exprs = [ builder(e) for e in expr_list ] return 'SELECT ', join(', ', exprs), '\n' def AS(builder, expr, alias): return builder(expr), ' AS ', builder.quote_name(alias) def compound_name(builder, name_parts): return '.'.join(p and builder.quote_name(p) or '' for p in name_parts) def sql_join(builder, join_type, sources): indent = builder.indent_spaces * (builder.indent-1) indent2 = indent + builder.indent_spaces indent3 = indent2 + builder.indent_spaces result = [ indent, 'FROM '] for i, source in enumerate(sources): if len(source) == 3: alias, kind, x = source join_cond = None elif len(source) == 4: alias, kind, x, join_cond = source else: throw(AstError, 'Invalid source in FROM section: %r' % source) if i > 0: if join_cond is None: result.append(', ') else: result += [ '\n', indent, ' %s JOIN ' % join_type ] if builder.suppress_aliases: alias = None elif alias is not None: alias = builder.quote_name(alias) if kind == 'TABLE': if isinstance(x, basestring): result.append(builder.quote_name(x)) else: result.append(builder.compound_name(x)) if alias is not None: result += ' ', alias # Oracle does not support 'AS' here elif kind == 'SELECT': if alias is None: throw(AstError, 'Subquery in FROM section must have an alias') result += builder.SELECT(*x), ' ', alias # Oracle does not support 'AS' here else: throw(AstError, 'Invalid source kind in FROM section: %r' % kind) if join_cond is not None: result += [ '\n', indent2, 'ON ', builder(join_cond) ] result.append('\n') return result def FROM(builder, *sources): return builder.sql_join('INNER', sources) def INNER_JOIN(builder, *sources): builder.inner_join_syntax = True return builder.sql_join('INNER', sources) @indentable def LEFT_JOIN(builder, *sources): return builder.sql_join('LEFT', sources) def WHERE(builder, *conditions): if not conditions: return '' conditions = flat_conditions(conditions) indent = builder.indent_spaces * (builder.indent-1) result = [ indent, 'WHERE ' ] extend = result.extend extend((builder(conditions[0]), '\n')) for condition in conditions[1:]: extend((indent, ' AND ', builder(condition), '\n')) return result def HAVING(builder, *conditions): if not conditions: return '' conditions = flat_conditions(conditions) indent = builder.indent_spaces * (builder.indent-1) result = [ indent, 'HAVING ' ] extend = result.extend extend((builder(conditions[0]), '\n')) for condition in conditions[1:]: extend((indent, ' AND ', 
builder(condition), '\n')) return result @indentable def GROUP_BY(builder, *expr_list): exprs = [ builder(e) for e in expr_list ] return 'GROUP BY ', join(', ', exprs), '\n' @indentable def UNION(builder, kind, *sections): return 'UNION ', kind, '\n', builder.SELECT(*sections) @indentable def INTERSECT(builder, *sections): return 'INTERSECT\n', builder.SELECT(*sections) @indentable def EXCEPT(builder, *sections): return 'EXCEPT\n', builder.SELECT(*sections) @indentable def ORDER_BY(builder, *order_list): result = [ 'ORDER BY ' ] result.extend(join(', ', [ builder(expr) for expr in order_list ])) result.append('\n') return result def DESC(builder, expr): return builder(expr), ' DESC' @indentable def LIMIT(builder, limit, offset=None): if limit is None: limit = 'null' else: assert isinstance(limit, int_types) assert offset is None or isinstance(offset, int) if offset: return 'LIMIT %s OFFSET %d\n' % (limit, offset) else: return 'LIMIT %s\n' % limit def COLUMN(builder, table_alias, col_name): if builder.suppress_aliases or not table_alias: return [ '%s' % builder.quote_name(col_name) ] return [ '%s.%s' % (builder.quote_name(table_alias), builder.quote_name(col_name)) ] def PARAM(builder, paramkey, converter=None, optimistic=False): return builder.make_param(builder.param_class, paramkey, converter, optimistic) def make_param(builder, param_class, paramkey, *args): keys = builder.keys param = keys.get(paramkey) if param is None: param = param_class(builder.paramstyle, paramkey, *args) keys[paramkey] = param return param def make_composite_param(builder, paramkey, items, func): return builder.make_param(builder.composite_param_class, paramkey, items, func) def STAR(builder, table_alias): return builder.quote_name(table_alias), '.*' def ROW(builder, *items): return '(', join(', ', imap(builder, items)), ')' def VALUE(builder, value): return builder.value_class(builder.paramstyle, value) def AND(builder, *cond_list): cond_list = [ builder(condition) for condition in cond_list ] return join(' AND ', cond_list) def OR(builder, *cond_list): cond_list = [ builder(condition) for condition in cond_list ] return '(', join(' OR ', cond_list), ')' def NOT(builder, condition): return 'NOT (', builder(condition), ')' def POW(builder, expr1, expr2): return 'power(', builder(expr1), ', ', builder(expr2), ')' EQ = make_binary_op(' = ') NE = make_binary_op(' <> ') LT = make_binary_op(' < ') LE = make_binary_op(' <= ') GT = make_binary_op(' > ') GE = make_binary_op(' >= ') ADD = make_binary_op(' + ', True) SUB = make_binary_op(' - ', True) MUL = make_binary_op(' * ', True) DIV = make_binary_op(' / ', True) FLOORDIV = make_binary_op(' / ', True) def MOD(builder, a, b): symbol = ' %% ' if builder.paramstyle in ('format', 'pyformat') else ' % ' return '(', builder(a), symbol, builder(b), ')' def FLOAT_EQ(builder, a, b): a, b = builder(a), builder(b) return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' def FLOAT_NE(builder, a, b): a, b = builder(a), builder(b) return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14' def CONCAT(builder, *args): return '(', join(' || ', imap(builder, args)), ')' def NEG(builder, expr): return '-(', builder(expr), ')' def IS_NULL(builder, expr): return builder(expr), ' IS NULL' def IS_NOT_NULL(builder, expr): return builder(expr), ' IS NOT NULL' def LIKE(builder, expr, template, escape=None): result = builder(expr), ' LIKE ', builder(template) if escape: result = result + (' ESCAPE ', 
builder(escape)) return result def NOT_LIKE(builder, expr, template, escape=None): result = builder(expr), ' NOT LIKE ', builder(template) if escape: result = result + (' ESCAPE ', builder(escape)) return result def BETWEEN(builder, expr1, expr2, expr3): return builder(expr1), ' BETWEEN ', builder(expr2), ' AND ', builder(expr3) def NOT_BETWEEN(builder, expr1, expr2, expr3): return builder(expr1), ' NOT BETWEEN ', builder(expr2), ' AND ', builder(expr3) def IN(builder, expr1, x): if not x: return '0 = 1' if len(x) >= 1 and x[0] == 'SELECT': return builder(expr1), ' IN ', builder(x) expr_list = [ builder(expr) for expr in x ] return builder(expr1), ' IN (', join(', ', expr_list), ')' def NOT_IN(builder, expr1, x): if not x: return '1 = 1' if len(x) >= 1 and x[0] == 'SELECT': return builder(expr1), ' NOT IN ', builder(x) expr_list = [ builder(expr) for expr in x ] return builder(expr1), ' NOT IN (', join(', ', expr_list), ')' def COUNT(builder, distinct, *expr_list): assert distinct in (None, True, False) if not distinct: if not expr_list: return ['COUNT(*)'] return 'COUNT(', join(', ', imap(builder, expr_list)), ')' if not expr_list: throw(AstError, 'COUNT(DISTINCT) without argument') if len(expr_list) == 1: return 'COUNT(DISTINCT ', builder(expr_list[0]), ')' if builder.dialect == 'PostgreSQL': return 'COUNT(DISTINCT ', builder.ROW(*expr_list), ')' elif builder.dialect == 'MySQL': return 'COUNT(DISTINCT ', join(', ', imap(builder, expr_list)), ')' # Oracle and SQLite queries translated to completely different subquery syntax else: throw(NotImplementedError) # This line must not be executed def SUM(builder, distinct, expr): assert distinct in (None, True, False) return distinct and 'coalesce(SUM(DISTINCT ' or 'coalesce(SUM(', builder(expr), '), 0)' def AVG(builder, distinct, expr): assert distinct in (None, True, False) return distinct and 'AVG(DISTINCT ' or 'AVG(', builder(expr), ')' def GROUP_CONCAT(builder, distinct, expr, sep=None): assert distinct in (None, True, False) result = distinct and 'GROUP_CONCAT(DISTINCT ' or 'GROUP_CONCAT(', builder(expr) if sep is not None: if builder.provider.dialect == 'MySQL': result = result, ' SEPARATOR ', builder(sep) else: result = result, ', ', builder(sep) return result, ')' UPPER = make_unary_func('upper') LOWER = make_unary_func('lower') LENGTH = make_unary_func('length') ABS = make_unary_func('abs') def COALESCE(builder, *args): if len(args) < 2: assert False # pragma: no cover return 'coalesce(', join(', ', imap(builder, args)), ')' def MIN(builder, distinct, *args): assert not distinct, distinct if len(args) == 0: assert False # pragma: no cover elif len(args) == 1: fname = 'MIN' else: fname = builder.least_func_name return fname, '(', join(', ', imap(builder, args)), ')' def MAX(builder, distinct, *args): assert not distinct, distinct if len(args) == 0: assert False # pragma: no cover elif len(args) == 1: fname = 'MAX' else: fname = builder.greatest_func_name return fname, '(', join(', ', imap(builder, args)), ')' def SUBSTR(builder, expr, start, len=None): if len is None: return 'substr(', builder(expr), ', ', builder(start), ')' return 'substr(', builder(expr), ', ', builder(start), ', ', builder(len), ')' def CASE(builder, expr, cases, default=None): if expr is None and default is not None and default[0] == 'CASE' and default[1] is None: cases2, default2 = default[2:] return builder.CASE(None, tuple(cases) + tuple(cases2), default2) result = [ 'case' ] if expr is not None: result.append(' ') result.extend(builder(expr)) for condition, expr 
in cases: result.extend((' when ', builder(condition), ' then ', builder(expr))) if default is not None: result.extend((' else ', builder(default))) result.append(' end') return result def TRIM(builder, expr, chars=None): if chars is None: return 'trim(', builder(expr), ')' return 'trim(', builder(expr), ', ', builder(chars), ')' def LTRIM(builder, expr, chars=None): if chars is None: return 'ltrim(', builder(expr), ')' return 'ltrim(', builder(expr), ', ', builder(chars), ')' def RTRIM(builder, expr, chars=None): if chars is None: return 'rtrim(', builder(expr), ')' return 'rtrim(', builder(expr), ', ', builder(chars), ')' def REPLACE(builder, str, from_, to): return 'replace(', builder(str), ', ', builder(from_), ', ', builder(to), ')' def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS integer)' def TO_STR(builder, expr): return 'CAST(', builder(expr), ' AS text)' def TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS real)' def TODAY(builder): return 'CURRENT_DATE' def NOW(builder): return 'CURRENT_TIMESTAMP' def DATE(builder, expr): return 'DATE(', builder(expr) ,')' def YEAR(builder, expr): return 'EXTRACT(YEAR FROM ', builder(expr), ')' def MONTH(builder, expr): return 'EXTRACT(MONTH FROM ', builder(expr), ')' def DAY(builder, expr): return 'EXTRACT(DAY FROM ', builder(expr), ')' def HOUR(builder, expr): return 'EXTRACT(HOUR FROM ', builder(expr), ')' def MINUTE(builder, expr): return 'EXTRACT(MINUTE FROM ', builder(expr), ')' def SECOND(builder, expr): return 'EXTRACT(SECOND FROM ', builder(expr), ')' def RANDOM(builder): return 'RAND()' def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] def build_json_path(builder, path): empty_slice = slice(None, None, None) has_params = False has_wildcards = False items = [ builder(element) for element in path ] for item in items: if isinstance(item, Param): has_params = True elif isinstance(item, Value): value = item.value if value is Ellipsis or value == empty_slice: has_wildcards = True else: assert isinstance(value, (int, basestring)), value else: assert False, item if has_params: paramkey = tuple(item.paramkey if isinstance(item, Param) else None if type(item.value) is slice else item.value for item in items) path_sql = builder.make_composite_param(paramkey, items, builder.eval_json_path) else: result_value = builder.eval_json_path(item.value for item in items) path_sql = builder.value_class(builder.paramstyle, result_value) return path_sql, has_params, has_wildcards @classmethod def eval_json_path(cls, values): result = ['$'] append = result.append empty_slice = slice(None, None, None) for value in values: if isinstance(value, int): append('[%d]' % value) elif isinstance(value, basestring): append('.' 
+ value if is_ident(value) else '."%s"' % value.replace('"', '\\"'))
            elif value is Ellipsis: append('.*')
            elif value == empty_slice: append('[*]')
            else: assert False, value
        return ''.join(result)
    def JSON_QUERY(builder, expr, path): throw(NotImplementedError)
    def JSON_VALUE(builder, expr, path, type): throw(NotImplementedError)
    def JSON_NONZERO(builder, expr): throw(NotImplementedError)
    def JSON_CONCAT(builder, left, right): throw(NotImplementedError)
    def JSON_CONTAINS(builder, expr, path, key): throw(NotImplementedError)
    def JSON_ARRAY_LENGTH(builder, value): throw(NotImplementedError)
    def JSON_PARAM(builder, expr): return builder(expr)
    def ARRAY_INDEX(builder, col, index): throw(NotImplementedError)
    def ARRAY_CONTAINS(builder, key, not_in, col): throw(NotImplementedError)
    def ARRAY_SUBSET(builder, array1, not_in, array2): throw(NotImplementedError)
    def ARRAY_LENGTH(builder, array): throw(NotImplementedError)
    def ARRAY_SLICE(builder, array, start, stop): throw(NotImplementedError)
    def MAKE_ARRAY(builder, *items): throw(NotImplementedError)

# ===== pony/orm/sqlsymbols.py =====

symbols = [
    'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'SELECT_FOR_UPDATE',
    'FROM', 'INNER_JOIN', 'LEFT_JOIN', 'WHERE', 'GROUP_BY', 'HAVING',
    'UNION', 'INTERSECT', 'EXCEPT', 'ORDER_BY', 'LIMIT',
    'ASC', 'DESC', 'DISTINCT', 'ALL', 'AGGREGATES', 'AS',
    'COUNT', 'SUM', 'MIN', 'MAX', 'AVG',
    'TABLE', 'COLUMN', 'PARAM', 'VALUE',
    'AND', 'OR', 'NOT',
    'EQ', 'NE', 'LT', 'LE', 'GT', 'GE',
    'IS_NULL', 'IS_NOT_NULL', 'LIKE', 'NOT_LIKE', 'BETWEEN', 'NOT_BETWEEN',
    'IN', 'NOT_IN', 'EXISTS', 'NOT_EXISTS', 'ROW',
    'ADD', 'SUB', 'MUL', 'DIV', 'POW', 'NEG', 'ABS',
    'UPPER', 'LOWER', 'CONCAT', 'STRIN', 'LIKE', 'SUBSTR', 'LENGTH',
    'TRIM', 'LTRIM', 'RTRIM', 'REPLACE', 'CASE', 'COALESCE', 'TO_INT', 'RANDOM',
    'DATE', 'YEAR', 'MONTH', 'DAY', 'HOUR', 'MINUTE', 'SECOND', 'TODAY', 'NOW',
    'DATE_ADD', 'DATE_SUB', 'DATETIME_ADD', 'DATETIME_SUB',
]
globals().update((s, s) for s in symbols)

# ===== pony/orm/sqltranslation.py =====

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, with_metaclass, int_types

import types, sys, re, itertools, inspect
from decimal import Decimal
from datetime import date, time, datetime, timedelta
from random import random
from copy import deepcopy
from functools import update_wrapper
from uuid import UUID

from pony.thirdparty.compiler import ast

from pony import options, utils
from pony.utils import localbase, is_ident, throw, reraise, copy_ast, between, concat, coalesce
from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError, create_extractors
from pony.orm.decompiling import decompile, DecompileError
from pony.orm.ormtypes import \
    numeric_types, comparable_types, SetType, FuncType, MethodType, raw_sql, RawSQLType, \
    normalize, normalize_type, coerce_types, are_comparable_types, \
    Json, QueryType, Array, array_types
from pony.orm import core
from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper, \
    special_functions, const_functions, extract_vars, Query, UseAnotherTranslator

NoneType = type(None)

def
check_comparable(left_monad, right_monad, op='=='): t1, t2 = left_monad.type, right_monad.type if t1 == 'METHOD': raise_forgot_parentheses(left_monad) if t2 == 'METHOD': raise_forgot_parentheses(right_monad) if not are_comparable_types(t1, t2, op): if op in ('in', 'not in') and isinstance(t2, SetType): t2 = t2.item_type throw(IncomparableTypesError, t1, t2) class IncomparableTypesError(TypeError): def __init__(exc, type1, type2): msg = 'Incomparable types %r and %r in expression: {EXPR}' % (type2str(type1), type2str(type2)) TypeError.__init__(exc, msg) exc.type1 = type1 exc.type2 = type2 def sqland(items): if not items: return [] if len(items) == 1: return items[0] result = [ 'AND' ] for item in items: if item[0] == 'AND': result.extend(item[1:]) else: result.append(item) return result def sqlor(items): if not items: return [] if len(items) == 1: return items[0] result = [ 'OR' ] for item in items: if item[0] == 'OR': result.extend(item[1:]) else: result.append(item) return result def join_tables(alias1, alias2, columns1, columns2): assert len(columns1) == len(columns2) return sqland([ [ 'EQ', [ 'COLUMN', alias1, c1 ], [ 'COLUMN', alias2, c2 ] ] for c1, c2 in izip(columns1, columns2) ]) def type2str(t): if type(t) is tuple: return 'list' if type(t) is SetType: return 'Set of ' + type2str(t.item_type) try: return t.__name__ except: return str(t) class Local(localbase): def __init__(local): local.translators = [] @property def translator(self): return local.translators[-1] local = Local() class SQLTranslator(ASTTranslator): dialect = None row_value_syntax = True json_path_wildcard_syntax = False json_values_are_comparable = True rowid_support = False def __enter__(translator): local.translators.append(translator) def __exit__(translator, exc_type, exc_val, exc_tb): t = local.translators.pop() if isinstance(exc_val, UseAnotherTranslator): assert t is exc_val.translator else: assert t is translator def default_post(translator, node): throw(NotImplementedError) # pragma: no cover def dispatch(translator, node): if hasattr(node, 'monad'): return # monad already assigned somehow if not getattr(node, 'external', False) or getattr(node, 'constant', False): return ASTTranslator.dispatch(translator, node) # default route translator.call(translator.__class__.dispatch_external, node) def dispatch_external(translator, node): varkey = translator.filter_num, node.src, translator.code_key t = translator.root_translator.vartypes[varkey] tt = type(t) if t is NoneType: monad = ConstMonad.new(None) elif tt is SetType: if isinstance(t.item_type, EntityMeta): monad = EntityMonad(t.item_type) else: throw(NotImplementedError) # pragma: no cover elif tt is QueryType: prev_translator = deepcopy(t.translator) prev_translator.parent = translator prev_translator.injected = True if translator.database is not prev_translator.database: throw(TranslationError, 'Mixing queries from different databases') monad = QuerySetMonad(prev_translator) if t.limit is not None or t.offset is not None: monad = monad.call_limit(t.limit, t.offset) elif tt is FuncType: func = t.func func_monad_class = translator.registered_functions.get(func) if func_monad_class is not None: monad = func_monad_class(func) else: monad = HybridFuncMonad(t, func.__name__) elif tt is MethodType: obj, func = t.obj, t.func if isinstance(obj, EntityMeta): entity_monad = EntityMonad(obj) if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) monad = MethodMonad(entity_monad, func.__name__) elif node.src == 'random': # For PyPy monad = 
FuncRandomMonad(t) else: throw(NotImplementedError) elif isinstance(node, ast.Name) and node.name in ('True', 'False'): value = True if node.name == 'True' else False monad = ConstMonad.new(value) elif tt is tuple: params = [] is_array = False if t and translator.database.provider.array_converter_cls is not None: types = set(t) if len(types) == 1 and unicode in types: item_type = unicode is_array = True else: item_type = int for type_ in types: if type_ is float: item_type = float if type_ not in (float, int) or not hasattr(type_, '__index__'): break else: is_array = True for i, item_type in enumerate(t): if item_type is NoneType: throw(TypeError, 'Expression `%s` should not contain None values' % node.src) param = ParamMonad.new(item_type, (varkey, i, None)) params.append(param) monad = ListMonad(params) if is_array: array_type = array_types.get(item_type, None) monad = ArrayParamMonad(array_type, (varkey, None, None), list_monad=monad) elif isinstance(t, RawSQLType): monad = RawSQLMonad(t, varkey) else: monad = ParamMonad.new(t, (varkey, None, None)) node.monad = monad monad.node = node monad.aggregated = monad.nogroup = False def call(translator, method, node): try: monad = method(translator, node) except Exception: exc_class, exc, tb = sys.exc_info() try: if not exc.args: exc.args = (ast2src(node),) else: msg = exc.args[0] if isinstance(msg, basestring) and '{EXPR}' in msg: msg = msg.replace('{EXPR}', ast2src(node)) exc.args = (msg,) + exc.args[1:] reraise(exc_class, exc, tb) finally: del exc, tb else: if monad is None: return node.monad = monad monad.node = node if not hasattr(monad, 'aggregated'): for child in node.getChildNodes(): m = getattr(child, 'monad', None) if m and getattr(m, 'aggregated', False): monad.aggregated = True break else: monad.aggregated = False if not hasattr(monad, 'nogroup'): for child in node.getChildNodes(): m = getattr(child, 'monad', None) if m and getattr(m, 'nogroup', False): monad.nogroup = True break else: monad.nogroup = False if monad.aggregated: translator.aggregated = True if monad.nogroup: if isinstance(monad, ListMonad): pass elif isinstance(monad, AndMonad): pass else: throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: %s' % ast2src(node)) return monad def __init__(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): local.translators.append(translator) try: translator.init(tree, parent_translator, code_key, filter_num, extractors, vars, vartypes, left_join, optimize) except UseAnotherTranslator as e: assert local.translators t = local.translators.pop() assert t is e.translator raise else: assert local.translators t = local.translators.pop() assert t is translator def init(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): this = translator assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) translator.can_be_cached = True translator.parent = parent_translator translator.injected = False if parent_translator is None: translator.root_translator = translator translator.database = None translator.sqlquery = SqlQuery(translator, left_join=left_join) assert code_key is not None and filter_num is not None translator.code_key = translator.original_code_key = code_key translator.filter_num = translator.original_filter_num = filter_num else: translator.root_translator = parent_translator.root_translator 
translator.database = parent_translator.database translator.sqlquery = SqlQuery(translator, parent_translator.sqlquery, left_join=left_join) assert code_key is None and filter_num is None translator.code_key = parent_translator.code_key translator.filter_num = parent_translator.filter_num translator.original_code_key = translator.original_filter_num = None translator.extractors = extractors translator.vars = vars translator.vartypes = vartypes translator.namespace_stack = [{}] if not parent_translator else [ parent_translator.namespace.copy() ] translator.func_extractors_map = {} translator.getattr_values = {} translator.func_vartypes = {} translator.left_join = left_join translator.optimize = optimize translator.from_optimized = False translator.optimization_failed = False translator.distinct = False translator.conditions = translator.sqlquery.conditions translator.having_conditions = [] translator.order = [] translator.limit = translator.offset = None translator.inside_order_by = False translator.aggregated = False if not optimize else True translator.hint_join = False translator.query_result_is_cacheable = True translator.aggregated_subquery_paths = set() for i, qual in enumerate(tree.quals): assign = qual.assign if isinstance(assign, ast.AssTuple): ass_names = tuple(assign.nodes) elif isinstance(assign, ast.AssName): ass_names = (assign,) else: throw(NotImplementedError, ast2src(assign)) for ass_name in ass_names: if not isinstance(ass_name, ast.AssName): throw(NotImplementedError, ast2src(ass_name)) if ass_name.flags != 'OP_ASSIGN': throw(TypeError, ast2src(ass_name)) names = tuple(ass_name.name for ass_name in ass_names) for name in names: if name in translator.namespace and name in translator.sqlquery.tablerefs: throw(TranslationError, 'Duplicate name: %r' % name) if name.startswith('__'): throw(TranslationError, 'Illegal name: %r' % name) name = names[0] if len(names) == 1 else None def check_name_is_single(): if len(names) > 1: throw(TypeError, 'Single variable name expected. Got: %s' % ast2src(assign)) database = entity = None node = qual.iter monad = getattr(node, 'monad', None) if monad: # Lambda was encountered inside generator check_name_is_single() assert parent_translator and i == 0 entity = monad.type.item_type if isinstance(monad, EntityMonad): tableref = TableRef(translator.sqlquery, name, entity) translator.sqlquery.tablerefs[name] = tableref elif isinstance(monad, AttrSetMonad): translator.sqlquery = monad._subselect(translator.sqlquery, extract_outer_conditions=False) tableref = monad.tableref else: assert False # pragma: no cover translator.namespace[name] = ObjectIterMonad(tableref, entity) elif node.external: varkey = translator.filter_num, node.src, translator.code_key iterable = translator.root_translator.vartypes[varkey] if isinstance(iterable, SetType): check_name_is_single() entity = iterable.item_type if not isinstance(entity, EntityMeta): throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) if i > 0: if translator.left_join: throw(TranslationError, 'Collection expected inside left join query. 
' 'Got: for %s in %s' % (name, ast2src(qual.iter))) translator.distinct = True tableref = TableRef(translator.sqlquery, name, entity) translator.sqlquery.tablerefs[name] = tableref tableref.make_join() translator.namespace[name] = node.monad = ObjectIterMonad(tableref, entity) elif isinstance(iterable, QueryType): prev_translator = deepcopy(iterable.translator) prev_limit = iterable.limit prev_offset = iterable.offset database = prev_translator.database try: translator.process_query_qual(prev_translator, prev_limit, prev_offset, names, try_extend_prev_query=not i) except UseAnotherTranslator as e: assert local.translators and local.translators[-1] is translator translator = e.translator local.translators[-1] = translator else: throw(TranslationError, 'Inside declarative query, iterator must be entity or query. ' 'Got: for %s in %s' % (name, ast2src(qual.iter))) else: translator.dispatch(node) monad = node.monad if isinstance(monad, QuerySetMonad): subtranslator = monad.subtranslator database = subtranslator.database try: translator.process_query_qual(subtranslator, monad.limit, monad.offset, names) except UseAnotherTranslator: assert False else: check_name_is_single() attr_names = [] while isinstance(monad, (AttrMonad, AttrSetMonad)) and monad.parent is not None: attr_names.append(monad.attr.name) monad = monad.parent attr_names.reverse() if not isinstance(monad, ObjectIterMonad): throw(NotImplementedError, 'for %s in %s' % (name, ast2src(qual.iter))) name_path = monad.tableref.alias # or name_path, it is the same parent_tableref = monad.tableref parent_entity = parent_tableref.entity last_index = len(attr_names) - 1 for j, attrname in enumerate(attr_names): attr = parent_entity._adict_.get(attrname) if attr is None: throw(AttributeError, attrname) entity = attr.py_type if not isinstance(entity, EntityMeta): throw(NotImplementedError, 'for %s in %s' % (name, ast2src(qual.iter))) can_affect_distinct = None if attr.is_collection: if not isinstance(attr, Set): throw(NotImplementedError, ast2src(qual.iter)) reverse = attr.reverse if reverse.is_collection: if not isinstance(reverse, Set): throw(NotImplementedError, ast2src(qual.iter)) translator.distinct = True elif parent_tableref.alias != tree.quals[i-1].assign.name: translator.distinct = True else: can_affect_distinct = True if j == last_index: name_path = name else: name_path += '-' + attr.name tableref = translator.sqlquery.add_tableref(name_path, parent_tableref, attr) tableref.make_join(pk_only=True) if j == last_index: translator.namespace[name] = ObjectIterMonad(tableref, tableref.entity) if can_affect_distinct is not None: tableref.can_affect_distinct = can_affect_distinct parent_tableref = tableref parent_entity = entity if database is None: assert entity is not None database = entity._database_ assert database.schema is not None if translator.database is None: translator.database = database elif translator.database is not database: throw(TranslationError, 'All entities in a query must belong to the same database') for if_ in qual.ifs: assert isinstance(if_, ast.GenExprIf) translator.dispatch(if_) if isinstance(if_.monad, AndMonad): cond_monads = if_.monad.operands else: cond_monads = [ if_.monad ] for m in cond_monads: if not getattr(m, 'aggregated', False): translator.conditions.extend(m.getsql()) else: translator.having_conditions.extend(m.getsql()) translator.dispatch(tree.expr) assert not translator.hint_join monad = tree.expr.monad if isinstance(monad, ParamMonad): throw(TranslationError, "External parameter '%s' cannot be used 
as query result" % ast2src(tree.expr)) translator.expr_monads = monad.items if isinstance(monad, ListMonad) else [ monad ] translator.groupby_monads = None expr_type = monad.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): entity = expr_type translator.expr_type = entity monad.orderby_columns = list(xrange(1, len(entity._pk_columns_)+1)) if monad.aggregated: throw(TranslationError) if isinstance(monad, QuerySetMonad): throw(NotImplementedError) elif isinstance(monad, ObjectMixin): tableref = monad.tableref elif isinstance(monad, AttrSetMonad): tableref = monad.make_tableref(translator.sqlquery) else: assert False # pragma: no cover if translator.aggregated: translator.groupby_monads = [ monad ] else: translator.distinct |= monad.requires_distinct() translator.tableref = tableref pk_only = parent_translator is not None or translator.aggregated alias, pk_columns = tableref.make_join(pk_only=pk_only) translator.alias = alias translator.expr_columns = [ [ 'COLUMN', alias, column ] for column in pk_columns ] translator.row_layout = None translator.col_names = [ attr.name for attr in entity._attrs_ if not attr.is_collection and not attr.lazy ] else: translator.alias = None expr_monads = translator.expr_monads if len(expr_monads) > 1: translator.expr_type = tuple(m.type for m in expr_monads) # ????? expr_columns = [] for m in expr_monads: expr_columns.extend(m.getsql()) translator.expr_columns = expr_columns else: translator.expr_type = monad.type translator.expr_columns = monad.getsql() if translator.aggregated: translator.groupby_monads = [ m for m in expr_monads if not m.aggregated and not m.nogroup ] else: expr_set = set() for m in expr_monads: if isinstance(m, ObjectIterMonad): expr_set.add(m.tableref.name_path) elif isinstance(m, AttrMonad) and isinstance(m.parent, ObjectIterMonad): expr_set.add((m.parent.tableref.name_path, m.attr)) for tr in translator.sqlquery.tablerefs.values(): if tr.entity is None: continue if not tr.can_affect_distinct: continue if tr.name_path in expr_set: continue if any((tr.name_path, attr) not in expr_set for attr in tr.entity._pk_attrs_): translator.distinct = True break row_layout = [] offset = 0 provider = translator.database.provider for m in expr_monads: if m.disable_distinct: translator.distinct = False expr_type = m.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): next_offset = offset + len(expr_type._pk_columns_) def func(values, constructor=expr_type._get_by_raw_pkval_): if None in values: return None return constructor(values) row_layout.append((func, slice(offset, next_offset), ast2src(m.node))) m.orderby_columns = list(xrange(offset+1, next_offset+1)) offset = next_offset else: converter = provider.get_converter_by_py_type(expr_type) def func(value, converter=converter): if value is None: return None value = converter.sql2py(value) value = converter.dbval2val(value) return value row_layout.append((func, offset, ast2src(m.node))) m.orderby_columns = (offset+1,) if not m.disable_ordering else () offset += 1 translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in translator.row_layout ] if translator.aggregated: translator.distinct = False translator.vars = None if translator is not this: raise UseAnotherTranslator(translator) @property def namespace(translator): return translator.namespace_stack[-1] def can_be_optimized(translator): if translator.groupby_monads: return False if 
len(translator.aggregated_subquery_paths) != 1: return False aggr_path = next(iter(translator.aggregated_subquery_paths)) for tableref in translator.sqlquery.tablerefs.values(): if tableref.joined and not aggr_path.startswith(tableref.name_path): return False return aggr_path def process_query_qual(translator, prev_translator, prev_limit, prev_offset, names, try_extend_prev_query=False): sqlquery = translator.sqlquery tablerefs = sqlquery.tablerefs expr_types = prev_translator.expr_type if not isinstance(expr_types, tuple): expr_types = (expr_types,) expr_count = len(expr_types) if expr_count > 1 and len(names) == 1: throw(NotImplementedError, 'Please unpack a tuple of (%s) in for-loop to individual variables (like: "for x, y in ...")' % (', '.join(ast2src(m.node) for m in prev_translator.expr_monads))) elif expr_count > len(names): throw(TranslationError, 'Not enough values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' % (', '.join(names), ', '.join(ast2src(m.node) for m in prev_translator.expr_monads), len(names), expr_count)) elif expr_count < len(names): throw(TranslationError, 'Too many values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' % (', '.join(names), ', '.join(ast2src(m.node) for m in prev_translator.expr_monads), len(names), expr_count)) if try_extend_prev_query: if prev_translator.aggregated: pass elif prev_translator.left_join: pass else: assert translator.parent is None assert prev_translator.vars is None prev_translator.code_key = translator.code_key prev_translator.filter_num = translator.filter_num prev_translator.extractors.update(translator.extractors) prev_translator.vars = translator.vars prev_translator.vartypes.update(translator.vartypes) prev_translator.left_join = translator.left_join prev_translator.optimize = translator.optimize prev_translator.namespace_stack = [ {name: expr for name, expr in izip(names, prev_translator.expr_monads)} ] prev_translator.limit, prev_translator.offset = combine_limit_and_offset( prev_translator.limit, prev_translator.offset, prev_limit, prev_offset) raise UseAnotherTranslator(prev_translator) if len(names) == 1 and isinstance(prev_translator.expr_type, EntityMeta) \ and not prev_translator.aggregated and not prev_translator.distinct: name = names[0] entity = prev_translator.expr_type [expr_monad] = prev_translator.expr_monads entity_alias = expr_monad.tableref.alias subquery_ast = prev_translator.construct_subquery_ast(prev_limit, prev_offset, star=entity_alias) tableref = StarTableRef(sqlquery, name, entity, subquery_ast) tablerefs[name] = tableref tableref.make_join() translator.namespace[name] = ObjectIterMonad(tableref, entity) else: aliases = [] aliases_dict = {} for name, base_expr_monad in izip(names, prev_translator.expr_monads): t = base_expr_monad.type if isinstance(t, EntityMeta): t_aliases = [] for suffix in t._pk_paths_: alias = '%s-%s' % (name, suffix) t_aliases.append(alias) aliases.extend(t_aliases) aliases_dict[base_expr_monad] = t_aliases else: aliases.append(name) aliases_dict[base_expr_monad] = name subquery_ast = prev_translator.construct_subquery_ast(prev_limit, prev_offset, aliases=aliases) tableref = ExprTableRef(sqlquery, 't', subquery_ast, names, aliases) for name in names: tablerefs[name] = tableref tableref.make_join() for name, base_expr_monad in izip(names, prev_translator.expr_monads): t = base_expr_monad.type if isinstance(t, EntityMeta): columns = aliases_dict[base_expr_monad] expr_tableref = ExprJoinedTableRef(sqlquery, tableref, columns, name, t) expr_monad = 
ObjectIterMonad(expr_tableref, t) else: column = aliases_dict[base_expr_monad] expr_ast = ['COLUMN', tableref.alias, column] expr_monad = ExprMonad.new(t, expr_ast, base_expr_monad.nullable) assert name not in translator.namespace translator.namespace[name] = expr_monad def construct_subquery_ast(translator, limit=None, offset=None, aliases=None, star=None, distinct=None, is_not_null_checks=False): subquery_ast, attr_offsets = translator.construct_sql_ast( limit, offset, distinct, is_not_null_checks=is_not_null_checks) assert len(subquery_ast) >= 3 and subquery_ast[0] == 'SELECT' select_ast = subquery_ast[1][:] assert select_ast[0] in ('ALL', 'DISTINCT', 'AGGREGATES'), select_ast if aliases: assert not star and len(aliases) == len(select_ast) - 1 for i, alias in enumerate(aliases): expr = select_ast[i+1] if expr[0] == 'AS': expr = expr[1] select_ast[i+1] = [ 'AS', expr, alias ] elif star is not None: assert isinstance(star, basestring) for section in subquery_ast: assert section[0] not in ('GROUP_BY', 'HAVING'), subquery_ast select_ast[1:] = [ [ 'STAR', star ] ] from_ast = subquery_ast[2][:] assert from_ast[0] in ('FROM', 'LEFT_JOIN') if len(subquery_ast) == 3: where_ast = [ 'WHERE' ] other_ast = [] elif subquery_ast[3][0] != 'WHERE': where_ast = [ 'WHERE' ] other_ast = subquery_ast[3:] else: where_ast = subquery_ast[3][:] other_ast = subquery_ast[4:] if len(from_ast[1]) == 4: outer_conditions = from_ast[1][-1] from_ast[1] = from_ast[1][:-1] if outer_conditions[0] == 'AND': where_ast[1:1] = outer_conditions[1:] else: where_ast.insert(1, outer_conditions) return [ 'SELECT', select_ast, from_ast, where_ast ] + other_ast def construct_sql_ast(translator, limit=None, offset=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, for_update=False, nowait=False, skip_locked=False, is_not_null_checks=False): attr_offsets = None if distinct is None: distinct = translator.distinct ast_transformer = lambda ast: ast if for_update: sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked ] translator.query_result_is_cacheable = False else: sql_ast = [ 'SELECT' ] select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + translator.expr_columns if aggr_func_name: expr_type = translator.expr_type if isinstance(expr_type, EntityMeta): if aggr_func_name == 'GROUP_CONCAT': if expr_type._pk_is_composite_: throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") elif aggr_func_name != 'COUNT': throw(TypeError, 'Attribute should be specified for %r aggregate function' % aggr_func_name.lower()) elif isinstance(expr_type, tuple): if aggr_func_name != 'COUNT': throw(TypeError, 'Single attribute should be specified for %r aggregate function' % aggr_func_name.lower()) else: if aggr_func_name in ('SUM', 'AVG') and expr_type not in numeric_types: throw(TypeError, '%r is valid for numeric attributes only' % aggr_func_name.lower()) assert len(translator.expr_columns) == 1 aggr_ast = None if translator.groupby_monads or ( aggr_func_name == 'COUNT' and distinct and isinstance(translator.expr_type, EntityMeta) and len(translator.expr_columns) > 1): outer_alias = 't' if aggr_func_name == 'COUNT' and not aggr_func_distinct: outer_aggr_ast = [ 'COUNT', None ] else: assert len(translator.expr_columns) == 1 expr_ast = translator.expr_columns[0] if expr_ast[0] == 'COLUMN': outer_alias, column_name = expr_ast[1:] outer_aggr_ast = [aggr_func_name, aggr_func_distinct, ['COLUMN', outer_alias, column_name]] if aggr_func_name == 'GROUP_CONCAT' and sep is not None: 
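# (editor's note) The optional separator travels through the SQL AST as a literal
# VALUE node appended to the GROUP_CONCAT aggregate:
#     ['GROUP_CONCAT', distinct, expr, ['VALUE', sep]]
# Illustrative query-level usage this serves (assuming the group_concat helper
# exported by pony.orm and a hypothetical Order entity):
#     group_concat((o.name for o in Order), sep=', ')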
outer_aggr_ast.append(['VALUE', sep]) else: select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + [ [ 'AS', expr_ast, 'expr' ] ] outer_aggr_ast = [ aggr_func_name, aggr_func_distinct, [ 'COLUMN', 't', 'expr' ] ] if aggr_func_name == 'GROUP_CONCAT' and sep is not None: outer_aggr_ast.append(['VALUE', sep]) def ast_transformer(ast): return [ 'SELECT', [ 'AGGREGATES', outer_aggr_ast ], [ 'FROM', [ outer_alias, 'SELECT', ast[1:] ] ] ] else: if aggr_func_name == 'COUNT': if isinstance(expr_type, (tuple, EntityMeta)) and not distinct and not aggr_func_distinct: aggr_ast = [ 'COUNT', aggr_func_distinct ] else: aggr_ast = [ 'COUNT', True if aggr_func_distinct is None else aggr_func_distinct, translator.expr_columns[0] ] else: aggr_ast = [ aggr_func_name, aggr_func_distinct, translator.expr_columns[0] ] if aggr_func_name == 'GROUP_CONCAT' and sep is not None: aggr_ast.append(['VALUE', sep]) if aggr_ast: select_ast = [ 'AGGREGATES', aggr_ast ] elif isinstance(translator.expr_type, EntityMeta) and not translator.parent \ and not translator.aggregated and not translator.optimize: select_ast, attr_offsets = translator.expr_type._construct_select_clause_( translator.alias, distinct, translator.tableref.used_attrs) sql_ast.append(select_ast) sql_ast.append(translator.sqlquery.from_ast) conditions = translator.conditions[:] having_conditions = translator.having_conditions[:] if is_not_null_checks: for monad in translator.expr_monads: if isinstance(monad, ObjectIterMonad): pass elif not monad.nullable: pass else: notnull_conditions = [ [ 'IS_NOT_NULL', column_ast ] for column_ast in monad.getsql() ] if monad.aggregated: having_conditions.extend(notnull_conditions) else: conditions.extend(notnull_conditions) if conditions: sql_ast.append([ 'WHERE' ] + conditions) if translator.groupby_monads: group_by = [ 'GROUP_BY' ] for m in translator.groupby_monads: group_by.extend(m.getsql()) sql_ast.append(group_by) else: group_by = None if having_conditions: if not group_by: throw(TranslationError, 'In order to use aggregated functions such as SUM(), COUNT(), etc., ' 'query must have grouping columns (i.e. resulting non-aggregated values)') sql_ast.append([ 'HAVING' ] + having_conditions) if translator.order and not aggr_func_name: sql_ast.append([ 'ORDER_BY' ] + translator.order) limit, offset = combine_limit_and_offset(translator.limit, translator.offset, limit, offset) if limit is not None or offset is not None: assert not aggr_func_name provider = translator.database.provider if limit is None: if provider.dialect == 'SQLite': limit = -1 elif provider.dialect == 'MySQL': limit = 18446744073709551615 limit_section = [ 'LIMIT', limit ] if offset: limit_section.append(offset) sql_ast.append(limit_section) sql_ast = ast_transformer(sql_ast) return sql_ast, attr_offsets def construct_delete_sql_ast(translator): entity = translator.expr_type expr_monad = translator.tree.expr.monad if not isinstance(entity, EntityMeta): throw(TranslationError, 'Delete query should be applied to a single entity. 
Got: %s' % ast2src(translator.tree.expr)) force_in = False if translator.groupby_monads: force_in = True else: assert not translator.having_conditions tableref = expr_monad.tableref from_ast = translator.sqlquery.from_ast if from_ast[0] != 'FROM': force_in = True if not force_in and len(from_ast) == 2 and not translator.sqlquery.used_from_subquery: sql_ast = [ 'DELETE', None, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) elif not force_in and translator.dialect == 'MySQL': sql_ast = [ 'DELETE', tableref.alias, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) else: delete_from_ast = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] if len(entity._pk_columns_) == 1: inner_expr = expr_monad.getsql() outer_expr = [ 'COLUMN', None, entity._pk_columns_[0] ] elif translator.rowid_support: inner_expr = [ [ 'COLUMN', tableref.alias, 'ROWID' ] ] outer_expr = [ 'COLUMN', None, 'ROWID' ] elif translator.row_value_syntax: inner_expr = expr_monad.getsql() outer_expr = [ 'ROW' ] + [ [ 'COLUMN', None, column_name ] for column_name in entity._pk_columns_ ] else: throw(NotImplementedError) subquery_ast = [ 'SELECT', [ 'ALL' ] + inner_expr, from_ast ] if translator.conditions: subquery_ast.append([ 'WHERE' ] + translator.conditions) delete_where_ast = [ 'WHERE', [ 'IN', outer_expr, subquery_ast ] ] sql_ast = [ 'DELETE', None, delete_from_ast, delete_where_ast ] return sql_ast def get_used_attrs(translator): if isinstance(translator.expr_type, EntityMeta) and not translator.aggregated and not translator.optimize: return translator.tableref.used_attrs return () def without_order(translator): translator = deepcopy(translator) translator.order = [] return translator def order_by_numbers(translator, numbers): if 0 in numbers: throw(ValueError, 'Numeric arguments of order_by() method must be non-zero') translator = deepcopy(translator) order = translator.order = translator.order[:] # only order will be changed expr_monads = translator.expr_monads new_order = [] for i in numbers: try: monad = expr_monads[abs(i)-1] except IndexError: if len(expr_monads) > 1: throw(IndexError, "Invalid index of order_by() method: %d " "(query result is list of tuples with only %d elements in each)" % (i, len(expr_monads))) else: throw(IndexError, "Invalid index of order_by() method: %d " "(query result is single list of elements and has only one 'column')" % i) for pos in monad.orderby_columns: new_order.append(i < 0 and [ 'DESC', [ 'VALUE', pos ] ] or [ 'VALUE', pos ]) order[:0] = new_order return translator def order_by_attributes(translator, attrs): entity = translator.expr_type if not isinstance(entity, EntityMeta): throw(NotImplementedError, 'Ordering by attributes is limited to queries which return simple list of objects. 
            ' 'Try using other forms of ordering (by tuple element numbers or by full-blown lambda expr).')
        translator = deepcopy(translator)
        order = translator.order = translator.order[:]  # only order will be changed
        alias = translator.alias
        new_order = []
        for x in attrs:
            if isinstance(x, DescWrapper):
                attr = x.attr
                desc_wrapper = lambda column: [ 'DESC', column ]
            elif isinstance(x, Attribute):
                attr = x
                desc_wrapper = lambda column: column
            else: assert False, x  # pragma: no cover
            if entity._adict_.get(attr.name) is not attr:
                throw(TypeError, 'Attribute %s does not belong to entity %s' % (attr, entity.__name__))
            if attr.is_collection:
                throw(TypeError, 'Collection attribute %s cannot be used for ordering' % attr)
            for column in attr.columns:
                new_order.append(desc_wrapper([ 'COLUMN', alias, column ]))
        order[:0] = new_order
        return translator
    def apply_kwfilters(translator, filterattrs, original_names=False):
        translator = deepcopy(translator)
        with translator:
            if original_names:
                object_monad = translator.tree.quals[0].iter.monad
                assert isinstance(object_monad.type, EntityMeta)
            else:
                object_monad = translator.tree.expr.monad
                if not isinstance(object_monad.type, EntityMeta):
                    throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects')
            monads = []
            none_monad = NoneMonad()
            for attr, id, is_none in filterattrs:
                attr_monad = object_monad.getattr(attr.name)
                if is_none:
                    monads.append(CmpMonad('is', attr_monad, none_monad))
                else:
                    param_monad = ParamMonad.new(attr.py_type, (id, None, None))
                    monads.append(CmpMonad('==', attr_monad, param_monad))
            for m in monads:
                translator.conditions.extend(m.getsql())
            return translator
    def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, original_names, extractors, vars, vartypes):
        translator = deepcopy(translator)
        func_ast = copy_ast(func_ast)  # func_ast = deepcopy(func_ast)
        translator.code_key = func_id
        translator.filter_num = filter_num
        translator.extractors.update(extractors)
        translator.vars = vars
        translator.vartypes = translator.vartypes.copy()  # make HashableDict mutable again
        translator.vartypes.update(vartypes)
        if not original_names:
            assert argnames
            namespace = {name: monad for name, monad in izip(argnames, translator.expr_monads)}
        elif argnames:
            namespace = {name: translator.namespace[name] for name in argnames}
        else:
            namespace = None
        if namespace is not None:
            translator.namespace_stack.append(namespace)
        with translator:
            try:
                translator.dispatch(func_ast)
                if isinstance(func_ast, ast.Tuple):
                    nodes = func_ast.nodes
                else:
                    nodes = (func_ast,)
                if order_by:
                    translator.inside_order_by = True
                    new_order = []
                    for node in nodes:
                        if isinstance(node.monad, SetMixin):
                            t = node.monad.type.item_type
                            if isinstance(type(t), type): t = t.__name__
                            throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' % (t, ast2src(node)))
                        new_order.extend(node.monad.getsql())
                    translator.order[:0] = new_order
                    translator.inside_order_by = False
                else:
                    for node in nodes:
                        monad = node.monad
                        if isinstance(monad, AndMonad):
                            cond_monads = monad.operands
                        else:
                            cond_monads = [ monad ]
                        for m in cond_monads:
                            if not m.aggregated:
                                translator.conditions.extend(m.getsql())
                            else:
                                translator.having_conditions.extend(m.getsql())
                translator.vars = None
                return translator
            finally:
                if namespace is not None:
                    ns = translator.namespace_stack.pop()
                    assert ns is namespace
    def preGenExpr(translator, node):
        inner_tree = node.code
        translator_cls = translator.__class__
        try:
            subtranslator = translator_cls(inner_tree, translator)
        except UseAnotherTranslator:
            assert False
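        # (editor's note) The child translator built above handles a nested
        # generator expression; it is returned (below) wrapped in a
        # QuerySetMonad so the inner query can appear as a subquery.
        # Illustrative example (hypothetical entities):
        #     select(c for c in Customer if count(o for o in c.orders) > 3)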
return QuerySetMonad(subtranslator) def postGenExprIf(translator, node): monad = node.test.monad if monad.type is not bool: monad = monad.nonzero() return monad def preCompare(translator, node): monads = [] ops = node.ops left = node.expr translator.dispatch(left) # op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '==' # | 'in' | 'not in' | 'is' | 'is not' for op, right in node.ops: translator.dispatch(right) if op.endswith('in'): monad = right.monad.contains(left.monad, op == 'not in') else: monad = left.monad.cmp(op, right.monad) if not hasattr(monad, 'aggregated'): monad.aggregated = getattr(left.monad, 'aggregated', False) or getattr(right.monad, 'aggregated', False) if not hasattr(monad, 'nogroup'): monad.nogroup = getattr(left.monad, 'nogroup', False) or getattr(right.monad, 'nogroup', False) if monad.aggregated and monad.nogroup: throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: {EXPR}') monads.append(monad) left = right if len(monads) == 1: return monads[0] return AndMonad(monads) def postConst(translator, node): value = node.value if type(value) is frozenset: value = tuple(sorted(value)) if type(value) is not tuple: return ConstMonad.new(value) else: return ListMonad([ ConstMonad.new(item) for item in value ]) def postEllipsis(translator, node): return ConstMonad.new(Ellipsis) def postList(translator, node): return ListMonad([ item.monad for item in node.nodes ]) def postTuple(translator, node): return ListMonad([ item.monad for item in node.nodes ]) def postName(translator, node): monad = translator.resolve_name(node.name) assert monad is not None return monad def resolve_name(translator, name): if name not in translator.namespace: throw(TranslationError, 'Name %s is not found in %s' % (name, translator.namespace)) monad = translator.namespace[name] assert isinstance(monad, Monad) if monad.translator is not translator: monad.translator.sqlquery.used_from_subquery = True return monad def postAdd(translator, node): return node.left.monad + node.right.monad def postSub(translator, node): return node.left.monad - node.right.monad def postMul(translator, node): return node.left.monad * node.right.monad def postDiv(translator, node): return node.left.monad / node.right.monad def postFloorDiv(translator, node): return node.left.monad // node.right.monad def postMod(translator, node): return node.left.monad % node.right.monad def postPower(translator, node): return node.left.monad ** node.right.monad def postUnarySub(translator, node): return -node.expr.monad def postGetattr(translator, node): return node.expr.monad.getattr(node.attrname) def postAnd(translator, node): return AndMonad([ subnode.monad for subnode in node.nodes ]) def postOr(translator, node): return OrMonad([ subnode.monad for subnode in node.nodes ]) def postBitor(translator, node): left, right = (subnode.monad for subnode in node.nodes) return left | right def postBitand(translator, node): left, right = (subnode.monad for subnode in node.nodes) return left & right def postBitxor(translator, node): left, right = (subnode.monad for subnode in node.nodes) return left ^ right def postNot(translator, node): return node.expr.monad.negate() def preCallFunc(translator, node): if node.star_args is not None: throw(NotImplementedError, '*%s is not supported' % ast2src(node.star_args)) if node.dstar_args is not None: throw(NotImplementedError, '**%s is not supported' % ast2src(node.dstar_args)) func_node = node.node if isinstance(func_node, ast.CallFunc): if isinstance(func_node.node, ast.Name) and 
func_node.node.name == 'getattr': return if not isinstance(func_node, (ast.Name, ast.Getattr)): throw(NotImplementedError) if len(node.args) > 1: return if not node.args: return arg = node.args[0] if isinstance(arg, ast.GenExpr): translator.dispatch(func_node) func_monad = func_node.monad translator.dispatch(arg) query_set_monad = arg.monad return func_monad(query_set_monad) if not isinstance(arg, ast.Lambda): return lambda_expr = arg translator.dispatch(func_node) method_monad = func_node.monad if not isinstance(method_monad, MethodMonad): throw(NotImplementedError) entity_monad = method_monad.parent if not isinstance(entity_monad, (EntityMonad, AttrSetMonad)): throw(NotImplementedError) entity = entity_monad.type.item_type method_name = method_monad.attrname if method_name not in ('select', 'filter', 'exists'): throw(TypeError) if len(lambda_expr.argnames) != 1: throw(TypeError) if lambda_expr.varargs: throw(TypeError) if lambda_expr.kwargs: throw(TypeError) if lambda_expr.defaults: throw(TypeError) iter_name = lambda_expr.argnames[0] cond_expr = lambda_expr.code if_expr = ast.GenExprIf(cond_expr) name_ast = ast.Name(entity.__name__) name_ast.monad = entity_monad for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ try: subtranslator = translator_cls(inner_expr, translator) except UseAnotherTranslator: assert False monad = QuerySetMonad(subtranslator) if method_name == 'exists': monad = monad.nonzero() return monad def postCallFunc(translator, node): args = [] kwargs = {} for arg in node.args: if isinstance(arg, ast.Keyword): kwargs[arg.name] = arg.expr.monad else: args.append(arg.monad) func_monad = node.node.monad return func_monad(*args, **kwargs) def postKeyword(translator, node): pass # this node will be processed by postCallFunc def postSubscript(translator, node): assert node.flags == 'OP_APPLY' assert isinstance(node.subs, list) if len(node.subs) > 1: for x in node.subs: if isinstance(x, ast.Sliceobj): throw(TypeError) key = ListMonad([ item.monad for item in node.subs ]) return node.expr.monad[key] sub = node.subs[0] if isinstance(sub, ast.Sliceobj): start, stop, step = (sub.nodes+[None])[:3] if start is not None: start = start.monad if isinstance(start, NoneMonad): start = None if stop is not None: stop = stop.monad if isinstance(stop, NoneMonad): stop = None if step is not None: step = step.monad if isinstance(step, NoneMonad): step = None return node.expr.monad[start:stop:step] else: return node.expr.monad[sub.monad] def postSlice(translator, node): assert node.flags == 'OP_APPLY' expr_monad = node.expr.monad upper = node.upper if upper is not None: upper = upper.monad if isinstance(upper, NoneMonad): upper = None lower = node.lower if lower is not None: lower = lower.monad if isinstance(lower, NoneMonad): lower = None return expr_monad[lower:upper] def postSliceobj(translator, node): pass def postIfExp(translator, node): test_monad, then_monad, else_monad = node.test.monad, node.then.monad, node.else_.monad if test_monad.type is not bool: test_monad = test_monad.nonzero() result_type = coerce_types(then_monad.type, else_monad.type) test_sql, then_sql, else_sql = test_monad.getsql()[0], then_monad.getsql(), else_monad.getsql() if len(then_sql) == 1: then_sql, else_sql = then_sql[0], else_sql[0] elif not translator.row_value_syntax: throw(NotImplementedError) else: then_sql, else_sql = [ 'ROW' ] + then_sql, [ 'ROW' ] + else_sql expr = [ 
'CASE', None, [ [ test_sql, then_sql ] ], else_sql ] result = ExprMonad.new(result_type, expr, nullable=test_monad.nullable or then_monad.nullable or else_monad.nullable) result.aggregated = test_monad.aggregated or then_monad.aggregated or else_monad.aggregated return result def postStr(translator, node): val_monad = node.value.monad if isinstance(val_monad, StringMixin): return val_monad sql = ['TO_STR', val_monad.getsql()[0] ] return StringExprMonad(unicode, sql, nullable=val_monad.nullable) def postJoinedStr(translator, node): nullable = False for subnode in node.values: assert isinstance(subnode.monad, StringMixin), (subnode.monad, subnode) if subnode.monad.nullable: nullable = True sql = [ 'CONCAT' ] + [ value.monad.getsql()[0] for value in node.values ] return StringExprMonad(unicode, sql, nullable=nullable) def postFormattedValue(translator, node): throw(NotImplementedError, 'You cannot set width and precision markers in query') def combine_limit_and_offset(limit, offset, limit2, offset2): assert limit is None or limit >= 0 assert limit2 is None or limit2 >= 0 if offset2 is not None: if limit is not None: limit = max(0, limit - offset2) offset = (offset or 0) + offset2 if limit2 is not None: if limit is not None: limit = min(limit, limit2) else: limit = limit2 if limit == 0: offset = None return limit, offset def coerce_monads(m1, m2, for_comparison=False): result_type = coerce_types(m1.type, m2.type) if result_type in numeric_types and bool in (m1.type, m2.type) and ( result_type is not bool or not for_comparison): translator = m1.translator if translator.dialect == 'PostgreSQL': if result_type is bool: result_type = int if m1.type is bool: new_m1 = NumericExprMonad(int, [ 'TO_INT', m1.getsql()[0] ], nullable=m1.nullable) new_m1.aggregated = m1.aggregated m1 = new_m1 if m2.type is bool: new_m2 = NumericExprMonad(int, [ 'TO_INT', m2.getsql()[0] ], nullable=m2.nullable) new_m2.aggregated = m2.aggregated m2 = new_m2 return result_type, m1, m2 max_alias_length = 30 class SqlQuery(object): def __init__(sqlquery, translator, parent_sqlquery=None, left_join=False): sqlquery.translator = translator sqlquery.parent_sqlquery = parent_sqlquery sqlquery.left_join = left_join sqlquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] sqlquery.conditions = [] sqlquery.outer_conditions = [] sqlquery.tablerefs = {} if parent_sqlquery is None: sqlquery.alias_counters = {} sqlquery.expr_counter = itertools.count(1) else: sqlquery.alias_counters = parent_sqlquery.alias_counters.copy() sqlquery.expr_counter = parent_sqlquery.expr_counter sqlquery.used_from_subquery = False def get_tableref(sqlquery, name_path): tableref = sqlquery.tablerefs.get(name_path) parent_sqlquery = sqlquery.parent_sqlquery if tableref is None and parent_sqlquery: tableref = parent_sqlquery.get_tableref(name_path) if tableref is not None: parent_sqlquery.used_from_subquery = True return tableref def add_tableref(sqlquery, name_path, parent_tableref, attr): assert name_path not in sqlquery.tablerefs if parent_tableref.sqlquery is not sqlquery: parent_tableref.sqlquery.used_from_subquery = True tableref = JoinedTableRef(sqlquery, name_path, parent_tableref, attr) sqlquery.tablerefs[name_path] = tableref return tableref def make_alias(sqlquery, name): name = name[:max_alias_length-3].lower() i = sqlquery.alias_counters.setdefault(name, 0) + 1 alias = name if i == 1 and name != 't' else '%s-%d' % (name, i) sqlquery.alias_counters[name] = i return alias def join_table(sqlquery, parent_alias, alias, table_name, join_cond): 
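# (editor's note) join_table, whose body follows, keeps each join chain
# contiguous: the new item is inserted right after the last entry already
# join-conditioned to parent_alias (entries shorter than 4 elements carry no
# join condition), falling back to a plain append.  Aliases come from
# make_alias above, which lowercases, truncates and deduplicates names, e.g.
# 'Customer' -> 'customer', a second 'Customer' -> 'customer-2', while 't' is
# always numbered ('t-1', 't-2', ...).
# Worked example for combine_limit_and_offset above: slicing a query that
# already has LIMIT 10 OFFSET 5 with [2:5] (limit2=3, offset2=2) gives
#     combine_limit_and_offset(10, 5, 3, 2) == (3, 7)
# i.e. skip 5+2 rows and take min(10-2, 3) of the rest.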
new_item = [alias, 'TABLE', table_name, join_cond] from_ast = sqlquery.from_ast for i in xrange(1, len(from_ast)): if from_ast[i][0] == parent_alias: for j in xrange(i+1, len(from_ast)): if len(from_ast[j]) < 4: # item without join condition from_ast.insert(j, new_item) return from_ast.append(new_item) class TableRef(object): def __init__(tableref, sqlquery, name, entity): tableref.sqlquery = sqlquery tableref.alias = sqlquery.make_alias(name) tableref.name_path = tableref.alias tableref.entity = entity tableref.joined = False tableref.can_affect_distinct = True tableref.used_attrs = set() def make_join(tableref, pk_only=False): entity = tableref.entity if not tableref.joined: sqlquery = tableref.sqlquery sqlquery.from_ast.append([ tableref.alias, 'TABLE', entity._table_ ]) if entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(tableref.alias) assert discr_criteria is not None sqlquery.conditions.append(discr_criteria) tableref.joined = True return tableref.alias, entity._pk_columns_ class ExprTableRef(TableRef): def __init__(tableref, sqlquery, name, subquery_ast, expr_names, expr_aliases): TableRef.__init__(tableref, sqlquery, name, None) tableref.subquery_ast = subquery_ast tableref.expr_names = expr_names tableref.expr_aliases = expr_aliases def make_join(tableref, pk_only=False): assert tableref.subquery_ast[0] == 'SELECT' if not tableref.joined: sqlquery = tableref.sqlquery sqlquery.from_ast.append([tableref.alias, 'SELECT', tableref.subquery_ast[1:]]) tableref.joined = True return tableref.alias, None class StarTableRef(TableRef): def __init__(tableref, sqlquery, name, entity, subquery_ast): TableRef.__init__(tableref, sqlquery, name, entity) tableref.subquery_ast = subquery_ast def make_join(tableref, pk_only=False): entity = tableref.entity assert tableref.subquery_ast[0] == 'SELECT' if not tableref.joined: sqlquery = tableref.sqlquery sqlquery.from_ast.append([ tableref.alias, 'SELECT', tableref.subquery_ast[1:] ]) if entity._discriminator_attr_: # ??? 
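# (editor's note) For entities in a single-table inheritance hierarchy, the
# discriminator criteria built below restrict the wrapped subquery to rows
# whose discriminator column (classtype by default) matches this entity or
# one of its subclasses.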
discr_criteria = entity._construct_discriminator_criteria_(tableref.alias) assert discr_criteria is not None sqlquery.conditions.append(discr_criteria) tableref.joined = True return tableref.alias, entity._pk_columns_ class ExprJoinedTableRef(object): def __init__(tableref, sqlquery, parent_tableref, parent_columns, name, entity): tableref.sqlquery = sqlquery tableref.parent_tableref = parent_tableref tableref.parent_columns = parent_columns tableref.name = tableref.name_path = name tableref.entity = entity tableref.alias = None tableref.joined = False tableref.can_affect_distinct = False tableref.used_attrs = set() def make_join(tableref, pk_only=False): entity = tableref.entity if tableref.joined: return tableref.alias, tableref.pk_columns sqlquery = tableref.sqlquery parent_alias, left_pk_columns = tableref.parent_tableref.make_join() if pk_only: tableref.alias = parent_alias tableref.pk_columns = tableref.parent_columns return tableref.alias, tableref.pk_columns tableref.alias = sqlquery.make_alias(tableref.name) tableref.pk_columns = entity._pk_columns_ join_cond = join_tables(parent_alias, tableref.alias, tableref.parent_columns, tableref.pk_columns) sqlquery.join_table(parent_alias, tableref.alias, entity._table_, join_cond) tableref.joined = True return tableref.alias, tableref.pk_columns class JoinedTableRef(object): def __init__(tableref, sqlquery, name_path, parent_tableref, attr): tableref.sqlquery = sqlquery tableref.name_path = name_path tableref.var_name = name_path if is_ident(name_path) else None tableref.alias = None tableref.optimized = None tableref.parent_tableref = parent_tableref tableref.attr = attr tableref.entity = attr.py_type assert isinstance(tableref.entity, EntityMeta) tableref.joined = False tableref.can_affect_distinct = False tableref.used_attrs = set() def make_join(tableref, pk_only=False): entity = tableref.entity if tableref.joined: if pk_only or not tableref.optimized: return tableref.alias, tableref.pk_columns sqlquery = tableref.sqlquery attr = tableref.attr parent_pk_only = attr.pk_offset is not None or attr.is_collection parent_alias, left_pk_columns = tableref.parent_tableref.make_join(parent_pk_only) left_entity = attr.entity pk_columns = entity._pk_columns_ if not attr.is_collection: if not attr.columns: # one-to-one relationship with foreign key column on the right side reverse = attr.reverse assert reverse.columns and not reverse.is_collection rentity = reverse.entity pk_columns = rentity._pk_columns_ alias = sqlquery.make_alias(tableref.var_name or rentity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, reverse.columns) else: # one-to-one or many-to-one relationship with foreign key column on the left side if attr.pk_offset is not None: offset = attr.pk_columns_offset left_columns = left_pk_columns[offset:offset+len(attr.columns)] else: left_columns = attr.columns if pk_only: tableref.alias = parent_alias tableref.pk_columns = left_columns tableref.optimized = True # tableref.joined = True return parent_alias, left_columns alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_columns, pk_columns) elif not attr.reverse.is_collection: # many-to-one relationship alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, attr.reverse.columns) else: # many-to-many relationship right_m2m_columns = attr.reverse_columns if attr.symmetric else attr.columns if not tableref.joined: m2m_table = 
attr.table m2m_alias = sqlquery.make_alias('t') reverse_columns = attr.columns if attr.symmetric else attr.reverse.columns m2m_join_cond = join_tables(parent_alias, m2m_alias, left_pk_columns, reverse_columns) sqlquery.join_table(parent_alias, m2m_alias, m2m_table, m2m_join_cond) if pk_only: tableref.alias = m2m_alias tableref.pk_columns = right_m2m_columns tableref.optimized = True tableref.joined = True return m2m_alias, tableref.pk_columns elif tableref.optimized: assert not pk_only m2m_alias = tableref.alias alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(m2m_alias, alias, right_m2m_columns, pk_columns) if not pk_only and entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(alias) assert discr_criteria is not None join_cond.append(discr_criteria) translator = tableref.sqlquery.translator.root_translator if translator.optimize == tableref.name_path and translator.from_optimized and tableref.sqlquery is translator.sqlquery: pass else: sqlquery.join_table(parent_alias, alias, entity._table_, join_cond) tableref.alias = alias tableref.pk_columns = pk_columns tableref.optimized = False tableref.joined = True return tableref.alias, pk_columns def wrap_monad_method(cls_name, func): overrider_name = '%s_%s' % (cls_name, func.__name__) def wrapper(monad, *args, **kwargs): method = getattr(monad.translator, overrider_name, func) return method(monad, *args, **kwargs) return update_wrapper(wrapper, func) class MonadMeta(type): def __new__(meta, cls_name, bases, cls_dict): for name, func in cls_dict.items(): if not isinstance(func, types.FunctionType): continue if name in ('__new__', '__init__'): continue cls_dict[name] = wrap_monad_method(cls_name, func) return super(MonadMeta, meta).__new__(meta, cls_name, bases, cls_dict) class MonadMixin(with_metaclass(MonadMeta)): pass class Monad(with_metaclass(MonadMeta)): disable_distinct = False disable_ordering = False def __init__(monad, type, nullable=True): monad.node = None monad.translator = local.translator monad.type = type monad.nullable = nullable monad.mixin_init() def mixin_init(monad): pass def cmp(monad, op, monad2): return CmpMonad(op, monad, monad2) def contains(monad, item, not_in=False): throw(TypeError) def nonzero(monad): return CmpMonad('is not', monad, NoneMonad()) def negate(monad): return NotMonad(monad) def getattr(monad, attrname): try: property_method = getattr(monad, 'attr_' + attrname) except AttributeError: if not hasattr(monad, 'call_' + attrname): throw(AttributeError, '%r object has no attribute %r: {EXPR}' % (type2str(monad.type), attrname)) return MethodMonad(monad, attrname) return property_method() def len(monad): throw(TypeError) def count(monad, distinct=None): distinct = distinct_from_monad(distinct, default=True) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. 
Got: {EXPR}') expr = monad.getsql() if monad.type is bool: expr = [ 'CASE', None, [ [ expr[0], [ 'VALUE', 1 ] ] ], [ 'VALUE', None ] ] distinct = None elif len(expr) == 1: expr = expr[0] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr expr = [ 'CASE', None, [ [ [ 'IS_NULL', row ], [ 'VALUE', None ] ] ], row ] # elif translator.dialect == 'PostgreSQL': # another way # alias, pk_columns = monad.tableref.make_join(pk_only=False) # expr = [ 'COLUMN', alias, 'ctid' ] elif translator.dialect in ('SQLite', 'Oracle'): alias, pk_columns = monad.tableref.make_join(pk_only=False) expr = [ 'COLUMN', alias, 'ROWID' ] # elif translator.row_value_syntax == True: # doesn't work in MySQL # expr = ['ROW'] + expr else: throw(NotImplementedError, '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. Got: {EXPR}' % translator.dialect) result = ExprMonad.new(int, [ 'COUNT', distinct, expr ], nullable=False) result.aggregated = True return result def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. Got: {EXPR}') expr_type = monad.type # if isinstance(expr_type, SetType): expr_type = expr_type.item_type if func_name in ('SUM', 'AVG'): if expr_type not in numeric_types: if expr_type is Json: monad = monad.to_real() else: throw(TypeError, "Function '%s' expects argument of numeric type, got %r in {EXPR}" % (func_name, type2str(expr_type))) elif func_name in ('MIN', 'MAX'): if expr_type not in comparable_types: throw(TypeError, "Function '%s' cannot be applied to type %r in {EXPR}" % (func_name, type2str(expr_type))) elif func_name == 'GROUP_CONCAT': if isinstance(expr_type, EntityMeta) and expr_type._pk_is_composite_: throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover expr = monad.getsql() if len(expr) == 1: expr = expr[0] elif translator.row_value_syntax: expr = ['ROW'] + expr else: throw(NotImplementedError, '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. 
Got: {EXPR} ' '(you can suggest us how to write SQL for this query)' % translator.dialect) if func_name == 'AVG': result_type = float elif func_name == 'GROUP_CONCAT': result_type = unicode else: result_type = expr_type if distinct is None: distinct = getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG') aggr_ast = [ func_name, distinct, expr ] if func_name == 'GROUP_CONCAT': if sep is not None: aggr_ast.append(['VALUE', sep]) result = ExprMonad.new(result_type, aggr_ast, nullable=True) result.aggregated = True return result def __call__(monad, *args, **kwargs): throw(TypeError) def __getitem__(monad, key): throw(TypeError) def __add__(monad, monad2): throw(TypeError) def __sub__(monad, monad2): throw(TypeError) def __mul__(monad, monad2): throw(TypeError) def __truediv__(monad, monad2): throw(TypeError) def __floordiv__(monad, monad2): throw(TypeError) def __pow__(monad, monad2): throw(TypeError) def __neg__(monad): throw(TypeError) def __or__(monad): throw(TypeError) def __and__(monad): throw(TypeError) def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) def cast_from_json(monad, type): assert False, monad def to_int(monad): return NumericExprMonad(int, [ 'TO_INT', monad.getsql()[0] ], nullable=monad.nullable) def to_str(monad): return StringExprMonad(unicode, [ 'TO_STR', monad.getsql()[0] ], nullable=monad.nullable) def to_real(monad): return NumericExprMonad(float, [ 'TO_REAL', monad.getsql()[0] ], nullable=monad.nullable) def distinct_from_monad(distinct, default=None): if distinct is None: return default if isinstance(distinct, NumericConstMonad) and isinstance(distinct.value, bool): return distinct.value throw(TypeError, '`distinct` value should be True or False. Got: %s' % ast2src(distinct.node)) class RawSQLMonad(Monad): def __init__(monad, rawtype, varkey, nullable=True): if rawtype.result_type is None: type = rawtype else: type = normalize_type(rawtype.result_type) Monad.__init__(monad, type, nullable=nullable) monad.rawtype = rawtype monad.varkey = varkey def contains(monad, item, not_in=False): translator = monad.translator expr = item.getsql() if len(expr) == 1: expr = expr[0] elif translator.row_value_syntax == True: expr = ['ROW'] + expr else: throw(TranslationError, '%s database provider does not support tuples. Got: {EXPR} ' % translator.dialect) op = 'NOT_IN' if not_in else 'IN' sql = [ op, expr, monad.getsql() ] return BoolExprMonad(sql, nullable=item.nullable) def nonzero(monad): return monad def getsql(monad, sqlquery=None): provider = monad.translator.database.provider rawtype = monad.rawtype result = [] types = enumerate(rawtype.types) for item in monad.rawtype.items: if isinstance(item, basestring): result.append(item) else: expr, code = item i, param_type = next(types) param_converter = provider.get_converter_by_py_type(param_type) result.append(['PARAM', (monad.varkey, i, None), param_converter]) return [ [ 'RAWSQL', result ] ] typeerror_re_1 = re.compile(r'\(\) takes (no|(?:exactly|at (?:least|most)))(?: (\d+))? 
        arguments \((\d+) given\)')
typeerror_re_2 = re.compile(r'\(\) takes from (\d+) to (\d+) positional arguments but (\d+) were given')
def reraise_improved_typeerror(exc, func_name, orig_func_name):
    if not exc.args: throw(exc)
    msg = exc.args[0]
    if not msg.startswith(func_name): throw(exc)
    msg = msg[len(func_name):]
    match = typeerror_re_1.match(msg)
    if match:
        what, takes, given = match.groups()
        takes, given = int(takes), int(given)
        if takes: what = '%s %d' % (what, takes-1)
        plural = 's' if takes > 2 else ''
        new_msg = '%s() takes %s argument%s (%d given)' % (orig_func_name, what, plural, given-1)
        exc.args = (new_msg,)
        throw(exc)
    match = typeerror_re_2.match(msg)
    if match:
        start, end, given = match.groups()
        start, end, given = int(start)-1, int(end)-1, int(given)-1
        if not start:
            plural = 's' if end > 1 else ''
            new_msg = '%s() takes at most %d argument%s (%d given)' % (orig_func_name, end, plural, given)
        else:
            new_msg = '%s() takes from %d to %d arguments (%d given)' % (orig_func_name, start, end, given)
        exc.args = (new_msg,)
        throw(exc)
    exc.args = (orig_func_name + msg,)
    throw(exc)
def raise_forgot_parentheses(monad):
    assert monad.type == 'METHOD'
    throw(TranslationError, 'It seems you forgot parentheses after %s' % ast2src(monad.node))
class MethodMonad(Monad):
    def __init__(monad, parent, attrname):
        Monad.__init__(monad, 'METHOD', nullable=False)
        monad.parent = parent
        monad.attrname = attrname
    def getattr(monad, attrname): raise_forgot_parentheses(monad)
    def __call__(monad, *args, **kwargs):
        method = getattr(monad.parent, 'call_' + monad.attrname)
        try:
            return method(*args, **kwargs)
        except TypeError as exc:
            reraise_improved_typeerror(exc, method.__name__, monad.attrname)
    def contains(monad, item, not_in=False): raise_forgot_parentheses(monad)
    def nonzero(monad): raise_forgot_parentheses(monad)
    def negate(monad): raise_forgot_parentheses(monad)
    def aggregate(monad, func_name, distinct=None, sep=None): raise_forgot_parentheses(monad)
    def __getitem__(monad, key): raise_forgot_parentheses(monad)
    def __add__(monad, monad2): raise_forgot_parentheses(monad)
    def __sub__(monad, monad2): raise_forgot_parentheses(monad)
    def __mul__(monad, monad2): raise_forgot_parentheses(monad)
    def __truediv__(monad, monad2): raise_forgot_parentheses(monad)
    def __floordiv__(monad, monad2): raise_forgot_parentheses(monad)
    def __pow__(monad, monad2): raise_forgot_parentheses(monad)
    def __neg__(monad): raise_forgot_parentheses(monad)
    def abs(monad): raise_forgot_parentheses(monad)
class EntityMonad(Monad):
    def __init__(monad, entity):
        Monad.__init__(monad, SetType(entity))
        translator = monad.translator
        if translator.database is None:
            translator.database = entity._database_
        elif translator.database is not entity._database_:
            throw(TranslationError, 'All entities in a query must belong to the same database')
    def __getitem__(monad, *args):
        throw(NotImplementedError)
class ListMonad(Monad):
    def __init__(monad, items):
        Monad.__init__(monad, tuple(item.type for item in items))
        monad.items = items
    def contains(monad, x, not_in=False):
        if isinstance(x.type, SetType):
            throw(TypeError, "Type of `%s` is '%s'. 
Expression `{EXPR}` is not supported" % (ast2src(x.node), type2str(x.type))) for item in monad.items: check_comparable(x, item) left_sql = x.getsql() if len(left_sql) == 1: if not_in: sql = [ 'NOT_IN', left_sql[0], [ item.getsql()[0] for item in monad.items ] ] else: sql = [ 'IN', left_sql[0], [ item.getsql()[0] for item in monad.items ] ] elif not_in: sql = sqland([ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) else: sql = sqlor([ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) return BoolExprMonad(sql, nullable=x.nullable or any(item.nullable for item in monad.items)) def getsql(monad, sqlquery=None): return [ [ 'ROW' ] + [ item.getsql()[0] for item in monad.items ] ] class BufferMixin(MonadMixin): pass class UuidMixin(MonadMixin): pass _binop_errmsg = 'Unsupported operand types %r and %r for operation %r in expression: {EXPR}' def make_numeric_binop(op, sqlop): def numeric_binop(monad, monad2): if isinstance(monad2, (AttrSetMonad, NumericSetExprMonad)): return NumericSetExprMonad(op, sqlop, monad, monad2) if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) result_type, monad, monad2 = coerce_monads(monad, monad2) if result_type is None: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql()[0] right_sql = monad2.getsql()[0] return NumericExprMonad(result_type, [ sqlop, left_sql, right_sql ]) numeric_binop.__name__ = sqlop return numeric_binop class NumericMixin(MonadMixin): def mixin_init(monad): assert monad.type in numeric_types, monad.type __add__ = make_numeric_binop('+', 'ADD') __sub__ = make_numeric_binop('-', 'SUB') __mul__ = make_numeric_binop('*', 'MUL') __truediv__ = make_numeric_binop('/', 'DIV') __floordiv__ = make_numeric_binop('//', 'FLOORDIV') __mod__ = make_numeric_binop('%', 'MOD') def __pow__(monad, monad2): if not isinstance(monad2, NumericMixin): throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), '**')) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 return NumericExprMonad(float, [ 'POW', left_sql[0], right_sql[0] ], nullable=monad.nullable or monad2.nullable) def __neg__(monad): sql = monad.getsql()[0] return NumericExprMonad(monad.type, [ 'NEG', sql ], nullable=monad.nullable) def abs(monad): sql = monad.getsql()[0] return NumericExprMonad(monad.type, [ 'ABS', sql ], nullable=monad.nullable) def nonzero(monad): translator = monad.translator sql = monad.getsql()[0] if not (translator.dialect == 'PostgreSQL' and monad.type is bool): sql = [ 'NE', sql, [ 'VALUE', 0 ] ] return BoolExprMonad(sql, nullable=False) def negate(monad): sql = monad.getsql()[0] translator = monad.translator pg_bool = translator.dialect == 'PostgreSQL' and monad.type is bool result_sql = [ 'NOT', sql ] if pg_bool else [ 'EQ', sql, [ 'VALUE', 0 ] ] if monad.nullable: if isinstance(monad, AttrMonad): result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] elif pg_bool: result_sql = [ 'NOT', [ 'COALESCE', sql, [ 'VALUE', True ] ] ] else: result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', 0 ] ], [ 'VALUE', 0 ] ] return BoolExprMonad(result_sql, nullable=False) def numeric_attr_factory(name): def attr_func(monad): sql = [ name, monad.getsql()[0] ] return NumericExprMonad(int, sql, nullable=monad.nullable) attr_func.__name__ = name.lower() return attr_func def make_datetime_binop(op, sqlop): def datetime_binop(monad, monad2): if monad2.type != timedelta: 
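# (editor's note) make_datetime_binop accepts only a timedelta right operand:
# anything else throws just below, otherwise the given sqlop node
# ('DATETIME_ADD', 'DATETIME_SUB', ...) is emitted.  Illustrative query
# (hypothetical Order.created attribute):
#     select(o for o in Order if o.created + timedelta(days=30) < datetime.now())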
throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) expr_monad_cls = DateExprMonad if monad.type is date else DatetimeExprMonad return expr_monad_cls(monad.type, [ sqlop, monad.getsql()[0], monad2.getsql()[0] ], nullable=monad.nullable or monad2.nullable) datetime_binop.__name__ = sqlop return datetime_binop class DateMixin(MonadMixin): def mixin_init(monad): assert monad.type is date attr_year = numeric_attr_factory('YEAR') attr_month = numeric_attr_factory('MONTH') attr_day = numeric_attr_factory('DAY') def __add__(monad, other): if other.type != timedelta: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '+')) return DateExprMonad(monad.type, [ 'DATE_ADD', monad.getsql()[0], other.getsql()[0] ], nullable=monad.nullable or other.nullable) def __sub__(monad, other): if other.type == timedelta: return DateExprMonad(monad.type, [ 'DATE_SUB', monad.getsql()[0], other.getsql()[0] ], nullable=monad.nullable or other.nullable) elif other.type == date: return TimedeltaExprMonad(timedelta, [ 'DATE_DIFF', monad.getsql()[0], other.getsql()[0] ], nullable=monad.nullable or other.nullable) throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '-')) class TimeMixin(MonadMixin): def mixin_init(monad): assert monad.type is time attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') class TimedeltaMixin(MonadMixin): def mixin_init(monad): assert monad.type is timedelta class DatetimeMixin(DateMixin): def mixin_init(monad): assert monad.type is datetime def call_date(monad): sql = [ 'DATE', monad.getsql()[0] ] return ExprMonad.new(date, sql, nullable=monad.nullable) attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') def __add__(monad, other): if other.type != timedelta: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '+')) return DatetimeExprMonad(monad.type, [ 'DATETIME_ADD', monad.getsql()[0], other.getsql()[0] ], nullable=monad.nullable or other.nullable) def __sub__(monad, other): if other.type == timedelta: return DatetimeExprMonad(monad.type, [ 'DATETIME_SUB', monad.getsql()[0], other.getsql()[0] ], nullable=monad.nullable or other.nullable) elif other.type == datetime: return TimedeltaExprMonad(timedelta, [ 'DATETIME_DIFF', monad.getsql()[0], other.getsql()[0] ], nullable=monad.nullable or other.nullable) throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '-')) def make_string_binop(op, sqlop): def string_binop(monad, monad2): if not are_comparable_types(monad.type, monad2.type, sqlop): if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 return StringExprMonad(monad.type, [ sqlop, left_sql[0], right_sql[0] ], nullable=monad.nullable or monad2.nullable) string_binop.__name__ = sqlop return string_binop def make_string_func(sqlop): def func(monad): sql = monad.getsql() assert len(sql) == 1 return StringExprMonad(monad.type, [ sqlop, sql[0] ], nullable=monad.nullable) func.__name__ = sqlop return func class StringMixin(MonadMixin): def mixin_init(monad): assert issubclass(monad.type, basestring), monad.type __add__ = make_string_binop('+', 'CONCAT') def __getitem__(monad, index): if isinstance(index, ListMonad): 
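# (editor's note) String indexing and slicing are translated to SQL SUBSTR
# with 1-based positions.  Worked example from the constant-bound branch
# below: s[2:5] becomes
#     ['SUBSTR', s, ['VALUE', 3], ['VALUE', 3]]
# i.e. position start+1 and length stop-start.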
throw(TypeError, "String index must be of 'int' type. Got 'tuple' in {EXPR}") elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') start, stop = index.start, index.stop if start is None and stop is None: return monad if isinstance(monad, StringConstMonad) \ and (start is None or isinstance(start, NumericConstMonad)) \ and (stop is None or isinstance(stop, NumericConstMonad)): if start is not None: start = start.value if stop is not None: stop = stop.value return ConstMonad.new(monad.value[start:stop]) if start is not None and start.type is not int: throw(TypeError, "Invalid type of start index (expected 'int', got %r) in string slice {EXPR}" % type2str(start.type)) if stop is not None and stop.type is not int: throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in string slice {EXPR}" % type2str(stop.type)) expr_sql = monad.getsql()[0] if start is None: start = ConstMonad.new(0) if isinstance(start, NumericConstMonad): if start.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') start_sql = [ 'VALUE', start.value + 1 ] else: start_sql = start.getsql()[0] start_sql = [ 'ADD', start_sql, [ 'VALUE', 1 ] ] if stop is None: len_sql = None elif isinstance(stop, NumericConstMonad): if stop.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') if isinstance(start, NumericConstMonad): len_sql = [ 'VALUE', stop.value - start.value ] else: len_sql = [ 'SUB', [ 'VALUE', stop.value ], start.getsql()[0] ] else: stop_sql = stop.getsql()[0] if isinstance(start, NumericConstMonad): len_sql = [ 'SUB', stop_sql, [ 'VALUE', start.value ] ] else: len_sql = [ 'SUB', stop_sql, start.getsql()[0] ] sql = [ 'SUBSTR', expr_sql, start_sql, len_sql ] return StringExprMonad(monad.type, sql, nullable=monad.nullable or start.nullable or stop is not None and stop.nullable) if isinstance(monad, StringConstMonad) and isinstance(index, NumericConstMonad): return ConstMonad.new(monad.value[index.value]) if index.type is not int: throw(TypeError, 'String indices must be integers. 
Got %r in expression {EXPR}' % type2str(index.type)) expr_sql = monad.getsql()[0] if isinstance(index, NumericConstMonad): value = index.value if value >= 0: value += 1 index_sql = [ 'VALUE', value ] else: inner_sql = index.getsql()[0] index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] return StringExprMonad(monad.type, sql, nullable=monad.nullable) def negate(monad): sql = monad.getsql()[0] translator = monad.translator if translator.dialect == 'Oracle': result_sql = [ 'IS_NULL', sql ] else: result_sql = [ 'EQ', sql, [ 'VALUE', '' ] ] if monad.nullable: if isinstance(monad, AttrMonad): result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] else: result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', '' ] ], [ 'VALUE', '' ]] result = BoolExprMonad(result_sql, nullable=False) result.aggregated = monad.aggregated return result def nonzero(monad): sql = monad.getsql()[0] translator = monad.translator if translator.dialect == 'Oracle': result_sql = [ 'IS_NOT_NULL', sql ] else: result_sql = [ 'NE', sql, [ 'VALUE', '' ] ] result = BoolExprMonad(result_sql, nullable=False) result.aggregated = monad.aggregated return result def len(monad): sql = monad.getsql()[0] return NumericExprMonad(int, [ 'LENGTH', sql ]) def contains(monad, item, not_in=False): check_comparable(item, monad, 'LIKE') return monad._like(item, before='%', after='%', not_like=not_in) call_upper = make_string_func('UPPER') call_lower = make_string_func('LOWER') def call_startswith(monad, arg): if not are_comparable_types(monad.type, arg.type, None): if arg.type == 'METHOD': raise_forgot_parentheses(arg) throw(TypeError, 'Expected %r argument but got %r in expression {EXPR}' % (type2str(monad.type), type2str(arg.type))) return monad._like(arg, after='%') def call_endswith(monad, arg): if not are_comparable_types(monad.type, arg.type, None): if arg.type == 'METHOD': raise_forgot_parentheses(arg) throw(TypeError, 'Expected %r argument but got %r in expression {EXPR}' % (type2str(monad.type), type2str(arg.type))) return monad._like(arg, before='%') def _like(monad, item, before=None, after=None, not_like=False): escape = False translator = monad.translator if isinstance(item, StringConstMonad): value = item.value if '%' in value or '_' in value: escape = True value = value.replace('!', '!!').replace('%', '!%').replace('_', '!_') if before: value = before + value if after: value = value + after item_sql = [ 'VALUE', value ] else: escape = True item_sql = item.getsql()[0] item_sql = [ 'REPLACE', item_sql, [ 'VALUE', '!' ], [ 'VALUE', '!!' ] ] item_sql = [ 'REPLACE', item_sql, [ 'VALUE', '%' ], [ 'VALUE', '!%' ] ] item_sql = [ 'REPLACE', item_sql, [ 'VALUE', '_' ], [ 'VALUE', '!_' ] ] if before and after: item_sql = [ 'CONCAT', [ 'VALUE', before ], item_sql, [ 'VALUE', after ] ] elif before: item_sql = [ 'CONCAT', [ 'VALUE', before ], item_sql ] elif after: item_sql = [ 'CONCAT', item_sql, [ 'VALUE', after ] ] sql = monad.getsql()[0] if not_like and monad.nullable and not isinstance(monad, AttrMonad) and translator.dialect != 'Oracle': sql = [ 'COALESCE', sql, [ 'VALUE', '' ] ] result_sql = [ 'NOT_LIKE' if not_like else 'LIKE', sql, item_sql ] if escape: result_sql.append([ 'VALUE', '!' 
]) if not_like and monad.nullable and (isinstance(monad, AttrMonad) or translator.dialect == 'Oracle'): result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] return BoolExprMonad(result_sql, nullable=not_like) def strip(monad, chars, strip_type): if chars is not None and not are_comparable_types(monad.type, chars.type, None): if chars.type == 'METHOD': raise_forgot_parentheses(chars) throw(TypeError, "'chars' argument must be of %r type in {EXPR}, got: %r" % (type2str(monad.type), type2str(chars.type))) parent_sql = monad.getsql()[0] sql = [ strip_type, parent_sql ] if chars is not None: sql.append(chars.getsql()[0]) return StringExprMonad(monad.type, sql, nullable=monad.nullable) def call_strip(monad, chars=None): return monad.strip(chars, 'TRIM') def call_lstrip(monad, chars=None): return monad.strip(chars, 'LTRIM') def call_rstrip(monad, chars=None): return monad.strip(chars, 'RTRIM') class JsonMixin(object): disable_distinct = True # at least in Oracle we cannot use DISTINCT with JSON column disable_ordering = True # at least in Oracle we cannot use ORDER BY with JSON column def mixin_init(monad): assert monad.type is Json, monad.type def get_path(monad): return monad, [] def __getitem__(monad, key): return JsonItemMonad(monad, key) def contains(monad, key, not_in=False): translator = monad.translator if isinstance(key, ParamMonad): if translator.dialect == 'Oracle': throw(TypeError, 'For `key in JSON` operation %s supports literal key values only, ' 'parameters are not allowed: {EXPR}' % translator.dialect) elif not isinstance(key, StringConstMonad): raise NotImplementedError base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] key_sql = key.getsql()[0] sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] if not_in: sql = [ 'NOT', sql ] return BoolExprMonad(sql) def __or__(monad, other): if not isinstance(other, JsonMixin): raise TypeError('Should be JSON: %s' % ast2src(other.node)) left_sql = monad.getsql()[0] right_sql = other.getsql()[0] sql = [ 'JSON_CONCAT', left_sql, right_sql ] return JsonExprMonad(Json, sql) def len(monad): sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] return NumericExprMonad(int, sql) def cast_from_json(monad, type): if type in (Json, NoneType): return monad throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') def nonzero(monad): return BoolExprMonad([ 'JSON_NONZERO', monad.getsql()[0] ]) class ArrayMixin(MonadMixin): def contains(monad, key, not_in=False): if key.type is monad.type.item_type: sql = 'ARRAY_CONTAINS', key.getsql()[0], not_in, monad.getsql()[0] return BoolExprMonad(sql) if isinstance(key, ListMonad): if not key.items: if not_in: return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False) else: return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], nullable=False) sql = [ 'MAKE_ARRAY' ] sql.extend(item.getsql()[0] for item in key.items) sql = 'ARRAY_SUBSET', sql, not_in, monad.getsql()[0] return BoolExprMonad(sql) elif isinstance(key, ArrayParamMonad): sql = 'ARRAY_SUBSET', key.getsql()[0], not_in, monad.getsql()[0] return BoolExprMonad(sql) throw(TypeError, 'Cannot search for %s in %s: {EXPR}' % (type2str(key.type), type2str(monad.type))) def len(monad): sql = ['ARRAY_LENGTH', monad.getsql()[0]] return NumericExprMonad(int, sql) def nonzero(monad): return BoolExprMonad(['GT', ['ARRAY_LENGTH', monad.getsql()[0]], ['VALUE', 0]]) def _index(monad, index, from_one, plus_one): if isinstance(index, NumericConstMonad): expr_sql = monad.getsql()[0] index_sql = 
index.getsql()[0]
            value = index_sql[1]
            if from_one and plus_one:
                if value >= 0:
                    index_sql = ['VALUE', value + 1]
                else:
                    # a negative Python index -k maps to ARRAY_LENGTH - (k - 1) in
                    # 1-based SQL arrays, matching ARRAY_LENGTH + index + 1 in the
                    # NumericMixin branch below
                    index_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value) - 1]]
            return index_sql
        elif isinstance(index, NumericMixin):
            expr_sql = monad.getsql()[0]
            index0 = index.getsql()[0]
            index1 = ['ADD', index0, ['VALUE', 1]] if from_one and plus_one else index0
            index_sql = ['CASE', None, [[['GE', index0, ['VALUE', 0]], index1]],
                         ['ADD', ['ARRAY_LENGTH', expr_sql], index1]]
            return index_sql
    def __getitem__(monad, index):
        dialect = monad.translator.database.provider.dialect
        expr_sql = monad.getsql()[0]
        from_one = dialect != 'SQLite'
        if isinstance(index, NumericMixin):
            index_sql = monad._index(index, from_one, plus_one=True)
            sql = ['ARRAY_INDEX', expr_sql, index_sql]
            return ExprMonad.new(monad.type.item_type, sql)
        elif isinstance(index, slice):
            if index.step is not None:
                throw(TypeError, 'Step is not supported in {EXPR}')
            start_sql = monad._index(index.start, from_one, plus_one=True)
            stop_sql = monad._index(index.stop, from_one, plus_one=False)
            sql = ['ARRAY_SLICE', expr_sql, start_sql, stop_sql]
            return ExprMonad.new(monad.type, sql)

class ObjectMixin(MonadMixin):
    def mixin_init(monad):
        assert isinstance(monad.type, EntityMeta)
    def negate(monad):
        return CmpMonad('is', monad, NoneMonad())
    def nonzero(monad):
        return CmpMonad('is not', monad, NoneMonad())
    def getattr(monad, attrname):
        entity = monad.type
        attr = entity._adict_.get(attrname) or entity._subclass_adict_.get(attrname)
        if attr is None:
            if hasattr(entity, attrname):
                attr = getattr(entity, attrname, None)
                if isinstance(attr, property):
                    new_monad = HybridMethodMonad(monad, attrname, attr.fget)
                    return new_monad()
                if callable(attr):
                    # the None default keeps the check below meaningful: a PY2
                    # class attribute without __func__ should fall through to
                    # NotImplementedError rather than raise AttributeError here
                    func = getattr(attr, '__func__', None) if PY2 else attr
                    if func is not None:
                        return HybridMethodMonad(monad, attrname, func)
                throw(NotImplementedError, '{EXPR} cannot be translated to SQL')
            throw(AttributeError, 'Entity %s does not have attribute %s: {EXPR}'
                                  % (entity.__name__, attrname))
        if hasattr(monad, 'tableref'):
            monad.tableref.used_attrs.add(attr)
        if not attr.is_collection:
            return AttrMonad.new(monad, attr)
        else:
            return AttrSetMonad(monad, attr)
    def requires_distinct(monad, joined=False):
        return monad.attr.reverse.is_collection or monad.parent.requires_distinct(joined)  # parent ???
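# NOTE: an illustrative sketch (not part of the original module) of what the
# ArrayMixin indexing logic above produces, assuming a hypothetical entity `p`
# with an array attribute `tags` on a 1-based backend such as PostgreSQL
# (from_one is True); <tags> stands for the column's SQL AST:
#
#     p.tags[0]    ->  ['ARRAY_INDEX', <tags>, ['VALUE', 1]]
#     p.tags[-2]   ->  ['ARRAY_INDEX', <tags>,
#                       ['SUB', ['ARRAY_LENGTH', <tags>], ['VALUE', 1]]]
#     p.tags[1:3]  ->  ['ARRAY_SLICE', <tags>, ['VALUE', 2], ['VALUE', 3]]
#
# i.e. Python's 0-based, end-exclusive indexing is mapped onto the database's
# 1-based, end-inclusive conventions; on SQLite (from_one is False) the Python
# index passes through unchanged.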
class ObjectIterMonad(ObjectMixin, Monad): def __init__(monad, tableref, entity): Monad.__init__(monad, entity) monad.tableref = tableref def getsql(monad, sqlquery=None): entity = monad.type alias, pk_columns = monad.tableref.make_join(pk_only=True) return [ [ 'COLUMN', alias, column ] for column in pk_columns ] def requires_distinct(monad, joined=False): return monad.tableref.name_path != monad.translator.tree.quals[-1].assign.name class AttrMonad(Monad): @staticmethod def new(parent, attr, *args, **kwargs): t = normalize_type(attr.py_type) if t in numeric_types: cls = NumericAttrMonad elif t is unicode: cls = StringAttrMonad elif t is date: cls = DateAttrMonad elif t is time: cls = TimeAttrMonad elif t is timedelta: cls = TimedeltaAttrMonad elif t is datetime: cls = DatetimeAttrMonad elif t is buffer: cls = BufferAttrMonad elif t is UUID: cls = UuidAttrMonad elif t is Json: cls = JsonAttrMonad elif isinstance(t, EntityMeta): cls = ObjectAttrMonad elif isinstance(t, type) and issubclass(t, Array): cls = ArrayAttrMonad else: throw(NotImplementedError, t) # pragma: no cover return cls(parent, attr, *args, **kwargs) def __new__(cls, *args): if cls is AttrMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, parent, attr): assert monad.__class__ is not AttrMonad attr_type = normalize_type(attr.py_type) Monad.__init__(monad, attr_type) monad.parent = parent monad.attr = attr monad.nullable = attr.nullable def getsql(monad, sqlquery=None): parent = monad.parent attr = monad.attr entity = attr.entity pk_only = attr.pk_offset is not None alias, parent_columns = monad.parent.tableref.make_join(pk_only) if pk_only: if entity._pk_is_composite_: offset = attr.pk_columns_offset columns = parent_columns[offset:offset+len(attr.columns)] else: columns = parent_columns elif not attr.columns: assert isinstance(monad, ObjectAttrMonad) sqlquery = monad.translator.sqlquery monad.translator.left_join = sqlquery.left_join = True sqlquery.from_ast[0] = 'LEFT_JOIN' alias, columns = monad.tableref.make_join() else: columns = attr.columns return [ [ 'COLUMN', alias, column ] for column in columns ] class ObjectAttrMonad(ObjectMixin, AttrMonad): def __init__(monad, parent, attr): AttrMonad.__init__(monad, parent, attr) translator = monad.translator parent_monad = monad.parent entity = monad.type name_path = '-'.join((parent_monad.tableref.name_path, attr.name)) monad.tableref = translator.sqlquery.get_tableref(name_path) if monad.tableref is None: parent_sqlquery = parent_monad.tableref.sqlquery monad.tableref = parent_sqlquery.add_tableref(name_path, parent_monad.tableref, attr) class StringAttrMonad(StringMixin, AttrMonad): pass class NumericAttrMonad(NumericMixin, AttrMonad): pass class DateAttrMonad(DateMixin, AttrMonad): pass class TimeAttrMonad(TimeMixin, AttrMonad): pass class TimedeltaAttrMonad(TimedeltaMixin, AttrMonad): pass class DatetimeAttrMonad(DatetimeMixin, AttrMonad): pass class BufferAttrMonad(BufferMixin, AttrMonad): pass class UuidAttrMonad(UuidMixin, AttrMonad): pass class JsonAttrMonad(JsonMixin, AttrMonad): pass class ArrayAttrMonad(ArrayMixin, AttrMonad): pass class ParamMonad(Monad): @staticmethod def new(t, paramkey): t = normalize_type(t) if t in numeric_types: cls = NumericParamMonad elif t is unicode: cls = StringParamMonad elif t is date: cls = DateParamMonad elif t is time: cls = TimeParamMonad elif t is timedelta: cls = TimedeltaParamMonad elif t is datetime: cls = DatetimeParamMonad elif t is buffer: cls = BufferParamMonad elif t is 
UUID: cls = UuidParamMonad elif t is Json: cls = JsonParamMonad elif isinstance(t, type) and issubclass(t, Array): cls = ArrayParamMonad elif isinstance(t, EntityMeta): cls = ObjectParamMonad else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (t,)) result = cls(t, paramkey) result.aggregated = False return result def __new__(cls, *args, **kwargs): if cls is ParamMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, t, paramkey): t = normalize_type(t) Monad.__init__(monad, t, nullable=False) monad.paramkey = paramkey if not isinstance(t, EntityMeta): provider = monad.translator.database.provider monad.converter = provider.get_converter_by_py_type(t) else: monad.converter = None def getsql(monad, sqlquery=None): return [ [ 'PARAM', monad.paramkey, monad.converter ] ] class ObjectParamMonad(ObjectMixin, ParamMonad): def __init__(monad, entity, paramkey): ParamMonad.__init__(monad, entity, paramkey) if monad.translator.database is not entity._database_: assert monad.translator.database is entity._database_, (paramkey, monad.translator.database, entity._database_) varkey, i, j = paramkey assert j is None monad.params = tuple((varkey, i, j) for j in xrange(len(entity._pk_converters_))) def getsql(monad, sqlquery=None): entity = monad.type assert len(monad.params) == len(entity._pk_converters_) return [ [ 'PARAM', param, converter ] for param, converter in izip(monad.params, entity._pk_converters_) ] def requires_distinct(monad, joined=False): assert False # pragma: no cover class StringParamMonad(StringMixin, ParamMonad): pass class NumericParamMonad(NumericMixin, ParamMonad): pass class DateParamMonad(DateMixin, ParamMonad): pass class TimeParamMonad(TimeMixin, ParamMonad): pass class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass class ArrayParamMonad(ArrayMixin, ParamMonad): def __init__(monad, t, paramkey, list_monad=None): ParamMonad.__init__(monad, t, paramkey) monad.list_monad = list_monad def contains(monad, key, not_in=False): if key.type is monad.type.item_type: return monad.list_monad.contains(key, not_in) return ArrayMixin.contains(monad, key, not_in) class JsonParamMonad(JsonMixin, ParamMonad): def getsql(monad, sqlquery=None): return [ [ 'JSON_PARAM', ParamMonad.getsql(monad)[0] ] ] class ExprMonad(Monad): @staticmethod def new(t, sql, nullable=True): if t in numeric_types: cls = NumericExprMonad elif t is unicode: cls = StringExprMonad elif t is date: cls = DateExprMonad elif t is time: cls = TimeExprMonad elif t is timedelta: cls = TimedeltaExprMonad elif t is datetime: cls = DatetimeExprMonad elif t is Json: cls = JsonExprMonad elif isinstance(t, EntityMeta): cls = ObjectExprMonad elif isinstance(t, type) and issubclass(t, Array): cls = ArrayExprMonad else: throw(NotImplementedError, t) # pragma: no cover return cls(t, sql, nullable=nullable) def __new__(cls, *args, **kwargs): if cls is ExprMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, type, sql, nullable=True): Monad.__init__(monad, type, nullable=nullable) monad.sql = sql def getsql(monad, sqlquery=None): return [ monad.sql ] class ObjectExprMonad(ObjectMixin, ExprMonad): def getsql(monad, sqlquery=None): return monad.sql class StringExprMonad(StringMixin, ExprMonad): pass class NumericExprMonad(NumericMixin, ExprMonad): pass 
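# NOTE: AttrMonad, ParamMonad and ExprMonad all follow the same factory idiom:
# a static new() dispatches on the normalized Python type and instantiates the
# subclass whose type mixin supplies the operations. A hypothetical sketch,
# where col_sql stands for any column AST:
#
#     ExprMonad.new(int, [ 'LENGTH', col_sql ])      # -> NumericExprMonad
#     ExprMonad.new(unicode, [ 'UPPER', col_sql ])   # -> StringExprMonad
#
# so arithmetic comes from NumericMixin and string methods from StringMixin,
# while the getsql() plumbing lives in the family base class.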
class DateExprMonad(DateMixin, ExprMonad): pass class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass class JsonExprMonad(JsonMixin, ExprMonad): pass class ArrayExprMonad(ArrayMixin, ExprMonad): pass class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): assert isinstance(parent, JsonMixin), parent Monad.__init__(monad, Json) monad.parent = parent if isinstance(key, slice): if key != slice(None, None, None): throw(NotImplementedError) monad.key_ast = [ 'VALUE', key ] elif isinstance(key, (ParamMonad, StringConstMonad, NumericConstMonad, EllipsisMonad)): monad.key_ast = key.getsql()[0] else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) translator = monad.translator if isinstance(key, (slice, EllipsisMonad)) and not translator.json_path_wildcard_syntax: throw(TranslationError, '%s does not support wildcards in JSON path: {EXPR}' % translator.dialect) def get_path(monad): path = [] while isinstance(monad, JsonItemMonad): path.append(monad.key_ast) monad = monad.parent path.reverse() return monad, path def to_int(monad): return monad.cast_from_json(int) def to_str(monad): return monad.cast_from_json(unicode) def to_real(monad): return monad.cast_from_json(float) def cast_from_json(monad, type): translator = monad.translator if issubclass(type, Json): if not translator.json_values_are_comparable: throw(TranslationError, '%s does not support comparison of json structures: {EXPR}' % translator.dialect) return monad base_monad, path = monad.get_path() sql = [ 'JSON_VALUE', base_monad.getsql()[0], path, type ] return ExprMonad.new(Json if type is NoneType else type, sql) def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] translator = monad.translator if translator.inside_order_by and translator.dialect == 'SQLite': return [ [ 'JSON_VALUE', base_sql, path, None ] ] return [ [ 'JSON_QUERY', base_sql, path ] ] class ConstMonad(Monad): @staticmethod def new(value): value_type, value = normalize(value) if value_type in numeric_types: cls = NumericConstMonad elif value_type is unicode: cls = StringConstMonad elif value_type is date: cls = DateConstMonad elif value_type is time: cls = TimeConstMonad elif value_type is timedelta: cls = TimedeltaConstMonad elif value_type is datetime: cls = DatetimeConstMonad elif value_type is NoneType: cls = NoneMonad elif value_type is buffer: cls = BufferConstMonad elif value_type is Json: cls = JsonConstMonad elif issubclass(value_type, type(Ellipsis)): cls = EllipsisMonad else: throw(NotImplementedError, value_type) # pragma: no cover result = cls(value) result.aggregated = False return result def __new__(cls, *args): if cls is ConstMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, value): value_type, value = normalize(value) Monad.__init__(monad, value_type, nullable=value_type is NoneType) monad.value = value def getsql(monad, sqlquery=None): return [ [ 'VALUE', monad.value ] ] class NoneMonad(ConstMonad): type = NoneType def __init__(monad, value=None): assert value is None ConstMonad.__init__(monad, value) class EllipsisMonad(ConstMonad): pass class StringConstMonad(StringMixin, ConstMonad): def len(monad): return ConstMonad.new(len(monad.value)) class JsonConstMonad(JsonMixin, ConstMonad): pass class BufferConstMonad(BufferMixin, ConstMonad): pass class NumericConstMonad(NumericMixin, ConstMonad): pass class 
DateConstMonad(DateMixin, ConstMonad): pass
class TimeConstMonad(TimeMixin, ConstMonad): pass
class TimedeltaConstMonad(TimedeltaMixin, ConstMonad): pass
class DatetimeConstMonad(DatetimeMixin, ConstMonad): pass

class BoolMonad(Monad):
    def __init__(monad, nullable=True):
        Monad.__init__(monad, bool, nullable=nullable)
    def nonzero(monad):
        return monad

sql_negation = { 'IN' : 'NOT_IN', 'EXISTS' : 'NOT_EXISTS', 'LIKE' : 'NOT_LIKE',
                 'BETWEEN' : 'NOT_BETWEEN', 'IS_NULL' : 'IS_NOT_NULL' }
sql_negation.update((value, key) for key, value in items_list(sql_negation))

class BoolExprMonad(BoolMonad):
    def __init__(monad, sql, nullable=True):
        BoolMonad.__init__(monad, nullable=nullable)
        monad.sql = sql
    def getsql(monad, sqlquery=None):
        return [ monad.sql ]
    def negate(monad):
        sql = monad.sql
        sqlop = sql[0]
        negated_op = sql_negation.get(sqlop)
        if negated_op is not None:
            negated_sql = [ negated_op ] + sql[1:]
        elif sqlop == 'NOT':  # a NOT node is negated by unwrapping its operand
            assert len(sql) == 2
            negated_sql = sql[1]
        else: return NotMonad(monad)
        return BoolExprMonad(negated_sql, nullable=monad.nullable)

cmp_ops = { '>=' : 'GE', '>' : 'GT', '<=' : 'LE', '<' : 'LT' }

cmp_negate = { '<' : '>=', '<=' : '>', '==' : '!=', 'is' : 'is not' }
cmp_negate.update((b, a) for a, b in items_list(cmp_negate))

class CmpMonad(BoolMonad):
    EQ = 'EQ'
    NE = 'NE'
    def __init__(monad, op, left, right):
        if op == '<>': op = '!='
        if left.type is NoneType:
            assert right.type is not NoneType
            left, right = right, left
        if right.type is NoneType:
            if op == '==': op = 'is'
            elif op == '!=': op = 'is not'
        elif op == 'is': op = '=='
        elif op == 'is not': op = '!='
        check_comparable(left, right, op)
        result_type, left, right = coerce_monads(left, right, for_comparison=True)
        BoolMonad.__init__(monad, nullable=left.nullable or right.nullable)
        monad.op = op
        monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False)
        if isinstance(left, JsonMixin):
            left = left.cast_from_json(right.type)
        if isinstance(right, JsonMixin):
            right = right.cast_from_json(left.type)
        monad.left = left
        monad.right = right
    def negate(monad):
        return CmpMonad(cmp_negate[monad.op], monad.left, monad.right)
    def getsql(monad, sqlquery=None):
        op = monad.op
        left_sql = monad.left.getsql()
        if op == 'is':
            return [ sqland([ [ 'IS_NULL', item ] for item in left_sql ]) ]
        if op == 'is not':
            return [ sqland([ [ 'IS_NOT_NULL', item ] for item in left_sql ]) ]
        right_sql = monad.right.getsql()
        if len(left_sql) == 1 and left_sql[0][0] == 'ROW': left_sql = left_sql[0][1:]
        if len(right_sql) == 1 and right_sql[0][0] == 'ROW': right_sql = right_sql[0][1:]
        assert len(left_sql) == len(right_sql)
        size = len(left_sql)
        if op in ('<', '<=', '>', '>='):
            if size == 1:
                return [ [ cmp_ops[op], left_sql[0], right_sql[0] ] ]
            if monad.translator.row_value_syntax:
                return [ [ cmp_ops[op], [ 'ROW' ] + left_sql, [ 'ROW' ] + right_sql ] ]
            # lexicographic emulation for dialects without row-value syntax
            clauses = []
            for i in xrange(size):
                clause = [ [ monad.EQ, left_sql[j], right_sql[j] ] for j in range(i) ]
                clause.append([ cmp_ops[op], left_sql[i], right_sql[i] ])
                clauses.append(sqland(clause))
            return [ sqlor(clauses) ]
        if op == '==':
            return [ sqland([ [ monad.EQ, a, b ] for a, b in izip(left_sql, right_sql) ]) ]
        if op == '!=':
            return [ sqlor([ [ monad.NE, a, b ] for a, b in izip(left_sql, right_sql) ]) ]
        assert False, op  # pragma: no cover

class LogicalBinOpMonad(BoolMonad):
    def __init__(monad, operands):
        assert len(operands) >= 2
        items = []
        for operand in operands:
            if operand.type is not bool:
                items.append(operand.nonzero())
            elif isinstance(operand, LogicalBinOpMonad) and monad.binop ==
operand.binop: items.extend(operand.operands) else: items.append(operand) nullable = any(item.nullable for item in items) BoolMonad.__init__(monad, nullable=nullable) monad.operands = items def getsql(monad, sqlquery=None): result = [ monad.binop ] for operand in monad.operands: operand_sql = operand.getsql() assert len(operand_sql) == 1 result.extend(operand_sql) return [ result ] class AndMonad(LogicalBinOpMonad): binop = 'AND' class OrMonad(LogicalBinOpMonad): binop = 'OR' class NotMonad(BoolMonad): def __init__(monad, operand): if operand.type is not bool: operand = operand.nonzero() BoolMonad.__init__(monad, nullable=operand.nullable) monad.operand = operand def negate(monad): return monad.operand def getsql(monad, sqlquery=None): return [ [ 'NOT', monad.operand.getsql()[0] ] ] class HybridFuncMonad(Monad): def __init__(monad, func_type, func_name, *params): Monad.__init__(monad, func_type) monad.func = func_type.func monad.func_name = func_name monad.params = params def __call__(monad, *args, **kwargs): translator = monad.translator name_mapping = inspect.getcallargs(monad.func, *(monad.params + args), **kwargs) func = monad.func if PY2 and isinstance(func, types.UnboundMethodType): func = func.im_func func_id = id(func) try: func_ast, external_names, cells = decompile(func) except DecompileError: throw(TranslationError, '%s(...) is too complex to decompile' % ast2src(monad.node)) func_ast, func_extractors = create_extractors( func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) root_translator = translator.root_translator if func not in root_translator.func_extractors_map: func_vars, func_vartypes = extract_vars(func_id, translator.filter_num, func_extractors, func.__globals__, {}, cells) translator.database.provider.normalize_vars(func_vars, func_vartypes) if func.__closure__: translator.can_be_cached = False if func_extractors: root_translator.func_extractors_map[func] = func_extractors root_translator.func_vartypes.update(func_vartypes) root_translator.vartypes.update(func_vartypes) root_translator.vars.update(func_vars) stack = translator.namespace_stack stack.append(name_mapping) func_ast = copy_ast(func_ast) try: prev_code_key = translator.code_key translator.code_key = func_id try: translator.dispatch(func_ast) finally: translator.code_key = prev_code_key except Exception as e: if len(e.args) == 1 and isinstance(e.args[0], basestring): msg = e.args[0] + ' (inside %s)' % (monad.func_name) e.args = (msg,) raise stack.pop() return func_ast.monad class HybridMethodMonad(HybridFuncMonad): def __init__(monad, parent, attrname, func): entity = parent.type assert isinstance(entity, EntityMeta) func_name = '%s.%s' % (entity.__name__, attrname) HybridFuncMonad.__init__(monad, FuncType(func), func_name, parent) registered_functions = SQLTranslator.registered_functions = {} class FuncMonadMeta(MonadMeta): def __new__(meta, cls_name, bases, cls_dict): func = cls_dict.get('func') monad_cls = super(FuncMonadMeta, meta).__new__(meta, cls_name, bases, cls_dict) if func: if type(func) is tuple: functions = func else: functions = (func,) for func in functions: registered_functions[func] = monad_cls return monad_cls class FuncMonad(with_metaclass(FuncMonadMeta, Monad)): def __call__(monad, *args, **kwargs): for arg in args: assert isinstance(arg, Monad) for value in kwargs.values(): assert isinstance(value, Monad) try: return monad.call(*args, **kwargs) except TypeError as exc: reraise_improved_typeerror(exc, 'call', monad.type.__name__) def 
get_classes(classinfo): if isinstance(classinfo, EntityMonad): yield classinfo.type.item_type elif isinstance(classinfo, ListMonad): for item in classinfo.items: for type in get_classes(item): yield type else: throw(TypeError, ast2src(classinfo.node)) class FuncIsinstanceMonad(FuncMonad): func = isinstance def call(monad, obj, classinfo): if not isinstance(obj, ObjectMixin): throw(ValueError, 'Inside a query, isinstance first argument should be of entity type. Got: %s' % ast2src(obj.node)) entity = obj.type classes = list(get_classes(classinfo)) subclasses = set() for cls in classes: if entity._root_ is cls._root_: subclasses.add(cls) subclasses.update(cls._subclasses_) if entity in subclasses: return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], nullable=False) subclasses.intersection_update(entity._subclasses_) if not subclasses: return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False) discr_attr = entity._discriminator_attr_ assert discr_attr is not None discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in subclasses ] alias, pk_columns = obj.tableref.make_join(pk_only=True) sql = [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] return BoolExprMonad(sql, nullable=False) class FuncBufferMonad(FuncMonad): func = buffer def call(monad, source, encoding=None, errors=None): if not isinstance(source, StringConstMonad): throw(TypeError) source = source.value if encoding is not None: if not isinstance(encoding, StringConstMonad): throw(TypeError) encoding = encoding.value if errors is not None: if not isinstance(errors, StringConstMonad): throw(TypeError) errors = errors.value if PY2: if encoding and errors: source = source.encode(encoding, errors) elif encoding: source = source.encode(encoding) return ConstMonad.new(buffer(source)) else: if encoding and errors: value = buffer(source, encoding, errors) elif encoding: value = buffer(source, encoding) else: value = buffer(source) return ConstMonad.new(value) class FuncBoolMonad(FuncMonad): func = bool def call(monad, x): return x.nonzero() class FuncIntMonad(FuncMonad): func = int def call(monad, x): return x.to_int() class FuncStrMonad(FuncMonad): func = str def call(monad, x): return x.to_str() class FuncFloatMonad(FuncMonad): func = float def call(monad, x): return x.to_real() class FuncDecimalMonad(FuncMonad): func = Decimal def call(monad, x): if not isinstance(x, StringConstMonad): throw(TypeError) return ConstMonad.new(Decimal(x.value)) class FuncDateMonad(FuncMonad): func = date def call(monad, year, month, day): for arg, name in izip((year, month, day), ('year', 'month', 'day')): if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of date(year, month, day) function must be of 'int' type. " "Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) return ConstMonad.new(date(year.value, month.value, day.value)) def call_today(monad): return DateExprMonad(date, [ 'TODAY' ], nullable=monad.nullable) class FuncTimeMonad(FuncMonad): func = time def call(monad, *args): for arg, name in izip(args, ('hour', 'minute', 'second', 'microsecond')): if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of time(...) function must be of 'int' type. 
Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) return ConstMonad.new(time(*tuple(arg.value for arg in args))) class FuncTimedeltaMonad(FuncMonad): func = timedelta def call(monad, days=None, seconds=None, microseconds=None, milliseconds=None, minutes=None, hours=None, weeks=None): args = days, seconds, microseconds, milliseconds, minutes, hours, weeks for arg, name in izip(args, ('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks')): if arg is None: continue if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of timedelta(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = timedelta(*(arg.value if arg is not None else 0 for arg in args)) return ConstMonad.new(value) class FuncDatetimeMonad(FuncDateMonad): func = datetime def call(monad, year, month, day, hour=None, minute=None, second=None, microsecond=None): args = year, month, day, hour, minute, second, microsecond for arg, name in izip(args, ('year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond')): if arg is None: continue if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of datetime(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = datetime(*(arg.value if arg is not None else 0 for arg in args)) return ConstMonad.new(value) def call_now(monad): return DatetimeExprMonad(datetime, [ 'NOW' ], nullable=monad.nullable) class FuncBetweenMonad(FuncMonad): func = between def call(monad, x, a, b): check_comparable(x, a, '<') check_comparable(x, b, '<') if isinstance(x.type, EntityMeta): throw(TypeError, '%s instance cannot be argument of between() function: {EXPR}' % x.type.__name__) sql = [ 'BETWEEN', x.getsql()[0], a.getsql()[0], b.getsql()[0] ] return BoolExprMonad(sql, nullable=x.nullable or a.nullable or b.nullable) class FuncConcatMonad(FuncMonad): func = concat def call(monad, *args): if len(args) < 2: throw(TranslationError, 'concat() function requires at least two arguments') result_ast = [ 'CONCAT' ] for arg in args: t = arg.type if isinstance(t, EntityMeta) or type(t) in (tuple, SetType): throw(TranslationError, 'Invalid argument of concat() function: %s' % ast2src(arg.node)) result_ast.extend(arg.getsql()) return ExprMonad.new(unicode, result_ast, nullable=any(arg.nullable for arg in args)) class FuncLenMonad(FuncMonad): func = len def call(monad, x): return x.len() class FuncGetattrMonad(FuncMonad): func = getattr def call(monad, obj_monad, name_monad): if isinstance(name_monad, ConstMonad): attrname = name_monad.value elif isinstance(name_monad, ParamMonad): translator = monad.translator.root_translator key = name_monad.paramkey[0] if key in translator.getattr_values: attrname = translator.getattr_values[key] else: attrname = translator.vars[key] translator.getattr_values[key] = attrname else: throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' 'because %s will be different for each row' % ast2src(name_monad.node)) if not isinstance(attrname, basestring): throw(TypeError, 'In `{EXPR}` second argument should be a string. 
Got: %r' % attrname) return obj_monad.getattr(attrname) class FuncRawSQLMonad(FuncMonad): func = raw_sql def call(monad, *args): throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' 'because raw SQL fragment will be different for each row') class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count def call(monad, x=None, distinct=None): if isinstance(x, StringConstMonad) and x.value == '*': x = None if x is not None: return x.count(distinct) result = ExprMonad.new(int, [ 'COUNT', None ], nullable=False) result.aggregated = True return result class FuncAbsMonad(FuncMonad): func = abs def call(monad, x): return x.abs() class FuncSumMonad(FuncMonad): func = sum, core.sum def call(monad, x, distinct=None): return x.aggregate('SUM', distinct) class FuncAvgMonad(FuncMonad): func = utils.avg, core.avg def call(monad, x, distinct=None): return x.aggregate('AVG', distinct) class FuncGroupConcatMonad(FuncMonad): func = utils.group_concat, core.group_concat def call(monad, x, sep=None, distinct=None): if sep is not None: if distinct and monad.translator.database.provider.dialect == 'SQLite': throw(TypeError, 'SQLite does not allow to specify distinct and separator in group_concat at the same time: {EXPR}') if not(isinstance(sep, StringConstMonad) and isinstance(sep.value, basestring)): throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % ast2src(sep.node)) sep = sep.value return x.aggregate('GROUP_CONCAT', distinct=distinct, sep=sep) class FuncCoalesceMonad(FuncMonad): func = coalesce def call(monad, *args): if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments') arg = args[0] t = arg.type result = [ [ sql ] for sql in arg.getsql() ] for arg in args[1:]: if arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type') for i, sql in enumerate(arg.getsql()): result[i].append(sql) sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] if not isinstance(t, EntityMeta): sql = sql[0] return ExprMonad.new(t, sql, nullable=all(arg.nullable for arg in args)) class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct def call(monad, x): if isinstance(x, SetMixin): return x.call_distinct() if not isinstance(x, NumericMixin): throw(TypeError) result = object.__new__(x.__class__) result.__dict__.update(x.__dict__) result.forced_distinct = True return result class FuncMinMonad(FuncMonad): func = min, core.min def call(monad, *args): if not args: throw(TypeError, 'min() function expected at least one argument') if len(args) == 1: return args[0].aggregate('MIN') return minmax(monad, 'MIN', *args) class FuncMaxMonad(FuncMonad): func = max, core.max def call(monad, *args): if not args: throw(TypeError, 'max() function expected at least one argument') if len(args) == 1: return args[0].aggregate('MAX') return minmax(monad, 'MAX', *args) def minmax(monad, sqlop, *args): assert len(args) > 1 translator = monad.translator t = args[0].type if t == 'METHOD': raise_forgot_parentheses(args[0]) if t not in comparable_types: throw(TypeError, "Value of type %r is not valid as argument of %r function in expression {EXPR}" % (type2str(t), sqlop.lower())) for arg in args[1:]: t2 = arg.type if t2 == 'METHOD': raise_forgot_parentheses(arg) t3 = coerce_types(t, t2) if t3 is None: throw(IncomparableTypesError, t, t2) t = t3 if t3 in numeric_types and translator.dialect == 'PostgreSQL': args = list(args) for i, arg in enumerate(args): if arg.type is bool: 
args[i] = NumericExprMonad(
                    # getsql() returns a list of SQL ASTs; TO_INT takes a single expression
                    int, [ 'TO_INT', arg.getsql()[0] ], nullable=arg.nullable)
    sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ]
    return ExprMonad.new(t, sql, nullable=any(arg.nullable for arg in args))

class FuncSelectMonad(FuncMonad):
    func = core.select
    def call(monad, queryset):
        if not isinstance(queryset, QuerySetMonad): throw(TypeError,
            "'select' function expects generator expression, got: {EXPR}")
        return queryset

class FuncExistsMonad(FuncMonad):
    func = core.exists
    def call(monad, arg):
        if not isinstance(arg, SetMixin): throw(TypeError,
            "'exists' function expects generator expression or collection, got: {EXPR}")
        return arg.nonzero()

class FuncDescMonad(FuncMonad):
    func = core.desc
    def call(monad, expr):
        return DescMonad(expr)

class DescMonad(Monad):
    def __init__(monad, expr):
        Monad.__init__(monad, expr.type, nullable=expr.nullable)
        monad.expr = expr
    def getsql(monad):
        return [ [ 'DESC', item ] for item in monad.expr.getsql() ]

class JoinMonad(Monad):
    def __init__(monad, type):
        Monad.__init__(monad, type)
        translator = monad.translator
        monad.hint_join_prev = translator.hint_join
        translator.hint_join = True
    def __call__(monad, x):
        monad.translator.hint_join = monad.hint_join_prev
        return x
registered_functions[JOIN] = JoinMonad

class FuncRandomMonad(FuncMonad):
    func = random
    def __init__(monad, type):
        FuncMonad.__init__(monad, type)
        monad.translator.query_result_is_cacheable = False
    def __call__(monad):
        return NumericExprMonad(float, [ 'RANDOM' ], nullable=False)

class SetMixin(MonadMixin):
    forced_distinct = False
    def call_distinct(monad):
        new_monad = object.__new__(monad.__class__)
        new_monad.__dict__.update(monad.__dict__)
        new_monad.forced_distinct = True
        return new_monad

def make_attrset_binop(op, sqlop):
    def attrset_binop(monad, monad2):
        return NumericSetExprMonad(op, sqlop, monad, monad2)
    return attrset_binop

class AttrSetMonad(SetMixin, Monad):
    def __init__(monad, parent, attr):
        item_type = normalize_type(attr.py_type)
        Monad.__init__(monad, SetType(item_type))
        monad.parent = parent
        monad.attr = attr
        monad.sqlquery = None
        monad.tableref = None
    def cmp(monad, op, monad2):
        if type(monad2.type) is SetType \
                and are_comparable_types(monad.type.item_type, monad2.type.item_type):
            pass
        elif monad.type != monad2.type:
            check_comparable(monad, monad2)
        throw(NotImplementedError)
    def contains(monad, item, not_in=False):
        translator = monad.translator
        check_comparable(item, monad, 'in')
        if not translator.hint_join:
            sqlop = 'NOT_IN' if not_in else 'IN'
            sqlquery = monad._subselect()
            expr_list = sqlquery.expr_list
            from_ast = sqlquery.from_ast
            conditions = sqlquery.outer_conditions + sqlquery.conditions
            if len(expr_list) == 1:
                subquery_ast = [ 'SELECT', [ 'ALL' ] + expr_list, from_ast, [ 'WHERE' ] + conditions ]
                sql_ast = [ sqlop, item.getsql()[0], subquery_ast ]
            elif translator.row_value_syntax:
                subquery_ast = [ 'SELECT', [ 'ALL' ] + expr_list, from_ast, [ 'WHERE' ] + conditions ]
                sql_ast = [ sqlop, [ 'ROW' ] + item.getsql(), subquery_ast ]
            else:
                conditions += [ [ 'EQ', expr1, expr2 ]
                                for expr1, expr2 in izip(item.getsql(), expr_list) ]
                sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS', from_ast, [ 'WHERE' ] + conditions ]
            result = BoolExprMonad(sql_ast, nullable=False)
            result.nogroup = True
            return result
        elif not not_in:
            translator.distinct = True
            tableref = monad.make_tableref(translator.sqlquery)
            expr_list = monad.make_expr_list()
            expr_ast = sqland([ [ 'EQ', expr1, expr2 ]
                                for expr1, expr2 in izip(expr_list, item.getsql()) ])
            return BoolExprMonad(expr_ast, nullable=False)
        else:
sqlquery = SqlQuery(translator, translator.sqlquery) tableref = monad.make_tableref(sqlquery) attr = monad.attr alias, columns = tableref.make_join(pk_only=attr.reverse) expr_list = monad.make_expr_list() if not attr.reverse: columns = attr.columns from_ast = translator.sqlquery.from_ast from_ast[0] = 'LEFT_JOIN' from_ast.extend(sqlquery.from_ast[1:]) conditions = [ [ 'EQ', [ 'COLUMN', alias, column ], expr ] for column, expr in izip(columns, item.getsql()) ] conditions.extend(sqlquery.conditions) from_ast[-1][-1] = sqland([ from_ast[-1][-1] ] + conditions) expr_ast = sqland([ [ 'IS_NULL', expr ] for expr in expr_list ]) return BoolExprMonad(expr_ast, nullable=False) def getattr(monad, name): try: return Monad.getattr(monad, name) except AttributeError: pass entity = monad.type.item_type if not isinstance(entity, EntityMeta): throw(AttributeError) attr = entity._adict_.get(name) if attr is None: throw(AttributeError) return AttrSetMonad(monad, attr) def call_select(monad): # calling with lambda argument processed in preCallFunc return monad call_filter = call_select def call_exists(monad): return monad def requires_distinct(monad, joined=False, for_count=False): if monad.parent.requires_distinct(joined): return True reverse = monad.attr.reverse if not reverse: return True if reverse.is_collection: translator = monad.translator if not for_count and not translator.hint_join: return True if isinstance(monad.parent, AttrSetMonad): return True return False def count(monad, distinct=None): translator = monad.translator distinct = distinct_from_monad(distinct, monad.requires_distinct(joined=translator.hint_join, for_count=True)) sqlquery = monad._subselect() expr_list = sqlquery.expr_list from_ast = sqlquery.from_ast inner_conditions = sqlquery.conditions outer_conditions = sqlquery.outer_conditions sql_ast = make_aggr = None extra_grouping = False if not distinct and monad.tableref.name_path != translator.optimize: make_aggr = lambda expr_list: [ 'COUNT', None ] elif len(expr_list) == 1: make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list elif translator.dialect == 'Oracle': if monad.tableref.name_path == translator.optimize: alias, pk_columns = monad.tableref.make_join(pk_only=True) make_aggr = lambda expr_list: [ 'COUNT', distinct, [ 'COLUMN', alias, 'ROWID' ] ] else: extra_grouping = True if translator.hint_join: make_aggr = lambda expr_list: [ 'COUNT', None ] else: make_aggr = lambda expr_list: [ 'COUNT', None, [ 'COUNT', None ] ] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr_list expr = [ 'CASE', None, [ [ [ 'IS_NULL', row ], [ 'VALUE', None ] ] ], row ] make_aggr = lambda expr_list: [ 'COUNT', True, expr ] elif translator.row_value_syntax: make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list elif translator.dialect == 'SQLite': if not distinct: alias, pk_columns = monad.tableref.make_join(pk_only=True) make_aggr = lambda expr_list: [ 'COUNT', None, [ 'COLUMN', alias, 'ROWID' ] ] elif translator.hint_join: # Same join as in Oracle extra_grouping = True make_aggr = lambda expr_list: [ 'COUNT', None ] elif translator.sqlite_version < (3, 6, 21): alias, pk_columns = monad.tableref.make_join(pk_only=False) make_aggr = lambda expr_list: [ 'COUNT', True, [ 'COLUMN', alias, 'ROWID' ] ] else: sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ 't', 'SELECT', [ [ 'DISTINCT' ] + expr_list, from_ast, [ 'WHERE' ] + outer_conditions + inner_conditions ] ] ] ] else: throw(NotImplementedError) # pragma: no cover if sql_ast: optimized = False elif 
translator.hint_join: sql_ast, optimized = monad._joined_subselect(make_aggr, extra_grouping, coalesce_to_zero=True) else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr, extra_grouping) translator.aggregated_subquery_paths.add(monad.tableref.name_path) result = ExprMonad.new(int, sql_ast, nullable=False) if optimized: result.aggregated = True else: result.nogroup = True return result len = count def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator item_type = monad.type.item_type if func_name in ('SUM', 'AVG'): if item_type not in numeric_types: throw(TypeError, "Function %s() expects query or items of numeric type, got %r in {EXPR}" % (func_name.lower(), type2str(item_type))) elif func_name in ('MIN', 'MAX'): if item_type not in comparable_types: throw(TypeError, "Function %s() expects query or items of comparable type, got %r in {EXPR}" % (func_name.lower(), type2str(item_type))) elif func_name == 'GROUP_CONCAT': if isinstance(item_type, EntityMeta) and item_type._pk_is_composite_: throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover def make_aggr(expr_list): result = [ func_name, distinct ] + expr_list if sep is not None: assert func_name == 'GROUP_CONCAT' result.append(['VALUE', sep]) return result # make_aggr = lambda expr_list: [ func_name, distinct ] + expr_list if translator.hint_join: sql_ast, optimized = monad._joined_subselect(make_aggr, coalesce_to_zero=(func_name=='SUM')) else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr) if func_name == 'AVG': result_type = float elif func_name == 'GROUP_CONCAT': result_type = unicode else: result_type = item_type translator.aggregated_subquery_paths.add(monad.tableref.name_path) result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') if optimized: result.aggregated = True else: result.nogroup = True return result def nonzero(monad): sqlquery = monad._subselect() sql_ast = [ 'EXISTS', sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] return BoolExprMonad(sql_ast, nullable=False) def negate(monad): sqlquery = monad._subselect() sql_ast = [ 'NOT_EXISTS', sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] return BoolExprMonad(sql_ast, nullable=False) call_is_empty = negate def make_tableref(monad, sqlquery): parent = monad.parent attr = monad.attr if isinstance(parent, ObjectMixin): parent_tableref = parent.tableref elif isinstance(parent, AttrSetMonad): parent_tableref = parent.make_tableref(sqlquery) else: assert False # pragma: no cover if attr.reverse: name_path = parent_tableref.name_path + '-' + attr.name monad.tableref = sqlquery.get_tableref(name_path) \ or sqlquery.add_tableref(name_path, parent_tableref, attr) else: monad.tableref = parent_tableref monad.tableref.can_affect_distinct = True return monad.tableref def make_expr_list(monad): attr = monad.attr pk_only = attr.reverse or attr.pk_offset is not None alias, columns = monad.tableref.make_join(pk_only) if attr.reverse: pass elif pk_only: offset = attr.pk_columns_offset columns = columns[offset:offset+len(attr.columns)] else: columns = attr.columns return [ [ 'COLUMN', alias, column ] for column in columns ] def _aggregated_scalar_subselect(monad, make_aggr, extra_grouping=False): translator = monad.translator sqlquery = monad._subselect() optimized 
= False if translator.optimize == monad.tableref.name_path: sql_ast = make_aggr(sqlquery.expr_list) optimized = True if not translator.from_optimized: from_ast = monad.sqlquery.from_ast[1:] assert sqlquery.outer_conditions from_ast[0] = from_ast[0] + [ sqland(sqlquery.outer_conditions) ] translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True else: sql_ast = [ 'SELECT', [ 'AGGREGATES', make_aggr(sqlquery.expr_list) ], sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] if extra_grouping: # This is for Oracle only, with COUNT(COUNT(*)) sql_ast.append([ 'GROUP_BY' ] + sqlquery.expr_list) return sql_ast, optimized def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=False): translator = monad.translator sqlquery = monad._subselect() expr_list = sqlquery.expr_list from_ast = sqlquery.from_ast inner_conditions = sqlquery.conditions outer_conditions = sqlquery.outer_conditions groupby_columns = [ inner_column[:] for cond, outer_column, inner_column in outer_conditions ] assert len({alias for _, alias, column in groupby_columns}) == 1 if extra_grouping: inner_alias = translator.sqlquery.make_alias('t') inner_columns = [ 'DISTINCT' ] col_mapping = {} col_names = set() for i, column_ast in enumerate(groupby_columns + expr_list): assert column_ast[0] == 'COLUMN' tname, cname = column_ast[1:] if cname not in col_names: col_mapping[tname, cname] = cname col_names.add(cname) expr = [ 'AS', column_ast, cname ] new_name = cname else: new_name = 'expr-%d' % next(translator.sqlquery.expr_counter) col_mapping[tname, cname] = new_name expr = [ 'AS', column_ast, new_name ] inner_columns.append(expr) if i < len(groupby_columns): groupby_columns[i] = [ 'COLUMN', inner_alias, new_name ] inner_select = [ inner_columns, from_ast ] if inner_conditions: inner_select.append([ 'WHERE' ] + inner_conditions) from_ast = [ 'FROM', [ inner_alias, 'SELECT', inner_select ] ] outer_conditions = outer_conditions[:] for i, (cond, outer_column, inner_column) in enumerate(outer_conditions): assert inner_column[0] == 'COLUMN' tname, cname = inner_column[1:] new_name = col_mapping[tname, cname] outer_conditions[i] = [ cond, outer_column, [ 'COLUMN', inner_alias, new_name ] ] subselect_columns = [ 'ALL' ] for column_ast in groupby_columns: assert column_ast[0] == 'COLUMN' subselect_columns.append([ 'AS', column_ast, column_ast[2] ]) expr_name = 'expr-%d' % next(translator.sqlquery.expr_counter) subselect_columns.append([ 'AS', make_aggr(expr_list), expr_name ]) subquery_ast = [ subselect_columns, from_ast ] if inner_conditions and not extra_grouping: subquery_ast.append([ 'WHERE' ] + inner_conditions) subquery_ast.append([ 'GROUP_BY' ] + groupby_columns) alias = translator.sqlquery.make_alias('t') for cond in outer_conditions: cond[2][1] = alias translator.sqlquery.from_ast.append([ alias, 'SELECT', subquery_ast, sqland(outer_conditions) ]) expr_ast = [ 'COLUMN', alias, expr_name ] if coalesce_to_zero: expr_ast = [ 'COALESCE', expr_ast, [ 'VALUE', 0 ] ] return expr_ast, False def _subselect(monad, sqlquery=None, extract_outer_conditions=True): if monad.sqlquery is not None: return monad.sqlquery attr = monad.attr translator = monad.translator if sqlquery is None: sqlquery = SqlQuery(translator, translator.sqlquery) monad.make_tableref(sqlquery) sqlquery.expr_list = monad.make_expr_list() if not attr.reverse and not attr.is_required: sqlquery.conditions.extend([ 'IS_NOT_NULL', expr ] for expr in sqlquery.expr_list) if sqlquery is not translator.sqlquery 
and extract_outer_conditions: outer_cond = sqlquery.from_ast[1].pop() if outer_cond[0] == 'AND': sqlquery.outer_conditions = outer_cond[1:] else: sqlquery.outer_conditions = [ outer_cond ] monad.sqlquery = sqlquery return sqlquery def getsql(monad, sqlquery=None): if sqlquery is None: sqlquery = monad.translator.sqlquery monad.make_tableref(sqlquery) return monad.make_expr_list() __add__ = make_attrset_binop('+', 'ADD') __sub__ = make_attrset_binop('-', 'SUB') __mul__ = make_attrset_binop('*', 'MUL') __truediv__ = make_attrset_binop('/', 'DIV') __floordiv__ = make_attrset_binop('//', 'FLOORDIV') def make_numericset_binop(op, sqlop): def numericset_binop(monad, monad2): return NumericSetExprMonad(op, sqlop, monad, monad2) return numericset_binop class NumericSetExprMonad(SetMixin, Monad): def __init__(monad, op, sqlop, left, right): result_type, left, right = coerce_monads(left, right) assert type(result_type) is SetType if result_type.item_type not in numeric_types: throw(TypeError, _binop_errmsg % (type2str(left.type), type2str(right.type), op)) Monad.__init__(monad, result_type) monad.op = op monad.sqlop = sqlop monad.left = left monad.right = right def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator sqlquery = SqlQuery(translator, translator.sqlquery) expr = monad.getsql(sqlquery)[0] translator.aggregated_subquery_paths.add(monad.tableref.name_path) outer_cond = sqlquery.from_ast[1].pop() if outer_cond[0] == 'AND': sqlquery.outer_conditions = outer_cond[1:] else: sqlquery.outer_conditions = [ outer_cond ] if func_name == 'AVG': result_type = float elif func_name == 'GROUP_CONCAT': result_type = unicode else: result_type = monad.type.item_type aggr_ast = [ func_name, distinct, expr ] if func_name == 'GROUP_CONCAT': if sep is not None: aggr_ast.append(['VALUE', sep]) if translator.optimize != monad.tableref.name_path: sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') result.nogroup = True else: if not translator.from_optimized: from_ast = sqlquery.from_ast[1:] assert sqlquery.outer_conditions from_ast[0] = from_ast[0] + [ sqland(sqlquery.outer_conditions) ] translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True sql_ast = aggr_ast result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') result.aggregated = True return result def getsql(monad, sqlquery=None): if sqlquery is None: sqlquery = monad.translator.sqlquery left, right = monad.left, monad.right left_expr = left.getsql(sqlquery)[0] right_expr = right.getsql(sqlquery)[0] if isinstance(left, NumericMixin): left_path = '' else: left_path = left.tableref.name_path + '-' if isinstance(right, NumericMixin): right_path = '' else: right_path = right.tableref.name_path + '-' if left_path.startswith(right_path): tableref = left.tableref elif right_path.startswith(left_path): tableref = right.tableref else: throw(TranslationError, 'Cartesian product detected in %s' % ast2src(monad.node)) monad.tableref = tableref return [ [ monad.sqlop, left_expr, right_expr ] ] __add__ = make_numericset_binop('+', 'ADD') __sub__ = make_numericset_binop('-', 'SUB') __mul__ = make_numericset_binop('*', 'MUL') __truediv__ = make_numericset_binop('/', 'DIV') __floordiv__ = make_numericset_binop('//', 'FLOORDIV') class 
QuerySetMonad(SetMixin, Monad): nogroup = True def __init__(monad, subtranslator): item_type = subtranslator.expr_type monad_type = SetType(item_type) Monad.__init__(monad, monad_type) monad.subtranslator = subtranslator monad.item_type = item_type monad.limit = monad.offset = None def requires_distinct(monad, joined=False): assert False def call_limit(monad, limit=None, offset=None): if limit is not None and not isinstance(limit, int_types): if not isinstance(limit, (NoneMonad, NumericConstMonad)): throw(TypeError, '`limit` parameter should be of int type') limit = limit.value if offset is not None and not isinstance(offset, int_types): if not isinstance(offset, (NoneMonad, NumericConstMonad)): throw(TypeError, '`offset` parameter should be of int type') offset = offset.value monad.limit = limit monad.offset = offset return monad def contains(monad, item, not_in=False): translator = monad.translator check_comparable(item, monad, 'in') if isinstance(item, ListMonad): item_columns = [] for subitem in item.items: item_columns.extend(subitem.getsql()) else: item_columns = item.getsql() sub = monad.subtranslator if translator.hint_join and len(sub.sqlquery.from_ast[1]) == 3: subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False) select_ast, from_ast, where_ast = subquery_ast[1:4] sqlquery = translator.sqlquery if not not_in: translator.distinct = True if sqlquery.from_ast[0] == 'FROM': sqlquery.from_ast[0] = 'INNER_JOIN' else: sqlquery.left_join = True sqlquery.from_ast[0] = 'LEFT_JOIN' col_names = set() new_names = [] exprs = [] for i, column_ast in enumerate(select_ast): if not i: continue # 'ALL' if column_ast[0] == 'COLUMN': tab_name, col_name = column_ast[1:] if col_name not in col_names: col_names.add(col_name) new_names.append(col_name) select_ast[i] = [ 'AS', column_ast, col_name ] continue new_name = 'expr-%d' % next(sqlquery.expr_counter) new_names.append(new_name) select_ast[i] = [ 'AS', column_ast, new_name ] alias = sqlquery.make_alias('t') outer_conditions = [ [ 'EQ', item_column, [ 'COLUMN', alias, new_name ] ] for item_column, new_name in izip(item_columns, new_names) ] sqlquery.from_ast.append([ alias, 'SELECT', subquery_ast[1:], sqland(outer_conditions) ]) if not_in: sql_ast = sqland([ [ 'IS_NULL', [ 'COLUMN', alias, new_name ] ] for new_name in new_names ]) else: sql_ast = [ 'EQ', [ 'VALUE', 1 ], [ 'VALUE', 1 ] ] else: if len(item_columns) == 1: subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', item_columns[0], subquery_ast ] elif translator.row_value_syntax: subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', [ 'ROW' ] + item_columns, subquery_ast ] else: ambiguous_names = set() if sub.injected: for name in translator.sqlquery.tablerefs: if name in sub.sqlquery.tablerefs: ambiguous_names.add(name) subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False) if ambiguous_names: select_ast = subquery_ast[1] expr_aliases = [] for i, expr_ast in enumerate(select_ast): if i > 0: if expr_ast[0] == 'AS': expr_ast = expr_ast[1] expr_alias = 'expr-%d' % i expr_aliases.append(expr_alias) expr_ast = [ 'AS', expr_ast, expr_alias ] select_ast[i] = expr_ast new_table_alias = translator.sqlquery.make_alias('t') new_select_ast = [ 'ALL' ] for expr_alias in expr_aliases: new_select_ast.append([ 'COLUMN', new_table_alias, expr_alias ]) 
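                    # at this point each selected expression of the subquery carries a
                    # stable 'expr-N' alias and is re-selected through a fresh derived-table
                    # alias, so inner table names that collide with outer ones (the
                    # ambiguous_names collected above) are hidden before the IN test is
                    # rewritten below as EXISTS with correlated equality conditions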
new_from_ast = [ 'FROM', [ new_table_alias, 'SELECT', subquery_ast[1:] ] ] new_where_ast = [ 'WHERE' ] subquery_ast = [ 'SELECT', new_select_ast, new_from_ast, new_where_ast ] select_ast, from_ast, where_ast = subquery_ast[1:4] in_conditions = [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item_columns, select_ast[1:]) ] if not ambiguous_names and sub.aggregated: having_ast = find_or_create_having_ast(subquery_ast) having_ast += in_conditions else: where_ast += in_conditions sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] return BoolExprMonad(sql_ast, nullable=False) def nonzero(monad): subquery_ast = monad.subtranslator.construct_subquery_ast(distinct=False) expr_monads = monad.subtranslator.expr_monads if len(expr_monads) > 1: throw(NotImplementedError) expr_monad = expr_monads[0] if not isinstance(expr_monad, ObjectIterMonad): sql = expr_monad.nonzero().getsql() assert subquery_ast[3][0] == 'WHERE' subquery_ast[3].append(sql[0]) subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] return BoolExprMonad(subquery_ast, nullable=False) def negate(monad): sql = monad.nonzero().sql assert sql[0] == 'EXISTS' return BoolExprMonad([ 'NOT_EXISTS' ] + sql[1:], nullable=False) def count(monad, distinct=None): distinct = distinct_from_monad(distinct) translator = monad.translator sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] sql_ast = None expr_type = sub.expr_type if isinstance(expr_type, (tuple, EntityMeta)): if not sub.distinct and not distinct: select_ast = [ 'AGGREGATES', [ 'COUNT', None ] ] elif len(sub.expr_columns) == 1: select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct ] + sub.expr_columns ] elif translator.dialect == 'Oracle': sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None, [ 'COUNT', None ] ] ], from_ast, where_ast, [ 'GROUP_BY' ] + sub.expr_columns ] elif translator.row_value_syntax: select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct ] + sub.expr_columns ] elif translator.dialect == 'SQLite': if translator.sqlite_version < (3, 6, 21): if sub.aggregated: throw(TranslationError) alias, pk_columns = sub.tableref.make_join(pk_only=False) subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, [ 'COLUMN', alias, 'ROWID' ] ] ], from_ast, where_ast ] else: alias = translator.sqlquery.make_alias('t') sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ alias, 'SELECT', [ [ 'DISTINCT' if distinct is not False else 'ALL' ] + sub.expr_columns, from_ast, where_ast ] ] ] ] else: assert False # pragma: no cover elif len(sub.expr_columns) == 1: select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, sub.expr_columns[0] ] ] else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] return ExprMonad.new(int, sql_ast, nullable=False) len = count def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] expr_type = sub.expr_type if 
func_name in ('SUM', 'AVG'):
            if expr_type not in numeric_types: throw(TypeError,
                "Function %s() expects query or items of numeric type, got %r in {EXPR}"
                % (func_name.lower(), type2str(expr_type)))
        elif func_name in ('MIN', 'MAX'):
            if expr_type not in comparable_types: throw(TypeError,
                "Function %s() cannot be applied to type %r in {EXPR}"
                % (func_name.lower(), type2str(expr_type)))
        elif func_name == 'GROUP_CONCAT':
            if isinstance(expr_type, EntityMeta) and expr_type._pk_is_composite_:
                throw(TypeError, "`group_concat` cannot be used with entity with composite primary key")
        else: assert False  # pragma: no cover
        assert len(sub.expr_columns) == 1
        aggr_ast = [ func_name, distinct, sub.expr_columns[0] ]
        if func_name == 'GROUP_CONCAT':
            if sep is not None:
                aggr_ast.append(['VALUE', sep])
        select_ast = [ 'AGGREGATES', aggr_ast ]
        sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ]
        if func_name == 'AVG': result_type = float
        elif func_name == 'GROUP_CONCAT': result_type = unicode
        else: result_type = expr_type
        return ExprMonad.new(result_type, sql_ast, func_name != 'SUM')
    def call_count(monad, distinct=None):
        return monad.count(distinct=distinct)
    def call_sum(monad, distinct=None):
        return monad.aggregate('SUM', distinct)
    def call_min(monad):
        return monad.aggregate('MIN')
    def call_max(monad):
        return monad.aggregate('MAX')
    def call_avg(monad, distinct=None):
        return monad.aggregate('AVG', distinct)
    def call_group_concat(monad, sep=None, distinct=None):
        if sep is not None:
            if not isinstance(sep, basestring):
                throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s'
                                 % type(sep).__name__)
        return monad.aggregate('GROUP_CONCAT', distinct, sep=sep)
    def getsql(monad):
        return monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset)

def find_or_create_having_ast(sections):
    groupby_offset = None
    for i, section in enumerate(sections):
        section_name = section[0]
        if section_name == 'GROUP_BY':
            groupby_offset = i
        elif section_name == 'HAVING':
            return section
    having_ast = [ 'HAVING' ]
    sections.insert(groupby_offset + 1, having_ast)
    return having_ast

pony-0.7.11/pony/orm/tests/__init__.py

import pony.orm.core, pony.options

pony.options.CUT_TRACEBACK = False
pony.orm.core.sql_debug(False)

pony-0.7.11/pony/orm/tests/fixtures.py

import sys
import os
import logging

from pony.py23compat import PY2

from ponytest import with_cli_args, pony_fixtures, provider_validators, provider, Fixture, \
    ValidationError

from functools import wraps, partial
import click
from contextlib import contextmanager, closing
from pony.utils import cached_property, class_property

if not PY2:
    from contextlib import contextmanager, ContextDecorator
else:
    from contextlib2 import contextmanager, ContextDecorator

import unittest
from pony.orm import db_session, Database, rollback, delete

if not PY2:
    from io import StringIO
else:
    from StringIO import StringIO

from multiprocessing import Process
import threading


class DBContext(ContextDecorator):
    fixture = 'db'
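    # Each DBContext subclass below is a per-backend test fixture: init_db()
    # (re)creates the test database, the cached `db` property builds the pony
    # Database, and __enter__ binds it to the test class and generates the
    # mapping when that class defines make_entities(). A hypothetical usage
    # sketch, where MyTests is any test class with a make_entities() method:
    #
    #     with MySqlContext(MyTests):
    #         ...  # MyTests.db is now a live, mapped Database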
    enabled = False
    def __init__(self, Test):
        if not isinstance(Test, type):
            # FIXME ?
            TestCls = type(Test)
            NewClass = type(TestCls.__name__, (TestCls,), {})
            NewClass.__module__ = TestCls.__module__
            NewClass.db = property(lambda t: self.db)
            Test.__class__ = NewClass
        else:
            Test.db = class_property(lambda cls: self.db)
        self.Test = Test
    @class_property
    def fixture_name(cls):
        return cls.db_provider
    @class_property
    def db_provider(cls):
        # is used in tests
        return cls.provider_key
    def init_db(self):
        raise NotImplementedError
    @cached_property
    def db(self):
        raise NotImplementedError
    def __enter__(self):
        self.init_db()
        try:
            self.Test.make_entities()
        except (AttributeError, TypeError):
            # No method make_entities with due signature
            pass
        else:
            self.db.generate_mapping(check_tables=True, create_tables=True)
        return self.db
    def __exit__(self, *exc_info):
        self.db.provider.disconnect()
    @classmethod
    def validate_fixtures(cls, fixtures, config):
        return any(f.fixture_key == 'db' for f in fixtures)
    db_name = 'testdb'


@provider()
class GenerateMapping(ContextDecorator):
    weight = 200
    fixture = 'generate_mapping'
    def __init__(self, Test):
        self.Test = Test
    def __enter__(self):
        db = getattr(self.Test, 'db', None)
        if not db or not db.entities:
            return
        for entity in db.entities.values():
            if entity._database_.schema is None:
                db.generate_mapping(check_tables=True, create_tables=True)
                break
    def __exit__(self, *exc_info):
        pass


@provider()
class MySqlContext(DBContext):
    provider_key = 'mysql'
    def drop_db(self, cursor):
        cursor.execute('use sys')
        cursor.execute('drop database %s' % self.db_name)
    def init_db(self):
        from pony.orm.dbproviders.mysql import mysql_module
        with closing(mysql_module.connect(**self.CONN).cursor()) as c:
            try:
                self.drop_db(c)
            except mysql_module.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            c.execute('create database %s' % self.db_name)
            c.execute('use %s' % self.db_name)
    CONN = {
        'host': "localhost",
        'user': "ponytest",
        'passwd': "ponytest",
    }
    @cached_property
    def db(self):
        CONN = dict(self.CONN, db=self.db_name)
        return Database('mysql', **CONN)


@provider()
class SqlServerContext(DBContext):
    provider_key = 'sqlserver'
    def get_conn_string(self, db=None):
        s = (
            'DSN=MSSQLdb;'
            'SERVER=mssql;'
            'UID=sa;'
            'PWD=pass;'
        )
        if db: s += 'DATABASE=%s' % db
        return s
    @cached_property
    def db(self):
        CONN = self.get_conn_string(self.db_name)
        return Database('mssqlserver', CONN)
    def init_db(self):
        import pyodbc
        cursor = pyodbc.connect(self.get_conn_string(), autocommit=True).cursor()
        with closing(cursor) as c:
            try:
                self.drop_db(c)
            except pyodbc.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            c.execute('''CREATE DATABASE %s
                DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci''' % self.db_name)
            c.execute('use %s' % self.db_name)
    def drop_db(self, cursor):
        cursor.execute('use master')
        cursor.execute('drop database %s' % self.db_name)


class SqliteMixin(DBContext):
    def init_db(self):
        try:
            os.remove(self.db_path)
        except OSError as exc:
            print('Failed to drop db: %s' % exc)
    @cached_property
    def db_path(self):
        p = os.path.dirname(__file__)
        p = os.path.join(p, '%s.sqlite' % self.db_name)
        return os.path.abspath(p)
    @cached_property
    def db(self):
        return Database('sqlite', self.db_path, create_db=True)


@provider()
class SqliteNoJson1(SqliteMixin):
    provider_key = 'sqlite_no_json1'
    enabled = True
    def __init__(self, cls):
        self.Test = cls
        cls.no_json1 = True
        return super(SqliteNoJson1, self).__init__(cls)
    def __enter__(self):
        resource = super(SqliteNoJson1, self).__enter__()
        self.json1_available = self.Test.db.provider.json1_available
        self.Test.db.provider.json1_available = False
        return resource
    def __exit__(self, *exc_info):
        self.Test.db.provider.json1_available = self.json1_available
        return super(SqliteNoJson1, self).__exit__(*exc_info)


@provider()
class SqliteJson1(SqliteMixin):
    provider_key = 'sqlite_json1'
    def __enter__(self):
        result = super(SqliteJson1, self).__enter__()
        if not self.db.provider.json1_available:
            raise unittest.SkipTest
        return result


@provider()
class PostgresContext(DBContext):
    provider_key = 'postgresql'
    def get_conn_dict(self, no_db=False):
        d = dict(
            user='ponytest', password='ponytest',
            host='localhost', database='postgres',
        )
        if not no_db:
            d.update(database=self.db_name)
        return d
    def init_db(self):
        import psycopg2
        conn = psycopg2.connect(**self.get_conn_dict(no_db=True))
        conn.set_isolation_level(0)
        with closing(conn.cursor()) as cursor:
            try:
                self.drop_db(cursor)
            except psycopg2.DatabaseError as exc:
                print('Failed to drop db: %s' % exc)
            cursor.execute('create database %s' % self.db_name)
    def drop_db(self, cursor):
        cursor.execute('drop database %s' % self.db_name)
    @cached_property
    def db(self):
        return Database('postgres', **self.get_conn_dict())


@provider()
class OracleContext(DBContext):
    provider_key = 'oracle'
    def __enter__(self):
        os.environ.update(dict(
            ORACLE_BASE='/u01/app/oracle',
            ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1',
            ORACLE_OWNR='oracle',
            ORACLE_SID='orcl',
        ))
        return super(OracleContext, self).__enter__()
    def init_db(self):
        import cx_Oracle
        with closing(self.connect_sys()) as conn:
            with closing(conn.cursor()) as cursor:
                try:
                    self._destroy_test_user(cursor)
                except cx_Oracle.DatabaseError as exc:
                    print('Failed to drop user: %s' % exc)
                try:
                    self._drop_tablespace(cursor)
                except cx_Oracle.DatabaseError as exc:
                    print('Failed to drop db: %s' % exc)
                cursor.execute(
                    """CREATE TABLESPACE %(tblspace)s
                    DATAFILE '%(datafile)s' SIZE 20M
                    REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s
                    """ % self.parameters)
                cursor.execute(
                    """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
                    TEMPFILE '%(datafile_tmp)s' SIZE 20M
                    REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s
                    """ % self.parameters)
                self._create_test_user(cursor)
    def _drop_tablespace(self, cursor):
        cursor.execute(
            'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
            % self.parameters)
        cursor.execute(
            'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
            % self.parameters)
    parameters = {
        'tblspace': 'test_tblspace',
        'tblspace_temp': 'test_tblspace_temp',
        'datafile': 'test_datafile.dbf',
        'datafile_tmp': 'test_datafile_tmp.dbf',
        'user': 'ponytest',
        'password': 'ponytest',
        'maxsize': '100M',
        'maxsize_tmp': '100M',
    }
    def connect_sys(self):
        import cx_Oracle
        return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA)
    def connect_test(self):
        import cx_Oracle
        return cx_Oracle.connect('ponytest/ponytest@localhost/ORCL')
    @cached_property
    def db(self):
        return Database('oracle', 'ponytest/ponytest@localhost/ORCL')
    def _create_test_user(self, cursor):
        cursor.execute(
            """CREATE USER %(user)s
            IDENTIFIED BY %(password)s
            DEFAULT TABLESPACE %(tblspace)s
            TEMPORARY TABLESPACE %(tblspace_temp)s
            QUOTA UNLIMITED ON %(tblspace)s
            """ % self.parameters)
        cursor.execute(
            """GRANT CREATE SESSION, CREATE TABLE, CREATE SEQUENCE,
            CREATE PROCEDURE, CREATE TRIGGER TO %(user)s
            """ % self.parameters)
    def _destroy_test_user(self, cursor):
        cursor.execute('''
            DROP USER %(user)s CASCADE
        ''' % self.parameters)


@provider(fixture='log', weight=100,
          enabled=False)
@contextmanager
def logging_context(test):
    level = logging.getLogger().level
    from pony.orm.core import debug, sql_debug
    logging.getLogger().setLevel(logging.INFO)
    sql_debug(True)
    yield
    logging.getLogger().setLevel(level)
    sql_debug(debug)


@provider(fixture='log_all', weight=-100, enabled=False)
def log_all(Test):
    return logging_context(Test)


# @with_cli_args
# @click.option('--log', 'scope', flag_value='test')
# @click.option('--log-all', 'scope', flag_value='all')
# def use_logging(scope):
#     if scope == 'test':
#         yield logging_context
#     elif scope =='all':
#         yield log_all

# @provider(enabled=False)
# class DBSessionProvider(object):
#
#     fixture= 'db_session'
#
#     weight = 30
#
#     def __new__(cls, test):
#         return db_session


@provider(fixture='rollback', weight=40)
@contextmanager
def do_rollback(test):
    try: yield
    finally: rollback()


@provider()
class SeparateProcess(object):
    # TODO read failures from sep process better
    fixture = 'separate_process'
    enabled = False
    def __init__(self, Test):
        self.Test = Test
    def __call__(self, func):
        def wrapper(Test):
            rnr = unittest.runner.TextTestRunner()
            TestCls = Test if isinstance(Test, type) else type(Test)
            def runTest(self):
                try: func(Test)
                finally: rnr.stream = unittest.runner._WritelnDecorator(StringIO())
            name = getattr(func, '__name__', 'runTest')
            Case = type(TestCls.__name__, (TestCls,), {name: runTest})
            Case.__module__ = TestCls.__module__
            case = Case(name)
            suite = unittest.suite.TestSuite([case])
            def run():
                result = rnr.run(suite)
                if not result.wasSuccessful(): sys.exit(1)
            p = Process(target=run, args=())
            p.start()
            p.join()
            case.assertEqual(p.exitcode, 0)
        return wrapper
    @classmethod
    def validate_chain(cls, fixtures, klass):
        for f in fixtures:
            if f.KEY in ('ipdb', 'ipdb_all'): return False
        for f in fixtures:
            if f.KEY == 'db' and f.provider_key in ('sqlserver', 'oracle'): return True


@provider()
class ClearTables(ContextDecorator):
    fixture = 'clear_tables'
    def __init__(self, test):
        self.test = test
    def __enter__(self):
        pass
    @db_session
    def __exit__(self, *exc_info):
        db = self.test.db
        for entity in db.entities.values():
            if entity._database_.schema is None: break
            delete(i for i in entity)


import signal

@provider()
class Timeout(object):
    fixture = 'timeout'
    @with_cli_args
    @click.option('--timeout', type=int)
    def __init__(self, Test, timeout):
        self.Test = Test
        self.timeout = timeout if timeout else Test.TIMEOUT
    enabled = False
    class Exception(Exception): pass
    class FailedInSubprocess(Exception): pass
    def __call__(self, func):
        def wrapper(test):
            p = Process(target=func, args=(test,))
            p.start()
            def on_expired():
                p.terminate()
            t = threading.Timer(self.timeout, on_expired)
            t.start()
            p.join()
            t.cancel()
            if p.exitcode == -signal.SIGTERM: raise self.Exception
            elif p.exitcode: raise self.FailedInSubprocess
        return wrapper
    @classmethod
    @with_cli_args
    @click.option('--timeout', type=int)
    def validate_chain(cls, fixtures, klass, timeout):
        if not getattr(klass, 'TIMEOUT', None) and not timeout: return False
        for f in fixtures:
            if f.KEY in ('ipdb', 'ipdb_all'): return False
        for f in fixtures:
            if f.KEY == 'db' and f.provider_key in ('sqlserver', 'oracle'): return True


pony_fixtures['test'].extend([
    'log',
    'clear_tables',
])

pony_fixtures['class'].extend([
    'separate_process',
    'timeout',
    'db',
    'log_all',
    'generate_mapping',
])

# ===== pony-0.7.11/pony/orm/tests/model1.py =====

from __future__ import absolute_import, print_function, division

from pony.orm.core import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    _table_ = "Students"
    record = PrimaryKey(int)
    name = Required(unicode, column="fio")
    group = Required("Group")
    scholarship = Required(int, default=0)
    marks = Set("Mark")

class Group(db.Entity):
    _table_ = "Groups"
    number = PrimaryKey(str)
    department = Required(int)
    students = Set("Student")
    subjects = Set("Subject")

class Subject(db.Entity):
    _table_ = "Subjects"
    name = PrimaryKey(unicode)
    groups = Set("Group")
    marks = Set("Mark")

class Mark(db.Entity):
    _table_ = "Exams"
    student = Required(Student, column="student")
    subject = Required(Subject, column="subject")
    value = Required(int)
    PrimaryKey(student, subject)

db.generate_mapping(create_tables=True)

@db_session
def populate_db():
    Physics = Subject(name='Physics')
    Chemistry = Subject(name='Chemistry')
    Math = Subject(name='Math')

    g3132 = Group(number='3132', department=33, subjects=[ Physics, Math ])
    g4145 = Group(number='4145', department=44, subjects=[ Physics, Chemistry, Math ])
    g4146 = Group(number='4146', department=44)

    s101 = Student(record=101, name='Bob', group=g4145, scholarship=0)
    s102 = Student(record=102, name='Joe', group=g4145, scholarship=800)
    s103 = Student(record=103, name='Alex', group=g4145, scholarship=0)
    s104 = Student(record=104, name='Brad', group=g3132, scholarship=500)
    s105 = Student(record=105, name='John', group=g3132, scholarship=1000)

    Mark(student=s101, subject=Physics, value=4)
    Mark(student=s101, subject=Math, value=3)
    Mark(student=s102, subject=Chemistry, value=5)
    Mark(student=s103, subject=Physics, value=2)
    Mark(student=s103, subject=Chemistry, value=4)
populate_db()

# ===== pony-0.7.11/pony/orm/tests/py36_test_f_strings.py =====

import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    first_name = Required(str)
    last_name = Required(str)
    age = Optional(int)
    value = Required(float)

db.generate_mapping(create_tables=True)

with db_session:
    Person(id=1, first_name='Alexander', last_name='Tischenko', age=23, value=1.4)
    Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2)
    Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33)
    Person(id=4, first_name='Okita', last_name='Souji', age=15, value=2.1)
    Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9)
    Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212)

class TestFString(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_1(self):
        x = 'Alexander'
        y = 'Tischenko'
        q = select(p.id for p in Person if p.first_name + ' ' + p.last_name == f'{x} {y}')
        self.assertEqual(set(q), {1})
    def test_2(self):
        q = select(p.id for p in Person if f'{p.first_name} {p.last_name}' == 'Alexander Tischenko')
        self.assertEqual(set(q), {1})
    def test_3(self):
        x = 'Great'
        q = select(f'{p.first_name} the {x}' for p in Person if p.id == 1)
        self.assertEqual(set(q), {'Alexander the Great'})
    def test_4(self):
        q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1)
        self.assertEqual(set(q), {'Alexander 23'})
    def test_5(self):
        q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1)
        self.assertEqual(set(q), {'Alexander 23'})
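    # For reference: an f-string inside a query is translated to SQL string
    # concatenation, so test_2 above presumably compiles to something like
    # the following for SQLite (an illustrative sketch, not a recorded
    # expectation of this test suite):
    #
    #   SELECT "p"."id" FROM "Person" "p"
    #   WHERE "p"."first_name" || ' ' || "p"."last_name" = 'Alexander Tischenko'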
    @raises_exception(NotImplementedError, 'You cannot set width and precision markers in query')
    def test_6(self):
        width = 3
        precision = 4
        q = select(p.id for p in Person if f'{p.value:{width}.{precision}}')[:]
        self.assertEqual({2,}, set(q))
    def test_7(self):
        x = 'Tischenko'
        q = select(p.first_name + f"{' ' + x}" for p in Person if p.id == 1)
        self.assertEqual(set(q), {'Alexander Tischenko'})
    def test_8(self):
        q = select(p for p in Person if not p.age)[:]

# ===== pony-0.7.11/pony/orm/tests/queries.txt =====

# Testing queries from documentation (CRUD):

Schema: pony.orm.examples.estore

>>> Product.get(id=1)
SELECT "id", "name", "description", "picture", "price", "quantity" FROM "Product" WHERE "id" = ?

>>> Product.get(name="Product1")
SELECT "id", "name", "description", "picture", "price", "quantity" FROM "Product" WHERE "name" = ? LIMIT 2

>>> Product.select(lambda p: p.price > 100)
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" WHERE "p"."price" > 100

>>> x = 100
>>> Product.select(lambda p: p.price > x)
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" WHERE "p"."price" > ?

>>> Product.select(lambda p: p.price > 100).order_by(lambda p: desc(p.price))
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" WHERE "p"."price" > 100 ORDER BY "p"."price" DESC

>>> Product.select(lambda p: p.price > 100).order_by(desc(Product.price))
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" WHERE "p"."price" > 100 ORDER BY "p"."price" DESC

>>> Product.select(lambda p: p.price > 100).order_by(desc(Product.price), Product.name)
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" WHERE "p"."price" > 100 ORDER BY "p"."price" DESC, "p"."name"

>>> Product.select(lambda p: p.price > 100).order_by("desc(p.price), p.name")
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" WHERE "p"."price" > 100 ORDER BY "p"."price" DESC, "p"."name"

>>> Product.select().order_by(lambda p: desc(p.price))[:10]
SELECT "p"."id", "p"."name", "p"."description", "p"."picture", "p"."price", "p"."quantity" FROM "Product" "p" ORDER BY "p"."price" DESC LIMIT 10

# Testing declarative queries from documentation:
# TODO

# Testing aggregation queries from documentation:

Schema: pony.orm.examples.university1

>>> sum(s.gpa for s in Student if s.group.number == 101)
SELECT coalesce(SUM("s"."gpa"), 0) FROM "Student" "s" WHERE "s"."group" = 101

>>> count(s for s in Student if s.gpa > 3)
SELECT COUNT(*) FROM "Student" "s" WHERE "s"."gpa" > 3

>>> min(s.name for s in Student if "Philosophy" in s.courses.name)
SELECT MIN("s"."name") FROM "Student" "s" WHERE 'Philosophy' IN ( SELECT "t-1"."course_name" FROM "Course_Student" "t-1" WHERE "s"."id" = "t-1"."student" )

>>> max(s.dob for s in Student if s.group.number == 101)
SELECT MAX("s"."dob") FROM "Student" "s" WHERE "s"."group" = 101

>>> avg(s.gpa for s in Student if s.group.dept.number == 44)
SELECT AVG("s"."gpa") FROM "Student" "s", "Group" "group" WHERE "group"."dept" = 44 AND "s"."group" = "group"."number"

>>> select(s for s in Student if s.group.number == 101 and s.dob == max(s.dob for s in
Student if s.group.number == 101)) SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE "s"."group" = 101 AND "s"."dob" = ( SELECT MAX("s-2"."dob") FROM "Student" "s-2" WHERE "s-2"."group" = 101 ) >>> select(g for g in Group if avg(s.gpa for s in g.students) > 4.5) SELECT "g"."number", "g"."major", "g"."dept" FROM "Group" "g" WHERE ( SELECT AVG("s"."gpa") FROM "Student" "s" WHERE "g"."number" = "s"."group" ) > 4.5 >>> select(g for g in Group if avg(g.students.gpa) > 4.5) SELECT "g"."number" FROM "Group" "g" LEFT JOIN "Student" "student" ON "g"."number" = "student"."group" GROUP BY "g"."number" HAVING AVG("student"."gpa") > 4.5 >>> select((s.group, min(s.gpa), max(s.gpa)) for s in Student) SELECT "s"."group", MIN("s"."gpa"), MAX("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" >>> select((g, min(g.students.gpa), max(g.students.gpa)) for g in Group) SELECT "g"."number", MIN("student"."gpa"), MAX("student"."gpa") FROM "Group" "g" LEFT JOIN "Student" "student" ON "g"."number" = "student"."group" GROUP BY "g"."number" >>> select((g, g.students.name, min(g.students.gpa), max(g.students.gpa)) for g in Group) SELECT "g"."number", "student"."name", MIN("student"."gpa"), MAX("student"."gpa") FROM "Group" "g" LEFT JOIN "Student" "student" ON "g"."number" = "student"."group" GROUP BY "g"."number", "student"."name" >>> count(s for s in Student if s.group.number == 101) SELECT COUNT(*) FROM "Student" "s" WHERE "s"."group" = 101 >>> select((g, count(g.students)) for g in Group if g.dept.number == 44) SELECT "g"."number", COUNT(DISTINCT "student"."id") FROM "Group" "g" LEFT JOIN "Student" "student" ON "g"."number" = "student"."group" WHERE "g"."dept" = 44 GROUP BY "g"."number" >>> select((s.group, count(s)) for s in Student if s.group.dept.number == 44) SELECT "s"."group", COUNT(DISTINCT "s"."id") FROM "Student" "s", "Group" "group" WHERE "group"."dept" = 44 AND "s"."group" = "group"."number" GROUP BY "s"."group" >>> select((g, count(s for s in g.students if s.gpa <= 3), count(s for s in g.students if s.gpa > 3 and s.gpa <= 4), count(s for s in g.students if s.gpa > 4)) for g in Group) SELECT "g"."number", ( SELECT COUNT(DISTINCT "s"."id") FROM "Student" "s" WHERE "g"."number" = "s"."group" AND "s"."gpa" <= 3 ), ( SELECT COUNT(DISTINCT "s"."id") FROM "Student" "s" WHERE "g"."number" = "s"."group" AND "s"."gpa" > 3 AND "s"."gpa" <= 4 ), ( SELECT COUNT(DISTINCT "s"."id") FROM "Student" "s" WHERE "g"."number" = "s"."group" AND "s"."gpa" > 4 ) FROM "Group" "g" >>> select((s.group, count(s.gpa <= 3), count(s.gpa > 3 and s.gpa <= 4), count(s.gpa > 4)) for s in Student) SELECT "s"."group", COUNT(case when "s"."gpa" <= 3 then 1 else null end), COUNT(case when "s"."gpa" > 3 AND "s"."gpa" <= 4 then 1 else null end), COUNT(case when "s"."gpa" > 4 then 1 else null end) FROM "Student" "s" GROUP BY "s"."group" >>> left_join((g, count(s.gpa <= 3), count(s.gpa > 3 and s.gpa <= 4), count(s.gpa > 4)) for g in Group for s in g.students) SELECT "g"."number", COUNT(case when "s"."gpa" <= 3 then 1 else null end), COUNT(case when "s"."gpa" > 3 AND "s"."gpa" <= 4 then 1 else null end), COUNT(case when "s"."gpa" > 4 then 1 else null end) FROM "Group" "g" LEFT JOIN "Student" "s" ON "g"."number" = "s"."group" GROUP BY "g"."number" >>> select((s.dob.year, avg(s.gpa)) for s in Student) SELECT cast(substr("s"."dob", 1, 4) as integer), AVG("s"."gpa") FROM "Student" "s" GROUP BY cast(substr("s"."dob", 1, 4) as integer) Schema: pony.orm.examples.estore >>> select((item.order, sum(item.price 
* item.quantity)) for item in OrderItem if item.order.id == 123) SELECT "item"."order", coalesce(SUM(("item"."price" * "item"."quantity")), 0) FROM "OrderItem" "item" WHERE "item"."order" = 123 GROUP BY "item"."order" >>> select((order, sum(order.items.price * order.items.quantity)) for order in Order if order.id == 123) SELECT "order"."id", coalesce(SUM(("orderitem"."price" * "orderitem"."quantity")), 0) FROM "Order" "order" LEFT JOIN "OrderItem" "orderitem" ON "order"."id" = "orderitem"."order" WHERE "order"."id" = 123 GROUP BY "order"."id" >>> select((item.order, item.order.total_price, sum(item.price * item.quantity)) for item in OrderItem if item.order.total_price < sum(item.price * item.quantity)) SELECT "item"."order", "order"."total_price", coalesce(SUM(("item"."price" * "item"."quantity")), 0) FROM "OrderItem" "item", "Order" "order" WHERE "item"."order" = "order"."id" GROUP BY "item"."order", "order"."total_price" HAVING "order"."total_price" < coalesce(SUM(("item"."price" * "item"."quantity")), 0) >>> select(c for c in Customer for p in c.orders.items.product if 'Tablets' in p.categories.name and count(p) > 1) SELECT "c"."id" FROM "Customer" "c", "Order" "order", "OrderItem" "orderitem" WHERE 'Tablets' IN ( SELECT "category"."name" FROM "Category_Product" "t-1", "Category" "category" WHERE "orderitem"."product" = "t-1"."product" AND "t-1"."category" = "category"."id" ) AND "c"."id" = "order"."customer" AND "order"."id" = "orderitem"."order" GROUP BY "c"."id" HAVING COUNT(DISTINCT "orderitem"."product") > 1 Schema: pony.orm.examples.university1 >>> select((s.group, count(s)) for s in Student if s.group.dept.number == 44 and avg(s.gpa) > 4) SELECT "s"."group", COUNT(DISTINCT "s"."id") FROM "Student" "s", "Group" "group" WHERE "group"."dept" = 44 AND "s"."group" = "group"."number" GROUP BY "s"."group" HAVING AVG("s"."gpa") > 4 >>> select(g for g in Group if max(g.students.gpa) < 4) SELECT "g"."number" FROM "Group" "g" LEFT JOIN "Student" "student" ON "g"."number" = "student"."group" GROUP BY "g"."number" HAVING MAX("student"."gpa") < 4 >>> select(g for g in Group if JOIN(max(g.students.gpa) < 4)) SELECT "g"."number" FROM "Group" "g" LEFT JOIN ( SELECT "student"."group" AS "group", MAX("student"."gpa") AS "expr-1" FROM "Student" "student" GROUP BY "student"."group" ) "t-1" ON "g"."number" = "t-1"."group" WHERE "t-1"."expr-1" < 4 GROUP BY "g"."number" >>> select(s.group for s in Student if max(s.gpa) < 4) SELECT "s"."group" FROM "Student" "s" GROUP BY "s"."group" HAVING MAX("s"."gpa") < 4 >>> select((s.group, avg(s.gpa)) for s in Student).order_by(lambda: avg(s.gpa)) SELECT "s"."group", AVG("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" ORDER BY AVG("s"."gpa") >>> select((s.group, avg(s.gpa)) for s in Student).order_by(lambda: desc(avg(s.gpa))) SELECT "s"."group", AVG("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" ORDER BY AVG("s"."gpa") DESC >>> select((s.group, avg(s.gpa)) for s in Student).order_by(2) SELECT "s"."group", AVG("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" ORDER BY 2 >>> select((s.group, avg(s.gpa)) for s in Student).order_by(-2) SELECT "s"."group", AVG("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" ORDER BY 2 DESC >>> select(sum(s.gpa) for s in Student) SELECT coalesce(SUM("s"."gpa"), 0) FROM "Student" "s" >>> select(s.gpa for s in Student).sum() SELECT coalesce(SUM("s"."gpa"), 0) FROM "Student" "s" # Other tests Schema: pony.orm.examples.university1 >>> select(s for s in Student if s.gpa > 3) SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", 
"s"."gpa", "s"."group" FROM "Student" "s" WHERE "s"."gpa" > 3 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."GPA" > 3 >>> select(s for s in Student if s.group.number == 1) SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE "s"."group" = 1 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."GROUP" = 1 >>> select(c for c in Course if count(c.students) > 3) SELECT "c"."name", "c"."semester" FROM "Course" "c" LEFT JOIN "Course_Student" "t-1" ON "c"."name" = "t-1"."course_name" AND "c"."semester" = "t-1"."course_semester" GROUP BY "c"."name", "c"."semester" HAVING COUNT(DISTINCT "t-1"."student") > 3 Oracle: SELECT "c"."NAME", "c"."SEMESTER" FROM "COURSE" "c" LEFT JOIN "COURSE_STUDENT" "t-1" ON "c"."NAME" = "t-1"."COURSE_NAME" AND "c"."SEMESTER" = "t-1"."COURSE_SEMESTER" GROUP BY "c"."NAME", "c"."SEMESTER" HAVING COUNT(DISTINCT "t-1"."STUDENT") > 3 >>> select(s for s in Student if count(s.courses) > 3) SELECT "s"."id" FROM "Student" "s" LEFT JOIN "Course_Student" "t-1" ON "s"."id" = "t-1"."student" GROUP BY "s"."id" HAVING COUNT("t-1"."ROWID") > 3 Oracle: SELECT "s"."ID" FROM "STUDENT" "s" LEFT JOIN "COURSE_STUDENT" "t-1" ON "s"."ID" = "t-1"."STUDENT" GROUP BY "s"."ID" HAVING COUNT("t-1"."ROWID") > 3 >>> select(s for s in Student).for_update()[:3] SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" LIMIT 3 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."ROWID" IN ( SELECT * FROM ( SELECT "s"."ROWID" AS "row-id" FROM "STUDENT" "s" ) WHERE ROWNUM <= 3 ) FOR UPDATE >>> select(s for s in Student).order_by(Student.name).for_update()[:3] SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" ORDER BY "s"."name" LIMIT 3 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."ROWID" IN ( SELECT * FROM ( SELECT "s"."ROWID" AS "row-id" FROM "STUDENT" "s" ORDER BY "s"."NAME" ) WHERE ROWNUM <= 3 ) ORDER BY "s"."NAME" FOR UPDATE >>> select(s for s in Student).order_by(Student.name).for_update()[3:6] SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" ORDER BY "s"."name" LIMIT 3 OFFSET 3 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."ROWID" IN ( SELECT t."row-id" FROM ( SELECT t.*, ROWNUM "row-num" FROM ( SELECT "s"."ROWID" AS "row-id" FROM "STUDENT" "s" ORDER BY "s"."NAME" ) t WHERE ROWNUM <= 6 ) t WHERE "row-num" > 3 ) ORDER BY "s"."NAME" FOR UPDATE >>> select(s for s in Student).for_update()[3:6] SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" LIMIT 3 OFFSET 3 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."ROWID" IN ( SELECT t."row-id" FROM ( SELECT t.*, ROWNUM "row-num" FROM ( SELECT "s"."ROWID" AS "row-id" FROM "STUDENT" "s" ) t WHERE ROWNUM <= 6 ) t WHERE "row-num" > 3 ) FOR UPDATE >>> select((g, count(s)) for g in Group for s in g.students) SELECT "g"."number", COUNT(DISTINCT "s"."id") FROM "Group" "g", "Student" "s" WHERE "g"."number" = "s"."group" GROUP BY "g"."number" Oracle: SELECT "g"."NUMBER", COUNT(DISTINCT "s"."ID") FROM "GROUP" "g", "STUDENT" "s" WHERE "g"."NUMBER" = "s"."GROUP" GROUP BY "g"."NUMBER" PostgreSQL: 
SELECT "g"."number", COUNT(DISTINCT "s"."id") FROM "group" "g", "student" "s" WHERE "g"."number" = "s"."group" GROUP BY "g"."number" >>> select((s, count(c)) for s in Student for c in s.courses) SELECT "s"."id", COUNT(DISTINCT "c"."ROWID") FROM "Student" "s", "Course_Student" "t-1", "Course" "c" WHERE "s"."id" = "t-1"."student" AND "t-1"."course_name" = "c"."name" AND "t-1"."course_semester" = "c"."semester" GROUP BY "s"."id" Oracle: SELECT "s"."ID", COUNT(DISTINCT "c"."ROWID") FROM "STUDENT" "s", "COURSE_STUDENT" "t-1", "COURSE" "c" WHERE "s"."ID" = "t-1"."STUDENT" AND "t-1"."COURSE_NAME" = "c"."NAME" AND "t-1"."COURSE_SEMESTER" = "c"."SEMESTER" GROUP BY "s"."ID" PostgreSQL: SELECT "s"."id", COUNT(DISTINCT case when ("t-1"."course_name", "t-1"."course_semester") IS NULL then null else ("t-1"."course_name", "t-1"."course_semester") end) FROM "student" "s", "course_student" "t-1" WHERE "s"."id" = "t-1"."student" GROUP BY "s"."id" >>> x = '123' >>> select(s for s in Student if s.tel == x)[:] SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE "s"."tel" = ? Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."TEL" = :p1 >>> x = '' >>> select(s for s in Student if s.tel == x)[:] Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."TEL" IS NULL >>> x = '' >>> select(s for s in Student).filter(lambda s: s.tel == x)[:] Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE "s"."TEL" IS NULL >>> select(s.name for s in Student if s.name in (s.name for s in Student if count() > 1))[:] SELECT DISTINCT "s"."name" FROM "Student" "s" WHERE "s"."name" IN ( SELECT "s-2"."name" FROM "Student" "s-2" GROUP BY "s-2"."name" HAVING COUNT(*) > 1 ) >>> empty_list = [] >>> select(s.name for s in Student if s.id in empty_list) SELECT DISTINCT "s"."name" FROM "Student" "s" WHERE 0 = 1 >>> empty_list = [] >>> select(s.name for s in Student if s.id not in empty_list) SELECT DISTINCT "s"."name" FROM "Student" "s" WHERE 1 = 1 >>> select(s.name for s in Student)[:] SELECT DISTINCT "s"."name" FROM "Student" "s" >>> select(s.name for s in Student).without_distinct()[:] SELECT "s"."name" FROM "Student" "s" >>> select(s.name for s in Student).without_distinct().distinct()[:] SELECT DISTINCT "s"."name" FROM "Student" "s" >>> select(s.name for s in Student).without_distinct().distinct().without_distinct()[:] SELECT "s"."name" FROM "Student" "s" >>> select(s.name for s in Student).first() SELECT "s"."name" FROM "Student" "s" ORDER BY 1 LIMIT 1 >>> select(s.name for s in Student).without_distinct().first() SELECT "s"."name" FROM "Student" "s" ORDER BY 1 LIMIT 1 >>> select(s.name for s in Student)[4:5] SELECT DISTINCT "s"."name" FROM "Student" "s" LIMIT 1 OFFSET 4 >>> select(s for s in Student if (s.group, s.gpa) in select((s2.group, max(s2.gpa)) for s2 in Student)) SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE EXISTS ( SELECT 1 FROM "Student" "s2" GROUP BY "s2"."group" HAVING "s"."group" = "s2"."group" AND "s"."gpa" = MAX("s2"."gpa") ) PostgreSQL: SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."group", "s"."gpa") IN ( SELECT "s2"."group", MAX("s2"."gpa") FROM "student" "s2" GROUP BY "s2"."group" ) >>> select(s for s in Student if (s.group, s.gpa) not in select((s2.group, max(s2.gpa)) for s2 in Student)) 
SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE NOT EXISTS ( SELECT 1 FROM "Student" "s2" GROUP BY "s2"."group" HAVING "s"."group" = "s2"."group" AND "s"."gpa" = MAX("s2"."gpa") ) PostgreSQL: SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."group", "s"."gpa") NOT IN ( SELECT "s2"."group", MAX("s2"."gpa") FROM "student" "s2" GROUP BY "s2"."group" HAVING MAX("s2"."gpa") IS NOT NULL ) # Complex aggregations: Schema: pony.orm.examples.university1 >>> count(g for g in Group if count(g.students) > 0 and g.number > 101) SELECT COUNT(*) FROM ( SELECT "g"."number" FROM "Group" "g" LEFT JOIN "Student" "student" ON "g"."number" = "student"."group" WHERE "g"."number" > 101 GROUP BY "g"."number" HAVING COUNT(DISTINCT "student"."id") > 0 ) "t" >>> count(g for g in Group if count(s for s in g.students) > 0 and g.number > 101) SELECT COUNT(*) FROM "Group" "g" WHERE ( SELECT COUNT(DISTINCT "s"."id") FROM "Student" "s" WHERE "g"."number" = "s"."group" ) > 0 AND "g"."number" > 101 >>> sum(s.gpa for s in Student if count(s.courses) > 3 and s.gpa > 3.5) SELECT coalesce(SUM("s"."gpa"), 0) FROM ( SELECT "s"."gpa" FROM "Student" "s" LEFT JOIN "Course_Student" "t-1" ON "s"."id" = "t-1"."student" WHERE "s"."gpa" > 3.5 GROUP BY "s"."gpa" HAVING COUNT("t-1"."ROWID") > 3 ) "s" >>> sum(s.dob.year for s in Student if count(s.courses) > 3 and s.gpa > 3.5) SELECT coalesce(SUM("t"."expr"), 0) FROM ( SELECT cast(substr("s"."dob", 1, 4) as integer) AS "expr" FROM "Student" "s" LEFT JOIN "Course_Student" "t-1" ON "s"."id" = "t-1"."student" WHERE "s"."gpa" > 3.5 GROUP BY cast(substr("s"."dob", 1, 4) as integer) HAVING COUNT("t-1"."ROWID") > 3 ) "t" # Bulk delete: >>> select(s for s in Student if s.gpa > 3).delete(bulk=True) DELETE FROM "Student" WHERE "gpa" > 3 >>> select(s for s in Student if s.group.dept.number == 1).delete(bulk=True) DELETE FROM "Student" WHERE "id" IN ( SELECT "s"."id" FROM "Student" "s", "Group" "group" WHERE "group"."dept" = 1 AND "s"."group" = "group"."number" ) MySQL: DELETE s FROM `student` `s` INNER JOIN `group` `group` ON `s`.`group` = `group`.`number` WHERE `group`.`dept` = 1 PostgreSQL: DELETE FROM "student" WHERE "id" IN ( SELECT "s"."id" FROM "student" "s", "group" "group" WHERE "group"."dept" = 1 AND "s"."group" = "group"."number" ) Oracle: DELETE FROM "STUDENT" WHERE "ID" IN ( SELECT "s"."ID" FROM "STUDENT" "s", "GROUP" "group" WHERE "group"."DEPT" = 1 AND "s"."GROUP" = "group"."NUMBER" ) >>> select(c for c in Course if c.dept.name.startswith('D')).delete(bulk=True) DELETE FROM "Course" WHERE "ROWID" IN ( SELECT "c"."ROWID" FROM "Course" "c", "Department" "department" WHERE "department"."name" LIKE 'D%' AND "c"."dept" = "department"."number" ) MySQL: DELETE c FROM `course` `c` INNER JOIN `department` `department` ON `c`.`dept` = `department`.`number` WHERE `department`.`name` LIKE 'D%%' PostgreSQL: DELETE FROM "course" WHERE ("name", "semester") IN ( SELECT "c"."name", "c"."semester" FROM "course" "c", "department" "department" WHERE "department"."name" LIKE 'D%%' AND "c"."dept" = "department"."number" ) Oracle: DELETE FROM "COURSE" WHERE "ROWID" IN ( SELECT "c"."ROWID" FROM "COURSE" "c", "DEPARTMENT" "department" WHERE "department"."NAME" LIKE 'D%' AND "c"."DEPT" = "department"."NUMBER" ) >>> select(s for s in Student if s.gpa > 3 and s not in (s2 for s2 in Student if s2.group.dept.name.startswith('A'))).delete(bulk=True) DELETE FROM "Student" WHERE "gpa" > 3 AND "id" NOT IN ( SELECT 
"s2"."id" FROM "Student" "s2", "Group" "group", "Department" "department" WHERE "department"."name" LIKE 'A%' AND "s2"."group" = "group"."number" AND "group"."dept" = "department"."number" ) # MySQL does not support such queries PostgreSQL: DELETE FROM "student" WHERE "gpa" > 3 AND "id" NOT IN ( SELECT "s2"."id" FROM "student" "s2", "group" "group", "department" "department" WHERE "department"."name" LIKE 'A%%' AND "s2"."group" = "group"."number" AND "group"."dept" = "department"."number" ) Oracle: DELETE FROM "STUDENT" WHERE "GPA" > 3 AND "ID" NOT IN ( SELECT "s2"."ID" FROM "STUDENT" "s2", "GROUP" "group", "DEPARTMENT" "department" WHERE "department"."NAME" LIKE 'A%' AND "s2"."GROUP" = "group"."NUMBER" AND "group"."DEPT" = "department"."NUMBER" ) >>> select(s for s in Student if exists(s2 for s2 in Student if s.gpa > s2.gpa)).delete(bulk=True) DELETE FROM "Student" WHERE "id" IN ( SELECT "s"."id" FROM "Student" "s" WHERE EXISTS ( SELECT 1 FROM "Student" "s2" WHERE "s"."gpa" > "s2"."gpa" ) ) PostgreSQL: DELETE FROM "student" WHERE "id" IN ( SELECT "s"."id" FROM "student" "s" WHERE EXISTS ( SELECT 1 FROM "student" "s2" WHERE "s"."gpa" > "s2"."gpa" ) ) Oracle: DELETE FROM "STUDENT" WHERE "ID" IN ( SELECT "s"."ID" FROM "STUDENT" "s" WHERE EXISTS ( SELECT 1 FROM "STUDENT" "s2" WHERE "s"."GPA" > "s2"."GPA" ) ) >>> select(s for s in Student if count(g for g in s.group.dept.groups) > 2).delete(bulk=True) DELETE FROM "Student" WHERE "id" IN ( SELECT "s"."id" FROM "Student" "s" WHERE ( SELECT COUNT(DISTINCT "g"."number") FROM "Group" "group", "Group" "g" WHERE "s"."group" = "group"."number" AND "group"."dept" = "g"."dept" ) > 2 ) >>> Student.select(lambda s: count(s.group.students) == 2).delete(bulk=True) DELETE FROM "Student" WHERE "id" IN ( SELECT "s"."id" FROM "Student" "s" LEFT JOIN "Student" "student" ON "s"."group" = "student"."group" ) # Test UPPER/LOWER functions: >>> select(s.name.upper() for s in Student) SELECT DISTINCT py_upper("s"."name") FROM "Student" "s" PostgreSQL: SELECT DISTINCT upper("s"."name") FROM "student" "s" # Test modulo division operator >>> select(s for s in Student if s.id % 2 == 0) SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE ("s"."id" % 2) = 0 PostgreSQL: SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."id" %% 2) = 0 MySQL: SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` FROM `student` `s` WHERE (`s`.`id` %% 2) = 0 Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE MOD("s"."ID", 2) = 0 # Test group_concat: >>> select((g, group_concat(s.name, '+')) for g in Group for s in g.students) SELECT "g"."number", GROUP_CONCAT("s"."name", '+') FROM "Group" "g", "Student" "s" WHERE "g"."number" = "s"."group" GROUP BY "g"."number" PostgreSQL: SELECT "g"."number", string_agg("s"."name"::text, '+') FROM "group" "g", "student" "s" WHERE "g"."number" = "s"."group" GROUP BY "g"."number" MySQL: SELECT `g`.`number`, GROUP_CONCAT(`s`.`name` SEPARATOR '+') FROM `group` `g`, `student` `s` WHERE `g`.`number` = `s`.`group` GROUP BY `g`.`number` Oracle: SELECT "g"."NUMBER", LISTAGG("s"."NAME", '+') WITHIN GROUP(ORDER BY 1) FROM "GROUP" "g", "STUDENT" "s" WHERE "g"."NUMBER" = "s"."GROUP" GROUP BY "g"."NUMBER" # Test offset without limit >>> select(s for s in Student)[3:] SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" LIMIT -1 OFFSET 3 PostgreSQL: SELECT 
"s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" LIMIT null OFFSET 3 MySQL: SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` FROM `student` `s` LIMIT 18446744073709551615 OFFSET 3 Oracle: SELECT t.* FROM ( SELECT t.*, ROWNUM "row-num" FROM ( SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" ) t ) t WHERE "row-num" > 3 # Test row comparison: >>> select((s1.id, s2.id) for s1 in Student for s2 in Student if (s1.name, s1.gpa, s1.tel) < (s2.name, s2.gpa, s2.tel)) SELECT DISTINCT "s1"."id", "s2"."id" FROM "Student" "s1", "Student" "s2" WHERE ("s1"."name" < "s2"."name" OR "s1"."name" = "s2"."name" AND "s1"."gpa" < "s2"."gpa" OR "s1"."name" = "s2"."name" AND "s1"."gpa" = "s2"."gpa" AND "s1"."tel" < "s2"."tel") PostgreSQL: SELECT DISTINCT "s1"."id", "s2"."id" FROM "student" "s1", "student" "s2" WHERE ("s1"."name", "s1"."gpa", "s1"."tel") < ("s2"."name", "s2"."gpa", "s2"."tel") MySQL: SELECT DISTINCT `s1`.`id`, `s2`.`id` FROM `student` `s1`, `student` `s2` WHERE (`s1`.`name`, `s1`.`gpa`, `s1`.`tel`) < (`s2`.`name`, `s2`.`gpa`, `s2`.`tel`) Oracle: SELECT DISTINCT "s1"."ID", "s2"."ID" FROM "STUDENT" "s1", "STUDENT" "s2" WHERE ("s1"."NAME", "s1"."GPA", "s1"."TEL") < ("s2"."NAME", "s2"."GPA", "s2"."TEL") >>> select((s1.id, s2.id) for s1 in Student for s2 in Student if (s1.name, s1.gpa, s1.tel) == (s2.name, s2.gpa, s2.tel)) SELECT DISTINCT "s1"."id", "s2"."id" FROM "Student" "s1", "Student" "s2" WHERE "s1"."name" = "s2"."name" AND "s1"."gpa" = "s2"."gpa" AND "s1"."tel" = "s2"."tel" PostgreSQL: SELECT DISTINCT "s1"."id", "s2"."id" FROM "student" "s1", "student" "s2" WHERE "s1"."name" = "s2"."name" AND "s1"."gpa" = "s2"."gpa" AND "s1"."tel" = "s2"."tel" MySQL: SELECT DISTINCT `s1`.`id`, `s2`.`id` FROM `student` `s1`, `student` `s2` WHERE `s1`.`name` = `s2`.`name` AND `s1`.`gpa` = `s2`.`gpa` AND `s1`.`tel` = `s2`.`tel` Oracle: SELECT DISTINCT "s1"."ID", "s2"."ID" FROM "STUDENT" "s1", "STUDENT" "s2" WHERE "s1"."NAME" = "s2"."NAME" AND "s1"."GPA" = "s2"."GPA" AND "s1"."TEL" = "s2"."TEL" # Test date operations: >>> select(s for s in Student if s.dob + timedelta(days=100) < date(2010, 1, 1)) SQLite: SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE date("s"."dob", '+100 days') < '2010-01-01' PostgreSQL: SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."dob" + INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' Oracle: SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE ("s"."DOB" + INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' MySQL: SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` FROM `student` `s` WHERE ADDDATE(`s`.`dob`, INTERVAL '2400:0:0' HOUR_SECOND) < DATE '2010-01-01' >>> td = timedelta(days=100) >>> select(s for s in Student if s.dob + td < date(2010, 1, 1)) SQLite: SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE datetime(julianday("s"."dob") + ?) 
< '2010-01-01'
PostgreSQL:
SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."dob" + %(p1)s) < DATE '2010-01-01'
Oracle:
SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE ("s"."DOB" + :p1) < DATE '2010-01-01'
MySQL:
SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` FROM `student` `s` WHERE ADDDATE(`s`.`dob`, %s) < DATE '2010-01-01'

>>> select(s for s in Student if s.dob - timedelta(days=100) < date(2010, 1, 1))
SQLite:
SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE date("s"."dob", '-100 days') < '2010-01-01'
PostgreSQL:
SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."dob" - INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01'
Oracle:
SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE ("s"."DOB" - INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01'
MySQL:
SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` FROM `student` `s` WHERE SUBDATE(`s`.`dob`, INTERVAL '2400:0:0' HOUR_SECOND) < DATE '2010-01-01'

>>> td = timedelta(days=100)
>>> select(s for s in Student if s.dob - td < date(2010, 1, 1))
SQLite:
SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE datetime(julianday("s"."dob") - ?) < '2010-01-01'
PostgreSQL:
SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "student" "s" WHERE ("s"."dob" - %(p1)s) < DATE '2010-01-01'
Oracle:
SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE ("s"."DOB" - :p1) < DATE '2010-01-01'
MySQL:
SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` FROM `student` `s` WHERE SUBDATE(`s`.`dob`, %s) < DATE '2010-01-01'

# ===== pony-0.7.11/pony/orm/tests/sql_tests.py =====

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2

import re, os, os.path, sys
from datetime import datetime, timedelta

from pony import orm
from pony.orm import core
from pony.orm.tests import testutils

core.suppress_debug_change = True

directive_re = re.compile(r'(\w+)(\s+[0-9\.]+)?:')
directive = module_name = None
statements = []
lines = []

def Schema(param):
    if not statement_used:
        print()
        print('Statement not used:')
        print()
        print('\n'.join(statements))
        print()
        sys.exit()
    assert len(lines) == 1
    global module_name
    module_name = lines[0].strip()

def SQLite(server_version):
    do_test('sqlite', server_version)

def MySQL(server_version):
    do_test('mysql', server_version)

def PostgreSQL(server_version):
    do_test('postgres', server_version)

def Oracle(server_version):
    do_test('oracle', server_version)

unavailable_providers = set()

def do_test(provider_name, raw_server_version):
    if provider_name in unavailable_providers: return
    testutils.TestDatabase.real_provider_name = provider_name
    testutils.TestDatabase.raw_server_version = raw_server_version
    core.Database = orm.Database = testutils.TestDatabase
    sys.modules.pop(module_name, None)
    try:
        __import__(module_name)
    except ImportError as e:
        print()
        print('ImportError for database provider %s:\n%s' % (provider_name, e))
        print()
        unavailable_providers.add(provider_name)
        return
    module = sys.modules[module_name]
    globals = vars(module).copy()
    globals.update(datetime=datetime, timedelta=timedelta)
    with orm.db_session:
        for statement in statements[:-1]:
            code = compile(statement, '<string>', 'exec')
            if PY2: exec('exec code in globals')
            else: exec(code, globals)
        statement = statements[-1]
        try:
            last_code = compile(statement, '<string>', 'eval')
        except SyntaxError:
            last_code = compile(statement, '<string>', 'exec')
            if PY2: exec('exec last_code in globals')
            else: exec(last_code, globals)
        else:
            result = eval(last_code, globals)
            if isinstance(result, core.Query): result = list(result)
    sql = module.db.sql
    expected_sql = '\n'.join(lines)
    if sql == expected_sql: print('.', end='')
    else:
        print()
        print(provider_name, statements[-1])
        print()
        print('Expected:')
        print(expected_sql)
        print()
        print('Got:')
        print(sql)
        print()
    global statement_used
    statement_used = True

dirname, fname = os.path.split(__file__)
queries_fname = os.path.join(dirname, 'queries.txt')

def orphan_lines(lines):
    SQLite(None)
    lines[:] = []

statement_used = True
for raw_line in open(queries_fname):
    line = raw_line.strip()
    if not line: continue
    if line.startswith('#'): continue
    match = directive_re.match(line)
    if match:
        if directive:
            directive(directive_param)
            lines[:] = []
        elif lines: orphan_lines(lines)
        directive = eval(match.group(1))
        if match.group(2): directive_param = match.group(2)
        else: directive_param = None
    elif line.startswith('>>> '):
        if directive:
            directive(directive_param)
            lines[:] = []
            statements[:] = []
        elif lines: orphan_lines(lines)
        directive = None
        directive_param = None
        statements.append(line[4:])
        statement_used = False
    else:
        lines.append(raw_line.rstrip())

if directive: directive(directive_param)
elif lines: orphan_lines(lines)

# ===== pony-0.7.11/pony/orm/tests/test_array.py =====

from pony.py23compat import PY2

import unittest

from pony.orm.tests.testutils import *
from pony.orm import *

db = Database('sqlite', ':memory:')

class Foo(db.Entity):
    id = PrimaryKey(int)
    a = Required(int)
    b = Required(int)
    c = Required(int)
    array1 = Required(IntArray, index=True)
    array2 = Required(FloatArray)
    array3 = Required(StrArray)
    array4 = Optional(IntArray)
    array5 = Optional(IntArray, nullable=True)

db.generate_mapping(create_tables=True)

with db_session:
    Foo(id=1, a=1, b=3, c=-2, array1=[10, 20, 30, 40, 50],
        array2=[1.1, 2.2, 3.3, 4.4, 5.5], array3=['foo', 'bar'])

class Test(unittest.TestCase):
    @db_session
    def test_1(self):
        foo = select(f for f in Foo if 10 in f.array1)[:]
        self.assertEqual([Foo[1]], foo)
    @db_session
    def test_2(self):
        foo = select(f for f in Foo if [10, 20, 50] in f.array1)[:]
        self.assertEqual([Foo[1]], foo)
    @db_session
    def test_2a(self):
        foo = select(f for f in Foo if [] in f.array1)[:]
        self.assertEqual([Foo[1]], foo)
    @db_session
    def test_3(self):
        x = [10, 20, 50]
        foo = select(f for f in Foo if x in f.array1)[:]
        self.assertEqual([Foo[1]], foo)
    @db_session
    def test_4(self):
        foo = select(f for f in Foo if 1.1 in f.array2)[:]
        self.assertEqual([Foo[1]], foo)
    err_msg = "Cannot store 'int' item in array of " + ("'unicode'" if PY2 else "'str'")
    @raises_exception(TypeError, err_msg)
    @db_session
    def test_5(self):
        foo = Foo.select().first()
        foo.array3.append(123)
    @raises_exception(TypeError, err_msg)
    @db_session
    def test_6(self):
        foo = Foo.select().first()
        foo.array3[0] = 123
    @raises_exception(TypeError, err_msg)
    @db_session
    def test_7(self):
        foo = Foo.select().first()
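        # The tracked array wrapper validates item types on every mutation
        # (append, __setitem__, extend), so pushing ints into a StrArray is
        # expected to raise the TypeError checked by the decorator above.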
        foo.array3.extend(['str', 123, 'str'])
    @db_session
    def test_8(self):
        foo = Foo.select().first()
        foo.array3.extend(['str1', 'str2'])
    @db_session
    def test_9(self):
        foos = select(f.array2[0] for f in Foo)[:]
        self.assertEqual([1.1], foos)
    @db_session
    def test_10(self):
        foos = select(f.array1[1:-1] for f in Foo)[:]
        self.assertEqual([20, 30, 40], foos[0])
    @db_session
    def test_11(self):
        foo = Foo.select().first()
        foo.array4.append(1)
        self.assertEqual([1], foo.array4)
    @raises_exception(AttributeError, "'NoneType' object has no attribute 'append'")
    @db_session
    def test_12(self):
        foo = Foo.select().first()
        foo.array5.append(1)
    @db_session
    def test_13(self):
        x = [10, 20, 30, 40, 50]
        ids = select(f.id for f in Foo if x == f.array1)[:]
        self.assertEqual(ids, [1])
    @db_session
    def test_14(self):
        val = select(f.array1[0] for f in Foo).first()
        self.assertEqual(val, 10)
    @db_session
    def test_15(self):
        val = select(f.array1[2] for f in Foo).first()
        self.assertEqual(val, 30)
    @db_session
    def test_16(self):
        val = select(f.array1[-1] for f in Foo).first()
        self.assertEqual(val, 50)
    @db_session
    def test_17(self):
        val = select(f.array1[-2] for f in Foo).first()
        self.assertEqual(val, 40)
    @db_session
    def test_18(self):
        x = 2
        val = select(f.array1[x] for f in Foo).first()
        self.assertEqual(val, 30)
    @db_session
    def test_19(self):
        val = select(f.array1[f.a] for f in Foo).first()
        self.assertEqual(val, 20)
    @db_session
    def test_20(self):
        val = select(f.array1[f.c] for f in Foo).first()
        self.assertEqual(val, 40)
    @db_session
    def test_21(self):
        array = select(f.array1[2:4] for f in Foo).first()
        self.assertEqual(array, [30, 40])
    @db_session
    def test_22(self):
        array = select(f.array1[1:-2] for f in Foo).first()
        self.assertEqual(array, [20, 30])
    @db_session
    def test_23(self):
        array = select(f.array1[10:-10] for f in Foo).first()
        self.assertEqual(array, [])
    @db_session
    def test_24(self):
        x = 2
        array = select(f.array1[x:4] for f in Foo).first()
        self.assertEqual(array, [30, 40])
    @db_session
    def test_25(self):
        y = 4
        array = select(f.array1[2:y] for f in Foo).first()
        self.assertEqual(array, [30, 40])
    @db_session
    def test_26(self):
        x, y = 2, 4
        array = select(f.array1[x:y] for f in Foo).first()
        self.assertEqual(array, [30, 40])
    @db_session
    def test_27(self):
        x, y = 1, -2
        array = select(f.array1[x:y] for f in Foo).first()
        self.assertEqual(array, [20, 30])
    @db_session
    def test_28(self):
        x = 1
        array = select(f.array1[x:f.b] for f in Foo).first()
        self.assertEqual(array, [20, 30])
    @db_session
    def test_29(self):
        array = select(f.array1[f.a:f.c] for f in Foo).first()
        self.assertEqual(array, [20, 30])
    @db_session
    def test_30(self):
        array = select(f.array1[:3] for f in Foo).first()
        self.assertEqual(array, [10, 20, 30])
    @db_session
    def test_31(self):
        array = select(f.array1[2:] for f in Foo).first()
        self.assertEqual(array, [30, 40, 50])
    @db_session
    def test_32(self):
        array = select(f.array1[:f.b] for f in Foo).first()
        self.assertEqual(array, [10, 20, 30])
    @db_session
    def test_33(self):
        array = select(f.array1[:f.c] for f in Foo).first()
        self.assertEqual(array, [10, 20, 30])
    @db_session
    def test_34(self):
        array = select(f.array1[f.c:] for f in Foo).first()
        self.assertEqual(array, [40, 50])
    @db_session
    def test_35(self):
        foo = Foo.select().first()
        self.assertTrue(10 in foo.array1)
        self.assertTrue(1000 not in foo.array1)
        self.assertTrue([10, 20] in foo.array1)
        self.assertTrue([20, 10] in foo.array1)
        self.assertTrue([10, 1000] not in foo.array1)
        self.assertTrue([] in foo.array1)
        self.assertTrue('bar' in foo.array3)
        self.assertTrue('baz' not in
                        foo.array3)
        self.assertTrue(['foo', 'bar'] in foo.array3)
        self.assertTrue(['bar', 'foo'] in foo.array3)
        self.assertTrue(['baz', 'bar'] not in foo.array3)
        self.assertTrue([] in foo.array3)
    @db_session
    def test_36(self):
        items = []
        result = select(foo for foo in Foo if foo in items)[:]
        self.assertEqual(result, [])
    @db_session
    def test_37(self):
        f1 = Foo[1]
        items = [f1]
        result = select(foo for foo in Foo if foo in items)[:]
        self.assertEqual(result, [f1])
    @db_session
    def test_38(self):
        items = []
        result = select(foo for foo in Foo if foo.id in items)[:]
        self.assertEqual(result, [])
    @db_session
    def test_39(self):
        items = [1]
        result = select(foo.id for foo in Foo if foo.id in items)[:]
        self.assertEqual(result, [1])

# ===== pony-0.7.11/pony/orm/tests/test_attribute_options.py =====

import unittest
from decimal import Decimal
from datetime import datetime, time
from random import randint

from pony import orm
from pony.orm.core import *
from pony.orm.tests.testutils import raises_exception

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    name = orm.Required(str, 40)
    lastName = orm.Required(str, max_len=40, unique=True)
    age = orm.Optional(int, max=60, min=10)
    nickName = orm.Optional(str, autostrip=False)
    middleName = orm.Optional(str, nullable=True)
    rate = orm.Optional(Decimal, precision=11)
    salaryRate = orm.Optional(Decimal, precision=13, scale=8)
    timeStmp = orm.Optional(datetime, precision=6)
    gpa = orm.Optional(float, py_check=lambda val: val >= 0 and val <= 5)
    vehicle = orm.Optional(str, column='car')

db.generate_mapping(create_tables=True)

with orm.db_session:
    p1 = Person(name='Andrew', lastName='Bodroue', age=40, rate=0.980000000001, salaryRate=0.98000001)
    p2 = Person(name='Vladimir', lastName='Andrew ', nickName='vlad ')
    p3 = Person(name='Nick', lastName='Craig', middleName=None,
                timeStmp='2010-12-10 14:12:09.019473', vehicle='dodge')

class TestAttributeOptions(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_optionalStringEmpty(self):
        queryResult = select(p.id for p in Person if p.nickName==None).first()
        self.assertIsNone(queryResult)
    def test_optionalStringNone(self):
        queryResult = select(p.id for p in Person if p.middleName==None).first()
        self.assertIsNotNone(queryResult)
    def test_stringAutoStrip(self):
        self.assertEqual(p2.lastName, 'Andrew')
    def test_stringAutoStripFalse(self):
        self.assertEqual(p2.nickName, 'vlad ')
    def test_intNone(self):
        queryResult = select(p.id for p in Person if p.age==None).first()
        self.assertIsNotNone(queryResult)
    def test_columnName(self):
        self.assertEqual(getattr(Person.vehicle, 'column'), 'car')
    def test_decimalPrecisionTwo(self):
        queryResult = select(p.rate for p in Person if p.age==40).first()
        self.assertAlmostEqual(float(queryResult), 0.98, 12)
    def test_decimalPrecisionEight(self):
        queryResult = select(p.salaryRate for p in Person if p.age==40).first()
        self.assertAlmostEqual(float(queryResult), 0.98000001, 8)
    def test_fractionalSeconds(self):
        queryResult = select(p.timeStmp for p in Person if p.name=='Nick').first()
        self.assertEqual(queryResult.microsecond, 19473)
    def test_intMax(self):
        p4 = Person(name='Denis', lastName='Blanc', age=60)
    def test_intMin(self):
        p4 = Person(name='Denis', lastName='Blanc', age=10)
    @raises_exception(ValueError, "Value 61 of attr Person.age is greater than the maximum allowed value 60")
    def test_intMaxException(self):
        p4 = Person(name='Denis', lastName='Blanc', age=61)
    @raises_exception(ValueError, "Value 9 of attr Person.age is less than the minimum allowed value 10")
    def test_intMinException(self):
        p4 = Person(name='Denis', lastName='Blanc', age=9)
    def test_py_check(self):
        p4 = Person(name='Denis', lastName='Blanc', gpa=5)
        p5 = Person(name='Mario', lastName='Gon', gpa=1)
        flush()
    @raises_exception(ValueError, "Check for attribute Person.gpa failed. Value: 6.0")
    def test_py_checkMoreException(self):
        p6 = Person(name='Daniel', lastName='Craig', gpa=6)
    @raises_exception(ValueError, "Check for attribute Person.gpa failed. Value: -1.0")
    def test_py_checkLessException(self):
        p6 = Person(name='Daniel', lastName='Craig', gpa=-1)
    @raises_exception(TransactionIntegrityError, 'Object Person[new:...] cannot be stored in the database.'
                      ' IntegrityError: UNIQUE constraint failed: Person.lastName')
    def test_unique(self):
        p6 = Person(name='Boris', lastName='Bodroue')
        flush()

# ===== pony-0.7.11/pony/orm/tests/test_autostrip.py =====

import unittest

from pony.orm import *
from pony.orm.tests.testutils import raises_exception

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    name = Required(str)
    tel = Optional(str)

db.generate_mapping(create_tables=True)

class TestAutostrip(unittest.TestCase):
    @db_session
    def test_1(self):
        p = Person(name='  John  ', tel='  ')
        p.flush()
        self.assertEqual(p.name, 'John')
        self.assertEqual(p.tel, '')
    @raises_exception(ValueError, 'Attribute Person.name is required')
    @db_session
    def test_2(self):
        p = Person(name='  ')

if __name__ == '__main__':
    unittest.main()

# ===== pony-0.7.11/pony/orm/tests/test_buffer.py =====

import unittest

from pony import orm
from pony.py23compat import buffer

db = orm.Database('sqlite', ':memory:')

class Foo(db.Entity):
    id = orm.PrimaryKey(int)
    b = orm.Optional(orm.buffer)

class Bar(db.Entity):
    b = orm.PrimaryKey(orm.buffer)

class Baz(db.Entity):
    id = orm.PrimaryKey(int)
    b = orm.Optional(orm.buffer, unique=True)

db.generate_mapping(create_tables=True)

buf = orm.buffer(b'123')

with orm.db_session:
    Foo(id=1, b=buf)
    Bar(b=buf)
    Baz(id=1, b=buf)

class Test(unittest.TestCase):
    def test_1(self):  # Bug #355
        with orm.db_session:
            Bar[buf]
    def test_2(self):  # Regression after #355 fix
        with orm.db_session:
            result = orm.select(bar.b for bar in Foo)[:]
            self.assertEqual(result, [buf])
    def test_3(self):  # Bug #390
        with orm.db_session:
            Baz.get(b=buf)

# ===== pony-0.7.11/pony/orm/tests/test_bug_170.py =====

import unittest

from pony import orm

class Test(unittest.TestCase):
    def test_1(self):
        db = orm.Database('sqlite', ':memory:')
        class Person(db.Entity):
            id = orm.PrimaryKey(int, auto=True)
            name = orm.Required(str)
            orm.composite_key(id, name)
        db.generate_mapping(create_tables=True)
        table = db.schema.tables[Person._table_]
        pk_column = table.column_dict[Person.id.column]
        self.assertTrue(pk_column.is_pk)
        with orm.db_session:
            p1 = Person(name='John')
            p2 = Person(name='Mike')
# ---- pony-0.7.11/pony/orm/tests/test_bug_182.py ----
import unittest

from pony.orm import *
from pony import orm


class Test(unittest.TestCase):
    def setUp(self):
        db = self.db = Database('sqlite', ':memory:')

        class User(db.Entity):
            name = Required(str)
            servers = Set("Server")

        class Worker(db.User):
            pass

        class Admin(db.Worker):
            pass

        # And M:1 relationship with another entity
        class Server(db.Entity):
            name = Required(str)
            user = Optional(User)

        db.generate_mapping(check_tables=True, create_tables=True)

        with orm.db_session:
            Server(name='s1.example.com', user=User(name="Alex"))
            Server(name='s2.example.com', user=Worker(name="John"))
            Server(name='free.example.com', user=None)

    @db_session
    def test(self):
        qu = left_join((s.name, s.user.name) for s in self.db.Server)[:]
        for server, user in qu:
            if user is None:
                break
        else:
            self.fail()

# ---- pony-0.7.11/pony/orm/tests/test_bug_331.py ----
import unittest

from pony import orm


class Test(unittest.TestCase):
    def test_1(self):
        db = orm.Database('sqlite', ':memory:')

        class Person(db.Entity):
            name = orm.Required(str)
            group = orm.Optional(lambda: Group)

        class Group(db.Entity):
            title = orm.PrimaryKey(str)
            persons = orm.Set(Person)
            def __len__(self):
                return len(self.persons)

        db.generate_mapping(create_tables=True)

        with orm.db_session:
            p1 = Person(name="Alex")
            p2 = Person(name="Brad")
            p3 = Person(name="Chad")
            p4 = Person(name="Dylan")
            p5 = Person(name="Ethan")
            g1 = Group(title="Foxes")
            g2 = Group(title="Gorillas")
            g1.persons.add(p1)
            g1.persons.add(p2)
            g1.persons.add(p3)
            g2.persons.add(p4)
            g2.persons.add(p5)
            orm.commit()

            foxes = Group['Foxes']
            gorillas = Group['Gorillas']
            self.assertEqual(len(foxes), 3)
            self.assertEqual(len(gorillas), 2)

# ---- pony-0.7.11/pony/orm/tests/test_bug_386.py ----
import unittest

from pony import orm


class Test(unittest.TestCase):
    def test_1(self):
        db = orm.Database('sqlite', ':memory:')

        class Person(db.Entity):
            name = orm.Required(str)

        db.generate_mapping(create_tables=True)

        with orm.db_session:
            a = Person(name='John')
            a.delete()
            Person.exists(name='Mike')

# ---- pony-0.7.11/pony/orm/tests/test_cascade.py ----
import unittest

from pony.orm import *
from pony.orm.tests.testutils import *


class TestCascade(unittest.TestCase):
    def test_1(self):
        db = self.db = Database('sqlite', ':memory:')

        class Person(self.db.Entity):
            name = Required(str)
            group = Required('Group')

        class Group(self.db.Entity):
            persons = Set(Person)

        db.generate_mapping(create_tables=True)
        self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command())

    def test_2(self):
        db = self.db = Database('sqlite', ':memory:')

        class Person(self.db.Entity):
            name = Required(str)
            group = Required('Group')

        class Group(self.db.Entity):
            persons = Set(Person, cascade_delete=True)

        db.generate_mapping(create_tables=True)
        self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command())

    def test_3(self):
        db = self.db = Database('sqlite',
':memory:') class Person(self.db.Entity): name = Required(str) group = Optional('Group') class Group(self.db.Entity): persons = Set(Person, cascade_delete=True) db.generate_mapping(create_tables=True) self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command()) @raises_exception(TypeError, "'cascade_delete' option cannot be set for attribute Group.persons, because reverse attribute Person.group is collection") def test_4(self): db = self.db = Database('sqlite', ':memory:') class Person(self.db.Entity): name = Required(str) group = Set('Group') class Group(self.db.Entity): persons = Set(Person, cascade_delete=True) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "'cascade_delete' option cannot be set for both sides of relationship (Person.group and Group.persons) simultaneously") def test_5(self): db = self.db = Database('sqlite', ':memory:') class Person(self.db.Entity): name = Required(str) group = Set('Group', cascade_delete=True) class Group(self.db.Entity): persons = Required(Person, cascade_delete=True) db.generate_mapping(create_tables=True) def test_6(self): db = self.db = Database('sqlite', ':memory:') class Person(self.db.Entity): name = Required(str) group = Set('Group') class Group(self.db.Entity): persons = Optional(Person) db.generate_mapping(create_tables=True) self.assertTrue('ON DELETE SET NULL' in self.db.schema.tables['Group'].get_create_command()) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_cascade_delete.py0000666000000000000000000000327200000000000020166 0ustar0000000000000000import unittest from pony.orm import * db = Database('sqlite', ':memory:') class X(db.Entity): id = PrimaryKey(int) parent = Optional('X', reverse='children') children = Set('X', reverse='parent', cascade_delete=True) class Y(db.Entity): parent = Optional('Y', reverse='children') children = Set('Y', reverse='parent', cascade_delete=True, lazy=True) db.generate_mapping(create_tables=True) with db_session: x1 = X(id=1) x2 = X(id=2, parent=x1) x3 = X(id=3, parent=x1) x4 = X(id=4, parent=x3) x5 = X(id=5, parent=x3) x6 = X(id=6, parent=x5) x7 = X(id=7, parent=x3) x8 = X(id=8, parent=x7) x9 = X(id=9, parent=x7) x10 = X(id=10) x11 = X(id=11, parent=x10) x12 = X(id=12, parent=x10) y1 = Y(id=1) y2 = Y(id=2, parent=y1) y3 = Y(id=3, parent=y1) y4 = Y(id=4, parent=y3) y5 = Y(id=5, parent=y3) y6 = Y(id=6, parent=y5) y7 = Y(id=7, parent=y3) y8 = Y(id=8, parent=y7) y9 = Y(id=9, parent=y7) y10 = Y(id=10) y11 = Y(id=11, parent=y10) y12 = Y(id=12, parent=y10) class TestCascade(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test_1(self): db.merge_local_stats() X[1].delete() stats = db.local_stats[None] self.assertEqual(5, stats.db_count) def test_2(self): db.merge_local_stats() Y[1].delete() stats = db.local_stats[None] self.assertEqual(10, stats.db_count) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_collections.py0000666000000000000000000000610400000000000017574 0ustar0000000000000000from __future__ import absolute_import, print_function, division from pony.py23compat import PY2 import unittest from pony.orm.tests.testutils import raises_exception from pony.orm.tests.model1 import * class TestCollections(unittest.TestCase): @db_session def 
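# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the cascade behaviour
# tested in test_cascade.py / test_cascade_delete.py above: marking a Set
# with cascade_delete=True makes delete() remove the whole subtree (and, as
# the assertions above show, Pony emits ON DELETE CASCADE in the schema).
# Hypothetical self-referencing entity:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Node(demo_db.Entity):
    id = PrimaryKey(int)
    parent = Optional('Node', reverse='children')
    children = Set('Node', reverse='parent', cascade_delete=True)

demo_db.generate_mapping(create_tables=True)

with db_session:
    root = Node(id=1)
    Node(id=2, parent=root)
    Node(id=3, parent=root)

with db_session:
    Node[1].delete()                       # children 2 and 3 go with it
    assert Node.select().count() == 0
# ----------------------------------------------------------------------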
test_setwrapper_len(self): g = Group.get(number='4145') self.assertTrue(len(g.students) == 3) @db_session def test_setwrapper_nonzero(self): g = Group.get(number='4145') self.assertTrue(bool(g.students) == True) self.assertTrue(len(g.students) == 3) @db_session @raises_exception(TypeError, 'Collection attribute Group.students cannot be specified as search criteria') def test_get_by_collection_error(self): Group.get(students=[]) @db_session def test_collection_create_one2many_1(self): g = Group['3132'] g.students.create(record=106, name='Mike', scholarship=200) flush() self.assertEqual(len(g.students), 3) rollback() @raises_exception(TypeError, "When using Group.students.create(), " "'group' attribute should not be passed explicitly") @db_session def test_collection_create_one2many_2(self): g = Group['3132'] g.students.create(record=106, name='Mike', scholarship=200, group=g) @raises_exception(TransactionIntegrityError, "Object Student[105] cannot be stored in the database...") @db_session def test_collection_create_one2many_3(self): g = Group['3132'] g.students.create(record=105, name='Mike', scholarship=200) @db_session def test_collection_create_many2many_1(self): g = Group['3132'] g.subjects.create(name='Biology') flush() self.assertEqual(len(g.subjects), 3) rollback() @raises_exception(TypeError, "When using Group.subjects.create(), " "'groups' attribute should not be passed explicitly") @db_session def test_collection_create_many2many_2(self): g = Group['3132'] g.subjects.create(name='Biology', groups=[g]) @raises_exception(TransactionIntegrityError, "Object Subject[u'Math'] cannot be stored in the database..." if PY2 else "Object Subject['Math'] cannot be stored in the database...") @db_session def test_collection_create_many2many_3(self): g = Group['3132'] g.subjects.create(name='Math') # replace collection items when the old ones are not fully loaded ##>>> from pony.examples.orm.students01.model import * ##>>> s1 = Student[101] ##>>> g = s1.group ##>>> g.__dict__[Group.students].is_fully_loaded ##False ##>>> s2 = Student[104] ##>>> g.students = [s2] ##>>> # replace collection items when the old ones are not loaded ##>>> from pony.examples.orm.students01.model import * ##>>> g = Group[4145] ##>>> Group.students not in g.__dict__ ##True ##>>> s2 = Student[104] ##>>> g.students = [s2] if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_core_find_in_cache.py0000666000000000000000000001373100000000000021023 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.tests.testutils import raises_exception from pony.orm import * db = Database('sqlite', ':memory:') class AbstractUser(db.Entity): username = PrimaryKey(unicode) class User(AbstractUser): diagrams = Set('Diagram') email = Optional(unicode) class SubUser1(User): attr1 = Optional(unicode) class SubUser2(User): attr2 = Optional(unicode) class Organization(AbstractUser): address = Optional(unicode) class SubOrg1(Organization): attr3 = Optional(unicode) class SubOrg2(Organization): attr4 = Optional(unicode) class Diagram(db.Entity): name = Required(unicode) owner = Required(User) db.generate_mapping(create_tables=True) with db_session: u1 = User(username='user1') u2 = SubUser1(username='subuser1', attr1='some attr') u3 = SubUser2(username='subuser2', attr2='some attr') o1 = Organization(username='org1') o2 = SubOrg1(username='suborg1', 
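# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the collection API in
# test_collections.py above: on a one-to-many Set, create() instantiates
# the related entity and wires the reverse attribute in a single call, so
# passing the reverse attribute explicitly raises TypeError.  Hypothetical
# schema:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Club(demo_db.Entity):
    number = PrimaryKey(int)
    members = Set('Member')

class Member(demo_db.Entity):
    name = Required(str)
    club = Required(Club)

demo_db.generate_mapping(create_tables=True)

with db_session:
    c = Club(number=1)
    c.members.create(name='Mike')    # same as Member(name='Mike', club=c)
    flush()
    assert len(c.members) == 1
# ----------------------------------------------------------------------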
attr3='some attr') o3 = SubOrg2(username='suborg2', attr4='some attr') au = AbstractUser(username='abstractUser') Diagram(name='diagram1', owner=u1) Diagram(name='diagram2', owner=u2) Diagram(name='diagram3', owner=u3) def is_seed(entity, pk): cache = entity._database_._get_cache() return pk in [ obj._pk_ for obj in cache.seeds[entity._pk_attrs_] ] class TestFindInCache(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test1(self): u = User.get(username='org1') org = Organization.get(username='org1') u1 = User.get(username='org1') self.assertEqual(u, None) self.assertEqual(org, Organization['org1']) self.assertEqual(u1, None) def test_user_1(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql self.assertTrue(is_seed(User, 'user1')) u = AbstractUser['user1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, User) def test_user_2(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql self.assertTrue(is_seed(User, 'user1')) u = User['user1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, User) @raises_exception(ObjectNotFound) def test_user_3(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql self.assertTrue(is_seed(User, 'user1')) try: SubUser1['user1'] finally: self.assertNotEqual(last_sql, db.last_sql) @raises_exception(ObjectNotFound) def test_user_4(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql self.assertTrue(is_seed(User, 'user1')) try: Organization['user1'] finally: self.assertEqual(last_sql, db.last_sql) @raises_exception(ObjectNotFound) def test_user_5(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql self.assertTrue(is_seed(User, 'user1')) try: SubOrg1['user1'] finally: self.assertEqual(last_sql, db.last_sql) def test_subuser_1(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql self.assertTrue(is_seed(User, 'subuser1')) u = AbstractUser['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) def test_subuser_2(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql self.assertTrue(is_seed(User, 'subuser1')) u = User['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) def test_subuser_3(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql self.assertTrue(is_seed(User, 'subuser1')) u = SubUser1['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) @raises_exception(ObjectNotFound) def test_subuser_4(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql self.assertTrue(is_seed(User, 'subuser1')) try: Organization['subuser1'] finally: self.assertEqual(last_sql, db.last_sql) @raises_exception(ObjectNotFound) def test_subuser_5(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql self.assertTrue(is_seed(User, 'subuser1')) try: SubUser2['subuser1'] finally: self.assertNotEqual(last_sql, db.last_sql) @raises_exception(ObjectNotFound) def test_subuser_6(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql self.assertTrue(is_seed(User, 'subuser1')) try: SubOrg2['subuser1'] finally: self.assertEqual(last_sql, db.last_sql) def test_user_6(self): u1 = SubUser1['subuser1'] last_sql = db.last_sql u2 = SubUser1['subuser1'] self.assertEqual(last_sql, db.last_sql) self.assertEqual(u1, u2) def test_user_7(self): u1 = 
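# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of what the is_seed()
# helper above inspects: loading an object makes its to-one targets
# "seeds" -- identity-map entries holding only the primary key -- and the
# first real attribute access resolves the seed with one extra query,
# which the tests detect by comparing db.last_sql.  Hypothetical schema:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Author(demo_db.Entity):
    id = PrimaryKey(int)
    name = Required(str)
    books = Set('Book')

class Book(demo_db.Entity):
    title = Required(str)
    author = Required(Author)

demo_db.generate_mapping(create_tables=True)

with db_session:
    Book(title='b1', author=Author(id=1, name='a1'))

with db_session:
    b = Book.select().first()       # loads the row; b.author is only a seed
    sql_before = demo_db.last_sql
    assert b.author.name == 'a1'    # resolving the seed issues one more query
    assert demo_db.last_sql != sql_before
# ----------------------------------------------------------------------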
SubUser1['subuser1'] u1.delete() last_sql = db.last_sql u2 = SubUser1.get(username='subuser1') self.assertEqual(last_sql, db.last_sql) self.assertEqual(u2, None) def test_user_8(self): u1 = SubUser1['subuser1'] last_sql = db.last_sql u2 = SubUser1.get(username='subuser1', attr1='wrong val') self.assertEqual(last_sql, db.last_sql) self.assertEqual(u2, None) if __name__ == '__main__': unittest.main()././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1570625601.0 pony-0.7.11/pony/orm/tests/test_core_multiset.py0000666000000000000000000001072000000000000020133 0ustar0000000000000000from __future__ import absolute_import, print_function, division from pony.py23compat import pickle import unittest from pony.orm.core import * db = Database('sqlite', ':memory:') class Department(db.Entity): number = PrimaryKey(int) groups = Set('Group') courses = Set('Course') class Group(db.Entity): number = PrimaryKey(int) department = Required(Department) students = Set('Student') class Student(db.Entity): name = Required(str) group = Required(Group) courses = Set('Course') class Course(db.Entity): name = PrimaryKey(str) department = Required(Department) students = Set('Student') db.generate_mapping(create_tables=True) with db_session: d1 = Department(number=1) d2 = Department(number=2) d3 = Department(number=3) g1 = Group(number=101, department=d1) g2 = Group(number=102, department=d1) g3 = Group(number=201, department=d2) c1 = Course(name='C1', department=d1) c2 = Course(name='C2', department=d1) c3 = Course(name='C3', department=d2) c4 = Course(name='C4', department=d2) c5 = Course(name='C5', department=d3) s1 = Student(name='S1', group=g1, courses=[c1, c2]) s2 = Student(name='S2', group=g1, courses=[c1, c3]) s3 = Student(name='S3', group=g1, courses=[c2, c3]) s4 = Student(name='S4', group=g2, courses=[c1, c2]) s5 = Student(name='S5', group=g2, courses=[c1, c2]) s6 = Student(name='A', group=g3, courses=[c5]) class TestMultiset(unittest.TestCase): @db_session def test_multiset_repr_1(self): d = Department[1] multiset = d.groups.students self.assertEqual(repr(multiset), "") @db_session def test_multiset_repr_2(self): g = Group[101] multiset = g.students.courses self.assertEqual(repr(multiset), "") @db_session def test_multiset_repr_3(self): g = Group[201] multiset = g.students.courses self.assertEqual(repr(multiset), "") def test_multiset_repr_4(self): with db_session: g = Group[101] multiset = g.students.courses self.assertIsNone(multiset._obj_._session_cache_) self.assertEqual(repr(multiset), "") @db_session def test_multiset_str(self): g = Group[101] multiset = g.students.courses self.assertEqual(str(multiset), "CourseMultiset({Course[%r]: 2, Course[%r]: 2, Course[%r]: 2})" % (u'C1', u'C2', u'C3')) @db_session def test_multiset_distinct(self): d = Department[1] multiset = d.groups.students.courses self.assertEqual(multiset.distinct(), {Course['C1']: 4, Course['C2']: 4, Course['C3']: 2}) @db_session def test_multiset_nonzero(self): d = Department[1] multiset = d.groups.students self.assertEqual(bool(multiset), True) @db_session def test_multiset_len(self): d = Department[1] multiset = d.groups.students.courses self.assertEqual(len(multiset), 10) @db_session def test_multiset_eq(self): d = Department[1] multiset = d.groups.students.courses c1, c2, c3 = Course['C1'], Course['C2'], Course['C3'] self.assertEqual(multiset, multiset) self.assertEqual(multiset, {c1: 4, c2: 4, c3: 2}) self.assertEqual(multiset, [ c1, c1, c1, c2, c2, c2, c2, c3, c3, c1 ]) @db_session def 
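# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the multisets tested
# in test_core_multiset.py above: chaining collection attributes yields a
# multiset that keeps duplicates; len() counts them and distinct() returns
# an item -> multiplicity dict.  Hypothetical schema:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Team(demo_db.Entity):
    id = PrimaryKey(int)
    players = Set('Player')

class Player(demo_db.Entity):
    name = Required(str)
    team = Required(Team)
    skills = Set('Skill')

class Skill(demo_db.Entity):
    name = PrimaryKey(str)
    players = Set(Player)

demo_db.generate_mapping(create_tables=True)

with db_session:
    t = Team(id=1)
    run, jump = Skill(name='run'), Skill(name='jump')
    Player(name='p1', team=t, skills=[run, jump])
    Player(name='p2', team=t, skills=[run])

with db_session:
    ms = Team[1].players.skills        # multiset: 'run' twice, 'jump' once
    assert len(ms) == 3                # duplicates included
    assert ms.distinct() == {Skill['run']: 2, Skill['jump']: 1}
# ----------------------------------------------------------------------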
test_multiset_ne(self): d = Department[1] multiset = d.groups.students.courses self.assertFalse(multiset != multiset) @db_session def test_multiset_contains(self): d = Department[1] multiset = d.groups.students.courses self.assertTrue(Course['C1'] in multiset) self.assertFalse(Course['C5'] in multiset) def test_multiset_reduce(self): with db_session: d = Department[1] multiset = d.groups.students s = pickle.dumps(multiset) with db_session: d = Department[1] multiset_2 = d.groups.students multiset_1 = pickle.loads(s) self.assertEqual(multiset_1, multiset_2) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_crud.py0000666000000000000000000001026100000000000016212 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from decimal import Decimal from datetime import date from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Group(db.Entity): id = PrimaryKey(int) major = Required(unicode) students = Set('Student') class Student(db.Entity): name = Required(unicode) age = Optional(int) scholarship = Required(Decimal, default=0) picture = Optional(buffer, lazy=True) email = Required(unicode, unique=True) phone = Optional(unicode, unique=True) courses = Set('Course') group = Optional('Group') class Course(db.Entity): name = Required(unicode) semester = Required(int) students = Set(Student) composite_key(name, semester) db.generate_mapping(create_tables=True) with db_session: g1 = Group(id=1, major='Math') g2 = Group(id=2, major='Physics') s1 = Student(id=1, name='S1', age=19, email='s1@example.com', group=g1) s2 = Student(id=2, name='S2', age=21, email='s2@example.com', group=g1) s3 = Student(id=3, name='S3', email='s3@example.com', group=g2) c1 = Course(name='Math', semester=1) c2 = Course(name='Math', semester=2) c3 = Course(name='Physics', semester=1) class TestCRUD(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test_getitem_1(self): g1 = Group[1] self.assertEqual(g1.id, 1) @raises_exception(ObjectNotFound, 'Group[333]') def test_getitem_2(self): g333 = Group[333] def test_exists_1(self): x = Group.exists(id=1) self.assertEqual(x, True) def test_exists_2(self): x = Group.exists(id=333) self.assertEqual(x, False) def test_exists_3(self): g1 = Group[1] x = Student.exists(group=g1) self.assertEqual(x, True) def test_numeric_nonzero(self): result = select(s.id for s in Student if s.age)[:] self.assertEqual(set(result), {1, 2}) def test_numeric_negate_1(self): result = select(s.id for s in Student if not s.age)[:] self.assertEqual(set(result), {3}) self.assertTrue('is null' in db.last_sql.lower()) def test_numeric_negate_2(self): result = select(c.id for c in Course if not c.semester)[:] self.assertEqual(result, []) self.assertTrue('is null' not in db.last_sql.lower()) def test_set1(self): s1 = Student[1] s1.set(name='New name', scholarship=100) self.assertEqual(s1.name, 'New name') self.assertEqual(s1.scholarship, 100) def test_set2(self): g1 = Group[1] s3 = Student[3] g1.set(students=[s3]) self.assertEqual(s3.group, Group[1]) def test_set3(self): c1 = Course[1] c1.set(name='Algebra', semester=3) def test_set4(self): s1 = Student[1] s1.set(name='New name', email='new_email@example.com') def test_validate_1(self): s4 = Student(id=3, name='S4', email='s4@example.com', group=1) def 
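# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the CRUD surface
# tested in test_crud.py above: Entity[pk] looks an object up by primary
# key (raising ObjectNotFound when absent), exists() probes by attribute
# values, and set() updates several attributes at once.  Hypothetical
# entity:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Item(demo_db.Entity):
    id = PrimaryKey(int)
    name = Required(str)

demo_db.generate_mapping(create_tables=True)

with db_session:
    Item(id=1, name='first')

with db_session:
    item = Item[1]
    assert Item.exists(id=1)
    item.set(name='renamed')
    try:
        Item[333]                # no such pk
    except ObjectNotFound:
        pass
# ----------------------------------------------------------------------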
test_validate_2(self): s4 = Student(id=3, name='S4', email='s4@example.com', group='1') @raises_exception(TransactionIntegrityError) def test_validate_3(self): s4 = Student(id=3, name='S4', email='s4@example.com', group=100) flush() @raises_exception(ValueError, "Value type for attribute Group.id must be int. Got string 'not a number'") def test_validate_5(self): s4 = Student(id=3, name='S4', email='s4@example.com', group='not a number') @raises_exception(TypeError, "Attribute Student.group must be of Group type. Got: datetime.date(2011, 1, 1)") def test_validate_6(self): s4 = Student(id=3, name='S4', email='s4@example.com', group=date(2011, 1, 1)) @raises_exception(TypeError, 'Invalid number of columns were specified for attribute Student.group. Expected: 1, got: 2') def test_validate_7(self): s4 = Student(id=3, name='S4', email='s4@example.com', group=(1, 2)) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571707432.0 pony-0.7.11/pony/orm/tests/test_crud_raw_sql.py0000666000000000000000000000550200000000000017744 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import raises_exception db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) age = Optional(int) friends = Set("Student", reverse='friends') group = Required("Group") bio = Optional("Bio") class Group(db.Entity): dept = Required(int) grad_year = Required(int) students = Set(Student) PrimaryKey(dept, grad_year) class Bio(db.Entity): picture = Optional(buffer) desc = Required(unicode) Student = Required(Student) db.generate_mapping(create_tables=True) class TestCrudRawSQL(unittest.TestCase): def setUp(self): with db_session: db.execute('delete from Student') db.execute('delete from "Group"') db.insert(Group, dept=44, grad_year=1999) db.insert(Student, id=1, name='A', age=30, group_dept=44, group_grad_year=1999) db.insert(Student, id=2, name='B', age=25, group_dept=44, group_grad_year=1999) db.insert(Student, id=3, name='C', age=20, group_dept=44, group_grad_year=1999) rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test1(self): students = set(Student.select_by_sql("select id, name, age, group_dept, group_grad_year from Student order by age")) self.assertEqual(students, {Student[3], Student[2], Student[1]}) def test2(self): students = set(Student.select_by_sql("select id, age, group_dept from Student order by age")) self.assertEqual(students, {Student[3], Student[2], Student[1]}) @raises_exception(NameError, "Column x does not belong to entity Student") def test3(self): students = set(Student.select_by_sql("select id, age, age*2 as x from Student order by age")) self.assertEqual(students, {Student[3], Student[2], Student[1]}) @raises_exception(TypeError, 'The first positional argument must be lambda function or its text source. 
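# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the raw-SQL entry
# points tested in test_crud_raw_sql.py above: select_by_sql() maps raw
# rows onto entity instances (the selected columns must belong to the
# entity), db.execute() returns a plain cursor, and $name interpolates a
# local variable as a query parameter ($$ yields a literal $).
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Account(demo_db.Entity):
    id = PrimaryKey(int)
    name = Required(str)

demo_db.generate_mapping(create_tables=True)

with db_session:
    Account(id=1, name='A')

with db_session:
    rows = list(Account.select_by_sql("select id, name from Account order by id"))
    assert rows[0].name == 'A'
    wanted = 1
    cursor = demo_db.execute("select name from Account where id = $wanted")
    assert cursor.fetchone()[0] == 'A'
# ----------------------------------------------------------------------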
Got: 123') def test4(self): students = Student.select(123) def test5(self): x = 1 y = 30 cursor = db.execute("select name from Student where id = $x and age = $y") self.assertEqual(cursor.fetchone()[0], 'A') def test6(self): x = 1 y = 30 cursor = db.execute("select name, 'abc$$def%' from Student where id = $x and age = $y") self.assertEqual(cursor.fetchone(), ('A', 'abc$def%')) def test7(self): cursor = db.execute("select name, 'abc$$def%' from Student where id = 1") self.assertEqual(cursor.fetchone(), ('A', 'abc$def%')) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_datetime.py0000666000000000000000000001200400000000000017046 0ustar0000000000000000from __future__ import absolute_import, print_function, division from pony.py23compat import PY2 import unittest from datetime import date, datetime, timedelta from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) d = Required(date) dt = Required(datetime) db.generate_mapping(create_tables=True) with db_session: Entity1(id=1, d=date(2009, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) Entity1(id=2, d=date(2010, 10, 21), dt=datetime(2010, 10, 21, 10, 21, 31)) Entity1(id=3, d=date(2011, 11, 22), dt=datetime(2011, 11, 22, 10, 20, 32)) class TestDate(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test_create(self): e1 = Entity1(id=4, d=date(2011, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) def test_date_year(self): result = select(e for e in Entity1 if e.d.year > 2009) self.assertEqual(len(result), 2) def test_date_month(self): result = select(e for e in Entity1 if e.d.month == 10) self.assertEqual(len(result), 2) def test_date_day(self): result = select(e for e in Entity1 if e.d.day == 22) self.assertEqual(len(result), 1) def test_datetime_year(self): result = select(e for e in Entity1 if e.dt.year > 2009) self.assertEqual(len(result), 2) def test_datetime_month(self): result = select(e for e in Entity1 if e.dt.month == 10) self.assertEqual(len(result), 2) def test_datetime_day(self): result = select(e for e in Entity1 if e.dt.day == 22) self.assertEqual(len(result), 1) def test_datetime_hour(self): result = select(e for e in Entity1 if e.dt.hour == 10) self.assertEqual(len(result), 3) def test_datetime_minute(self): result = select(e for e in Entity1 if e.dt.minute == 20) self.assertEqual(len(result), 2) def test_datetime_second(self): result = select(e for e in Entity1 if e.dt.second == 30) self.assertEqual(len(result), 1) def test_date_sub_date(self): dt = date(2012, 1, 1) result = select(e.id for e in Entity1 if dt - e.d > timedelta(days=500)) self.assertEqual(set(result), {1}) def test_datetime_sub_datetime(self): dt = datetime(2012, 1, 1, 10, 20, 30) result = select(e.id for e in Entity1 if dt - e.dt > timedelta(days=500)) self.assertEqual(set(result), {1}) def test_date_sub_timedelta_param(self): td = timedelta(days=500) result = select(e.id for e in Entity1 if e.d - td < date(2009, 1, 1)) self.assertEqual(set(result), {1}) def test_date_sub_const_timedelta(self): result = select(e.id for e in Entity1 if e.d - timedelta(days=500) < date(2009, 1, 1)) self.assertEqual(set(result), {1}) def test_datetime_sub_timedelta_param(self): td = timedelta(days=500) result = select(e.id for e in Entity1 if e.dt - td < datetime(2009, 1, 1, 10, 
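# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the date/datetime
# translation tested in test_datetime.py above: sub-fields such as .year,
# .month and .hour, and +/- timedelta arithmetic, are all translated into
# SQL inside a declarative query.  Hypothetical entity:
from datetime import date, datetime, timedelta
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Event(demo_db.Entity):
    id = PrimaryKey(int)
    d = Required(date)
    dt = Required(datetime)

demo_db.generate_mapping(create_tables=True)

with db_session:
    Event(id=1, d=date(2009, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30))
    Event(id=2, d=date(2011, 11, 22), dt=datetime(2011, 11, 22, 10, 20, 32))

with db_session:
    recent = select(e.id for e in Event if e.d.year > 2009)[:]
    shifted = select(e.id for e in Event
                     if e.d + timedelta(days=500) > date(2012, 1, 1))[:]
    assert recent == [2] and shifted == [2]
# ----------------------------------------------------------------------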
20, 30)) self.assertEqual(set(result), {1}) def test_datetime_sub_const_timedelta(self): result = select(e.id for e in Entity1 if e.dt - timedelta(days=500) < datetime(2009, 1, 1, 10, 20, 30)) self.assertEqual(set(result), {1}) def test_date_add_timedelta_param(self): td = timedelta(days=500) result = select(e.id for e in Entity1 if e.d + td > date(2013, 1, 1)) self.assertEqual(set(result), {3}) def test_date_add_const_timedelta(self): result = select(e.id for e in Entity1 if e.d + timedelta(days=500) > date(2013, 1, 1)) self.assertEqual(set(result), {3}) def test_datetime_add_timedelta_param(self): td = timedelta(days=500) result = select(e.id for e in Entity1 if e.dt + td > date(2013, 1, 1)) self.assertEqual(set(result), {3}) def test_datetime_add_const_timedelta(self): result = select(e.id for e in Entity1 if e.dt + timedelta(days=500) > date(2013, 1, 1)) self.assertEqual(set(result), {3}) @raises_exception(TypeError, "Unsupported operand types 'date' and '%s' " "for operation '-' in expression: e.d - s" % ('unicode' if PY2 else 'str')) def test_date_sub_error(self): s = 'hello' result = select(e.id for e in Entity1 if e.d - s > timedelta(days=500)) self.assertEqual(set(result), {1}) @raises_exception(TypeError, "Unsupported operand types 'datetime' and '%s' " "for operation '-' in expression: e.dt - s" % ('unicode' if PY2 else 'str')) def test_datetime_sub_error(self): s = 'hello' result = select(e.id for e in Entity1 if e.dt - s > timedelta(days=500)) self.assertEqual(set(result), {1}) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862664.0 pony-0.7.11/pony/orm/tests/test_db_session.py0000666000000000000000000003771400000000000017421 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest, warnings from datetime import date from decimal import Decimal from itertools import count from pony.orm.core import * from pony.orm.tests.testutils import * class TestDBSession(unittest.TestCase): def setUp(self): self.db = Database('sqlite', ':memory:') class X(self.db.Entity): a = PrimaryKey(int) b = Optional(int) self.X = X self.db.generate_mapping(create_tables=True) with db_session: x1 = X(a=1, b=1) x2 = X(a=2, b=2) @raises_exception(TypeError, "Pass only keyword arguments to db_session or use db_session as decorator") def test_db_session_1(self): db_session(1, 2, 3) @raises_exception(TypeError, "Pass only keyword arguments to db_session or use db_session as decorator") def test_db_session_2(self): db_session(1, 2, 3, a=10, b=20) def test_db_session_3(self): self.assertTrue(db_session is db_session()) def test_db_session_4(self): # Nested db_sessions are ignored with db_session: with db_session: self.X(a=3, b=3) with db_session: self.assertEqual(count(x for x in self.X), 3) def test_db_session_decorator_1(self): # Should commit changes on exit from db_session @db_session def test(): self.X(a=3, b=3) test() with db_session: self.assertEqual(count(x for x in self.X), 3) def test_db_session_decorator_2(self): # Should rollback changes if an exception is occurred @db_session def test(): self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_db_session_decorator_3(self): # Should rollback changes if the exception is not in the list of allowed exceptions @db_session(allowed_exceptions=[TypeError]) def test(): self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: 
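# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the commit/rollback
# rules tested in test_db_session.py around here: db_session commits on a
# clean exit and rolls back on an exception, unless the exception matches
# allowed_exceptions (a list or a predicate), in which case the changes
# are committed anyway.  Hypothetical entity:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Row(demo_db.Entity):
    a = PrimaryKey(int)

demo_db.generate_mapping(create_tables=True)

@db_session(allowed_exceptions=[ZeroDivisionError])
def insert_then_fail():
    Row(a=1)
    1 / 0                      # allowed, so the insert is still committed

try:
    insert_then_fail()
except ZeroDivisionError:
    pass

with db_session:
    assert Row.exists(a=1)
# ----------------------------------------------------------------------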
with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_db_session_decorator_4(self): # Should commit changes if the exception is in the list of allowed exceptions @db_session(allowed_exceptions=[ZeroDivisionError]) def test(): self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: with db_session: self.assertEqual(count(x for x in self.X), 3) else: self.fail() def test_allowed_exceptions_1(self): # allowed_exceptions may be callable, should commit if nonzero @db_session(allowed_exceptions=lambda e: isinstance(e, ZeroDivisionError)) def test(): self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: with db_session: self.assertEqual(count(x for x in self.X), 3) else: self.fail() def test_allowed_exceptions_2(self): # allowed_exceptions may be callable, should rollback if not nonzero @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) def test(): self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() @raises_exception(TypeError, "'retry' parameter of db_session must be of integer type. Got: %r" % str) def test_retry_1(self): @db_session(retry='foobar') def test(): pass @raises_exception(TypeError, "'retry' parameter of db_session must not be negative. Got: -1") def test_retry_2(self): @db_session(retry=-1) def test(): pass def test_retry_3(self): # Should not to do retry until retry count is specified counter = count() @db_session(retry_exceptions=[ZeroDivisionError]) def test(): next(counter) self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: self.assertEqual(next(counter), 1) with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_retry_4(self): # Should rollback & retry 1 time if retry=1 counter = count() @db_session(retry=1, retry_exceptions=[ZeroDivisionError]) def test(): next(counter) self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: self.assertEqual(next(counter), 2) with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_retry_5(self): # Should rollback & retry N time if retry=N counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) def test(): next(counter) self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: self.assertEqual(next(counter), 6) with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_retry_6(self): # Should not retry if the exception not in the list of retry_exceptions counter = count() @db_session(retry=3, retry_exceptions=[TypeError]) def test(): next(counter) self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: self.assertEqual(next(counter), 1) with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_retry_7(self): # Should commit after successful retrying counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) def test(): i = next(counter) self.X(a=3, b=3) if i < 2: 1/0 try: test() except ZeroDivisionError: self.fail() else: self.assertEqual(next(counter), 3) with db_session: self.assertEqual(count(x for x in self.X), 3) @raises_exception(TypeError, "The same exception ZeroDivisionError cannot be specified " "in both allowed and retry exception lists simultaneously") def test_retry_8(self): @db_session(retry=3, retry_exceptions=[ZeroDivisionError], allowed_exceptions=[ZeroDivisionError]) def test(): pass def test_retry_9(self): # retry_exceptions may be callable, should retry if nonzero counter = count() 
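# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the retry machinery
# tested around here: with retry=N and a matching retry_exceptions entry
# (list or predicate), the decorated function is rolled back and re-run up
# to N extra times; it must also be the outermost db_session, otherwise
# the retry option is ignored with a PonyRuntimeWarning.
from itertools import count
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Row(demo_db.Entity):
    a = PrimaryKey(int)

demo_db.generate_mapping(create_tables=True)

attempts = count()

@db_session(retry=2, retry_exceptions=[ZeroDivisionError])
def flaky():
    if next(attempts) < 2:     # fail on the first two attempts
        1 / 0
    Row(a=1)

flaky()                        # third attempt succeeds and commits
assert next(attempts) == 3
# ----------------------------------------------------------------------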
@db_session(retry=3, retry_exceptions=lambda e: isinstance(e, ZeroDivisionError)) def test(): i = next(counter) self.X(a=3, b=3) 1/0 try: test() except ZeroDivisionError: self.assertEqual(next(counter), 4) with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_retry_10(self): # Issue 313: retry on exception raised during db_session.__exit__ retries = count() @db_session(retry=3) def test(): next(retries) self.X(a=1, b=1) try: test() except TransactionIntegrityError: self.assertEqual(next(retries), 4) else: self.fail() @raises_exception(PonyRuntimeWarning, '@db_session decorator with `retry=3` option is ignored for test() function ' 'because it is called inside another db_session') def test_retry_11(self): @db_session(retry=3) def test(): pass with warnings.catch_warnings(): warnings.simplefilter('error', PonyRuntimeWarning) with db_session: test() def test_db_session_manager_1(self): with db_session: self.X(a=3, b=3) with db_session: self.assertEqual(count(x for x in self.X), 3) @raises_exception(TypeError, "@db_session can accept 'retry' parameter " "only when used as decorator and not as context manager") def test_db_session_manager_2(self): with db_session(retry=3): self.X(a=3, b=3) def test_db_session_manager_3(self): # Should rollback if the exception is not in the list of allowed_exceptions try: with db_session(allowed_exceptions=[TypeError]): self.X(a=3, b=3) 1/0 except ZeroDivisionError: with db_session: self.assertEqual(count(x for x in self.X), 2) else: self.fail() def test_db_session_manager_4(self): # Should commit if the exception is in the list of allowed_exceptions try: with db_session(allowed_exceptions=[ZeroDivisionError]): self.X(a=3, b=3) 1/0 except ZeroDivisionError: with db_session: self.assertEqual(count(x for x in self.X), 3) else: self.fail() # restriction removed in 0.7.3: # @raises_exception(TypeError, "@db_session can accept 'ddl' parameter " # "only when used as decorator and not as context manager") def test_db_session_ddl_1(self): with db_session(ddl=True): pass def test_db_session_ddl_1a(self): with db_session(ddl=True): with db_session(ddl=True): pass def test_db_session_ddl_1b(self): with db_session(ddl=True): with db_session: pass @raises_exception(TransactionError, 'Cannot start ddl transaction inside non-ddl transaction') def test_db_session_ddl_1c(self): with db_session: with db_session(ddl=True): pass @raises_exception(TransactionError, "@db_session-decorated test() function with `ddl` option " "cannot be called inside of another db_session") def test_db_session_ddl_2(self): @db_session(ddl=True) def test(): pass with db_session: test() def test_db_session_ddl_3(self): @db_session(ddl=True) def test(): pass test() @raises_exception(ZeroDivisionError) def test_db_session_exceptions_1(self): def before_insert(self): 1/0 self.X.before_insert = before_insert with db_session: self.X(a=3, b=3) # Should raise ZeroDivisionError and not CommitException @raises_exception(ZeroDivisionError) def test_db_session_exceptions_2(self): def before_insert(self): 1 / 0 self.X.before_insert = before_insert with db_session: self.X(a=3, b=3) commit() # Should raise ZeroDivisionError and not CommitException @raises_exception(ZeroDivisionError) def test_db_session_exceptions_3(self): def before_insert(self): 1 / 0 self.X.before_insert = before_insert with db_session: self.X(a=3, b=3) db.commit() # Should raise ZeroDivisionError and not CommitException @raises_exception(ZeroDivisionError) def test_db_session_exceptions_4(self): with 
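# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the session-scope
# rules tested in TestDBSessionScope below: already-loaded attributes stay
# readable after the session ends, but anything that would hit the
# database raises DatabaseSessionIsOver, and strict=True forbids even
# reading loaded values.  Hypothetical entity with a lazy attribute:
from pony.orm import *
from pony.orm.core import DatabaseSessionIsOver

demo_db = Database('sqlite', ':memory:')

class Doc(demo_db.Entity):
    id = PrimaryKey(int)
    body = Optional(str, lazy=True)    # lazy: not fetched with the row

demo_db.generate_mapping(create_tables=True)

with db_session:
    Doc(id=1, body='text')

with db_session:
    d = Doc[1]

try:
    d.body                             # needs a query, session is over
except DatabaseSessionIsOver:
    pass
# ----------------------------------------------------------------------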
db_session: connection = self.db.get_connection() connection.close() 1/0 db = Database('sqlite', ':memory:') class Group(db.Entity): id = PrimaryKey(int) major = Required(unicode) students = Set('Student') class Student(db.Entity): name = Required(unicode) picture = Optional(buffer, lazy=True) group = Required('Group') db.generate_mapping(create_tables=True) with db_session: g1 = Group(id=1, major='Math') g2 = Group(id=2, major='Physics') s1 = Student(id=1, name='S1', group=g1) s2 = Student(id=2, name='S2', group=g1) s3 = Student(id=3, name='S3', group=g2) class TestDBSessionScope(unittest.TestCase): def setUp(self): rollback() def tearDown(self): rollback() def test1(self): with db_session: s1 = Student[1] name = s1.name @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Student[1].picture: the database session is over') def test2(self): with db_session: s1 = Student[1] picture = s1.picture @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over') def test3(self): with db_session: s1 = Student[1] group_id = s1.group.id major = s1.group.major @raises_exception(DatabaseSessionIsOver, 'Cannot assign new value to Student[1].name: the database session is over') def test4(self): with db_session: s1 = Student[1] s1.name = 'New name' def test5(self): with db_session: g1 = Group[1] self.assertEqual(str(g1.students), 'StudentSet([...])') @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Group[1].students: the database session is over') def test6(self): with db_session: g1 = Group[1] l = len(g1.students) @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test7(self): with db_session: s1 = Student[1] g1 = Group[1] g1.students.remove(s1) @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test8(self): with db_session: g2_students = Group[2].students g1 = Group[1] g1.students = g2_students @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test9(self): with db_session: s3 = Student[3] g1 = Group[1] g1.students.add(s3) @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test10(self): with db_session: g1 = Group[1] g1.students.clear() @raises_exception(DatabaseSessionIsOver, 'Cannot delete object Student[1]: the database session is over') def test11(self): with db_session: s1 = Student[1] s1.delete() @raises_exception(DatabaseSessionIsOver, 'Cannot change object Student[1]: the database session is over') def test12(self): with db_session: s1 = Student[1] s1.set(name='New name') def test_db_session_strict_1(self): with db_session(strict=True): s1 = Student[1] @raises_exception(DatabaseSessionIsOver, 'Cannot read value of Student[1].name: the database session is over') def test_db_session_strict_2(self): with db_session(strict=True): s1 = Student[1] name = s1.name if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571707432.0 pony-0.7.11/pony/orm/tests/test_declarative_attr_set_monad.py0000666000000000000000000002134400000000000022627 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Student(db.Entity): name 
= Required(unicode) scholarship = Optional(int) group = Required("Group") marks = Set("Mark") class Group(db.Entity): number = PrimaryKey(int) department = Required(int) students = Set(Student) subjects = Set("Subject") class Subject(db.Entity): name = PrimaryKey(unicode) groups = Set(Group) marks = Set("Mark") class Mark(db.Entity): value = Required(int) student = Required(Student) subject = Required(Subject) PrimaryKey(student, subject) db.generate_mapping(create_tables=True) with db_session: g41 = Group(number=41, department=101) g42 = Group(number=42, department=102) g43 = Group(number=43, department=102) g44 = Group(number=44, department=102) s1 = Student(id=1, name="Joe", scholarship=None, group=g41) s2 = Student(id=2, name="Bob", scholarship=100, group=g41) s3 = Student(id=3, name="Beth", scholarship=500, group=g41) s4 = Student(id=4, name="Jon", scholarship=500, group=g42) s5 = Student(id=5, name="Pete", scholarship=700, group=g42) s6 = Student(id=6, name="Mary", scholarship=300, group=g44) Math = Subject(name="Math") Physics = Subject(name="Physics") History = Subject(name="History") g41.subjects = [ Math, Physics, History ] g42.subjects = [ Math, Physics ] g43.subjects = [ Physics ] Mark(value=5, student=s1, subject=Math) Mark(value=4, student=s2, subject=Physics) Mark(value=3, student=s2, subject=Math) Mark(value=2, student=s2, subject=History) Mark(value=1, student=s3, subject=History) Mark(value=2, student=s3, subject=Math) Mark(value=2, student=s4, subject=Math) class TestAttrSetMonad(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test1(self): groups = select(g for g in Group if len(g.students) > 2)[:] self.assertEqual(groups, [Group[41]]) def test2(self): groups = set(select(g for g in Group if len(g.students.name) >= 2)) self.assertEqual(groups, {Group[41], Group[42]}) def test3(self): groups = select(g for g in Group if len(g.students.marks) > 2)[:] self.assertEqual(groups, [Group[41]]) def test3a(self): groups = select(g for g in Group if len(g.students.marks) < 2)[:] self.assertEqual(groups, [Group[42], Group[43], Group[44]]) def test4(self): groups = select(g for g in Group if max(g.students.marks.value) <= 2)[:] self.assertEqual(groups, [Group[42]]) def test5(self): students = select(s for s in Student if len(s.marks.subject.name) > 5)[:] self.assertEqual(students, []) def test6(self): students = set(select(s for s in Student if len(s.marks.subject) >= 2)) self.assertEqual(students, {Student[2], Student[3]}) def test8(self): students = set(select(s for s in Student if s.group in (g for g in Group if g.department == 101))) self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test9(self): students = set(select(s for s in Student if s.group not in (g for g in Group if g.department == 101))) self.assertEqual(students, {Student[4], Student[5], Student[6]}) def test10(self): students = set(select(s for s in Student if s.group in (g for g in Group if g.department == 101))) self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test11(self): students = set(select(g for g in Group if len(g.subjects.groups.subjects) > 1)) self.assertEqual(students, {Group[41], Group[42], Group[43]}) def test12(self): groups = set(select(g for g in Group if len(g.subjects) >= 2)) self.assertEqual(groups, {Group[41], Group[42]}) def test13(self): groups = set(select(g for g in Group if g.students)) self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test14(self): groups = 
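# ----------------------------------------------------------------------
# Illustrative sketch (not from the distribution) of the "attribute set"
# monads tested in this file: inside a declarative query a Set attribute
# can be aggregated directly -- len(), membership and exists() translate
# into correlated subqueries.  Hypothetical schema:
from pony.orm import *

demo_db = Database('sqlite', ':memory:')

class Team(demo_db.Entity):
    number = PrimaryKey(int)
    members = Set('Member')

class Member(demo_db.Entity):
    name = Required(str)
    scholarship = Optional(int)
    team = Required(Team)

demo_db.generate_mapping(create_tables=True)

with db_session:
    t1, t2 = Team(number=1), Team(number=2)
    Member(name='a', scholarship=100, team=t1)
    Member(name='b', scholarship=500, team=t1)
    Member(name='c', team=t2)

with db_session:
    big = select(t.number for t in Team if len(t.members) > 1)[:]
    rich = select(t.number for t in Team if 500 in t.members.scholarship)[:]
    nonempty = select(t.number for t in Team if exists(t.members))[:]
    assert big == [1] and rich == [1] and set(nonempty) == {1, 2}
# ----------------------------------------------------------------------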
set(select(g for g in Group if not g.students)) self.assertEqual(groups, {Group[43]}) def test15(self): groups = set(select(g for g in Group if exists(g.students))) self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test15a(self): groups = set(select(g for g in Group if not not exists(g.students))) self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test16(self): groups = select(g for g in Group if not exists(g.students))[:] self.assertEqual(groups, [Group[43]]) def test17(self): groups = set(select(g for g in Group if 100 in g.students.scholarship)) self.assertEqual(groups, {Group[41]}) def test18(self): groups = set(select(g for g in Group if 100 not in g.students.scholarship)) self.assertEqual(groups, {Group[42], Group[43], Group[44]}) def test19(self): groups = set(select(g for g in Group if not not not 100 not in g.students.scholarship)) self.assertEqual(groups, {Group[41]}) def test20(self): groups = set(select(g for g in Group if exists(s for s in Student if s.group == g and s.scholarship == 500))) self.assertEqual(groups, {Group[41], Group[42]}) def test21(self): groups = set(select(g for g in Group if g.department is not None)) self.assertEqual(groups, {Group[41], Group[42], Group[43], Group[44]}) def test21a(self): groups = set(select(g for g in Group if not g.department is not None)) self.assertEqual(groups, set()) def test21b(self): groups = set(select(g for g in Group if not not not g.department is None)) self.assertEqual(groups, {Group[41], Group[42], Group[43], Group[44]}) def test22(self): groups = set(select(g for g in Group if 700 in (s.scholarship for s in Student if s.group == g))) self.assertEqual(groups, {Group[42]}) def test23a(self): groups = set(select(g for g in Group if 700 not in g.students.scholarship)) self.assertEqual(groups, {Group[41], Group[43], Group[44]}) def test23b(self): groups = set(select(g for g in Group if 700 not in (s.scholarship for s in Student if s.group == g))) self.assertEqual(groups, {Group[41], Group[43], Group[44]}) @raises_exception(NotImplementedError) def test24(self): groups = set(select(g for g in Group for g2 in Group if g.students == g2.students)) def test25(self): m1 = Mark[Student[1], Subject["Math"]] students = set(select(s for s in Student if m1 in s.marks)) self.assertEqual(students, {Student[1]}) def test26(self): s1 = Student[1] groups = set(select(g for g in Group if s1 in g.students)) self.assertEqual(groups, {Group[41]}) @raises_exception(AttributeError, 'g.students.name.foo') def test27(self): select(g for g in Group if g.students.name.foo == 1) def test28(self): groups = set(select(g for g in Group if not g.students.is_empty())) self.assertEqual(groups, {Group[41], Group[42], Group[44]}) @raises_exception(NotImplementedError) def test29(self): students = select(g.students.select(lambda s: s.scholarship > 0) for g in Group if g.department == 101)[:] def test30a(self): s = Student[2] groups = select(g for g in Group if g.department == 101 and s in g.students.select(lambda s: s.scholarship > 0))[:] self.assertEqual(set(groups), {Group[41]}) def test30b(self): s = Student[2] groups = select(g for g in Group if g.department == 101 and s in g.students.filter(lambda s: s.scholarship > 0))[:] self.assertEqual(set(groups), {Group[41]}) def test30c(self): s = Student[2] groups = select(g for g in Group if g.department == 101 and s in g.students.select())[:] self.assertEqual(set(groups), {Group[41]}) def test30d(self): s = Student[2] groups = select(g for g in Group if g.department == 101 and s in 
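# ----------------------------------------------------------------------
# Continuing the hypothetical Team/Member sketch above (doctest-style, in
# the same spirit as the commented examples in test_collections.py): the
# tests just above also exercise .select(), .filter() and .exists() with a
# lambda on a Set attribute, each becoming a correlated subquery:
##>>> m = Member[2]
##>>> select(t for t in Team
##...        if m in t.members.filter(lambda x: x.scholarship > 0))[:]
##[Team[1]]
##>>> select(t for t in Team
##...        if t.members.exists(lambda x: x.scholarship > 0))[:]
##[Team[1]]
# ----------------------------------------------------------------------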
g.students.filter())[:] self.assertEqual(set(groups), {Group[41]}) def test31(self): s = Student[2] groups = select(g for g in Group if g.department == 101 and g.students.exists(lambda s: s.scholarship > 0))[:] self.assertEqual(set(groups), {Group[41]}) if __name__ == "__main__": unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_declarative_exceptions.py0000666000000000000000000003313000000000000022001 0ustar0000000000000000from __future__ import absolute_import, print_function, division from pony.py23compat import PYPY, PYPY2 import sys, unittest from datetime import date from decimal import Decimal from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) dob = Optional(date) gpa = Optional(float) scholarship = Optional(Decimal, 7, 2) group = Required('Group') courses = Set('Course') class Group(db.Entity): number = PrimaryKey(int) students = Set(Student) dept = Required('Department') class Department(db.Entity): number = PrimaryKey(int) groups = Set(Group) class Course(db.Entity): name = Required(unicode) semester = Required(int) PrimaryKey(name, semester) students = Set(Student) db.generate_mapping(create_tables=True) with db_session: d1 = Department(number=44) g1 = Group(number=101, dept=d1) Student(name='S1', group=g1) Student(name='S2', group=g1) Student(name='S3', group=g1) class TestSQLTranslatorExceptions(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() @raises_exception(NotImplementedError, 'for x in s.name') def test1(self): x = 10 select(s for s in Student for x in s.name) @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. Got: for i in x") def test2(self): x = [1, 2, 3] select(s for s in Student for i in x) @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. 
Got: for s2 in g.students") def test3(self): g = Group[101] select(s for s in Student for s2 in g.students) @raises_exception(NotImplementedError, "*args is not supported") def test4(self): args = 'abc' select(s for s in Student if s.name.upper(*args)) if sys.version_info[:2] < (3, 5): # TODO @raises_exception(NotImplementedError) # "**{'a': 'b', 'c': 'd'} is not supported def test5(self): select(s for s in Student if s.name.upper(**{'a':'b', 'c':'d'})) @raises_exception(ExprEvalError, "`1 in 2` raises TypeError: argument of type 'int' is not iterable" if not PYPY else "`1 in 2` raises TypeError: 'int' object is not iterable") def test6(self): select(s for s in Student if 1 in 2) @raises_exception(NotImplementedError, 'Group[s.group.number]') def test7(self): select(s for s in Student if Group[s.group.number].dept.number == 44) @raises_exception(ExprEvalError, "`Group[123, 456].dept.number == 44` raises TypeError: Invalid count of attrs in Group primary key (2 instead of 1)") def test8(self): select(s for s in Student if Group[123, 456].dept.number == 44) @raises_exception(ExprEvalError, "`Course[123]` raises TypeError: Invalid count of attrs in Course primary key (1 instead of 2)") def test9(self): select(s for s in Student if Course[123] in s.courses) @raises_exception(TypeError, "Incomparable types '%s' and 'float' in expression: s.name < s.gpa" % unicode.__name__) def test10(self): select(s for s in Student if s.name < s.gpa) @raises_exception(ExprEvalError, "`Group(101)` raises TypeError: Group constructor accept only keyword arguments. Got: 1 positional argument") def test11(self): select(s for s in Student if s.group == Group(101)) @raises_exception(ExprEvalError, "`Group[date(2011, 1, 2)]` raises TypeError: Value type for attribute Group.number must be int. Got: %r" % date) def test12(self): select(s for s in Student if s.group == Group[date(2011, 1, 2)]) @raises_exception(TypeError, "Unsupported operand types 'int' and '%s' for operation '+' in expression: s.group.number + s.name" % unicode.__name__) def test13(self): select(s for s in Student if s.group.number + s.name < 0) @raises_exception(TypeError, "Unsupported operand types 'Decimal' and 'float' for operation '+' in expression: s.scholarship + 1.1") def test14(self): select(s for s in Student if s.scholarship + 1.1 > 10) @raises_exception(TypeError, "Unsupported operand types 'Decimal' and '%s' for operation '**' " "in expression: s.scholarship ** 'abc'" % unicode.__name__) def test15(self): select(s for s in Student if s.scholarship ** 'abc' > 10) @raises_exception(TypeError, "Unsupported operand types '%s' and 'int' for operation '+' in expression: s.name + 2" % unicode.__name__) def test16(self): select(s for s in Student if s.name + 2 > 10) @raises_exception(TypeError, "Step is not supported in s.name[1:3:5]") def test17(self): select(s for s in Student if s.name[1:3:5] == 'A') @raises_exception(TypeError, "Invalid type of start index (expected 'int', got '%s') in string slice s.name['a':1]" % unicode.__name__) def test18(self): select(s for s in Student if s.name['a':1] == 'A') @raises_exception(TypeError, "Invalid type of stop index (expected 'int', got '%s') in string slice s.name[1:'a']" % unicode.__name__) def test19(self): select(s for s in Student if s.name[1:'a'] == 'A') @raises_exception(NotImplementedError, "Negative indices are not supported in string slice s.name[-1:1]") def test20(self): select(s for s in Student if s.name[-1:1] == 'A') @raises_exception(TypeError, "String indices must be integers. 
                                 Got '%s' in expression s.name['a']" % unicode.__name__)
    def test21(self):
        select(s.name for s in Student if s.name['a'] == 'h')
    @raises_exception(TypeError, "Incomparable types 'int' and '%s' in expression: 1 in s.name" % unicode.__name__)
    def test22(self):
        select(s.name for s in Student if 1 in s.name)
    @raises_exception(TypeError, "Expected '%s' argument but got 'int' in expression s.name.startswith(1)" % unicode.__name__)
    def test23(self):
        select(s.name for s in Student if s.name.startswith(1))
    @raises_exception(TypeError, "Expected '%s' argument but got 'int' in expression s.name.endswith(1)" % unicode.__name__)
    def test24(self):
        select(s.name for s in Student if s.name.endswith(1))
    @raises_exception(TypeError, "'chars' argument must be of '%s' type in s.name.strip(1), got: 'int'" % unicode.__name__)
    def test25(self):
        select(s.name for s in Student if s.name.strip(1))
    @raises_exception(AttributeError, "'%s' object has no attribute 'unknown': s.name.unknown" % unicode.__name__)
    def test26(self):
        result = set(select(s for s in Student if s.name.unknown() == "joe"))
    @raises_exception(AttributeError, "Entity Group does not have attribute foo: s.group.foo")
    def test27(self):
        select(s.name for s in Student if s.group.foo.bar == 10)
    @raises_exception(ExprEvalError, "`g.dept.foo.bar` raises AttributeError: 'Department' object has no attribute 'foo'")
    def test28(self):
        g = Group[101]
        select(s for s in Student if s.name == g.dept.foo.bar)
    @raises_exception(TypeError, "'year' argument of date(year, month, day) function must be of 'int' type. "
                                 "Got: '%s'" % unicode.__name__)
    def test29(self):
        select(s for s in Student if s.dob < date('2011', 1, 1))
    @raises_exception(NotImplementedError, "date(s.id, 1, 1)")
    def test30(self):
        select(s for s in Student if s.dob < date(s.id, 1, 1))
    @raises_exception(ExprEvalError, "`max()` raises TypeError: max() expects at least one argument" if PYPY else
                      "`max()` raises TypeError: max expected 1 arguments, got 0" if sys.version_info[:2] < (3, 8) else
                      "`max()` raises TypeError: max expected 1 argument, got 0")
    def test31(self):
        select(s for s in Student if s.id < max())
    @raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses")
    def test32(self):
        select(s for s in Student if s in s.courses)
    @raises_exception(AttributeError, "s.courses.name.foo")
    def test33(self):
        select(s for s in Student if 'x' in s.courses.name.foo.bar)
    @raises_exception(AttributeError, "s.courses.foo")
    def test34(self):
        select(s for s in Student if 'x' in s.courses.foo.bar)
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got '%s' in sum(s.courses.name)" % unicode.__name__)
    def test35(self):
        select(s for s in Student if sum(s.courses.name) > 10)
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got '%s' in sum(c.name for c in s.courses)" % unicode.__name__)
    def test36(self):
        select(s for s in Student if sum(c.name for c in s.courses) > 10)
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got '%s' in sum(c.name for c in s.courses)" % unicode.__name__)
    def test37(self):
        select(s for s in Student if sum(c.name for c in s.courses) > 10)
    @raises_exception(TypeError, "Function avg() expects query or items of numeric type, got '%s' in avg(c.name for c in s.courses)" % unicode.__name__)
    def test38(self):
        select(s for s in Student if avg(c.name for c in s.courses) > 10 and len(s.courses) > 1)
    @raises_exception(TypeError, "strip() takes at most 1 argument (3 given)")
    def test39(self):
        select(s for s in Student if s.name.strip(1, 2, 3))
    @raises_exception(ExprEvalError, "`len(1, 2) == 3` raises TypeError: len() takes exactly 1 argument (2 given)" if PYPY2 else
                      "`len(1, 2) == 3` raises TypeError: len() takes 1 positional argument but 2 were given" if PYPY else
                      "`len(1, 2) == 3` raises TypeError: len() takes exactly one argument (2 given)")
    def test40(self):
        select(s for s in Student if len(1, 2) == 3)
    @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got 'Student' in sum(s for s in Student if s.group == g)")
    def test41(self):
        select(g for g in Group if sum(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Function avg() expects query or items of numeric type, got 'Student' in avg(s for s in Student if s.group == g)")
    def test42(self):
        select(g for g in Group if avg(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Function min() cannot be applied to type 'Student' in min(s for s in Student if s.group == g)")
    def test43(self):
        select(g for g in Group if min(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Function max() cannot be applied to type 'Student' in max(s for s in Student if s.group == g)")
    def test44(self):
        select(g for g in Group if max(s for s in Student if s.group == g) > 1)
    @raises_exception(TypeError, "Attribute should be specified for 'max' aggregate function")
    def test45(self):
        max(s for s in Student)
    @raises_exception(TypeError, "Single attribute should be specified for 'max' aggregate function")
    def test46(self):
        max((s.name, s.gpa) for s in Student)
    @raises_exception(TypeError, "Attribute should be specified for 'sum' aggregate function")
    def test47(self):
        sum(s for s in Student)
    @raises_exception(TypeError, "Single attribute should be specified for 'sum' aggregate function")
    def test48(self):
        sum((s.name, s.gpa) for s in Student)
    @raises_exception(TypeError, "'sum' is valid for numeric attributes only")
    def test49(self):
        sum(s.name for s in Student)
    @raises_exception(TypeError, "Cannot compare whole JSON value, you need to select specific sub-item: s.name == {'a':'b'}")
    def test50(self):
        # cannot compare JSON value to dynamic string,
        # because a database does not provide json.dumps(s.name) functionality
        select(s for s in Student if s.name == {'a': 'b'})
    @raises_exception(IncomparableTypesError, "Incomparable types '%s' and 'int' in expression: s.name > a & 2" % unicode.__name__)
    def test51(self):
        a = 1
        select(s for s in Student if s.name > a & 2)
    @raises_exception(TypeError, "Incomparable types '%s' and 'float' in expression: s.name > 1 / a - 3" % unicode.__name__)
    def test52(self):
        a = 1
        select(s for s in Student if s.name > 1 / a - 3)
    @raises_exception(TypeError, "Incomparable types '%s' and 'int' in expression: s.name > 1 // a - 3" % unicode.__name__)
    def test53(self):
        a = 1
        select(s for s in Student if s.name > 1 // a - 3)
    @raises_exception(TypeError, "Incomparable types '%s' and 'int' in expression: s.name > -a" % unicode.__name__)
    def test54(self):
        a = 1
        select(s for s in Student if s.name > -a)
    @raises_exception(TypeError, "Incomparable types '%s' and 'list' in expression: s.name == [1, (2,)]" % unicode.__name__)
    def test55(self):
        select(s for s in Student if s.name == [1, (2,)])
    @raises_exception(TypeError, "Delete query should be applied to a single entity. Got: (s, g)")
    def test56(self):
        delete((s, g) for g in Group for s in g.students if s.gpa > 3)

if __name__ == '__main__':
    unittest.main()
# pony-0.7.11/pony/orm/tests/test_declarative_func_monad.py

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2, PYPY, PYPY2

import sys, unittest
from datetime import date, datetime
from decimal import Decimal

from pony.orm.core import *
from pony.orm.sqltranslation import IncomparableTypesError
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    id = PrimaryKey(int)
    name = Required(unicode)
    dob = Required(date)
    last_visit = Required(datetime)
    scholarship = Required(Decimal, 6, 2)
    phd = Required(bool)
    group = Required('Group')

class Group(db.Entity):
    number = PrimaryKey(int)
    students = Set(Student)

db.generate_mapping(create_tables=True)

with db_session:
    g1 = Group(number=1)
    g2 = Group(number=2)
    Student(id=1, name="AA", dob=date(1981, 1, 1), last_visit=datetime(2011, 1, 1, 11, 11, 11), scholarship=Decimal("0"), phd=True, group=g1)
    Student(id=2, name="BB", dob=date(1982, 2, 2), last_visit=datetime(2011, 2, 2, 12, 12, 12), scholarship=Decimal("202.2"), phd=True, group=g1)
    Student(id=3, name="CC", dob=date(1983, 3, 3), last_visit=datetime(2011, 3, 3, 13, 13, 13), scholarship=Decimal("303.3"), phd=False, group=g1)
    Student(id=4, name="DD", dob=date(1984, 4, 4), last_visit=datetime(2011, 4, 4, 14, 14, 14), scholarship=Decimal("404.4"), phd=False, group=g2)
    Student(id=5, name="EE", dob=date(1985, 5, 5), last_visit=datetime(2011, 5, 5, 15, 15, 15), scholarship=Decimal("505.5"), phd=False, group=g2)

class TestFuncMonad(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_minmax1(self):
        result = set(select(s for s in Student if max(s.id, 3) == 3))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_minmax2(self):
        result = set(select(s for s in Student if min(s.id, 3) == 3))
        self.assertEqual(result, {Student[4], Student[5], Student[3]})
    def test_minmax3(self):
        result = set(select(s for s in Student if max(s.name, "CC") == "CC"))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_minmax4(self):
        result = set(select(s for s in Student if min(s.name, "CC") == "CC"))
        self.assertEqual(result, {Student[4], Student[5], Student[3]})
    def test_minmax5(self):
        x = chr(128)
        try:
            result = set(select(s for s in Student if min(s.name, x) == "CC"))
        except TypeError as e:
            self.assertTrue(PY2 and e.args[0] == "The bytestring '\\x80' contains non-ascii symbols. Try to pass unicode string instead")
        else:
            self.assertFalse(PY2)
    def test_minmax6(self):
        x = chr(128)
        try:
            result = set(select(s for s in Student if min(s.name, x, "CC") == "CC"))
        except TypeError as e:
            self.assertTrue(PY2 and e.args[0] == "The bytestring '\\x80' contains non-ascii symbols. Try to pass unicode string instead")
        else:
            self.assertFalse(PY2)
    def test_minmax7(self):
        result = set(select(s for s in Student if min(s.phd, 2) == 2))
    def test_date_func1(self):
        result = set(select(s for s in Student if s.dob >= date(1983, 3, 3)))
        self.assertEqual(result, {Student[3], Student[4], Student[5]})
    # @raises_exception(ExprEvalError, "date(1983, 'three', 3) raises TypeError: an integer is required")
    @raises_exception(TypeError, "'month' argument of date(year, month, day) function must be of 'int' type. "
                                 "Got: '%s'" % unicode.__name__)
    def test_date_func2(self):
        result = set(select(s for s in Student if s.dob >= date(1983, 'three', 3)))
    # @raises_exception(NotImplementedError)
    # def test_date_func3(self):
    #     d = 3
    #     result = set(select(s for s in Student if s.dob >= date(1983, d, 3)))
    def test_datetime_func1(self):
        result = set(select(s for s in Student if s.last_visit >= date(2011, 3, 3)))
        self.assertEqual(result, {Student[3], Student[4], Student[5]})
    def test_datetime_func2(self):
        result = set(select(s for s in Student if s.last_visit >= datetime(2011, 3, 3)))
        self.assertEqual(result, {Student[3], Student[4], Student[5]})
    def test_datetime_func3(self):
        result = set(select(s for s in Student if s.last_visit >= datetime(2011, 3, 3, 13, 13, 13)))
        self.assertEqual(result, {Student[3], Student[4], Student[5]})
    # @raises_exception(ExprEvalError, "datetime(1983, 'three', 3) raises TypeError: an integer is required")
    @raises_exception(TypeError, "'month' argument of datetime(...) function must be of 'int' type. "
                                 "Got: '%s'" % unicode.__name__)
    def test_datetime_func4(self):
        result = set(select(s for s in Student if s.last_visit >= datetime(1983, 'three', 3)))
    # @raises_exception(NotImplementedError)
    # def test_datetime_func5(self):
    #     d = 3
    #     result = set(select(s for s in Student if s.last_visit >= date(1983, d, 3)))
    def test_datetime_now1(self):
        result = set(select(s for s in Student if s.dob < date.today()))
        self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]})
    @raises_exception(ExprEvalError, "`1 < datetime.now()` raises TypeError: " + (
        "can't compare 'datetime' to 'int'" if PYPY2 else
        "'<' not supported between instances of 'int' and 'datetime'" if PYPY and sys.version_info >= (3, 6) else
        "unorderable types: int < datetime" if PYPY else
        "can't compare datetime.datetime to int" if PY2 else
        "unorderable types: int() < datetime.datetime()" if sys.version_info < (3, 6) else
        "'<' not supported between instances of 'int' and 'datetime.datetime'"))
    def test_datetime_now2(self):
        select(s for s in Student if 1 < datetime.now())
    def test_datetime_now3(self):
        result = set(select(s for s in Student if s.dob < datetime.today()))
        self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]})
    def test_decimal_func(self):
        result = set(select(s for s in Student if s.scholarship >= Decimal("303.3")))
        self.assertEqual(result, {Student[3], Student[4], Student[5]})
    def test_concat_1(self):
        result = set(select(concat(s.name, ':', s.dob.year, ':', s.scholarship) for s in Student))
        self.assertEqual(result, {'AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'})
    @raises_exception(TranslationError, 'Invalid argument of concat() function: g.students')
    def test_concat_2(self):
        result = set(select(concat(g.number, g.students) for g in Group))

if __name__ == '__main__':
    unittest.main()
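
# Note on the tests above: calls like date(...) and datetime(...) inside a
# query are resolved at translation time, which is why their arguments must
# be int constants or attributes. A query such as
#
#     select(s for s in Student if s.dob >= date(1983, 3, 3))
#
# is translated to ordinary SQL with a date literal, roughly (illustrative,
# not the verbatim generated SQL):
#
#     SELECT ... FROM "Student" "s" WHERE "s"."dob" >= '1983-03-03'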
# pony-0.7.11/pony/orm/tests/test_declarative_join_optimization.py

from __future__ import absolute_import, print_function, division
import unittest
from datetime import date

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Department(db.Entity):
    name = Required(str)
    groups = Set('Group')
    courses = Set('Course')

class Group(db.Entity):
    number = PrimaryKey(int)
    dept = Required(Department)
    major = Required(unicode)
    students = Set("Student")

class Course(db.Entity):
    name = Required(unicode)
    dept = Required(Department)
    semester = Required(int)
    credits = Required(int)
    students = Set("Student")
    PrimaryKey(name, semester)

class Student(db.Entity):
    id = PrimaryKey(int, auto=True)
    name = Required(unicode)
    dob = Required(date)
    picture = Optional(buffer)
    gpa = Required(float, default=0)
    group = Required(Group)
    courses = Set(Course)

db.generate_mapping(create_tables=True)

class TestM2MOptimization(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test1(self):
        q = select(s for s in Student if len(s.courses) > 2)
        self.assertEqual(Course._table_ not in flatten(q._translator.conditions), True)
    def test2(self):
        q = select(s for s in Student if max(s.courses.semester) > 2)
        self.assertEqual(Course._table_ not in flatten(q._translator.conditions), True)
    # def test3(self):
    #     q = select(s for s in Student if max(s.courses.credits) > 2)
    #     self.assertEqual(Course._table_ in flatten(q._translator.conditions), True)
    #     self.assertEqual(Course.students.table in flatten(q._translator.conditions), True)
    def test4(self):
        q = select(g for g in Group if sum(g.students.gpa) > 5)
        self.assertEqual(Group._table_ not in flatten(q._translator.conditions), True)
    def test5(self):
        q = select(s for s in Student if s.group.number == 1 or s.group.major == '1')
        self.assertEqual(Group._table_ in flatten(q._translator.sqlquery.from_ast), True)
    # def test6(self):  ### Broken with ExprEvalError: Group[101] raises ObjectNotFound: Group[101]
    #     q = select(s for s in Student if s.group == Group[101])
    #     self.assertEqual(Group._table_ not in flatten(q._translator.sqlquery.from_ast), True)
    def test7(self):
        q = select(s for s in Student if sum(c.credits for c in Course if s.group.dept == c.dept) > 10)
        objects = q[:]
        self.assertEqual(str(q._translator.sqlquery.from_ast),
            "['FROM', ['s', 'TABLE', 'Student'], ['group', 'TABLE', 'Group', ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group', 'number']]]]")

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_declarative_object_flat_monad.py

from __future__ import absolute_import, print_function, division
import unittest

from pony.orm.core import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    name = Required(unicode)
    scholarship = Optional(int)
    group = Required("Group")
    marks = Set("Mark")

class Group(db.Entity):
    number = PrimaryKey(int)
    department = Required(int)
    students = Set(Student)
    subjects = Set("Subject")

class Subject(db.Entity):
    name = PrimaryKey(unicode)
    groups = Set(Group)
    marks = Set("Mark")

class Mark(db.Entity):
    value = Required(int)
    student = Required(Student)
    subject = Required(Subject)
    PrimaryKey(student, subject)

db.generate_mapping(create_tables=True)

with db_session:
    Math = Subject(name="Math")
    Physics = Subject(name="Physics")
    History = Subject(name="History")
    g41 = Group(number=41, department=101, subjects=[ Math, Physics, History ])
    g42 = Group(number=42, department=102, subjects=[ Math, Physics ])
    g43 = Group(number=43, department=102, subjects=[ Physics ])
    s1 = Student(id=1, name="Joe", scholarship=None, group=g41)
    s2 = Student(id=2, name="Bob", scholarship=100, group=g41)
    s3 = Student(id=3, name="Beth", scholarship=500, group=g41)
    s4 = Student(id=4, name="Jon", scholarship=500, group=g42)
    s5 = Student(id=5, name="Pete", scholarship=700, group=g42)
    Mark(value=5, student=s1, subject=Math)
    Mark(value=4, student=s2, subject=Physics)
    Mark(value=3, student=s2, subject=Math)
    Mark(value=2, student=s2, subject=History)
    Mark(value=1, student=s3, subject=History)
    Mark(value=2, student=s3, subject=Math)
    Mark(value=2, student=s4, subject=Math)

class TestObjectFlatMonad(unittest.TestCase):
    @db_session
    def test1(self):
        result = set(select(s.groups for s in Subject if len(s.name) == 4))
        self.assertEqual(result, {Group[41], Group[42]})
    @db_session
    def test2(self):
        result = set(select(g.students for g in Group if g.department == 102))
        self.assertEqual(result, {Student[5], Student[4]})

if __name__ == '__main__':
    unittest.main()
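
# In the two tests above, selecting a collection attribute such as s.groups
# flattens it through the many-to-many link table and de-duplicates the
# result. Conceptually (illustrative SQL only; the table and column names
# here are guesses based on Pony's default naming, not verbatim output):
#
#     SELECT DISTINCT g.*
#     FROM "Subject" s
#     JOIN "Group_Subject" gs ON gs."subject" = s."name"
#     JOIN "Group" g ON g."number" = gs."group"
#     WHERE LENGTH(s."name") = 4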
# pony-0.7.11/pony/orm/tests/test_declarative_orderby_limit.py

from __future__ import absolute_import, print_function, division
import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    name = Required(unicode)
    scholarship = Optional(int)
    group = Required(int)

db.generate_mapping(create_tables=True)

with db_session:
    Student(id=1, name="B", scholarship=None, group=41)
    Student(id=2, name="C", scholarship=700, group=41)
    Student(id=3, name="A", scholarship=500, group=42)
    Student(id=4, name="D", scholarship=500, group=43)
    Student(id=5, name="E", scholarship=700, group=42)

class TestOrderbyLimit(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test1(self):
        students = set(select(s for s in Student).order_by(Student.name))
        self.assertEqual(students, {Student[3], Student[1], Student[2], Student[4], Student[5]})
    def test2(self):
        students = set(select(s for s in Student).order_by(Student.name.asc))
        self.assertEqual(students, {Student[3], Student[1], Student[2], Student[4], Student[5]})
    def test3(self):
        students = set(select(s for s in Student).order_by(Student.id.desc))
        self.assertEqual(students, {Student[5], Student[4], Student[3], Student[2], Student[1]})
    def test4(self):
        students = set(select(s for s in Student).order_by(Student.scholarship.asc, Student.group.desc))
        self.assertEqual(students, {Student[1], Student[4], Student[3], Student[5], Student[2]})
    def test5(self):
        students = set(select(s for s in Student).order_by(Student.name).limit(3))
        self.assertEqual(students, {Student[3], Student[1], Student[2]})
    def test6(self):
        students = set(select(s for s in Student).order_by(Student.name).limit(3, 1))
        self.assertEqual(students, {Student[1], Student[2], Student[4]})
    def test7(self):
        q = select(s for s in Student).order_by(Student.name).limit(3, 1)
        students = set(q)
        self.assertEqual(students, {Student[1], Student[2], Student[4]})
        students = set(q)
        self.assertEqual(students, {Student[1], Student[2], Student[4]})
    # @raises_exception(TypeError, "query.order_by() arguments must be attributes. Got: 'name'")
    # now generate: ExprEvalError: name raises NameError: name 'name' is not defined
    # def test8(self):
    #     students = select(s for s in Student).order_by("name")
    def test9(self):
        students = set(select(s for s in Student).order_by(Student.id)[1:4])
        self.assertEqual(students, {Student[2], Student[3], Student[4]})
    def test10(self):
        students = set(select(s for s in Student).order_by(Student.id)[:4])
        self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4]})
    # @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified")
    # def test11(self):
    #     students = select(s for s in Student).order_by(Student.id)[4:]
    @raises_exception(TypeError, "Parameter 'start' of slice object cannot be negative")
    def test12(self):
        students = select(s for s in Student).order_by(Student.id)[-3:2]
    @raises_exception(TypeError, 'If you want apply index to a query, convert it to list first')
    def test13(self):
        students = select(s for s in Student).order_by(Student.id)[3]
        self.assertEqual(students, Student[4])
    # @raises_exception(TypeError, 'If you want apply index to query, convert it to list first')
    # def test14(self):
    #     students = select(s for s in Student).order_by(Student.id)["a"]
    def test15(self):
        students = set(select(s for s in Student).order_by(Student.id)[0:4][1:3])
        self.assertEqual(students, {Student[2], Student[3]})
    def test16(self):
        students = set(select(s for s in Student).order_by(Student.id)[0:4][1:])
        self.assertEqual(students, {Student[2], Student[3], Student[4]})
    def test17(self):
        students = set(select(s for s in Student).order_by(Student.id)[:4][1:])
        self.assertEqual(students, {Student[2], Student[3], Student[4]})
    def test18(self):
        students = set(select(s for s in Student).order_by(Student.id)[:])
        self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4], Student[5]})
    def test19(self):
        q = select(s for s in Student).order_by(Student.id)
        students = q[1:3]
        self.assertEqual(students, [Student[2], Student[3]])
        students = q[2:4]
        self.assertEqual(students, [Student[3], Student[4]])
        students = q[:]
        self.assertEqual(students, [Student[1], Student[2], Student[3], Student[4], Student[5]])
    def test20(self):
        q = select(s for s in Student).limit(offset=2)
        self.assertEqual(set(q), {Student[3], Student[4], Student[5]})
        self.assertTrue('LIMIT -1 OFFSET 2' in db.last_sql)
    def test21(self):
        q = select(s for s in Student).limit(0, offset=2)
        self.assertEqual(set(q), set())
    def test22(self):
        q = select(s for s in Student).order_by(Student.id).limit(offset=1)
        self.assertEqual(set(q), {Student[2], Student[3], Student[4], Student[5]})
    def test23(self):
        q = select(s for s in Student)[2:2]
        self.assertEqual(set(q), set())
        self.assertTrue('LIMIT 0' in db.last_sql)
    def test24(self):
        q = select(s for s in Student)[2:]
        self.assertEqual(set(q), {Student[3], Student[4], Student[5]})
    def test25(self):
        q = select(s for s in Student)[:2]
        self.assertEqual(set(q), {Student[2], Student[1]})

if __name__ == "__main__":
    unittest.main()
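
# Slicing and .limit() are two spellings of the same thing: query[1:4] and
# query.limit(3, 1) both skip one row and take three, and the tests above
# assert on fragments of db.last_sql such as 'LIMIT -1 OFFSET 2' (SQLite's
# way of saying "no limit, skip 2"). For example (illustrative SQL):
#
#     select(s for s in Student).order_by(Student.id)[1:4]
#     # SELECT ... FROM "Student" "s" ORDER BY "s"."id" LIMIT 3 OFFSET 1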
@raises_exception(TypeError, "query.order_by() arguments must be attributes. Got: 'name'") # now generate: ExprEvalError: name raises NameError: name 'name' is not defined # def test8(self): # students = select(s for s in Student).order_by("name") def test9(self): students = set(select(s for s in Student).order_by(Student.id)[1:4]) self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test10(self): students = set(select(s for s in Student).order_by(Student.id)[:4]) self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4]}) # @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified") # def test11(self): # students = select(s for s in Student).order_by(Student.id)[4:] @raises_exception(TypeError, "Parameter 'start' of slice object cannot be negative") def test12(self): students = select(s for s in Student).order_by(Student.id)[-3:2] @raises_exception(TypeError, 'If you want apply index to a query, convert it to list first') def test13(self): students = select(s for s in Student).order_by(Student.id)[3] self.assertEqual(students, Student[4]) # @raises_exception(TypeError, 'If you want apply index to query, convert it to list first') # def test14(self): # students = select(s for s in Student).order_by(Student.id)["a"] def test15(self): students = set(select(s for s in Student).order_by(Student.id)[0:4][1:3]) self.assertEqual(students, {Student[2], Student[3]}) def test16(self): students = set(select(s for s in Student).order_by(Student.id)[0:4][1:]) self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test17(self): students = set(select(s for s in Student).order_by(Student.id)[:4][1:]) self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test18(self): students = set(select(s for s in Student).order_by(Student.id)[:]) self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4], Student[5]}) def test19(self): q = select(s for s in Student).order_by(Student.id) students = q[1:3] self.assertEqual(students, [Student[2], Student[3]]) students = q[2:4] self.assertEqual(students, [Student[3], Student[4]]) students = q[:] self.assertEqual(students, [Student[1], Student[2], Student[3], Student[4], Student[5]]) def test20(self): q = select(s for s in Student).limit(offset=2) self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) self.assertTrue('LIMIT -1 OFFSET 2' in db.last_sql) def test21(self): q = select(s for s in Student).limit(0, offset=2) self.assertEqual(set(q), set()) def test22(self): q = select(s for s in Student).order_by(Student.id).limit(offset=1) self.assertEqual(set(q), {Student[2], Student[3], Student[4], Student[5]}) def test23(self): q = select(s for s in Student)[2:2] self.assertEqual(set(q), set()) self.assertTrue('LIMIT 0' in db.last_sql) def test24(self): q = select(s for s in Student)[2:] self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) def test25(self): q = select(s for s in Student)[:2] self.assertEqual(set(q), {Student[2], Student[1]}) if __name__ == "__main__": unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_declarative_query_set_monad.py0000666000000000000000000003515100000000000023023 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Group(db.Entity): id = PrimaryKey(int) students = 
Set('Student') class Student(db.Entity): name = Required(unicode) age = Required(int) group = Required('Group') scholarship = Required(int, default=0) courses = Set('Course') class Course(db.Entity): name = Required(unicode) semester = Required(int) PrimaryKey(name, semester) students = Set('Student') db.generate_mapping(create_tables=True) with db_session: g1 = Group(id=1) g2 = Group(id=2) s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) s2 = Student(id=2, name='S2', age=23, group=g1, scholarship=100) s3 = Student(id=3, name='S3', age=23, group=g2, scholarship=500) c1 = Course(name='C1', semester=1, students=[s1, s2]) c2 = Course(name='C2', semester=1, students=[s2, s3]) c3 = Course(name='C3', semester=2, students=[s3]) class TestQuerySetMonad(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test_len(self): result = set(select(g for g in Group if len(g.students) > 1)) self.assertEqual(result, {Group[1]}) def test_len_2(self): result = set(select(g for g in Group if len(s for s in Student if s.group == g) > 1)) self.assertEqual(result, {Group[1]}) def test_len_3(self): result = set(select(g for g in Group if len(s.name for s in Student if s.group == g) > 1)) self.assertEqual(result, {Group[1]}) def test_count_1(self): result = set(select(g for g in Group if count(s.name for s in g.students) > 1)) self.assertEqual(result, {Group[1]}) def test_count_2(self): result = set(select(g for g in Group if select(s.name for s in g.students).count() > 1)) self.assertEqual(result, {Group[1]}) def test_count_3(self): result = set(select(s for s in Student if count(c for c in s.courses) > 1)) self.assertEqual(result, {Student[2], Student[3]}) def test_count_3a(self): result = set(select(s for s in Student if select(c for c in s.courses).count() > 1)) self.assertEqual(result, {Student[2], Student[3]}) self.assertTrue('DISTINCT' in db.last_sql) def test_count_3b(self): result = set(select(s for s in Student if select(c for c in s.courses).count(distinct=False) > 1)) self.assertEqual(result, {Student[2], Student[3]}) self.assertTrue('DISTINCT' not in db.last_sql) def test_count_4(self): result = set(select(c for c in Course if count(s for s in c.students) > 1)) self.assertEqual(result, {Course['C1', 1], Course['C2', 1]}) def test_count_5(self): result = select(c.semester for c in Course).count(distinct=True) self.assertEqual(result, 2) def test_count_6(self): result = select(c for c in Course).count() self.assertEqual(result, 3) self.assertTrue('DISTINCT' not in db.last_sql) def test_count_7(self): result = select(c for c in Course).count(distinct=True) self.assertEqual(result, 3) self.assertTrue('DISTINCT' in db.last_sql) def test_count_8(self): select(count(c.semester, distinct=False) for c in Course)[:] self.assertTrue('DISTINCT' not in db.last_sql) @raises_exception(TypeError, "`distinct` value should be True or False. 
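    # The count() tests above pin down the DISTINCT behaviour: .count() on a
    # sub-select of values is DISTINCT by default, while .count(distinct=False)
    # and count('*') emit a plain COUNT. Roughly (illustrative):
    #
    #     select(c for c in s.courses).count()                # COUNT(DISTINCT ...)
    #     select(c for c in s.courses).count(distinct=False)  # COUNT(*)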
    @raises_exception(TypeError)
    def test_sum_1(self):
        result = set(select(g for g in Group if sum(s for s in Student if s.group == g) > 1))
    @raises_exception(TypeError)
    def test_sum_2(self):
        select(g for g in Group if sum(s.name for s in Student if s.group == g) > 1)
    def test_sum_3(self):
        result = sum(s.scholarship for s in Student)
        self.assertEqual(result, 600)
    def test_sum_4(self):
        result = sum(s.scholarship for s in Student if s.name == 'Unnamed')
        self.assertEqual(result, 0)
    def test_sum_5(self):
        result = select(c.semester for c in Course).sum()
        self.assertEqual(result, 4)
    def test_sum_6(self):
        result = select(c.semester for c in Course).sum(distinct=True)
        self.assertEqual(result, 3)
    def test_sum_7(self):
        result = set(select(g for g in Group if sum(s.scholarship for s in Student if s.group == g) > 500))
        self.assertEqual(result, set())
    def test_sum_8(self):
        result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum() > 200))
        self.assertEqual(result, {Group[2]})
        self.assertTrue('DISTINCT' not in db.last_sql)
    def test_sum_9(self):
        result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum(distinct=True) > 200))
        self.assertEqual(result, {Group[2]})
        self.assertTrue('DISTINCT' in db.last_sql)
    def test_sum_10(self):
        select(sum(s.scholarship, distinct=True) for s in Student)[:]
        self.assertTrue('SUM(DISTINCT' in db.last_sql)
    def test_min_1(self):
        result = set(select(g for g in Group if min(s.name for s in Student if s.group == g) == 'S1'))
        self.assertEqual(result, {Group[1]})
    @raises_exception(TypeError)
    def test_min_2(self):
        select(g for g in Group if min(s for s in Student if s.group == g) == None)
    def test_min_3(self):
        result = set(select(g for g in Group if select(s.scholarship for s in g.students).min() == 0))
        self.assertEqual(result, {Group[1]})
    def test_min_4(self):
        result = select(s.scholarship for s in Student).min()
        self.assertEqual(0, result)
    def test_max_1(self):
        result = set(select(g for g in Group if max(s.scholarship for s in Student if s.group == g) > 100))
        self.assertEqual(result, {Group[2]})
    @raises_exception(TypeError)
    def test_max_2(self):
        select(g for g in Group if max(s for s in Student if s.group == g) == None)
    def test_max_3(self):
        result = set(select(g for g in Group if select(s.scholarship for s in g.students).max() == 100))
        self.assertEqual(result, {Group[1]})
    def test_max_4(self):
        result = select(s.scholarship for s in Student).max()
        self.assertEqual(result, 500)
    def test_avg_1(self):
        result = select(g for g in Group if avg(s.scholarship for s in Student if s.group == g) == 50)[:]
        self.assertEqual(result, [Group[1]])
    def test_avg_2(self):
        result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg() == 50))
        self.assertEqual(result, {Group[1]})
    def test_avg_3(self):
        result = select(c.semester for c in Course).avg()
        self.assertAlmostEqual(1.33, result, places=2)
    def test_avg_4(self):
        result = select(c.semester for c in Course).avg(distinct=True)
        self.assertAlmostEqual(1.5, result)
    def test_avg_5(self):
        result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg(distinct=True) == 50))
        self.assertEqual(result, {Group[1]})
        self.assertTrue('AVG(DISTINCT' in db.last_sql)
    def test_avg_6(self):
        select(avg(s.scholarship, distinct=True) for s in Student)[:]
        self.assertTrue('AVG(DISTINCT' in db.last_sql)
    def test_exists_1(self):
        result = set(select(g for g in Group if exists(s for s in g.students if s.age < 23)))
        self.assertEqual(result, {Group[1]})
    def test_exists_2(self):
        result = set(select(g for g in Group if exists(s.age < 23 for s in g.students)))
        self.assertEqual(result, {Group[1]})
    def test_exists_3(self):
        result = set(select(g for g in Group if (s.age < 23 for s in g.students)))
        self.assertEqual(result, {Group[1]})
    def test_negate(self):
        result = set(select(g for g in Group if not(s.scholarship for s in Student if s.group == g)))
        self.assertEqual(result, set())
    def test_no_conditions(self):
        students = set(select(s for s in Student if s.group in (g for g in Group)))
        self.assertEqual(students, {Student[1], Student[2], Student[3]})
    def test_no_conditions_2(self):
        students = set(select(s for s in Student if s.scholarship == max(s.scholarship for s in Student)))
        self.assertEqual(students, {Student[3]})
    def test_hint_join_1(self):
        result = set(select(s for s in Student if JOIN(s.group in select(g for g in Group if g.id < 2))))
        self.assertEqual(result, {Student[1], Student[2]})
    def test_hint_join_2(self):
        result = set(select(s for s in Student if JOIN(s.group not in select(g for g in Group if g.id < 2))))
        self.assertEqual(result, {Student[3]})
    def test_hint_join_3(self):
        result = set(select(s for s in Student if JOIN(s.scholarship in select(s.scholarship + 100 for s in Student if s.name != 'S2'))))
        self.assertEqual(result, {Student[2]})
    def test_hint_join_4(self):
        result = set(select(g for g in Group if JOIN(g in select(s.group for s in g.students))))
        self.assertEqual(result, {Group[1], Group[2]})
    def test_group_concat_1(self):
        result = select(s.name for s in Student).group_concat()
        self.assertEqual(result, 'S1,S2,S3')
    def test_group_concat_2(self):
        result = select(s.name for s in Student).group_concat('-')
        self.assertEqual(result, 'S1-S2-S3')
    def test_group_concat_3(self):
        result = select(s for s in Student if s.name in group_concat(s.name for s in Student))[:]
        self.assertEqual(set(result), {Student[1], Student[2], Student[3]})
    def test_group_concat_4(self):
        result = Student.select().group_concat()
        self.assertEqual(result, '1,2,3')
    def test_group_concat_5(self):
        result = Student.select().group_concat('.')
        self.assertEqual(result, '1.2.3')
    @raises_exception(TypeError, '`group_concat` cannot be used with entity with composite primary key')
    def test_group_concat_6(self):
        select(group_concat(s.courses, '-') for s in Student)
    def test_group_concat_7(self):
        result = select(group_concat(c.semester) for c in Course)[:]
        self.assertEqual(result[0], '1,1,2')
    def test_group_concat_8(self):
        result = select(group_concat(c.semester, '-') for c in Course)[:]
        self.assertEqual(result[0], '1-1-2')
    def test_group_concat_9(self):
        result = select(group_concat(c.semester, distinct=True) for c in Course)[:]
        self.assertEqual(result[0], '1,2')
    def test_group_concat_10(self):
        result = group_concat((s.name for s in Student if int(s.name[1]) > 1), sep='-')
        self.assertEqual(result, 'S2-S3')
    def test_group_concat_11(self):
        result = group_concat((c.semester for c in Course), distinct=True)
        self.assertEqual(result, '1,2')
    @raises_exception(TypeError, 'Query can only iterate over entity or another query (not a list of objects)')
    def test_select_from_select_1(self):
        query = select(s for s in Student if s.scholarship > 0)[:]
        result = set(select(x for x in query))
        self.assertEqual(result, {})
    def test_select_from_select_2(self):
        p, q = 50, 400
        query = select(s for s in Student if s.scholarship > p)
        result = select(x.id for x in query if x.scholarship < q)[:]
        self.assertEqual(set(result), {2})
    def test_select_from_select_3(self):
        p, q = 50, 400
        g = (s for s in Student if s.scholarship > p)
        result = select(x.id for x in g if x.scholarship < q)[:]
        self.assertEqual(set(result), {2})
    def test_select_from_select_4(self):
        p, q = 50, 400
        result = select(x.id for x in (s for s in Student if s.scholarship > p) if x.scholarship < q)[:]
        self.assertEqual(set(result), {2})
    def test_select_from_select_5(self):
        p, q = 50, 400
        result = select(x.id for x in select(s for s in Student if s.scholarship > 0) if x.scholarship < 400)[:]
        self.assertEqual(set(result), {2})
    def test_select_from_select_6(self):
        query = select(s.name for s in Student if s.scholarship > 0)
        result = select(x for x in query if not x.endswith('3'))
        self.assertEqual(set(result), {'S2'})
    @raises_exception(TranslationError, 'Too many values to unpack "for a, b in select(s for ...)" (expected 2, got 1)')
    def test_select_from_select_7(self):
        query = select(s for s in Student if s.scholarship > 0)
        result = select(a for a, b in query)
    @raises_exception(NotImplementedError, 'Please unpack a tuple of (s.name, s.group) in for-loop '
                                           'to individual variables (like: "for x, y in ...")')
    def test_select_from_select_8(self):
        query = select((s.name, s.group) for s in Student if s.scholarship > 0)
        result = select(x for x in query)
    @raises_exception(TranslationError, 'Not enough values to unpack "for x, y in '
                                        'select(s.name, s.group, s.scholarship for ...)" (expected 2, got 3)')
    def test_select_from_select_9(self):
        query = select((s.name, s.group, s.scholarship) for s in Student if s.scholarship > 0)
        result = select(x for x, y in query)
    def test_select_from_select_10(self):
        query = select((s.name, s.age) for s in Student if s.scholarship > 0)
        result = select(n for n, a in query if n.endswith('2') and a > 20)
        self.assertEqual(set(x for x in result), {'S2'})
    def test_aggregations_1(self):
        query = select((min(s.age), max(s.scholarship)) for s in Student)
        result = query[:]
        self.assertEqual(result, [(20, 500)])
    def test_aggregations_2(self):
        query = select((min(s.age), max(s.scholarship)) for s in Student for g in Group)
        result = query[:]
        self.assertEqual(result, [(20, 500)])

if __name__ == "__main__":
    unittest.main()
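
# The select_from_select tests above show that one query can serve as the
# source of another; the inner query becomes a subquery in the generated SQL.
# A typical composition (illustrative):
#
#     inner = select(s for s in Student if s.scholarship > 50)
#     outer = select(x.id for x in inner if x.scholarship < 400)
#     # roughly: SELECT "x"."id" FROM (SELECT ... FROM "Student" ...) "x"
#     #          WHERE "x"."scholarship" < 400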
# pony-0.7.11/pony/orm/tests/test_declarative_sqltranslator.py

from __future__ import absolute_import, print_function, division
from pony.py23compat import PY2

import unittest
from datetime import date

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Department(db.Entity):
    number = PrimaryKey(int)
    groups = Set('Group')
    courses = Set('Course')

class Student(db.Entity):
    name = Required(unicode)
    group = Required('Group')
    scholarship = Required(int, default=0)
    picture = Optional(buffer)
    courses = Set('Course')
    grades = Set('Grade')

class Group(db.Entity):
    id = PrimaryKey(int)
    students = Set(Student)
    dept = Required(Department)
    rooms = Set('Room')

class Course(db.Entity):
    dept = Required(Department)
    name = Required(unicode)
    credits = Optional(int)
    semester = Required(int)
    PrimaryKey(name, semester)
    grades = Set('Grade')
    students = Set(Student)

class Grade(db.Entity):
    student = Required(Student)
    course = Required(Course)
    PrimaryKey(student, course)
    value = Required(str)
    date = Optional(date)
    teacher = Required('Teacher')

class Teacher(db.Entity):
    name = Required(unicode)
    grades = Set(Grade)

class Room(db.Entity):
    name = PrimaryKey(unicode)
    groups = Set(Group)

db.generate_mapping(create_tables=True)

with db_session:
    d1 = Department(number=44)
    d2 = Department(number=43)
    g1 = Group(id=1, dept=d1)
    g2 = Group(id=2, dept=d2)
    s1 = Student(id=1, name='S1', group=g1, scholarship=0)
    s2 = Student(id=2, name='S2', group=g1, scholarship=100)
    s3 = Student(id=3, name='S3', group=g2, scholarship=500)
    c1 = Course(name='Math', semester=1, dept=d1)
    c2 = Course(name='Economics', semester=1, dept=d1, credits=3)
    c3 = Course(name='Physics', semester=2, dept=d2)
    t1 = Teacher(id=101, name="T1")
    t2 = Teacher(id=102, name="T2")
    Grade(student=s1, course=c1, value='C', teacher=t2, date=date(2011, 1, 1))
    Grade(student=s1, course=c3, value='A', teacher=t1, date=date(2011, 2, 1))
    Grade(student=s2, course=c2, value='B', teacher=t1)
    r1 = Room(name='Room1')
    r2 = Room(name='Room2')
    r3 = Room(name='Room3')
    g1.rooms = [ r1, r2 ]
    g2.rooms = [ r2, r3 ]
    c1.students.add(s1)
    c1.students.add(s2)
    c2.students.add(s2)

db2 = Database('sqlite', ':memory:')

class Room2(db2.Entity):
    name = PrimaryKey(unicode)

db2.generate_mapping(create_tables=True)

name1 = 'S1'

class TestSQLTranslator(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_select1(self):
        result = set(select(s for s in Student))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_select_param(self):
        result = select(s for s in Student if s.name == name1)[:]
        self.assertEqual(result, [Student[1]])
    def test_select_object_param(self):
        stud1 = Student[1]
        result = set(select(s for s in Student if s != stud1))
        self.assertEqual(result, {Student[2], Student[3]})
    def test_select_deref(self):
        x = 'S1'
        result = select(s for s in Student if s.name == x)[:]
        self.assertEqual(result, [Student[1]])
    def test_select_composite_key(self):
        grade1 = Grade[Student[1], Course['Physics', 2]]
        result = select(g for g in Grade if g != grade1)
        grades = [ grade.value for grade in result ]
        grades.sort()
        self.assertEqual(grades, ['B', 'C'])
    def test_function_max1(self):
        result = select(s for s in Student if max(s.grades.value) == 'C')[:]
        self.assertEqual(result, [Student[1]])
    @raises_exception(TypeError)
    def test_function_max2(self):
        grade1 = Grade[Student[1], Course['Physics', 2]]
        select(s for s in Student if max(s.grades) == grade1)
    def test_function_min(self):
        result = select(s for s in Student if min(s.grades.value) == 'B')[:]
        self.assertEqual(result, [Student[2]])
    @raises_exception(TypeError)
    def test_function_min2(self):
        grade1 = Grade[Student[1], Course['Physics', 2]]
        select(s for s in Student if min(s.grades) == grade1)
    def test_min3(self):
        d = date(2011, 1, 1)
        result = set(select(g for g in Grade if min(g.date, d) == d and g.date is not None))
        self.assertEqual(result, {Grade[Student[1], Course[u'Math', 1]], Grade[Student[1], Course[u'Physics', 2]]})
    def test_function_len1(self):
        result = select(s for s in Student if len(s.grades) == 1)[:]
        self.assertEqual(result, [Student[2]])
    def test_function_len2(self):
        result = select(s for s in Student if max(s.grades.value) == 'C')[:]
        self.assertEqual(result, [Student[1]])
    def test_function_sum1(self):
        result = select(g for g in Group if sum(g.students.scholarship) == 100)[:]
        self.assertEqual(result, [Group[1]])
    def test_function_avg1(self):
        result = select(g for g in Group if avg(g.students.scholarship) == 50)[:]
        self.assertEqual(result, [Group[1]])
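    # The aggregate tests above (max/min/len/sum/avg over s.grades or
    # g.students) are translated into correlated subqueries over the related
    # table; e.g. max(s.grades.value) == 'C' becomes, roughly (illustrative,
    # not verbatim generated SQL):
    #
    #     (SELECT MAX("g"."value") FROM "Grade" "g"
    #      WHERE "g"."student" = "s"."id") = 'C'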
    @raises_exception(TypeError)
    def test_function_sum2(self):
        select(g for g in Group if sum(g.students) == 100)
    @raises_exception(TypeError)
    def test_function_sum3(self):
        select(g for g in Group if sum(g.students.name) == 100)
    def test_function_abs(self):
        result = select(s for s in Student if abs(s.scholarship) == 100)[:]
        self.assertEqual(result, [Student[2]])
    def test_builtin_in_locals(self):
        x = max
        gen = (s.group for s in Student if x(s.grades.value) == 'C')
        result = select(gen)[:]
        self.assertEqual(result, [Group[1]])
        x = min
        result = select(gen)[:]
        self.assertEqual(result, [])
    # @raises_exception(TranslationError, "Name 'g' must be defined in query")
    # def test_name(self):
    #     select(s for s in Student for g in g.subjects)
    def test_chain1(self):
        result = set(select(g for g in Group for s in g.students if s.name.endswith('3')))
        self.assertEqual(result, {Group[2]})
    def test_chain2(self):
        result = set(select(s for g in Group if g.dept.number == 44 for s in g.students if s.name.startswith('S')))
        self.assertEqual(result, {Student[1], Student[2]})
    def test_chain_m2m(self):
        result = set(select(g for g in Group for r in g.rooms if r.name == 'Room2'))
        self.assertEqual(result, {Group[1], Group[2]})
    @raises_exception(TranslationError, 'All entities in a query must belong to the same database')
    def test_two_diagrams(self):
        select(g for g in Group for r in Room2 if r.name == 'Room2')
    def test_add_sub_mul_etc(self):
        result = select(s for s in Student if ((-s.scholarship + 200) * 10 / 5 - 100) ** 2 == 10000 or 5 == 2)[:]
        self.assertEqual(result, [Student[2]])
    def test_subscript(self):
        result = set(select(s for s in Student if s.name[1] == '2'))
        self.assertEqual(result, {Student[2]})
    def test_slice(self):
        result = set(select(s for s in Student if s.name[:1] == 'S'))
        self.assertEqual(result, {Student[3], Student[2], Student[1]})
    def test_attr_chain(self):
        s1 = Student[1]
        result = select(s for s in Student if s == s1)[:]
        self.assertEqual(result, [Student[1]])
        result = select(s for s in Student if not s == s1)[:]
        self.assertEqual(result, [Student[2], Student[3]])
        result = select(s for s in Student if s.group == s1.group)[:]
        self.assertEqual(result, [Student[1], Student[2]])
        result = select(s for s in Student if s.group.dept == s1.group.dept)[:]
        self.assertEqual(result, [Student[1], Student[2]])
    def test_list_monad1(self):
        result = select(s for s in Student if s.name in ['S1'])[:]
        self.assertEqual(result, [Student[1]])
    def test_list_monad2(self):
        result = select(s for s in Student if s.name not in ['S1', 'S2'])[:]
        self.assertEqual(result, [Student[3]])
    def test_list_monad3(self):
        grade1 = Grade[Student[1], Course['Physics', 2]]
        grade2 = Grade[Student[1], Course['Math', 1]]
        result = set(select(g for g in Grade if g in [grade1, grade2]))
        self.assertEqual(result, {grade1, grade2})
        result = set(select(g for g in Grade if g not in [grade1, grade2]))
        self.assertEqual(result, {Grade[Student[2], Course['Economics', 1]]})
    def test_tuple_monad1(self):
        n1 = 'S1'
        n2 = 'S2'
        result = select(s for s in Student if s.name in (n1, n2))[:]
        self.assertEqual(result, [Student[1], Student[2]])
    def test_None_value(self):
        result = select(s for s in Student if s.name is None)[:]
        self.assertEqual(result, [])
    def test_None_value2(self):
        result = select(s for s in Student if None == s.name)[:]
        self.assertEqual(result, [])
    def test_None_value3(self):
        n = None
        result = select(s for s in Student if s.name == n)[:]
        self.assertEqual(result, [])
    def test_None_value4(self):
        n = None
        result = select(s for s in Student if n == s.name)[:]
        self.assertEqual(result, [])
    @raises_exception(TranslationError, "External parameter 'a' cannot be used as query result")
    def test_expr1(self):
        a = 100
        result = select(a for s in Student)
    def test_expr2(self):
        result = set(select(s.group for s in Student))
        self.assertEqual(result, {Group[1], Group[2]})
    def test_numeric_binop(self):
        i = 100
        f = 2.0
        result = select(s for s in Student if s.scholarship > i + f)[:]
        self.assertEqual(result, [Student[3]])
    def test_string_const_monad(self):
        result = select(s for s in Student if len(s.name) > len('ABC'))[:]
        self.assertEqual(result, [])
    def test_numeric_to_bool1(self):
        result = set(select(s for s in Student if s.name != 'John' or s.scholarship))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_numeric_to_bool2(self):
        result = set(select(s for s in Student if not s.scholarship))
        self.assertEqual(result, {Student[1]})
    def test_not_monad1(self):
        result = set(select(s for s in Student if not (s.scholarship > 0 and s.name != 'S1')))
        self.assertEqual(result, {Student[1]})
    def test_not_monad2(self):
        result = set(select(s for s in Student if not not (s.scholarship > 0 and s.name != 'S1')))
        self.assertEqual(result, {Student[2], Student[3]})
    def test_subquery_with_attr(self):
        result = set(select(s for s in Student if max(g.value for g in s.grades) == 'C'))
        self.assertEqual(result, {Student[1]})
    def test_query_reuse(self):
        q = select(s for s in Student if s.scholarship > 0)
        q.count()
        self.assertTrue("ORDER BY" not in db.last_sql.upper())
        objects = q[:]  # should not throw exception, query can be reused
    def test_select(self):
        result = Student.select(lambda s: s.scholarship > 0)[:]
        self.assertEqual(result, [Student[2], Student[3]])
    def test_get(self):
        result = Student.get(lambda s: s.scholarship == 500)
        self.assertEqual(result, Student[3])
    def test_order_by(self):
        result = list(Student.select().order_by(Student.name))
        self.assertEqual(result, [Student[1], Student[2], Student[3]])
    def test_read_inside_query(self):
        result = set(select(s for s in Student if Group[1].dept.number == 44))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_crud_attr_chain(self):
        result = set(select(s for s in Student if Group[1].dept.number == s.group.dept.number))
        self.assertEqual(result, {Student[1], Student[2]})
    def test_composite_key1(self):
        result = set(select(t for t in Teacher if Grade[Student[1], Course['Physics', 2]] in t.grades))
        self.assertEqual(result, {Teacher.get(name='T1')})
    def test_composite_key2(self):
        result = set(select(s for s in Student if Course['Math', 1] in s.courses))
        self.assertEqual(result, {Student[1], Student[2]})
    def test_composite_key3(self):
        result = set(select(s for s in Student if Course['Math', 1] not in s.courses))
        self.assertEqual(result, {Student[3]})
    def test_composite_key4(self):
        result = set(select(s for s in Student if len(c for c in Course if c not in s.courses) == 2))
        self.assertEqual(result, {Student[1]})
    def test_composite_key5(self):
        result = set(select(s for s in Student if not (c for c in Course if c not in s.courses)))
        self.assertEqual(result, set())
    def test_composite_key6(self):
        result = set(select(c for c in Course if c not in (c2 for s in Student for c2 in s.courses)))
        self.assertEqual(result, {Course['Physics', 2]})
    def test_composite_key7(self):
        result = set(select(c for s in Student for c in s.courses))
        self.assertEqual(result, {Course['Math', 1], Course['Economics', 1]})
    def test_contains1(self):
        s1 = Student[1]
        result = set(select(g for g in Group if s1 in g.students))
        self.assertEqual(result, {Group[1]})
    def test_contains2(self):
        s1 = Student[1]
        result = set(select(g for g in Group if s1.name in g.students.name))
        self.assertEqual(result, {Group[1]})
    def test_contains3(self):
        s1 = Student[1]
        result = set(select(g for g in Group if s1 not in g.students))
        self.assertEqual(result, {Group[2]})
    def test_contains4(self):
        s1 = Student[1]
        result = set(select(g for g in Group if s1.name not in g.students.name))
        self.assertEqual(result, {Group[2]})
    def test_buffer_monad1(self):
        try:
            select(s for s in Student if s.picture == buffer('abc'))
        except TypeError as e:
            self.assertTrue(not PY2 and str(e) == 'string argument without an encoding')
        else:
            self.assertTrue(PY2)
    def test_buffer_monad2(self):
        select(s for s in Student if s.picture == buffer('abc', 'ascii'))
    def test_database_monad(self):
        result = set(select(s for s in db.Student if db.Student[1] == s))
        self.assertEqual(result, {Student[1]})
    def test_duplicate_name(self):
        result = set(select(x for x in Student if x.group in (x for x in Group)))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_hint_join1(self):
        result = set(select(s for s in Student if JOIN(max(s.courses.credits) == 3)))
        self.assertEqual(result, {Student[2]})
    def test_hint_join2(self):
        result = set(select(c for c in Course if JOIN(len(c.students) == 1)))
        self.assertEqual(result, {Course['Economics', 1]})
    def test_tuple_param(self):
        x = Student[1], Student[2]
        result = set(select(s for s in Student if s not in x))
        self.assertEqual(result, {Student[3]})
    @raises_exception(TypeError, "Expression `x` should not contain None values")
    def test_tuple_param_2(self):
        x = Student[1], None
        result = set(select(s for s in Student if s not in x))
        self.assertEqual(result, {Student[3]})
    def test_method_monad(self):
        result = set(select(s for s in Student if s not in Student.select(lambda s: s.scholarship > 0)))
        self.assertEqual(result, {Student[1]})
    def test_lambda_1(self):
        q = select(s for s in Student)
        q = q.filter(lambda s: s.name == 'S1')
        self.assertEqual(list(q), [Student[1]])
    def test_lambda_2(self):
        q = select(s for s in Student)
        q = q.filter(lambda stud: stud.name == 'S1')
        self.assertEqual(list(q), [Student[1]])
    def test_lambda_3(self):
        q = select(s for s in Student)
        q = q.filter(lambda stud: exists(x for x in Student if stud.name < x.name))
        self.assertEqual(set(q), {Student[1], Student[2]})
    def test_lambda_4(self):
        q = select(s for s in Student)
        q = q.filter(lambda stud: exists(s for s in Student if stud.name < s.name))
        self.assertEqual(set(q), {Student[1], Student[2]})
    def test_optimized_1(self):
        q = select((g, count(g.students)) for g in Group if count(g.students) > 1)
        self.assertEqual(set(q), {(Group[1], 2)})
    def test_optimized_2(self):
        q = select((s, count(s.courses)) for s in Student if count(s.courses) > 1)
        self.assertEqual(set(q), {(Student[2], 2)})
    def test_delete(self):
        q = select(g for g in Grade if g.teacher.id == 101).delete()
        q2 = select(g for g in Grade)[:]
        self.assertEqual([g.value for g in q2], ['C'])
    def test_delete_2(self):
        delete(g for g in Grade if g.teacher.id == 101)
        q2 = select(g for g in Grade)[:]
        self.assertEqual([g.value for g in q2], ['C'])
    def test_delete_3(self):
        select(g for g in Grade if g.teacher.id == 101).delete(bulk=True)
        q2 = select(g for g in Grade)[:]
        self.assertEqual([g.value for g in q2], ['C'])
    def test_delete_4(self):
        select(g for g in Grade if exists(g2 for g2 in Grade if g2.value > g.value)).delete(bulk=True)
        q2 = select(g for g in Grade)[:]
        self.assertEqual([g.value for g in q2], ['C'])
    def test_select_2(self):
        result = select(s for s in Student)[:]
        self.assertEqual(result, [Student[1], Student[2], Student[3]])
    def test_select_add(self):
        result = [None] + select(s for s in Student)[:]
        self.assertEqual(result, [None, Student[1], Student[2], Student[3]])
    def test_query_result_radd(self):
        result = select(s for s in Student)[:] + [None]
        self.assertEqual(result, [Student[1], Student[2], Student[3], None])
    def test_query_result_sort(self):
        result = select(s for s in Student)[:]
        result.sort()
        self.assertEqual(result, [Student[1], Student[2], Student[3]])
    def test_query_result_reverse(self):
        result = select(s for s in Student)[:]
        items = list(result)
        result.reverse()
        self.assertEqual(items, list(reversed(result)))
    def test_query_result_shuffle(self):
        result = select(s for s in Student)[:]
        items = set(result)
        result.shuffle()
        self.assertEqual(items, set(result))
    def test_query_result_to_list(self):
        result = select(s for s in Student)[:]
        items = result.to_list()
        self.assertTrue(type(items) is list)
    @raises_exception(TypeError, 'In order to do item assignment, cast QueryResult to list first')
    def test_query_result_setitem(self):
        result = select(s for s in Student)[:]
        result[0] = None
    @raises_exception(TypeError, 'In order to do item deletion, cast QueryResult to list first')
    def test_query_result_delitem(self):
        result = select(s for s in Student)[:]
        del result[0]
    @raises_exception(TypeError, 'In order to do +=, cast QueryResult to list first')
    def test_query_result_iadd(self):
        result = select(s for s in Student)[:]
        result += None
    @raises_exception(TypeError, 'In order to do append, cast QueryResult to list first')
    def test_query_result_append(self):
        result = select(s for s in Student)[:]
        result.append(None)
    @raises_exception(TypeError, 'In order to do clear, cast QueryResult to list first')
    def test_query_result_clear(self):
        result = select(s for s in Student)[:]
        result.clear()
    @raises_exception(TypeError, 'In order to do extend, cast QueryResult to list first')
    def test_query_result_extend(self):
        result = select(s for s in Student)[:]
        result.extend([])
    @raises_exception(TypeError, 'In order to do insert, cast QueryResult to list first')
    def test_query_result_insert(self):
        result = select(s for s in Student)[:]
        result.insert(0, None)
    @raises_exception(TypeError, 'In order to do pop, cast QueryResult to list first')
    def test_query_result_pop(self):
        result = select(s for s in Student)[:]
        result.pop()
    @raises_exception(TypeError, 'In order to do remove, cast QueryResult to list first')
    def test_query_result_remove(self):
        result = select(s for s in Student)[:]
        result.remove(None)

if __name__ == "__main__":
    unittest.main()
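
# The test_query_result_* tests above document that the QueryResult returned
# by query[:] behaves like a read-only list: in-place mutation (item
# assignment, append, +=, etc.) raises TypeError until you cast it. E.g.:
#
#     result = select(s for s in Student)[:]
#     items = list(result)   # an ordinary list now
#     items.append(None)     # allowed after the cast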
# pony-0.7.11/pony/orm/tests/test_declarative_sqltranslator2.py

from __future__ import absolute_import, print_function, division
import unittest
from datetime import date
from decimal import Decimal

from pony.orm.core import *
from pony.orm.sqltranslation import IncomparableTypesError
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Department(db.Entity):
    number = PrimaryKey(int, auto=True)
    name = Required(unicode, unique=True)
    groups = Set("Group")
    courses = Set("Course")

class Group(db.Entity):
    number = PrimaryKey(int)
    major = Required(unicode)
    dept = Required("Department")
    students = Set("Student")

class Course(db.Entity):
    name = Required(unicode)
    semester = Required(int)
    lect_hours = Required(int)
    lab_hours = Required(int)
    credits = Required(int)
    dept = Required(Department)
    students = Set("Student")
    PrimaryKey(name, semester)

class Student(db.Entity):
    id = PrimaryKey(int, auto=True)
    name = Required(unicode)
    dob = Required(date)
    tel = Optional(str)
    picture = Optional(buffer, lazy=True)
    gpa = Required(float, default=0)
    phd = Optional(bool)
    group = Required(Group)
    courses = Set(Course)

db.generate_mapping(create_tables=True)

with db_session:
    d1 = Department(name="Department of Computer Science")
    d2 = Department(name="Department of Mathematical Sciences")
    d3 = Department(name="Department of Applied Physics")
    c1 = Course(name="Web Design", semester=1, dept=d1, lect_hours=30, lab_hours=30, credits=3)
    c2 = Course(name="Data Structures and Algorithms", semester=3, dept=d1, lect_hours=40, lab_hours=20, credits=4)
    c3 = Course(name="Linear Algebra", semester=1, dept=d2, lect_hours=30, lab_hours=30, credits=4)
    c4 = Course(name="Statistical Methods", semester=2, dept=d2, lect_hours=50, lab_hours=25, credits=5)
    c5 = Course(name="Thermodynamics", semester=2, dept=d3, lect_hours=25, lab_hours=40, credits=4)
    c6 = Course(name="Quantum Mechanics", semester=3, dept=d3, lect_hours=40, lab_hours=30, credits=5)
    g101 = Group(number=101, major='B.E. in Computer Engineering', dept=d1)
    g102 = Group(number=102, major='B.S./M.S. in Computer Science', dept=d2)
    g103 = Group(number=103, major='B.S. in Applied Mathematics and Statistics', dept=d2)
    g104 = Group(number=104, major='B.S./M.S. in Pure Mathematics', dept=d2)
    g105 = Group(number=105, major='B.E in Electronics', dept=d3)
    g106 = Group(number=106, major='B.S./M.S. in Nuclear Engineering', dept=d3)
    Student(name='John Smith', dob=date(1991, 3, 20), tel='123-456', gpa=3, group=g101, phd=True, courses=[c1, c2, c4, c6])
    Student(name='Matthew Reed', dob=date(1990, 11, 26), gpa=3.5, group=g101, phd=True, courses=[c1, c3, c4, c5])
    Student(name='Chuan Qin', dob=date(1989, 2, 5), gpa=4, group=g101, courses=[c3, c5, c6])
    Student(name='Rebecca Lawson', dob=date(1990, 4, 18), tel='234-567', gpa=3.3, group=g102, courses=[c1, c4, c5, c6])
    Student(name='Maria Ionescu', dob=date(1991, 4, 23), gpa=3.9, group=g102, courses=[c1, c2, c4, c6])
    Student(name='Oliver Blakey', dob=date(1990, 9, 8), gpa=3.1, group=g102, courses=[c1, c2, c5])
    Student(name='Jing Xia', dob=date(1988, 12, 30), gpa=3.2, group=g102, courses=[c1, c3, c5, c6])

class TestSQLTranslator2(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_distinct1(self):
        q = select(c.students for c in Course)
        self.assertEqual(q._translator.distinct, True)
        self.assertEqual(q.count(), 7)
    def test_distinct3(self):
        q = select(d for d in Department if len(s for c in d.courses for s in c.students) > len(s for s in Student))
        self.assertEqual(q[:], [])
        self.assertTrue('DISTINCT' in db.last_sql)
    def test_distinct4(self):
        q = select(d for d in Department if len(d.groups.students) > 3)
        self.assertEqual(q[:], [Department[2]])
        self.assertTrue("DISTINCT" not in db.last_sql)
    def test_distinct5(self):
        result = set(select(s for s in Student))
        self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]})
    def test_distinct6(self):
        result = set(select(s for s in Student).distinct())
        self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]})
    def test_not_null1(self):
        q = select(g for g in Group if '123-45-67' not in g.students.tel and g.dept == Department[1])
        not_null = "IS_NOT_NULL COLUMN student tel" in (" ".join(str(i) for i in flatten(q._translator.conditions)))
        self.assertEqual(not_null, True)
        self.assertEqual(q[:], [Group[101]])
    def test_not_null2(self):
        q = select(g for g in Group if 'John' not in g.students.name and g.dept == Department[1])
        not_null = "IS_NOT_NULL COLUMN student name" in (" ".join(str(i) for i in flatten(q._translator.conditions)))
        self.assertEqual(not_null, False)
        self.assertEqual(q[:], [Group[101]])
    def test_chain_of_attrs_inside_for1(self):
        result = set(select(s for d in Department if d.number == 2 for s in d.groups.students))
        self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]})
    def test_chain_of_attrs_inside_for2(self):
        pony.options.SIMPLE_ALIASES = False
        result = set(select(s for d in Department if d.number == 2 for s in d.groups.students))
        self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]})
        pony.options.SIMPLE_ALIASES = True
    def test_non_entity_result1(self):
        result = select((s.name, s.group.number) for s in Student if s.name.startswith("J"))[:]
        self.assertEqual(sorted(result), [(u'Jing Xia', 102), (u'John Smith', 101)])
    def test_non_entity_result2(self):
        result = select((s.dob.year, s.group.number) for s in Student)[:]
        self.assertEqual(sorted(result), [(1988, 102), (1989, 101), (1990, 101), (1990, 102), (1991, 101), (1991, 102)])
    def test_non_entity_result3(self):
        result = select(s.dob.year for s in Student).without_distinct()
        self.assertEqual(sorted(result), [1988, 1989, 1990, 1990, 1990, 1991, 1991])
        result = select(s.dob.year for s in Student)[:]  # test the last query didn't override the cached one
        self.assertEqual(sorted(result), [1988, 1989, 1990, 1991])
    def test_non_entity_result3a(self):
        result = select(s.dob.year for s in Student)[:]
        self.assertEqual(sorted(result), [1988, 1989, 1990, 1991])
    def test_non_entity_result4(self):
        result = set(select(s.name for s in Student if s.name.startswith('M')))
        self.assertEqual(result, {u'Matthew Reed', u'Maria Ionescu'})
    def test_non_entity_result5(self):
        result = select((s.group, s.dob) for s in Student if s.group == Group[101])[:]
        self.assertEqual(sorted(result), [(Group[101], date(1989, 2, 5)), (Group[101], date(1990, 11, 26)), (Group[101], date(1991, 3, 20))])
    def test_non_entity_result6(self):
        result = select((c, s) for s in Student for c in Course if c.semester == 1 and s.id < 3)[:]
        self.assertEqual(sorted(result), sorted([(Course[u'Linear Algebra',1], Student[1]), (Course[u'Linear Algebra',1], Student[2]), (Course[u'Web Design',1], Student[1]), (Course[u'Web Design',1], Student[2])]))
    def test_non_entity7(self):
        result = set(select(s for s in Student if (s.name, s.dob) not in (((s2.name, s2.dob) for s2 in Student if s.group.number == 101))))
        self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]})
    @raises_exception(IncomparableTypesError, "Incomparable types 'int' and 'Set of Student' in expression: g.number == g.students")
    def test_incompartible_types(self):
        select(g for g in Group if g.number == g.students)
    @raises_exception(TranslationError, "External parameter 'x' cannot be used as query result")
    def test_external_param1(self):
        x = Student[1]
        select(x for s in Student)
    def test_external_param2(self):
        x = Student[1]
        result = set(select(s for s in Student if s.name != x.name))
        self.assertEqual(result, {Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]})
    @raises_exception(TypeError, "Use select(...) function or Group.select(...) method for iteration")
    def test_exception1(self):
        for g in Group:
            pass
    @raises_exception(MultipleObjectsFoundError, "Multiple objects were found. Use select(...) to retrieve them")
    def test_exception2(self):
        get(s for s in Student)
    def test_exists(self):
        result = exists(s for s in Student)
    @raises_exception(ExprEvalError, "`db.FooBar` raises AttributeError: 'Database' object has no attribute 'FooBar'")
    def test_entity_not_found(self):
        select(s for s in db.Student for g in db.FooBar)
    def test_keyargs1(self):
        result = set(select(s for s in Student if s.dob < date(year=1990, month=10, day=20)))
        self.assertEqual(result, {Student[3], Student[4], Student[6], Student[7]})
    def test_query_as_string1(self):
        result = set(select('s for s in Student if 3 <= s.gpa < 4'))
        self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]})
    def test_query_as_string2(self):
        result = set(select('s for s in db.Student if 3 <= s.gpa < 4'))
        self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]})
    def test_str_subclasses(self):
        result = select(d for d in Department for g in d.groups for c in d.courses if g.number == 106 and c.name.startswith('T'))[:]
        self.assertEqual(result, [Department[3]])
    def test_unicode_subclass(self):
        class Unicode2(unicode):
            pass
        u2 = Unicode2(u'\xf0')
        select(s for s in Student if len(u2) == 1)
    def test_bool(self):
        result = set(select(s for s in Student if s.phd == True))
        self.assertEqual(result, {Student[1], Student[2]})
    def test_bool2(self):
        result = list(select(s for s in Student if s.phd + 1 == True))
        self.assertEqual(result, [])
    def test_bool3(self):
        result = list(select(s for s in Student if s.phd + 1.1 == True))
        self.assertEqual(result, [])
    def test_bool4(self):
        result = list(select(s for s in Student if s.phd + Decimal('1.1') == True))
        self.assertEqual(result, [])
    def test_bool5(self):
        x = True
        result = set(select(s for s in Student if s.phd == True and (False or (True and x))))
        self.assertEqual(result, {Student[1], Student[2]})
    def test_bool6(self):
        x = False
        result = list(select(s for s in Student if s.phd == (False or (True and x)) and s.phd is True))
        self.assertEqual(result, [])

if __name__ == "__main__":
    unittest.main()
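
# As test_query_as_string1/2 above show, select() also accepts the generator
# expression as a string, which is parsed and pushed through the same
# translator, e.g.:
#
#     select('s for s in Student if 3 <= s.gpa < 4')
#
# This form is useful when the query text has to be built dynamically.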
function or Group.select(...) method for iteration") def test_exception1(self): for g in Group: pass @raises_exception(MultipleObjectsFoundError, "Multiple objects were found. Use select(...) to retrieve them") def test_exception2(self): get(s for s in Student) def test_exists(self): result = exists(s for s in Student) @raises_exception(ExprEvalError, "`db.FooBar` raises AttributeError: 'Database' object has no attribute 'FooBar'") def test_entity_not_found(self): select(s for s in db.Student for g in db.FooBar) def test_keyargs1(self): result = set(select(s for s in Student if s.dob < date(year=1990, month=10, day=20))) self.assertEqual(result, {Student[3], Student[4], Student[6], Student[7]}) def test_query_as_string1(self): result = set(select('s for s in Student if 3 <= s.gpa < 4')) self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]}) def test_query_as_string2(self): result = set(select('s for s in db.Student if 3 <= s.gpa < 4')) self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]}) def test_str_subclasses(self): result = select(d for d in Department for g in d.groups for c in d.courses if g.number == 106 and c.name.startswith('T'))[:] self.assertEqual(result, [Department[3]]) def test_unicode_subclass(self): class Unicode2(unicode): pass u2 = Unicode2(u'\xf0') select(s for s in Student if len(u2) == 1) def test_bool(self): result = set(select(s for s in Student if s.phd == True)) self.assertEqual(result, {Student[1], Student[2]}) def test_bool2(self): result = list(select(s for s in Student if s.phd + 1 == True)) self.assertEqual(result, []) def test_bool3(self): result = list(select(s for s in Student if s.phd + 1.1 == True)) self.assertEqual(result, []) def test_bool4(self): result = list(select(s for s in Student if s.phd + Decimal('1.1') == True)) self.assertEqual(result, []) def test_bool5(self): x = True result = set(select(s for s in Student if s.phd == True and (False or (True and x)))) self.assertEqual(result, {Student[1], Student[2]}) def test_bool6(self): x = False result = list(select(s for s in Student if s.phd == (False or (True and x)) and s.phd is True)) self.assertEqual(result, []) if __name__ == "__main__": unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_declarative_strings.py0000666000000000000000000001703700000000000021321 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode, autostrip=False) foo = Optional(unicode) bar = Optional(unicode) db.generate_mapping(create_tables=True) with db_session: Student(id=1, name="Jon", foo='Abcdef', bar='b%d') Student(id=2, name=" Bob ", foo='Ab%def', bar='b%d') Student(id=3, name=" Beth ", foo='Ab_def', bar='b%d') Student(id=4, name="Jonathan") Student(id=5, name="Pete") class TestStringMethods(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test_nonzero(self): result = set(select(s for s in Student if s.foo)) self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_add(self): name = 'Jonny' result = set(select(s for s in Student if s.name + "ny" == name)) self.assertEqual(result, {Student[1]}) def test_slice_1(self): result = set(select(s 
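# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: a minimal standalone
# illustration of the distinct/without_distinct behavior that the tests
# above exercise.  The names db_ex1 and PersonEx are illustrative
# placeholders, not part of Pony or this test suite.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex1 = Database('sqlite', ':memory:')

class PersonEx(db_ex1.Entity):
    name = Required(str)

db_ex1.generate_mapping(create_tables=True)

with db_session:
    PersonEx(name='Ann')
    PersonEx(name='Ann')
    # selecting a plain column may produce duplicate rows, so Pony adds DISTINCT
    print(select(p.name for p in PersonEx)[:])                     # expected: ['Ann']
    # .without_distinct() keeps the raw rows
    print(select(p.name for p in PersonEx).without_distinct()[:])  # expected: ['Ann', 'Ann']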
# ===== pony-0.7.11/pony/orm/tests/test_declarative_strings.py =====

from __future__ import absolute_import, print_function, division
import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    name = Required(unicode, autostrip=False)
    foo = Optional(unicode)
    bar = Optional(unicode)

db.generate_mapping(create_tables=True)

with db_session:
    Student(id=1, name="Jon", foo='Abcdef', bar='b%d')
    Student(id=2, name=" Bob ", foo='Ab%def', bar='b%d')
    Student(id=3, name=" Beth ", foo='Ab_def', bar='b%d')
    Student(id=4, name="Jonathan")
    Student(id=5, name="Pete")

class TestStringMethods(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_nonzero(self):
        result = set(select(s for s in Student if s.foo))
        self.assertEqual(result, {Student[1], Student[2], Student[3]})
    def test_add(self):
        name = 'Jonny'
        result = set(select(s for s in Student if s.name + "ny" == name))
        self.assertEqual(result, {Student[1]})
    def test_slice_1(self):
        result = set(select(s for s in Student if s.name[0:3] == "Jon"))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_slice_2(self):
        result = set(select(s for s in Student if s.name[:3] == "Jon"))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_slice_3(self):
        x = 3
        result = set(select(s for s in Student if s.name[:x] == "Jon"))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_slice_4(self):
        x = 3
        result = set(select(s for s in Student if s.name[0:x] == "Jon"))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_slice_5(self):
        result = set(select(s for s in Student if s.name[0:10] == "Jon"))
        self.assertEqual(result, {Student[1]})
    def test_slice_6(self):
        result = set(select(s for s in Student if s.name[0:] == "Jon"))
        self.assertEqual(result, {Student[1]})
    def test_slice_7(self):
        result = set(select(s for s in Student if s.name[:] == "Jon"))
        self.assertEqual(result, {Student[1]})
    def test_slice_8(self):
        result = set(select(s for s in Student if s.name[1:] == "on"))
        self.assertEqual(result, {Student[1]})
    def test_slice_9(self):
        x = 1
        result = set(select(s for s in Student if s.name[x:] == "on"))
        self.assertEqual(result, {Student[1]})
    def test_slice_10(self):
        x = 0
        result = set(select(s for s in Student if s.name[x:3] == "Jon"))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_slice_11(self):
        x = 1
        y = 3
        result = set(select(s for s in Student if s.name[x:y] == "on"))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_slice_12(self):
        x = 10
        y = 20
        result = set(select(s for s in Student if s.name[x:y] == ''))
        self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]})
    def test_getitem_1(self):
        result = set(select(s for s in Student if s.name[1] == 'o'))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_getitem_2(self):
        x = 1
        result = set(select(s for s in Student if s.name[x] == 'o'))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_getitem_3(self):
        result = set(select(s for s in Student if s.name[-1] == 'n'))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_getitem_4(self):
        x = -1
        result = set(select(s for s in Student if s.name[x] == 'n'))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_contains_1(self):
        result = set(select(s for s in Student if 'o' in s.name))
        self.assertEqual(result, {Student[1], Student[2], Student[4]})
    def test_contains_2(self):
        result = set(select(s for s in Student if 'on' in s.name))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_contains_3(self):
        x = 'on'
        result = set(select(s for s in Student if x in s.name))
        self.assertEqual(result, {Student[1], Student[4]})
    def test_contains_4(self):
        x = 'on'
        result = set(select(s for s in Student if x not in s.name))
        self.assertEqual(result, {Student[2], Student[3], Student[5]})
    def test_contains_5(self):
        result = set(select(s for s in Student if '%' in s.foo))
        self.assertEqual(result, {Student[2]})
    def test_contains_6(self):
        x = '%'
        result = set(select(s for s in Student if x in s.foo))
        self.assertEqual(result, {Student[2]})
    def test_contains_7(self):
        result = set(select(s for s in Student if '_' in s.foo))
        self.assertEqual(result, {Student[3]})
    def test_contains_8(self):
        x = '_'
        result = set(select(s for s in Student if x in s.foo))
        self.assertEqual(result, {Student[3]})
    def test_contains_9(self):
        result = set(select(s for s in Student if s.foo in 'Abcdef'))
        self.assertEqual(result, {Student[1], Student[4], Student[5]})
    def test_contains_10(self):
        result = set(select(s for s in Student if s.bar in s.foo))
        self.assertEqual(result, {Student[2], Student[4], Student[5]})
    def test_startswith_1(self):
        students = set(select(s for s in Student if s.name.startswith('J')))
        self.assertEqual(students, {Student[1], Student[4]})
    def test_startswith_2(self):
        students = set(select(s for s in Student if not s.name.startswith('J')))
        self.assertEqual(students, {Student[2], Student[3], Student[5]})
    def test_startswith_3(self):
        students = set(select(s for s in Student if not not s.name.startswith('J')))
        self.assertEqual(students, {Student[1], Student[4]})
    def test_startswith_4(self):
        students = set(select(s for s in Student if not not not s.name.startswith('J')))
        self.assertEqual(students, {Student[2], Student[3], Student[5]})
    def test_startswith_5(self):
        x = "Pe"
        students = select(s for s in Student if s.name.startswith(x))[:]
        self.assertEqual(students, [Student[5]])
    def test_endswith_1(self):
        students = set(select(s for s in Student if s.name.endswith('n')))
        self.assertEqual(students, {Student[1], Student[4]})
    def test_endswith_2(self):
        x = "te"
        students = select(s for s in Student if s.name.endswith(x))[:]
        self.assertEqual(students, [Student[5]])
    def test_strip_1(self):
        students = select(s for s in Student if s.name.strip() == 'Beth')[:]
        self.assertEqual(students, [Student[3]])
    def test_rstrip(self):
        students = select(s for s in Student if s.name.rstrip('n') == 'Jo')[:]
        self.assertEqual(students, [Student[1]])
    def test_lstrip(self):
        students = select(s for s in Student if s.name.lstrip('P') == 'ete')[:]
        self.assertEqual(students, [Student[5]])
    def test_upper(self):
        result = select(s for s in Student if s.name.upper() == "JON")[:]
        self.assertEqual(result, [Student[1]])
    def test_lower(self):
        result = select(s for s in Student if s.name.lower() == "jon")[:]
        self.assertEqual(result, [Student[1]])

if __name__ == "__main__":
    unittest.main()
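# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: the contains tests
# above rely on Pony translating `x in s.attr` to a SQL LIKE test in
# which '%' and '_' match themselves literally rather than acting as
# wildcards.  The names db_ex2 and Doc are illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex2 = Database('sqlite', ':memory:')

class Doc(db_ex2.Entity):
    text = Required(str)

db_ex2.generate_mapping(create_tables=True)

with db_session:
    Doc(text='100% done')
    Doc(text='plain text')
    # only the row containing a literal percent sign should match
    print(select(d for d in Doc if '%' in d.text).count())  # expected: 1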
# ===== pony-0.7.11/pony/orm/tests/test_decompiler.py =====

import unittest

from pony.thirdparty.compiler.transformer import parse
from pony.orm.decompiling import Decompiler
from pony.orm.asttranslation import ast2src

def generate_gens():
    patterns = [
        '(x * y) * [z * j)',
        '([x * y) * z) * j',
        '(x * [y * z)) * j',
        'x * ([y * z) * j)',
        'x * (y * [z * j))'
    ]
    ops = ('and', 'or')
    nots = (True, False)
    result = []
    for pat in patterns:
        cur = pat
        for op1 in ops:
            for op2 in ops:
                for op3 in ops:
                    res = cur.replace('*', op1, 1)
                    res = res.replace('*', op2, 1)
                    res = res.replace('*', op3, 1)
                    result.append(res)
    final = []
    for res in result:
        for par1 in nots:
            for par2 in nots:
                for a in nots:
                    for b in nots:
                        for c in nots:
                            for d in nots:
                                cur = res.replace('(', 'not(') if not par1 else res
                                if not par2:
                                    cur = cur.replace('[', 'not(')
                                else:
                                    cur = cur.replace('[', '(')
                                if not a:
                                    cur = cur.replace('x', 'not x')
                                if not b:
                                    cur = cur.replace('y', 'not y')
                                if not c:
                                    cur = cur.replace('z', 'not z')
                                if not d:
                                    cur = cur.replace('j', 'not j')
                                final.append(cur)
    return final

def create_test(gen):
    def wrapped_test(self):
        def get_condition_values(cond):
            result = []
            vals = (True, False)
            for x in vals:
                for y in vals:
                    for z in vals:
                        for j in vals:
                            result.append(eval(cond, {'x': x, 'y': y, 'z': z, 'j': j}))
            return result
        src1 = '(a for a in [] if %s)' % gen
        src2 = 'lambda x, y, z, j: (%s)' % gen
        src3 = '(m for m in [] if %s for n in [] if %s)' % (gen, gen)
        code1 = compile(src1, '', 'eval').co_consts[0]
        ast1 = Decompiler(code1).ast
        res1 = ast2src(ast1).replace('.0', '[]')
        res1 = res1[res1.find('if')+2:-1]
        code2 = compile(src2, '', 'eval').co_consts[0]
        ast2 = Decompiler(code2).ast
        res2 = ast2src(ast2).replace('.0', '[]')
        res2 = res2[res2.find(':')+1:]
        code3 = compile(src3, '', 'eval').co_consts[0]
        ast3 = Decompiler(code3).ast
        res3 = ast2src(ast3).replace('.0', '[]')
        res3 = res3[res3.find('if')+2: res3.rfind('for')-1]
        if get_condition_values(gen) != get_condition_values(res1):
            self.fail("Incorrect generator decompilation: %s -> %s" % (gen, res1))
        if get_condition_values(gen) != get_condition_values(res2):
            self.fail("Incorrect lambda decompilation: %s -> %s" % (gen, res2))
        if get_condition_values(gen) != get_condition_values(res3):
            self.fail("Incorrect multi-for generator decompilation: %s -> %s" % (gen, res3))
    return wrapped_test

class TestDecompiler(unittest.TestCase):
    pass

for i, gen in enumerate(generate_gens()):
    test_method = create_test(gen)
    test_method.__name__ = 'test_decompiler_%d' % i
    setattr(TestDecompiler, test_method.__name__, test_method)
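# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: a single round trip
# through the decompiler, the same operation the generated tests above
# perform for hundreds of and/or/not combinations.  Exact output
# formatting may vary across the Python versions the bundled decompiler
# supports.
# ----------------------------------------------------------------------
from pony.orm.decompiling import Decompiler
from pony.orm.asttranslation import ast2src

src = '(a for a in [] if x and not y)'
code = compile(src, '', 'eval').co_consts[0]   # code object of the genexpr itself
print(ast2src(Decompiler(code).ast).replace('.0', '[]'))
# expected: source equivalent to "(a for a in [] if x and not y)"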
# ===== pony-0.7.11/pony/orm/tests/test_deduplication.py =====

from pony.py23compat import StringIO

import unittest

from pony import orm

db = orm.Database('sqlite', ':memory:')

class A(db.Entity):
    id = orm.PrimaryKey(int)
    x = orm.Required(bool)
    y = orm.Required(float)

db.generate_mapping(create_tables=True)

with orm.db_session:
    a1 = A(id=1, x=False, y=3.0)
    a2 = A(id=2, x=True, y=4.0)
    a3 = A(id=3, x=False, y=1.0)

class TestDeduplication(unittest.TestCase):
    @orm.db_session
    def test_1(self):
        a2 = A.get(id=2)
        a1 = A.get(id=1)
        self.assertIs(a1.id, 1)
    @orm.db_session
    def test_2(self):
        a3 = A.get(id=3)
        a1 = A.get(id=1)
        self.assertIs(a1.id, 1)
    @orm.db_session
    def test_3(self):
        q = A.select().order_by(-1)
        stream = StringIO()
        q.show(stream=stream)
        s = stream.getvalue()
        self.assertEqual(s, 'id|x    |y  \n'
                            '--+-----+---\n'
                            '3 |False|1.0\n'
                            '2 |True |4.0\n'
                            '1 |False|3.0\n')

# ===== pony-0.7.11/pony/orm/tests/test_diagram.py =====

from __future__ import absolute_import, print_function, division
import unittest

from pony.orm.core import *
from pony.orm.core import Entity
from pony.orm.tests.testutils import *

class TestDiag(unittest.TestCase):
    @raises_exception(ERDiagramError, 'Entity Entity1 already exists')
    def test_entity_duplicate(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
        class Entity1(db.Entity):
            id = PrimaryKey(int)
    @raises_exception(ERDiagramError, 'Interrelated entities must belong to same database.'
                                      ' Entities Entity2 and Entity1 belongs to different databases')
    def test_diagram1(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Required('Entity2')
        db = Database('sqlite', ':memory:')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping()
    @raises_exception(ERDiagramError, 'Entity definition Entity2 was not found')
    def test_diagram2(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Required('Entity2')
        db.generate_mapping()
    @raises_exception(TypeError, 'Entity1._table_ property must be a string. Got: 123')
    def test_diagram3(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            _table_ = 123
            id = PrimaryKey(int)
        db.generate_mapping()
    def test_diagram4(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Set('Entity2', table='Table1')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1, table='Table1')
        db.generate_mapping(create_tables=True)
    def test_diagram5(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Set('Entity2')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Required(Entity1)
        db.generate_mapping(create_tables=True)
    @raises_exception(MappingError, "Parameter 'table' for Entity1.attr1 and Entity2.attr2 do not match")
    def test_diagram6(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Set('Entity2', table='Table1')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1, table='Table2')
        db.generate_mapping()
    @raises_exception(MappingError, 'Table name "Table1" is already in use')
    def test_diagram7(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            _table_ = 'Table1'
            id = PrimaryKey(int)
            attr1 = Set('Entity2', table='Table1')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1, table='Table1')
        db.generate_mapping()
    def test_diagram8(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Set('Entity2')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1)
        db.generate_mapping(create_tables=True)
        m2m_table = db.schema.tables['Entity1_Entity2']
        col_names = {col.name for col in m2m_table.column_list}
        self.assertEqual(col_names, {'entity1', 'entity2'})
        self.assertEqual(Entity1.attr1.get_m2m_columns(), ['entity1'])
    def test_diagram9(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(str)
            PrimaryKey(a, b)
            attr1 = Set('Entity2')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1)
        db.generate_mapping(create_tables=True)
        m2m_table = db.schema.tables['Entity1_Entity2']
        col_names = {col.name for col in m2m_table.column_list}
        self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2'})
    def test_diagram10(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(str)
            PrimaryKey(a, b)
            attr1 = Set('Entity2', column='z')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1, columns=['x', 'y'])
        db.generate_mapping(create_tables=True)
    @raises_exception(MappingError, 'Invalid number of columns for Entity2.attr2')
    def test_diagram11(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(str)
            PrimaryKey(a, b)
            attr1 = Set('Entity2', column='z')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Set(Entity1, columns=['x'])
        db.generate_mapping()
    @raises_exception(ERDiagramError, 'Base Entity does not belong to any database')
    def test_diagram12(self):
        class Test(Entity):
            name = Required(unicode)
    @raises_exception(ERDiagramError, 'Entity class name should start with a capital letter. Got: entity1')
    def test_diagram13(self):
        db = Database('sqlite', ':memory:')
        class entity1(db.Entity):
            a = Required(int)
        db.generate_mapping()

if __name__ == '__main__':
    unittest.main()
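# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: mirrors test_diagram8
# above -- a pair of Set attributes produces an intermediate many-to-many
# table automatically, visible in db.schema.  The names db_ex4, Author
# and Book are illustrative placeholders; the 'Author_Book' table name
# follows the naming pattern asserted by test_diagram8.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex4 = Database('sqlite', ':memory:')

class Author(db_ex4.Entity):
    books = Set('Book')

class Book(db_ex4.Entity):
    authors = Set(Author)

db_ex4.generate_mapping(create_tables=True)

m2m_table = db_ex4.schema.tables['Author_Book']
print({col.name for col in m2m_table.column_list})  # expected: {'author', 'book'}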
Got: _attr1_") def test_attribute9(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) _attr1_ = Required(str) @raises_exception(ERDiagramError, "Duplicate use of attribute Entity1.attr1 in entity Entity2") def test_attribute10(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required(str) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Entity1.attr1 @raises_exception(ERDiagramError, "Invalid use of attribute Entity1.a in entity Entity2") def test_attribute11(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(str) class Entity2(db.Entity): b = Required(str) composite_key(Entity1.a, b) @raises_exception(ERDiagramError, "Cannot create default primary key attribute for Entity1 because name 'id' is already in use." " Please create a PrimaryKey attribute for entity Entity1 or rename the 'id' attribute") def test_attribute12(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = Optional(str) @raises_exception(ERDiagramError, "Reverse attribute for Entity1.attr1 not found") def test_attribute13(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) db.generate_mapping() @raises_exception(ERDiagramError, "Reverse attribute Entity1.attr1 not found") def test_attribute14(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1, reverse='attr1') db.generate_mapping() @raises_exception(ERDiagramError, "Inconsistent reverse attributes Entity3.attr3 and Entity2.attr2") def test_attribute15(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1) class Entity3(db.Entity): id = PrimaryKey(int) attr3 = Required(Entity2, reverse='attr2') db.generate_mapping() @raises_exception(ERDiagramError, "Inconsistent reverse attributes Entity3.attr3 and Entity2.attr2") def test_attribute16(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1) class Entity3(db.Entity): id = PrimaryKey(int) attr3 = Required(Entity2, reverse=Entity2.attr2) db.generate_mapping() @raises_exception(ERDiagramError, 'Reverse attribute for Entity2.attr2 not found') def test_attribute18(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required('Entity1') db.generate_mapping() @raises_exception(ERDiagramError, "Ambiguous reverse attribute for Entity1.a. Use the 'reverse' parameter for pointing to right attribute") def test_attribute19(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2') b = Optional('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) c = Set(Entity1) d = Set(Entity1) db.generate_mapping() @raises_exception(ERDiagramError, "Ambiguous reverse attribute for Entity1.c. 
Use the 'reverse' parameter for pointing to right attribute") def test_attribute20(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) c = Set('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) a = Required(Entity1, reverse='c') b = Optional(Entity1, reverse='c') db.generate_mapping() def test_attribute21(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2', reverse='c') b = Optional('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) c = Set(Entity1) d = Set(Entity1) def test_attribute22(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2', reverse='c') b = Optional('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) c = Set(Entity1, reverse='a') d = Set(Entity1) @raises_exception(ERDiagramError, 'Inconsistent reverse attributes Entity1.a and Entity2.b') def test_attribute23(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required('Entity2', reverse='b') class Entity2(db.Entity): b = Optional('Entity3') class Entity3(db.Entity): c = Required('Entity2') db.generate_mapping() @raises_exception(ERDiagramError, 'Inconsistent reverse attributes Entity1.a and Entity2.c') def test_attribute23(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required('Entity2', reverse='c') b = Required('Entity2', reverse='d') class Entity2(db.Entity): c = Optional('Entity1', reverse='b') d = Optional('Entity1', reverse='a') db.generate_mapping() def test_attribute24(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = PrimaryKey(str, auto=True) db.generate_mapping(create_tables=True) self.assertTrue('AUTOINCREMENT' not in db.schema.tables['Entity1'].get_create_command()) @raises_exception(TypeError, "Parameters 'column' and 'columns' cannot be specified simultaneously") def test_columns1(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a', columns=['b', 'c']) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) db.generate_mapping(create_tables=True) def test_columns2(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int, column='a') self.assertEqual(Entity1.id.columns, ['a']) def test_columns3(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int, columns=['a']) self.assertEqual(Entity1.id.column, 'a') @raises_exception(MappingError, "Too many columns were specified for Entity1.id") def test_columns5(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int, columns=['a', 'b']) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'columns' must be a list. Got: %r'" % {'a'}) def test_columns6(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int, columns={'a'}) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'column' must be a string. 
Got: 4") def test_columns7(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int, column=4) db.generate_mapping(create_tables=True) def test_columns8(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(int) b = Required(int) attr1 = Optional('Entity2') PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=['x', 'y']) self.assertEqual(Entity2.attr2.column, None) self.assertEqual(Entity2.attr2.columns, ['x', 'y']) @raises_exception(MappingError, 'Invalid number of columns specified for Entity2.attr2') def test_columns9(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(int) b = Required(int) attr1 = Optional('Entity2') PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=['x', 'y', 'z']) db.generate_mapping(create_tables=True) @raises_exception(MappingError, 'Invalid number of columns specified for Entity2.attr2') def test_columns10(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(int) b = Required(int) attr1 = Optional('Entity2') PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, column='x') db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Items of parameter 'columns' must be strings. Got: [1, 2]") def test_columns11(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(int) b = Required(int) attr1 = Optional('Entity2') PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=[1, 2]) def test_columns12(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column='column2', reverse_columns=['column2']) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameters 'reverse_column' and 'reverse_columns' cannot be specified simultaneously") def test_columns13(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column='column2', reverse_columns=['column3']) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'reverse_column' must be a string. Got: 5") def test_columns14(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column=5) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'reverse_columns' must be a list. Got: 'column3'") def test_columns15(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns='column3') db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'reverse_columns' must be a list of strings. Got: [5]") def test_columns16(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns=[5]) db.generate_mapping(create_tables=True) def test_columns17(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns=['column2']) db.generate_mapping(create_tables=True) def test_columns18(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table='T1') db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'table' must be a string. 
Got: 5") def test_columns19(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table=5) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Each part of table name must be a string. Got: 1") def test_columns20(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table=[1, 'T1']) db.generate_mapping(create_tables=True) def test_columns_21(self): db = Database('sqlite', ':memory:') class Stat(db.Entity): webinarshow = Optional('WebinarShow') class WebinarShow(db.Entity): stats = Required('Stat') db.generate_mapping(create_tables=True) self.assertEqual(Stat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') def test_columns_22(self): db = Database('sqlite', ':memory:') class ZStat(db.Entity): webinarshow = Optional('WebinarShow') class WebinarShow(db.Entity): stats = Required('ZStat') db.generate_mapping(create_tables=True) self.assertEqual(ZStat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') def test_nullable1(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(unicode, unique=True) db.generate_mapping(create_tables=True) self.assertEqual(Entity1.a.nullable, True) def test_nullable2(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(unicode, unique=True) db.generate_mapping(create_tables=True) with db_session: Entity1() commit() Entity1() commit() def test_lambda_1(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(lambda: db.Entity2) class Entity2(db.Entity): b = Set(lambda: db.Entity1) db.generate_mapping(create_tables=True) self.assertEqual(Entity1.a.py_type, Entity2) self.assertEqual(Entity2.b.py_type, Entity1) @raises_exception(TypeError, "Invalid type of attribute Entity1.a: expected entity class, got 'Entity2'") def test_lambda_2(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(lambda: 'Entity2') class Entity2(db.Entity): b = Set(lambda: db.Entity1) db.generate_mapping(create_tables=True) @raises_exception(ERDiagramError, 'Interrelated entities must belong to same database. ' 'Entities Entity1 and Entity2 belongs to different databases') def test_lambda_3(self): db1 = Database('sqlite', ':memory:') class Entity1(db1.Entity): a = Required(lambda: db2.Entity2) db2 = Database('sqlite', ':memory:') class Entity2(db2.Entity): b = Set(lambda: db1.Entity1) db1.generate_mapping(create_tables=True) @raises_exception(ValueError, 'Check for attribute Entity1.a failed. Value: 1') def test_py_check_1(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(int, py_check=lambda val: val > 5 and val < 10) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=1) def test_py_check_2(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(int, py_check=lambda val: val > 5 and val < 10) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=7) def test_py_check_3(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=None) @raises_exception(ValueError, 'Check for attribute Entity1.a failed. 
Value: datetime.date(1999, 1, 1)') def test_py_check_4(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=date(1999, 1, 1)) def test_py_check_5(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=date(2010, 1, 1)) @raises_exception(ValueError, 'Should be positive number') def test_py_check_6(self): def positive_number(val): if val <= 0: raise ValueError('Should be positive number') db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(int, py_check=positive_number) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=-1) def test_py_check_7(self): def positive_number(val): if val <= 0: raise ValueError('Should be positive number') return True db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Optional(int, py_check=positive_number) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a=1) @raises_exception(NotImplementedError, "'py_check' parameter is not supported for collection attributes") def test_py_check_8(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required('Entity2') class Entity2(db.Entity): a = Set('Entity1', py_check=lambda val: True) db.generate_mapping(create_tables=True) def test_py_check_truncate(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(str, py_check=lambda val: False) db.generate_mapping(create_tables=True) with db_session: try: obj = Entity1(a='1234567890' * 1000) except ValueError as e: error_message = "Check for attribute Entity1.a failed. Value: " + ( "u'12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345..." if PY2 else "'123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456...") self.assertEqual(str(e), error_message) else: self.assert_(False) @raises_exception(ValueError, 'Value for attribute Entity1.a is too long. 
Max length is 10, value length is 10000') def test_str_max_len(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(str, 10) db.generate_mapping(create_tables=True) with db_session: obj = Entity1(a='1234567890' * 1000) def test_foreign_key_sql_type_1(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SOME_TYPE') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) db.generate_mapping(create_tables=True) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'SOME_TYPE') def test_foreign_key_sql_type_2(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SOME_TYPE') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo, sql_type='ANOTHER_TYPE') db.generate_mapping(create_tables=True) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'ANOTHER_TYPE') def test_foreign_key_sql_type_3(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SERIAL') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo, sql_type='ANOTHER_TYPE') db.generate_mapping(create_tables=True) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'ANOTHER_TYPE') def test_foreign_key_sql_type_4(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SERIAL') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) db.generate_mapping(create_tables=True) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'INTEGER') def test_foreign_key_sql_type_5(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='serial') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) db.generate_mapping(create_tables=True) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'integer') def test_self_referenced_m2m_1(self): db = Database('sqlite', ':memory:') class Node(db.Entity): id = PrimaryKey(int) prev_nodes = Set("Node") next_nodes = Set("Node") db.generate_mapping(create_tables=True) def test_implicit_1(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): name = Required(str) bar = Required("Bar") class Bar(db.Entity): id = PrimaryKey(int) name = Optional(str) foos = Set("Foo") db.generate_mapping(create_tables=True) self.assertTrue(Foo.id.is_implicit) self.assertFalse(Foo.name.is_implicit) self.assertFalse(Foo.bar.is_implicit) self.assertFalse(Bar.id.is_implicit) self.assertFalse(Bar.name.is_implicit) self.assertFalse(Bar.foos.is_implicit) def test_implicit_2(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): x = Required(str) class Bar(Foo): y = Required(str) db.generate_mapping(create_tables=True) self.assertTrue(Foo.id.is_implicit) self.assertTrue(Foo.classtype.is_implicit) self.assertFalse(Foo.x.is_implicit) self.assertTrue(Bar.id.is_implicit) self.assertTrue(Bar.classtype.is_implicit) self.assertFalse(Bar.x.is_implicit) self.assertFalse(Bar.y.is_implicit) @raises_exception(TypeError, 'Attribute Foo.x has invalid type NoneType') def test_none_type(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): x = Required(type(None)) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "'sql_default' option value cannot be empty 
    @raises_exception(TypeError, "'sql_default' option value cannot be empty string, "
                                 "because it should be valid SQL literal or expression. "
                                 "Try to use \"''\", or just specify default='' instead.")
    def test_empty_sql_default(self):
        db = Database('sqlite', ':memory:')
        class Foo(db.Entity):
            x = Required(str, sql_default='')

if __name__ == '__main__':
    unittest.main()
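# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: the py_check tests
# above show that validation runs in Python at assignment time, so the
# failure surfaces before any SQL is issued.  The names db_ex5 and
# Account are illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex5 = Database('sqlite', ':memory:')

class Account(db_ex5.Entity):
    balance = Required(int, py_check=lambda val: val >= 0)

db_ex5.generate_mapping(create_tables=True)

with db_session:
    Account(balance=10)         # passes the check
    try:
        Account(balance=-5)     # fails the check at object creation
    except ValueError as e:
        print(e)  # expected: Check for attribute Account.balance failed. Value: -5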
# ===== pony-0.7.11/pony/orm/tests/test_diagram_keys.py =====

from __future__ import absolute_import, print_function, division
import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

class TestKeys(unittest.TestCase):
    def test_keys1(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Required(str)
        self.assertEqual(Entity1._pk_attrs_, (Entity1.a,))
        self.assertEqual(Entity1._pk_is_composite_, False)
        self.assertEqual(Entity1._pk_, Entity1.a)
        self.assertEqual(Entity1._keys_, [])
        self.assertEqual(Entity1._simple_keys_, [])
        self.assertEqual(Entity1._composite_keys_, [])
    def test_keys2(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(str)
            PrimaryKey(a, b)
        self.assertEqual(Entity1._pk_attrs_, (Entity1.a, Entity1.b))
        self.assertEqual(Entity1._pk_is_composite_, True)
        self.assertEqual(Entity1._pk_, (Entity1.a, Entity1.b))
        self.assertEqual(Entity1._keys_, [])
        self.assertEqual(Entity1._simple_keys_, [])
        self.assertEqual(Entity1._composite_keys_, [])
    @raises_exception(ERDiagramError, 'Only one primary key can be defined in each entity class')
    def test_keys3(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = PrimaryKey(int)
    @raises_exception(ERDiagramError, 'Only one primary key can be defined in each entity class')
    def test_keys4(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Required(int)
            c = Required(int)
            PrimaryKey(b, c)
    def test_unique1(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Required(int, unique=True)
        self.assertEqual(Entity1._keys_, [(Entity1.b,)])
        self.assertEqual(Entity1._simple_keys_, [Entity1.b])
        self.assertEqual(Entity1._composite_keys_, [])
    def test_unique2(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Optional(int, unique=True)
        self.assertEqual(Entity1._keys_, [(Entity1.b,)])
        self.assertEqual(Entity1._simple_keys_, [Entity1.b])
        self.assertEqual(Entity1._composite_keys_, [])
    def test_unique2_1(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Optional(int)
            c = Required(int)
            composite_key(b, c)
        self.assertEqual(Entity1._keys_, [(Entity1.b, Entity1.c)])
        self.assertEqual(Entity1._simple_keys_, [])
        self.assertEqual(Entity1._composite_keys_, [(Entity1.b, Entity1.c)])
    @raises_exception(TypeError, 'composite_key() must receive at least two attributes as arguments')
    def test_unique3(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            composite_key()
    @raises_exception(TypeError, 'composite_key() arguments must be attributes. Got: 123')
    def test_unique4(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            composite_key(123, 456)
    @raises_exception(TypeError, "composite_key() arguments must be attributes. Got: %r" % int)
    def test_unique5(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            composite_key(int, a)
    @raises_exception(TypeError, 'Set attribute Entity1.b cannot be part of unique index')
    def test_unique6(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Set('Entity2')
            composite_key(a, b)
    @raises_exception(TypeError, "'unique' option cannot be set for attribute Entity1.b because it is collection")
    def test_unique7(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Set('Entity2', unique=True)
    @raises_exception(TypeError, 'Optional attribute Entity1.b cannot be part of primary key')
    def test_unique8(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Optional(int)
            PrimaryKey(a, b)
    @raises_exception(TypeError, 'PrimaryKey attribute Entity1.a cannot be of type float')
    def test_float_pk(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(float)
    @raises_exception(TypeError, 'Attribute Entity1.b of type float cannot be part of primary key')
    def test_float_composite_pk(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(float)
            PrimaryKey(a, b)
    @raises_exception(TypeError, 'Attribute Entity1.b of type float cannot be part of unique index')
    def test_float_composite_key(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(float)
            composite_key(a, b)
    @raises_exception(TypeError, 'Unique attribute Entity1.a cannot be of type float')
    def test_float_unique(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(float, unique=True)
    @raises_exception(TypeError, 'PrimaryKey attribute Entity1.a cannot be volatile')
    def test_volatile_pk(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int, volatile=True)
    @raises_exception(TypeError, 'Set attribute Entity1.b cannot be volatile')
    def test_volatile_set(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            b = Set('Entity2', volatile=True)
    @raises_exception(TypeError, 'Volatile attribute Entity1.b cannot be part of primary key')
    def test_volatile_composite_pk(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(int, volatile=True)
            PrimaryKey(a, b)
    def test_composite_key_update(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            s = Set('Entity3')
        class Entity2(db.Entity):
            s = Set('Entity3')
        class Entity3(db.Entity):
            a = Required(Entity1)
            b = Required(Entity2)
            composite_key(a, b)
        db.generate_mapping(create_tables=True)
        with db_session:
            x = Entity1(id=1)
            y = Entity2(id=1)
            z = Entity3(id=1, a=x, b=y)
        with db_session:
            z = Entity3[1]
            self.assertEqual(z.a, Entity1[1])
            self.assertEqual(z.b, Entity2[1])
        with db_session:
            z = Entity3[1]
            w = Entity1(id=2)
            z.a = w
        with db_session:
            z = Entity3[1]
            self.assertEqual(z.a, Entity1[2])
            self.assertEqual(z.b, Entity2[1])

if __name__ == '__main__':
    unittest.main()
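# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: composite_key()
# declares a secondary unique index (as in test_unique2_1 above), while
# PrimaryKey(a, b) would make the pair the primary key itself.  The
# names db_ex6 and Booking are illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex6 = Database('sqlite', ':memory:')

class Booking(db_ex6.Entity):
    room = Required(str)
    day = Required(str)
    composite_key(room, day)  # unique together, independent of the implicit id PK

db_ex6.generate_mapping(create_tables=True)

with db_session:
    Booking(room='A', day='mon')
    # a second Booking(room='A', day='mon') would violate the unique index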
# ===== pony-0.7.11/pony/orm/tests/test_distinct.py =====

from __future__ import absolute_import, print_function, division
import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Department(db.Entity):
    number = PrimaryKey(int)
    groups = Set('Group')

class Group(db.Entity):
    id = PrimaryKey(int)
    dept = Required('Department')
    students = Set('Student')

class Student(db.Entity):
    name = Required(unicode)
    age = Required(int)
    group = Required('Group')
    scholarship = Required(int, default=0)
    courses = Set('Course')

class Course(db.Entity):
    name = Required(unicode)
    semester = Required(int)
    PrimaryKey(name, semester)
    students = Set('Student')

db.generate_mapping(create_tables=True)

with db_session:
    d1 = Department(number=1)
    d2 = Department(number=2)
    g1 = Group(id=1, dept=d1)
    g2 = Group(id=2, dept=d2)
    s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0)
    s2 = Student(id=2, name='S2', age=21, group=g1, scholarship=100)
    s3 = Student(id=3, name='S3', age=23, group=g1, scholarship=200)
    s4 = Student(id=4, name='S4', age=21, group=g1, scholarship=100)
    s5 = Student(id=5, name='S5', age=23, group=g2, scholarship=0)
    s6 = Student(id=6, name='S6', age=23, group=g2, scholarship=200)
    c1 = Course(name='C1', semester=1, students=[s1, s2, s3])
    c2 = Course(name='C2', semester=1, students=[s2, s3, s5, s6])
    c3 = Course(name='C3', semester=2, students=[s4, s5, s6])

class TestDistinct(unittest.TestCase):
    def setUp(self):
        db_session.__enter__()
    def tearDown(self):
        db_session.__exit__()
    def test_group_by(self):
        result = set(select((s.age, sum(s.scholarship)) for s in Student if s.scholarship > 0))
        self.assertEqual(result, {(21, 200), (23, 400)})
        self.assertNotIn('distinct', db.last_sql.lower())
    def test_group_by_having(self):
        result = set(select((s.age, sum(s.scholarship)) for s in Student if sum(s.scholarship) < 300))
        self.assertEqual(result, {(20, 0), (21, 200)})
        self.assertNotIn('distinct', db.last_sql.lower())
    def test_aggregation_no_group_by_1(self):
        result = set(select(sum(s.scholarship) for s in Student if s.age < 23))
        self.assertEqual(result, {200})
        self.assertNotIn('distinct', db.last_sql.lower())
    def test_aggregation_no_group_by_2(self):
        result = set(select((sum(s.scholarship), min(s.scholarship)) for s in Student if s.age < 23))
        self.assertEqual(result, {(200, 0)})
        self.assertNotIn('distinct', db.last_sql.lower())
    def test_aggregation_no_group_by_3(self):
        result = set(select((sum(s.scholarship), min(s.scholarship)) for s in Student for g in Group
                            if s.group == g and g.dept.number == 1))
        self.assertEqual(result, {(400, 0)})
        self.assertNotIn('distinct', db.last_sql.lower())

if __name__ == "__main__":
    unittest.main()
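# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: mirrors the group-by
# tests above -- selecting (key, aggregate) tuples makes Pony emit a
# GROUP BY query, with no DISTINCT required.  The names db_ex7 and Emp
# are illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex7 = Database('sqlite', ':memory:')

class Emp(db_ex7.Entity):
    dept = Required(int)
    salary = Required(int)

db_ex7.generate_mapping(create_tables=True)

with db_session:
    Emp(dept=1, salary=100)
    Emp(dept=1, salary=200)
    Emp(dept=2, salary=300)
    print(set(select((e.dept, sum(e.salary)) for e in Emp)))  # expected: {(1, 300), (2, 300)}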
# ===== pony-0.7.11/pony/orm/tests/test_entity_init.py =====

from __future__ import absolute_import, print_function, division
import unittest
from datetime import date, datetime
from hashlib import md5

from pony.orm.tests.testutils import raises_exception
from pony.orm import *

class TestCustomInit(unittest.TestCase):
    def test1(self):
        db = Database('sqlite', ':memory:')
        class User(db.Entity):
            name = Required(str)
            password = Required(str)
            created_at = Required(datetime)
            def __init__(self, name, password):
                password = md5(password.encode('utf8')).hexdigest()
                super(User, self).__init__(name=name, password=password,
                                           created_at=datetime.now())
                self.uppercase_name = name.upper()
        db.generate_mapping(create_tables=True)
        with db_session:
            u1 = User('John', '123')
            u2 = User('Mike', '456')
            commit()
            self.assertEqual(u1.name, 'John')
            self.assertEqual(u1.uppercase_name, 'JOHN')
            self.assertEqual(u1.password, md5(b'123').hexdigest())
            self.assertEqual(u2.name, 'Mike')
            self.assertEqual(u2.uppercase_name, 'MIKE')
            self.assertEqual(u2.password, md5(b'456').hexdigest())
        with db_session:
            users = select(u for u in User).order_by(User.id)
            self.assertEqual(len(users), 2)
            u1, u2 = users
            self.assertEqual(u1.name, 'John')
            self.assertTrue(not hasattr(u1, 'uppercase_name'))
            self.assertEqual(u1.password, md5(b'123').hexdigest())
            self.assertEqual(u2.name, 'Mike')
            self.assertTrue(not hasattr(u2, 'uppercase_name'))
            self.assertEqual(u2.password, md5(b'456').hexdigest())

if __name__ == '__main__':
    unittest.main()
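# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: the same pattern as
# TestCustomInit above -- override __init__ to transform arguments, then
# delegate to the generated constructor via super().  The names db_ex8
# and Tag are illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex8 = Database('sqlite', ':memory:')

class Tag(db_ex8.Entity):
    name = Required(str)
    def __init__(self, name):
        # normalize before handing the value to the generated __init__
        super(Tag, self).__init__(name=name.strip().lower())

db_ex8.generate_mapping(create_tables=True)

with db_session:
    t = Tag('  PYTHON ')
    print(t.name)  # expected: python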
# ===== pony-0.7.11/pony/orm/tests/test_entity_instances.py =====

import unittest

from pony import orm
from pony.orm.core import *
from pony.orm.tests.testutils import raises_exception

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    id = orm.PrimaryKey(int, auto=True)
    name = orm.Required(str, 40)
    lastName = orm.Required(str, max_len=40, unique=True)
    age = orm.Optional(int)
    groupName = orm.Optional('Group')
    chiefOfGroup = orm.Optional('Group')

class Group(db.Entity):
    name = orm.Required(str)
    persons = orm.Set(Person)
    chief = orm.Optional(Person, reverse='chiefOfGroup')

db.generate_mapping(create_tables=True)

class TestEntityInstances(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()
    def tearDown(self):
        rollback()
        db_session.__exit__()
    def test_create_instance(self):
        with orm.db_session:
            Person(id=1, name='Philip', lastName='Croissan')
            Person(id=2, name='Philip', lastName='Parlee', age=40)
            Person(id=3, name='Philip', lastName='Illinois', age=50)
            commit()
    def test_getObjectByPK(self):
        self.assertEqual(Person[1].lastName, "Croissan")
    @raises_exception(ObjectNotFound, "Person[666]")
    def test_getObjectByPKexception(self):
        p = Person[666]
    def test_getObjectByGet(self):
        p = Person.get(age=40)
        self.assertEqual(p.lastName, "Parlee")
    def test_getObjectByGetNone(self):
        self.assertIsNone(Person.get(age=41))
    @raises_exception(MultipleObjectsFoundError,
                      'Multiple objects were found. Use Person.select(...) to retrieve them')
    def test_getObjectByGetException(self):
        p = Person.get(name="Philip")
    def test_updateObject(self):
        with db_session:
            Person[2].age = 42
            self.assertEqual(Person[2].age, 42)
            commit()
    @raises_exception(ObjectNotFound, 'Person[2]')
    def test_deleteObject(self):
        with db_session:
            Person[2].delete()
            p = Person[2]
    def test_bulkDelete(self):
        with orm.db_session:
            Person(id=4, name='Klaus', lastName='Mem', age=12)
            Person(id=5, name='Abraham', lastName='Wrangler', age=13)
            Person(id=6, name='Kira', lastName='Phito', age=20)
            delete(p for p in Person if p.age <= 20)
            self.assertEqual(select(p for p in Person if p.age <= 20).count(), 0)
    def test_bulkDeleteV2(self):
        with orm.db_session:
            Person(id=4, name='Klaus', lastName='Mem', age=12)
            Person(id=5, name='Abraham', lastName='Wrangler', age=13)
            Person(id=6, name='Kira', lastName='Phito', age=20)
            Person.select(lambda p: p.id >= 4).delete(bulk=True)
            self.assertEqual(select(p for p in Person if p.id >= 4).count(), 0)
    @raises_exception(UnresolvableCyclicDependency, 'Cannot save cyclic chain: Person -> Group')
    def test_saveChainsException(self):
        with orm.db_session:
            claire = Person(name='Claire', lastName='Forlani')
            annabel = Person(name='Annabel', lastName='Fiji')
            Group(name='Aspen', persons=[claire, annabel], chief=claire)
            print('group1=', Group[1])
    def test_saveChainsWithFlush(self):
        with orm.db_session:
            claire = Person(name='Claire', lastName='Forlani')
            annabel = Person(name='Annabel', lastName='Fiji')
            flush()
            Group(name='Aspen', persons=[claire, annabel], chief=claire)
            self.assertEqual(Group[1].name, 'Aspen')
            self.assertEqual(Group[1].chief.lastName, 'Forlani')
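# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: the two deletion
# styles exercised by test_bulkDelete / test_bulkDeleteV2 above.  The
# names db_ex9 and Task are illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm.core import *

db_ex9 = Database('sqlite', ':memory:')

class Task(db_ex9.Entity):
    done = Required(bool, default=False)

db_ex9.generate_mapping(create_tables=True)

with db_session:
    Task(done=True)
    Task()
    Task(done=True)
    delete(t for t in Task if t.done)                   # declarative delete
    # Task.select(lambda t: t.done).delete(bulk=True)   # query-level alternative
    print(Task.select().count())  # expected: 1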
# ===== pony-0.7.11/pony/orm/tests/test_entity_proxy.py =====

import unittest

from pony.orm import *
from pony.orm.tests.testutils import *

class TestProxy(unittest.TestCase):
    def setUp(self):
        db = self.db = Database('sqlite', ':memory:')
        class Country(db.Entity):
            id = PrimaryKey(int)
            name = Required(str)
            persons = Set("Person")
        class Person(db.Entity):
            id = PrimaryKey(int)
            name = Required(str)
            country = Required(Country)
        db.generate_mapping(create_tables=True)
        with db_session:
            c1 = Country(id=1, name='Russia')
            c2 = Country(id=2, name='Japan')
            Person(id=1, name='Alexander Nevskiy', country=c1)
            Person(id=2, name='Raikou Minamoto', country=c2)
            Person(id=3, name='Ibaraki Douji', country=c2)
    def test_1(self):
        db = self.db
        with db_session:
            p = make_proxy(db.Person[2])
        with db_session:
            x1 = db.local_stats[None].db_count  # number of queries
            # it is possible to access p attributes in a new db_session
            name = p.name
            country = p.country
            x2 = db.local_stats[None].db_count
            # p.name and p.country are loaded with a single query
            self.assertEqual(x1, x2-1)
    def test_2(self):
        db = self.db
        with db_session:
            p = make_proxy(db.Person[2])
            name = p.name
            country = p.country
        with db_session:
            x1 = db.local_stats[None].db_count
            name = p.name
            country = p.country
            x2 = db.local_stats[None].db_count
            # attribute values from the first db_session should be ignored and loaded again
            self.assertEqual(x1, x2-1)
    def test_3(self):
        db = self.db
        with db_session:
            p = db.Person[2]
            proxy = make_proxy(p)
        with db_session:
            p2 = db.Person[2]
            name1 = 'Tamamo no Mae'
            # It is possible to assign new attribute values to a proxy object
            p2.name = name1
            name2 = proxy.name
            self.assertEqual(name1, name2)
    def test_4(self):
        db = self.db
        with db_session:
            p = db.Person[2]
            proxy = make_proxy(p)
        with db_session:
            p2 = db.Person[2]
            name1 = 'Tamamo no Mae'
            p2.name = name1
        with db_session:
            # new attribute value was successfully stored in the database
            name2 = proxy.name
            self.assertEqual(name1, name2)
    def test_5(self):
        db = self.db
        with db_session:
            p = db.Person[2]
            r = repr(p)
            self.assertEqual(r, 'Person[2]')
            proxy = make_proxy(p)
            r = repr(proxy)  # proxy object has specific repr
            self.assertEqual(r, '<EntityProxy(Person[2])>')
        r = repr(proxy)  # repr of proxy object can be used outside of db_session
        self.assertEqual(r, '<EntityProxy(Person[2])>')
        del p
        r = repr(proxy)  # repr works even if the original object was deleted
        self.assertEqual(r, '<EntityProxy(Person[2])>')
    def test_6(self):
        db = self.db
        with db_session:
            p = db.Person[2]
            proxy = make_proxy(p)
            proxy.name = 'Okita Souji'
            # after assignment, the attribute value is the same for the proxy and for the original object
            self.assertEqual(proxy.name, 'Okita Souji')
            self.assertEqual(p.name, 'Okita Souji')
    def test_7(self):
        db = self.db
        with db_session:
            p = db.Person[2]
            proxy = make_proxy(p)
            proxy.name = 'Okita Souji'
            # after assignment, the attribute value is the same for the proxy and for the original object
            self.assertEqual(proxy.name, 'Okita Souji')
            self.assertEqual(p.name, 'Okita Souji')
    def test_8(self):
        db = self.db
        with db_session:
            c1 = db.Country[1]
            c1_proxy = make_proxy(c1)
            p2 = db.Person[2]
            self.assertNotEqual(p2.country, c1)
            self.assertNotEqual(p2.country, c1_proxy)
            # proxy can be used in attribute assignment
            p2.country = c1_proxy
            self.assertEqual(p2.country, c1_proxy)
            self.assertIs(p2.country, c1)
    def test_9(self):
        db = self.db
        with db_session:
            c2 = db.Country[2]
            c2_proxy = make_proxy(c2)
            persons = select(p for p in db.Person if p.country == c2_proxy)
            self.assertEqual({p.id for p in persons}, {2, 3})

if __name__ == '__main__':
    unittest.main()
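# ----------------------------------------------------------------------
# Editorial sketch, not part of the original suite: make_proxy() wraps
# an entity instance so it can be used across db_session boundaries, as
# test_1/test_2 above verify.  The names db_ex10 and City are
# illustrative placeholders.
# ----------------------------------------------------------------------
from pony.orm import *

db_ex10 = Database('sqlite', ':memory:')

class City(db_ex10.Entity):
    id = PrimaryKey(int)
    name = Required(str)

db_ex10.generate_mapping(create_tables=True)

with db_session:
    City(id=1, name='Kyoto')

with db_session:
    proxy = make_proxy(City[1])

with db_session:
    print(proxy.name)  # expected: Kyoto -- attributes reload in the new session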
# pony-0.7.11/pony/orm/tests/test_exists.py

import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Group(db.Entity):
    students = Set('Student')

class Student(db.Entity):
    first_name = Required(str)
    last_name = Required(str)
    login = Optional(str, nullable=True)
    graduated = Optional(bool, default=False)
    group = Required(Group)
    passport = Optional('Passport', column='passport')

class Passport(db.Entity):
    student = Optional(Student)

db.generate_mapping(create_tables=True)

with db_session:
    g1 = Group()
    g2 = Group()
    p = Passport()
    Student(first_name='Mashu', last_name='Kyrielight', login='Shielder', group=g1)
    Student(first_name='Okita', last_name='Souji', login='Sakura', group=g1)
    Student(first_name='Francis', last_name='Drake', group=g2, graduated=True)
    Student(first_name='Oda', last_name='Nobunaga', group=g2, graduated=True)
    Student(first_name='William', last_name='Shakespeare', group=g2, graduated=True, passport=p)

class TestExists(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()

    def tearDown(self):
        rollback()
        db_session.__exit__()

    def test_1(self):
        q = select(g for g in Group if exists(s.login for s in g.students))[:]
        self.assertEqual(q[0], Group[1])

    def test_2(self):
        q = select(g for g in Group if exists(s.graduated for s in g.students))[:]
        self.assertEqual(q[0], Group[2])

    def test_3(self):
        q = select(s for s in Student
                   if exists(len(s2.first_name) == len(s.first_name) and s != s2 for s2 in Student))[:]
        self.assertEqual(set(q), {Student[1], Student[2], Student[3], Student[5]})

    def test_4(self):
        q = select(g for g in Group if not exists(not s.graduated for s in g.students))[:]
        self.assertEqual(q[0], Group[2])

    def test_5(self):
        q = select(g for g in Group if exists(s for s in g.students))[:]
        self.assertEqual(set(q), {Group[1], Group[2]})

    def test_6(self):
        q = select(g for g in Group
                   if exists(s.login for s in g.students if s.first_name != 'Okita') and g.id != 10)[:]
        self.assertEqual(q[0], Group[1])

    def test_7(self):
        q = select(g for g in Group if exists(s.passport for s in g.students))[:]
        self.assertEqual(q[0], Group[2])

# pony-0.7.11/pony/orm/tests/test_f_strings.py

from sys import version_info

if version_info[:2] >= (3, 6):
    from pony.orm.tests.py36_test_f_strings import *

# pony-0.7.11/pony/orm/tests/test_filter.py

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.tests.model1 import *

class TestFilter(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()

    def tearDown(self):
        rollback()
        db_session.__exit__()

    def test_filter_1(self):
        q = select(s for s in Student)
        result = set(q.filter(scholarship=0))
        self.assertEqual(result, {Student[101], Student[103]})

    def test_filter_2(self):
        q = select(s for s in Student)
        q2 = q.filter(scholarship=500)
        result = set(q2.filter(group=Group['3132']))
        self.assertEqual(result, {Student[104]})

    def test_filter_3(self):
        q = select(s for s in Student)
        q2 = q.filter(lambda s: s.scholarship > 500)
        result = set(q2.filter(lambda s: count(s.marks) > 0))
        self.assertEqual(result, {Student[102]})

    def test_filter_4(self):
        q = select(s for s in Student)
        q2 = q.filter(lambda s: s.scholarship != 500)
        q3 = q2.order_by(1)
        result = list(q3.filter(lambda s: count(s.marks) > 1))
        self.assertEqual(result, [Student[101], Student[103]])

    def test_filter_5(self):
        q = select(s for s in Student)
        q2 = q.filter(lambda s: s.scholarship != 500)
        q3 = q2.order_by(Student.name)
        result = list(q3.filter(lambda s: count(s.marks) > 1))
        self.assertEqual(result, [Student[103], Student[101]])

    def test_filter_6(self):
        q = select(s for s in Student)
        q2 = q.filter(lambda s: s.scholarship != 500)
        q3 = q2.order_by(lambda s: s.name)
        result = list(q3.filter(lambda s: count(s.marks) > 1))
        self.assertEqual(result, [Student[103], Student[101]])

    def test_filter_7(self):
        q = select(s for s in Student)
        q2 = q.filter(scholarship=0)
        result = set(q2.filter(lambda s: count(s.marks) > 1))
        self.assertEqual(result, {Student[103], Student[101]})

    def test_filter_8(self):
        q = select(s for s in Student)
        q2 = q.filter(lambda s: s.scholarship != 500)
        q3 = q2.order_by(lambda s: s.name)
        q4 = q3.order_by(None)
        result = set(q4.filter(lambda s: count(s.marks) > 1))
        self.assertEqual(result, {Student[103], Student[101]})
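# A short sketch of the two filter() styles exercised above (query chaining as
# in this module, values illustrative): keyword arguments compare attributes
# for equality, while a lambda allows arbitrary conditions.
#
#     q = select(s for s in Student)
#     q_kw = q.filter(scholarship=0)                  # WHERE scholarship = 0
#     q_fn = q.filter(lambda s: s.scholarship > 500)  # WHERE scholarship > 500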
# pony-0.7.11/pony/orm/tests/test_flush.py

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

class TestFlush(unittest.TestCase):
    def setUp(self):
        self.db = Database('sqlite', ':memory:')

        class Person(self.db.Entity):
            name = Required(unicode)

        self.db.generate_mapping(create_tables=True)

    def tearDown(self):
        self.db = None

    def test1(self):
        Person = self.db.Person
        with db_session:
            a = Person(name='A')
            b = Person(name='B')
            c = Person(name='C')
            self.assertEqual(a.id, None)
            self.assertEqual(b.id, None)
            self.assertEqual(c.id, None)
            b.flush()
            self.assertEqual(a.id, None)
            self.assertEqual(b.id, 1)
            self.assertEqual(c.id, None)
            flush()
            self.assertEqual(a.id, 2)
            self.assertEqual(b.id, 1)
            self.assertEqual(c.id, 3)

# pony-0.7.11/pony/orm/tests/test_frames.py

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
import pony.orm.decompiling
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    name = Required(unicode)
    age = Required(int)

db.generate_mapping(create_tables=True)

with db_session:
    p1 = Person(name='John', age=22)
    p2 = Person(name='Mary', age=18)
    p3 = Person(name='Mike', age=25)

class TestFrames(unittest.TestCase):
    @db_session
    def test_select(self):
        x = 20
        result = select(p.id for p in Person if p.age > x)[:]
        self.assertEqual(set(result), {1, 3})

    @db_session
    def test_select_str(self):
        x = 20
        result = select('p.id for p in Person if p.age > x')[:]
        self.assertEqual(set(result), {1, 3})

    @db_session
    def test_left_join(self):
        x = 20
        result = left_join(p.id for p in Person if p.age > x)[:]
        self.assertEqual(set(result), {1, 3})

    @db_session
    def test_left_join_str(self):
        x = 20
        result = left_join('p.id for p in Person if p.age > x')[:]
        self.assertEqual(set(result), {1, 3})

    @db_session
    def test_get(self):
        x = 23
        result = get(p.id for p in Person if p.age > x)
        self.assertEqual(result, 3)

    @db_session
    def test_get_str(self):
        x = 23
        result = get('p.id for p in Person if p.age > x')
        self.assertEqual(result, 3)

    @db_session
    def test_exists(self):
        x = 23
        result = exists(p for p in Person if p.age > x)
        self.assertEqual(result, True)

    @db_session
    def test_exists_str(self):
        x = 23
        result = exists('p for p in Person if p.age > x')
        self.assertEqual(result, True)

    @db_session
    def test_entity_get(self):
        x = 23
        p = Person.get(lambda p: p.age > x)
        self.assertEqual(p, Person[3])

    @db_session
    def test_entity_get_str(self):
        x = 23
        p = Person.get('lambda p: p.age > x')
        self.assertEqual(p, Person[3])

    @db_session
    def test_entity_get_by_sql(self):
        x = 25
        p = Person.get_by_sql('select * from Person where age = $x')
        self.assertEqual(p, Person[3])

    @db_session
    def test_entity_select_by_sql(self):
        x = 25
        p = Person.select_by_sql('select * from Person where age = $x')
        self.assertEqual(p, [Person[3]])

    @db_session
    def test_entity_exists(self):
        x = 23
        result = Person.exists(lambda p: p.age > x)
        self.assertTrue(result)

    @db_session
    def test_entity_exists_str(self):
        x = 23
        result = Person.exists('lambda p: p.age > x')
        self.assertTrue(result)

    @db_session
    def test_entity_select(self):
        x = 20
        result = Person.select(lambda p: p.age > x)[:]
        self.assertEqual(set(result), {Person[1], Person[3]})

    @db_session
    def test_entity_select_str(self):
        x = 20
        result = Person.select('lambda p: p.age > x')[:]
        self.assertEqual(set(result), {Person[1], Person[3]})

    @db_session
    def test_order_by(self):
        x = 20
        y = -1
        result = Person.select(lambda p: p.age > x).order_by(lambda p: p.age * y)[:]
        self.assertEqual(result, [Person[3], Person[1]])
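    # The *_str variants in this class pass the query as a string; Pony compiles
    # it against the caller's stack frame, so names like x and y are still
    # resolved from the enclosing method. A minimal sketch (names illustrative):
    #
    #     threshold = 21
    #     adults = select('p for p in Person if p.age > threshold')[:]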
    @db_session
    def test_order_by_str(self):
        x = 20
        y = -1
        result = Person.select('lambda p: p.age > x').order_by('p.age * y')[:]
        self.assertEqual(result, [Person[3], Person[1]])

    @db_session
    def test_filter(self):
        x = 20
        y = 'M'
        result = Person.select(lambda p: p.age > x).filter(lambda p: p.name.startswith(y))[:]
        self.assertEqual(result, [Person[3]])

    @db_session
    def test_filter_str(self):
        x = 20
        y = 'M'
        result = Person.select('lambda p: p.age > x').filter('p.name.startswith(y)')[:]
        self.assertEqual(result, [Person[3]])

    @db_session
    def test_db_select(self):
        x = 20
        result = db.select('name from Person where age > $x order by name')
        self.assertEqual(result, ['John', 'Mike'])

    @db_session
    def test_db_get(self):
        x = 18
        result = db.get('name from Person where age = $x')
        self.assertEqual(result, 'Mary')

    @db_session
    def test_db_execute(self):
        x = 18
        result = db.execute('select name from Person where age = $x').fetchone()
        self.assertEqual(result, ('Mary',))

    @db_session
    def test_db_exists(self):
        x = 18
        result = db.exists('name from Person where age = $x')
        self.assertEqual(result, True)

    @raises_exception(pony.orm.decompiling.InvalidQuery,
                      'Use generator expression (... for ... in ...) '
                      'instead of list comprehension [... for ... in ...] inside query')
    @db_session
    def test_inner_list_comprehension(self):
        result = select(p.id for p in Person
                        if p.age not in [p2.age for p2 in Person if p2.name.startswith('M')])[:]

    @db_session
    def test_outer_list_comprehension(self):
        names = ['John', 'Mary', 'Mike']
        persons = [Person.select(lambda p: p.name == name).first() for name in names]
        self.assertEqual(set(p.name for p in persons), {'John', 'Mary', 'Mike'})

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_generator_db_session.py

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
from pony.orm.core import local
from pony.orm.tests.testutils import *

class TestGeneratorDbSession(unittest.TestCase):
    def setUp(self):
        db = Database('sqlite', ':memory:')

        class Account(db.Entity):
            id = PrimaryKey(int)
            amount = Required(int)

        db.generate_mapping(create_tables=True)
        self.db = db
        self.Account = Account
        with db_session:
            a1 = Account(id=1, amount=1000)
            a2 = Account(id=2, amount=2000)
            a3 = Account(id=3, amount=3000)

    def tearDown(self):
        assert local.db_session is None
        self.db = self.Account = None

    @raises_exception(TypeError, 'db_session with `retry` option cannot be applied to generator function')
    def test1(self):
        @db_session(retry=3)
        def f():
            yield

    @raises_exception(TypeError, 'db_session with `ddl` option cannot be applied to generator function')
    def test2(self):
        @db_session(ddl=True)
        def f():
            yield

    @raises_exception(TypeError, 'db_session with `serializable` option cannot be applied to generator function')
    def test3(self):
        @db_session(serializable=True)
        def f():
            yield

    def test4(self):
        @db_session(immediate=True)
        def f():
            yield

    @raises_exception(TransactionError, '@db_session-wrapped generator cannot be used inside another db_session')
    def test5(self):
        @db_session
        def f():
            yield
        with db_session:
            next(f())
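    # Sketch of the behaviour under test (names illustrative): a
    # @db_session-wrapped generator suspends its session at every yield and
    # re-activates it on the next next()/send(), so pending changes must be
    # committed before each suspension (see test8 below).
    #
    #     @db_session
    #     def transfer(account_id):
    #         acc = Account[account_id]  # session active here
    #         delta = yield acc.amount   # session suspended while caller decides
    #         acc.amount += delta        # session active again after send()
    #         commit()                   # commit before the generator suspends/ends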
    def test6(self):
        @db_session
        def f():
            x = local.db_session
            self.assertTrue(x is not None)
            yield self.db._get_cache()
            self.assertEqual(local.db_session, x)
            a1 = self.Account[1]
            yield a1.amount
            self.assertEqual(local.db_session, x)
            a2 = self.Account[2]
            yield a2.amount

        gen = f()
        cache = next(gen)
        self.assertTrue(cache.is_alive)
        self.assertEqual(local.db_session, None)

        amount = next(gen)
        self.assertEqual(amount, 1000)
        self.assertEqual(local.db_session, None)

        amount = next(gen)
        self.assertEqual(amount, 2000)
        self.assertEqual(local.db_session, None)

        try:
            next(gen)
        except StopIteration:
            self.assertFalse(cache.is_alive)
        else:
            self.fail()

    def test7(self):
        @db_session
        def f(id1):
            a1 = self.Account[id1]
            id2 = yield a1.amount
            a2 = self.Account[id2]
            amount = yield a2.amount
            a1.amount -= amount
            a2.amount += amount
            commit()

        gen = f(1)
        amount1 = next(gen)
        self.assertEqual(amount1, 1000)

        amount2 = gen.send(2)
        self.assertEqual(amount2, 2000)

        try:
            gen.send(100)
        except StopIteration:
            pass
        else:
            self.fail()

        with db_session:
            a1 = self.Account[1]
            self.assertEqual(a1.amount, 900)
            a2 = self.Account[2]
            self.assertEqual(a2.amount, 2100)

    @raises_exception(TransactionError, 'You need to manually commit() changes before suspending the generator')
    def test8(self):
        @db_session
        def f(id1):
            a1 = self.Account[id1]
            a1.amount += 100
            yield a1.amount

        for amount in f(1):
            pass

    def test9(self):
        @db_session
        def f(id1):
            a1 = self.Account[id1]
            a1.amount += 100
            commit()
            yield a1.amount

        for amount in f(1):
            pass

    def test10(self):
        @db_session
        def f(id1):
            a1 = self.Account[id1]
            yield a1.amount
            a1.amount += 100

        with db_session:
            a = self.Account[1].amount

        for amount in f(1):
            pass

        with db_session:
            b = self.Account[1].amount

        self.assertEqual(b, a + 100)

    def test12(self):
        @db_session
        def f(id1):
            a1 = self.Account[id1]
            yield a1.amount

        gen = f(1)
        next(gen)
        gen.close()

    @raises_exception(TypeError, 'error message')
    def test13(self):
        @db_session
        def f(id1):
            a1 = self.Account[id1]
            yield a1.amount

        gen = f(1)
        next(gen)
        gen.throw(TypeError('error message'))

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_get_pk.py

from pony.py23compat import basestring

import unittest

from pony.orm import *
from pony import orm
from pony.utils import cached_property

from datetime import date

class Test(unittest.TestCase):
    @cached_property
    def db(self):
        return orm.Database('sqlite', ':memory:')

    def setUp(self):
        db = self.db
        self.day = date.today()

        class A(db.Entity):
            b = Required("B")
            c = Required("C")
            PrimaryKey(b, c)

        class B(db.Entity):
            id = PrimaryKey(date)
            a_set = Set(A)

        class C(db.Entity):
            x = Required("X")
            y = Required("Y")
            a_set = Set(A)
            PrimaryKey(x, y)

        class X(db.Entity):
            id = PrimaryKey(int)
            c_set = Set(C)

        class Y(db.Entity):
            id = PrimaryKey(int)
            c_set = Set(C)

        db.generate_mapping(check_tables=True, create_tables=True)

        with orm.db_session:
            x1 = X(id=123)
            y1 = Y(id=456)
            b1 = B(id=self.day)
            c1 = C(x=x1, y=y1)
            A(b=b1, c=c1)

    @db_session
    def test_1(self):
        a1 = self.db.A.select().first()
        a2 = self.db.A[a1.get_pk()]
        self.assertEqual(a1, a2)

    @db_session
    def test2(self):
        a = self.db.A.select().first()
        b = self.db.B.select().first()
        c = self.db.C.select().first()
        pk = (b.get_pk(), c._get_raw_pkval_())
        self.assertTrue(a is self.db.A[pk])

    @db_session
    def test3(self):
        a = self.db.A.select().first()
        c = self.db.C.select().first()
        pk = (self.day, c.get_pk())
        self.assertTrue(a is self.db.A[pk])
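# Sketch of what get_pk() returns for the composite keys above (entity names
# from this module, values illustrative): a scalar for a simple key, a tuple
# for a composite key, nested when the key contains composite-key entities.
#
#     with orm.db_session:
#         a = db.A.select().first()
#         pk = a.get_pk()        # e.g. (date(...), (123, 456))
#         assert db.A[pk] is a   # the pk round-trips through Entity[...]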
# pony-0.7.11/pony/orm/tests/test_getattr.py

from pony.py23compat import basestring

import unittest

from pony.orm import *
from pony import orm
from pony.utils import cached_property
from pony.orm.tests.testutils import raises_exception

class Test(unittest.TestCase):
    @cached_property
    def db(self):
        return orm.Database('sqlite', ':memory:')

    def setUp(self):
        db = self.db

        class Genre(db.Entity):
            name = orm.Required(str)
            artists = orm.Set('Artist')

        class Hobby(db.Entity):
            name = orm.Required(str)
            artists = orm.Set('Artist')

        class Artist(db.Entity):
            name = orm.Required(str)
            age = orm.Optional(int)
            hobbies = orm.Set(Hobby)
            genres = orm.Set(Genre)

        db.generate_mapping(check_tables=True, create_tables=True)

        with orm.db_session:
            pop = Genre(name='Pop')
            Artist(name='Sia', age=40, genres=[pop])
            Hobby(name='Swimming')

        pony.options.INNER_JOIN_SYNTAX = True

    @db_session
    def test_no_caching(self):
        for attr_name, attr_type in zip(['name', 'age'], [basestring, int]):
            val = select(getattr(x, attr_name) for x in self.db.Artist).first()
            self.assertIsInstance(val, attr_type)

    @db_session
    def test_simple(self):
        val = select(getattr(x, 'age') for x in self.db.Artist).first()
        self.assertIsInstance(val, int)

    @db_session
    def test_expr(self):
        val = select(getattr(x, ''.join(['ag', 'e'])) for x in self.db.Artist).first()
        self.assertIsInstance(val, int)

    @db_session
    def test_external(self):
        class data:
            id = 1
        val = select(x.id for x in self.db.Artist if x.id >= getattr(data, 'id')).first()
        self.assertIsNotNone(val)

    @db_session
    def test_related(self):
        val = select(getattr(x.genres, 'name') for x in self.db.Artist).first()
        self.assertIsNotNone(val)

    @db_session
    def test_not_instance_iter(self):
        val = select(getattr(x.name, 'startswith')('S') for x in self.db.Artist).first()
        self.assertTrue(val)

    @raises_exception(TranslationError, 'Expression `getattr(x, x.name)` cannot be translated into SQL '
                                        'because x.name will be different for each row')
    @db_session
    def test_not_external(self):
        select(getattr(x, x.name) for x in self.db.Artist)

    @raises_exception(TypeError, 'In `getattr(x, 1)` second argument should be a string. Got: 1')
    @db_session
    def test_not_string(self):
        select(getattr(x, 1) for x in self.db.Artist)
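    # Sketch of the pattern covered by these tests (names from this module):
    # getattr() with a dynamic attribute name translates to SQL as long as the
    # name is a constant coming from outside the query, e.g. a filter built at
    # runtime:
    #
    #     column = 'age'  # decided at runtime, constant for the query
    #     q = select(getattr(a, column) for a in self.db.Artist)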
    @raises_exception(TypeError, 'In `getattr(x, name)` second argument should be a string. Got: 1')
    @db_session
    def test_not_string_2(self):  # renamed: a second `test_not_string` would shadow the first
        name = 1
        select(getattr(x, name) for x in self.db.Artist)

    @db_session
    def test_lambda_1(self):
        for name, value in [('name', 'Sia'), ('age', 40), ('name', 'Sia')]:
            result = self.db.Artist.select(lambda a: getattr(a, name) == value)
            self.assertEqual(set(obj.name for obj in result), {'Sia'})

    @db_session
    def test_lambda_2(self):
        for entity, name, value in [
            (self.db.Genre, 'name', 'Pop'),
            (self.db.Artist, 'age', 40),
            (self.db.Hobby, 'name', 'Swimming'),
        ]:
            result = entity.select(lambda a: getattr(a, name) == value)
            self.assertEqual(set(result[:]), {entity.select().first()})

# pony-0.7.11/pony/orm/tests/test_hooks.py

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *

logged_events = []

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    id = PrimaryKey(int)
    name = Required(unicode)
    age = Required(int)

    def before_insert(self):
        logged_events.append('BI_' + self.name)
        do_before_insert(self)

    def before_update(self):
        logged_events.append('BU_' + self.name)
        do_before_update(self)

    def before_delete(self):
        logged_events.append('BD_' + self.name)
        do_before_delete(self)

    def after_insert(self):
        logged_events.append('AI_' + self.name)
        do_after_insert(self)

    def after_update(self):
        logged_events.append('AU_' + self.name)
        do_after_update(self)

    def after_delete(self):
        logged_events.append('AD_' + self.name)
        do_after_delete(self)

def do_nothing(person):
    pass

def set_hooks_to_do_nothing():
    global do_before_insert, do_before_update, do_before_delete
    global do_after_insert, do_after_update, do_after_delete
    do_before_insert = do_before_update = do_before_delete = do_nothing
    do_after_insert = do_after_update = do_after_delete = do_nothing

set_hooks_to_do_nothing()

db.generate_mapping(create_tables=True)
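# The Person hooks above delegate to module-level do_* callables so each test
# can swap the behaviour in. Sketch of the hook mechanism itself (entity and
# hook body illustrative): before_* runs just before the corresponding SQL
# statement is emitted during flush, after_* once it has been executed.
#
#     class Order(db2.Entity):  # hypothetical entity on some other db2
#         total = Required(int)
#         def before_insert(self):
#             print('about to INSERT', self)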
class TestHooks(unittest.TestCase):
    def setUp(self):
        set_hooks_to_do_nothing()
        with db_session:
            db.execute('delete from Person')
            p1 = Person(id=1, name='John', age=22)
            p2 = Person(id=2, name='Mary', age=18)
            p3 = Person(id=3, name='Mike', age=25)
        logged_events[:] = []

    def tearDown(self):
        pass

    @db_session
    def test_1(self):
        p4 = Person(id=4, name='Bob', age=16)
        p5 = Person(id=5, name='Lucy', age=23)
        self.assertEqual(logged_events, [])
        db.flush()
        self.assertEqual(logged_events, ['BI_Bob', 'BI_Lucy', 'AI_Bob', 'AI_Lucy'])

    @db_session
    def test_2(self):
        p4 = Person(id=4, name='Bob', age=16)
        p1 = Person[1]  # auto-flush here
        p2 = Person[2]
        self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob'])
        p2.age += 1
        p5 = Person(id=5, name='Lucy', age=23)
        db.flush()
        self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BU_Mary', 'BI_Lucy', 'AU_Mary', 'AI_Lucy'])

    @db_session
    def test_3(self):
        global do_before_insert
        def do_before_insert(person):
            some_person = Person.select().first()  # should not cause infinite recursion
        p4 = Person(id=4, name='Bob', age=16)
        db.flush()

def flush_for(*objects):
    for obj in objects:
        obj.flush()

class ObjectFlushTest(unittest.TestCase):
    def setUp(self):
        set_hooks_to_do_nothing()
        with db_session:
            db.execute('delete from Person')
            p1 = Person(id=1, name='John', age=22)
            p2 = Person(id=2, name='Mary', age=18)
            p3 = Person(id=3, name='Mike', age=25)
        logged_events[:] = []

    def tearDown(self):
        pass

    @db_session
    def test_1(self):
        p4 = Person(id=4, name='Bob', age=16)
        p5 = Person(id=5, name='Lucy', age=23)
        self.assertEqual(logged_events, [])
        flush_for(p4, p5)
        self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BI_Lucy', 'AI_Lucy'])

    @db_session
    def test_2(self):
        p4 = Person(id=4, name='Bob', age=16)
        p1 = Person[1]  # auto-flush here
        p2 = Person[2]
        self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob'])
        p2.age += 1
        p5 = Person(id=5, name='Lucy', age=23)
        flush_for(p4, p2, p5)
        self.assertEqual(logged_events, ['BI_Bob', 'AI_Bob', 'BU_Mary', 'AU_Mary', 'BI_Lucy', 'AI_Lucy'])

    @db_session
    def test_3(self):
        global do_before_insert
        def do_before_insert(person):
            some_person = Person.select().first()  # should not cause infinite recursion
        p4 = Person(id=4, name='Bob', age=16)
        p4.flush()

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_hybrid_methods_and_properties.py

import unittest

from pony.orm import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

sep = ' '

class Person(db.Entity):
    id = PrimaryKey(int)
    first_name = Required(str)
    last_name = Required(str)
    favorite_color = Optional(str)
    cars = Set(lambda: Car)

    @property
    def full_name(self):
        return self.first_name + sep + self.last_name

    @property
    def full_name_2(self):
        # tests using of function `concat` from external scope
        return concat(self.first_name, sep, self.last_name)

    @property
    def has_car(self):
        return not self.cars.is_empty()

    def cars_by_color1(self, color):
        return select(car for car in self.cars if car.color == color)

    def cars_by_color2(self, color):
        return self.cars.select(lambda car: car.color == color)

    @property
    def cars_price(self):
        return sum(c.price for c in self.cars)

    @property
    def incorrect_full_name(self):
        return self.first_name + ' ' + p.last_name  # p is FakePerson instance here

    @classmethod
    def find_by_full_name(cls, full_name):
        return cls.select(lambda p: p.full_name_2 == full_name)

    def complex_method(self):
        result = ''
        for i in range(10):
            result += str(i)
        return result

    def simple_method(self):
        return self.complex_method()

class FakePerson(object):
    pass

p = FakePerson()
p.last_name = '***'

class Car(db.Entity):
    brand = Required(str)
    model = Required(str)
    owner = Optional(Person)
    year = Required(int)
    price = Required(int)
    color = Required(str)

db.generate_mapping(create_tables=True)

def simple_func(person):
    return person.full_name

def complex_func(person):
    return person.complex_method()

with db_session:
    p1 = Person(id=1, first_name='Alexander', last_name='Kozlovsky', favorite_color='white')
    p2 = Person(id=2, first_name='Alexei', last_name='Malashkevich', favorite_color='green')
    p3 = Person(id=3, first_name='Vitaliy', last_name='Abetkin')
    p4 = Person(id=4, first_name='Alexander', last_name='Tischenko', favorite_color='blue')

    c1 = Car(brand='Peugeot', model='306', owner=p1, year=2006, price=14000, color='red')
    c2 = Car(brand='Honda', model='Accord', owner=p1, year=2007, price=13850, color='white')
    c3 = Car(brand='Nissan', model='Skyline', owner=p2, year=2008, price=29900, color='black')
    c4 = Car(brand='Volkswagen', model='Passat', owner=p1, year=2012, price=9400, color='blue')
    c5 = Car(brand='Koenigsegg', model='CCXR', owner=p4, year=2016, price=4850000, color='white')
    c6 = Car(brand='Lada', model='Kalina', owner=p4, year=2015, price=5000, color='white')
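# Hybrid methods and properties: the same Python definition is evaluated
# in-memory on a loaded object and decompiled into SQL when used inside a
# query. A minimal sketch with the entities above:
#
#     with db_session:
#         Person[1].full_name                              # evaluated in Python
#         select(p for p in Person
#                if p.full_name == 'Alexander Kozlovsky')  # translated to SQL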
class TestHybridsAndProperties(unittest.TestCase):
    @db_session
    def test1(self):
        persons = select(p.full_name for p in Person if p.has_car)[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexei Malashkevich', 'Alexander Tischenko'})

    @db_session
    def test2(self):
        cars_prices = select(p.cars_price for p in Person)[:]
        self.assertEqual(set(cars_prices), {0, 29900, 37250, 4855000})

    @db_session
    def test3(self):
        persons = select(p.full_name for p in Person if p.cars_price > 100000)[:]
        self.assertEqual(set(persons), {'Alexander Tischenko'})

    @db_session
    def test4(self):
        persons = select(p.full_name for p in Person if not p.cars_price)[:]
        self.assertEqual(set(persons), {'Vitaliy Abetkin'})

    @db_session
    def test5(self):
        persons = select(p.full_name for p in Person
                         if exists(c for c in p.cars_by_color2('white') if c.price > 10000))[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexander Tischenko'})

    @db_session
    def test6(self):
        persons = select(p.full_name for p in Person
                         if exists(c for c in p.cars_by_color1('white') if c.price > 10000))[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexander Tischenko'})

    @db_session
    def test7(self):
        c1 = Car[1]
        persons = select(p.full_name for p in Person if c1 in p.cars_by_color2('red'))[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky'})

    @db_session
    def test8(self):
        c1 = Car[1]
        persons = select(p.full_name for p in Person if c1 in p.cars_by_color1('red'))[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky'})

    @db_session
    def test9(self):
        persons = select(p.full_name for p in Person if p.cars_by_color1(p.favorite_color))[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky'})

    @db_session
    def test10(self):
        persons = select(p.full_name for p in Person if not p.cars_by_color1(p.favorite_color))[:]
        self.assertEqual(set(persons), {'Alexander Tischenko', 'Alexei Malashkevich', 'Vitaliy Abetkin'})

    @db_session
    def test11(self):
        persons = select(p.full_name for p in Person if p.cars_by_color2(p.favorite_color))[:]
        self.assertEqual(set(persons), {'Alexander Kozlovsky'})

    @db_session
    def test12(self):
        persons = select(p.full_name for p in Person if not p.cars_by_color2(p.favorite_color))[:]
        self.assertEqual(set(persons), {'Alexander Tischenko', 'Alexei Malashkevich', 'Vitaliy Abetkin'})

    @db_session
    def test13(self):
        persons = select(p.full_name for p in Person if count(p.cars_by_color1('white')) > 1)
        self.assertEqual(set(persons), {'Alexander Tischenko'})

    @db_session
    def test14(self):
        # This test checks if accessing function-specific globals works correctly
        persons = select(p.incorrect_full_name for p in Person if p.has_car)[:]
        self.assertEqual(set(persons), {'Alexander ***', 'Alexei ***', 'Alexander ***'})

    @db_session
    def test15(self):
        # Test repeated use of the same generator with hybrid method/property
        # that uses function from external scope
        result = Person.find_by_full_name('Alexander Kozlovsky')
        self.assertEqual(set(obj.last_name for obj in result), {'Kozlovsky'})
        result = Person.find_by_full_name('Alexander Kozlovsky')
        self.assertEqual(set(obj.last_name for obj in result), {'Kozlovsky'})
        result = Person.find_by_full_name('Alexander Tischenko')
        self.assertEqual(set(obj.last_name for obj in result), {'Tischenko'})

    @db_session
    def test16(self):
        result = Person.select(lambda p: p.full_name == 'Alexander Kozlovsky')
        self.assertEqual(set(p.id for p in result), {1})
    @db_session
    def test17(self):
        global sep
        sep = '.'
        try:
            result = Person.select(lambda p: p.full_name == 'Alexander.Kozlovsky')
            self.assertEqual(set(p.id for p in result), {1})
        finally:
            sep = ' '

    @db_session
    def test18(self):
        result = Person.select().filter(lambda p: p.full_name == 'Alexander Kozlovsky')
        self.assertEqual(set(p.id for p in result), {1})

    @db_session
    def test19(self):
        global sep
        sep = '.'
        try:
            result = Person.select().filter(lambda p: p.full_name == 'Alexander.Kozlovsky')
            self.assertEqual(set(p.id for p in result), {1})
        finally:
            sep = ' '

    @db_session
    @raises_exception(TranslationError, 'p.complex_method(...) is too complex to decompile')
    def test_20(self):
        q = select(p.complex_method() for p in Person)[:]

    @db_session
    @raises_exception(TranslationError, 'p.to_dict(...) is too complex to decompile')
    def test_21(self):
        q = select(p.to_dict() for p in Person)[:]

    @db_session
    @raises_exception(TranslationError, 'self.complex_method(...) is too complex to decompile (inside Person.simple_method)')
    def test_22(self):
        q = select(p.simple_method() for p in Person)[:]

    @db_session
    def test_23(self):
        q = select(simple_func(p) for p in Person)[:]

    @db_session
    @raises_exception(TranslationError, 'person.complex_method(...) is too complex to decompile (inside complex_func)')
    def test_24(self):
        q = select(complex_func(p) for p in Person)[:]

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_indexes.py

import sys, unittest
from decimal import Decimal
from datetime import date

from pony.orm import *
from pony.orm.tests.testutils import *

class TestIndexes(unittest.TestCase):
    def test_1(self):
        db = Database('sqlite', ':memory:')

        class Person(db.Entity):
            name = Required(str)
            age = Required(int)
            composite_key(name, 'age')

        db.generate_mapping(create_tables=True)

        i1, i2 = Person._indexes_
        self.assertEqual(i1.attrs, (Person.id,))
        self.assertEqual(i1.is_pk, True)
        self.assertEqual(i1.is_unique, True)
        self.assertEqual(i2.attrs, (Person.name, Person.age))
        self.assertEqual(i2.is_pk, False)
        self.assertEqual(i2.is_unique, True)

        table = db.schema.tables['Person']
        name_column = table.column_dict['name']
        age_column = table.column_dict['age']
        self.assertEqual(len(table.indexes), 2)
        db_index = table.indexes[name_column, age_column]
        self.assertEqual(db_index.is_pk, False)
        self.assertEqual(db_index.is_unique, True)

    def test_2(self):
        db = Database('sqlite', ':memory:')

        class Person(db.Entity):
            name = Required(str)
            age = Required(int)
            composite_index(name, 'age')

        db.generate_mapping(create_tables=True)

        i1, i2 = Person._indexes_
        self.assertEqual(i1.attrs, (Person.id,))
        self.assertEqual(i1.is_pk, True)
        self.assertEqual(i1.is_unique, True)
        self.assertEqual(i2.attrs, (Person.name, Person.age))
        self.assertEqual(i2.is_pk, False)
        self.assertEqual(i2.is_unique, False)

        table = db.schema.tables['Person']
        name_column = table.column_dict['name']
        age_column = table.column_dict['age']
        self.assertEqual(len(table.indexes), 2)
        db_index = table.indexes[name_column, age_column]
        self.assertEqual(db_index.is_pk, False)
        self.assertEqual(db_index.is_unique, False)

        create_script = db.schema.generate_create_script()
        index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")'
        self.assertTrue(index_sql in create_script)
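    # As test_1/test_2 above show, composite_key() declares a UNIQUE composite
    # index while composite_index() declares a plain, non-unique one. Sketch
    # (hypothetical entity, not part of this module):
    #
    #     class Booking(db.Entity):
    #         room = Required(str)
    #         day = Required(date)
    #         composite_key(room, day)  # at most one booking per room and day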
    def test_3(self):
        db = Database('sqlite', ':memory:')

        class User(db.Entity):
            name = Required(str, unique=True)

        db.generate_mapping(create_tables=True)

        with db_session:
            u = User(id=1, name='A')
        with db_session:
            u = User[1]
            u.name = 'B'
        with db_session:
            u = User[1]
            self.assertEqual(u.name, 'B')

    def test_4(self):  # issue 321
        db = Database('sqlite', ':memory:')

        class Person(db.Entity):
            name = Required(str)
            age = Required(int)
            composite_key(name, age)

        db.generate_mapping(create_tables=True)

        with db_session:
            p1 = Person(id=1, name='John', age=19)
        with db_session:
            p1 = Person[1]
            p1.set(name='John', age=19)
            p1.delete()

    def test_5(self):
        db = Database('sqlite', ':memory:')

        class Table1(db.Entity):
            name = Required(str)
            table2s = Set('Table2')

        class Table2(db.Entity):
            height = Required(int)
            length = Required(int)
            table1 = Optional('Table1')
            composite_key(height, length, table1)

        db.generate_mapping(create_tables=True)

        with db_session:
            Table2(height=2, length=1)
            Table2.exists(height=2, length=1)

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_inheritance.py

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

class TestInheritance(unittest.TestCase):
    def test_0(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            id = PrimaryKey(int)

        self.assertTrue(Entity1._root_ is Entity1)
        self.assertEqual(Entity1._discriminator_attr_, None)
        self.assertEqual(Entity1._discriminator_, None)

    def test_1(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            id = PrimaryKey(int)

        class Entity2(Entity1):
            a = Required(int)

        class Entity3(Entity1):
            b = Required(int)

        class Entity4(Entity2, Entity3):
            c = Required(int)

        self.assertTrue(Entity1._root_ is Entity1)
        self.assertTrue(Entity2._root_ is Entity1)
        self.assertTrue(Entity3._root_ is Entity1)
        self.assertTrue(Entity4._root_ is Entity1)

        self.assertTrue(Entity1._discriminator_attr_ is Entity1.classtype)
        self.assertTrue(Entity2._discriminator_attr_ is Entity1._discriminator_attr_)
        self.assertTrue(Entity3._discriminator_attr_ is Entity1._discriminator_attr_)
        self.assertTrue(Entity4._discriminator_attr_ is Entity1._discriminator_attr_)

        self.assertEqual(Entity1._discriminator_, 'Entity1')
        self.assertEqual(Entity2._discriminator_, 'Entity2')
        self.assertEqual(Entity3._discriminator_, 'Entity3')
        self.assertEqual(Entity4._discriminator_, 'Entity4')

    @raises_exception(ERDiagramError, "Multiple inheritance graph must be diamond-like. "
                                      "Entity Entity3 inherits from Entity1 and Entity2 entities "
                                      "which don't have common base class.")
    def test_2(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = PrimaryKey(int)

        class Entity2(db.Entity):
            b = PrimaryKey(int)

        class Entity3(Entity1, Entity2):
            c = Required(int)
    @raises_exception(ERDiagramError, 'Attribute Entity3.a conflicts with attribute Entity2.a '
                                      'because both entities inherit from Entity1. '
                                      'To fix this, move attribute definition to base class')
    def test_3a(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            id = PrimaryKey(int)

        class Entity2(Entity1):
            a = Required(int)

        class Entity3(Entity1):
            a = Required(int)

    def test3b(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            id = PrimaryKey(int)

        class Entity2(Entity1):
            a = Set('Entity4')

        class Entity3(Entity1):
            a = Set('Entity4')

        class Entity4(db.Entity):
            x = Required(Entity2)
            y = Required(Entity3)

        db.generate_mapping(create_tables=True)

        self.assertTrue(Entity2.a not in Entity1._subclass_attrs_)
        self.assertTrue(Entity3.a not in Entity1._subclass_attrs_)

    @raises_exception(ERDiagramError, "Name 'a' hides base attribute Entity1.a")
    def test_4(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            id = PrimaryKey(int)
            a = Required(int)

        class Entity2(Entity1):
            a = Required(int)

    @raises_exception(ERDiagramError, "Primary key cannot be redefined in derived classes")
    def test_5(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = PrimaryKey(int)

        class Entity2(Entity1):
            b = PrimaryKey(int)

    def test_6(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Discriminator(str)
            b = Required(int)

        class Entity2(Entity1):
            c = Required(int)

        self.assertTrue(Entity1._discriminator_attr_ is Entity1.a)
        self.assertTrue(Entity2._discriminator_attr_ is Entity1.a)

    @raises_exception(TypeError, "Discriminator value for entity Entity1 "
                                 "with custom discriminator column 'a' of 'int' type is not set")
    def test_7(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Discriminator(int)
            b = Required(int)

        class Entity2(Entity1):
            c = Required(int)

    def test_8(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            _discriminator_ = 1
            a = Discriminator(int)
            b = Required(int)

        class Entity2(Entity1):
            _discriminator_ = 2
            c = Required(int)

        db.generate_mapping(create_tables=True)

        with db_session:
            x = Entity1(b=10)
            y = Entity2(b=10, c=20)
        with db_session:
            x = Entity1[1]
            y = Entity1[2]
            self.assertTrue(isinstance(x, Entity1))
            self.assertTrue(isinstance(y, Entity2))
            self.assertEqual(x.a, 1)
            self.assertEqual(y.a, 2)

    def test_9(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            _discriminator_ = '1'
            a = Discriminator(int)
            b = Required(int)

        class Entity2(Entity1):
            _discriminator_ = '2'
            c = Required(int)

        db.generate_mapping(create_tables=True)

        with db_session:
            x = Entity1(b=10)
            y = Entity2(b=10, c=20)
        with db_session:
            x = Entity1[1]
            y = Entity1[2]
            self.assertTrue(isinstance(x, Entity1))
            self.assertTrue(isinstance(y, Entity2))
            self.assertEqual(x.a, 1)
            self.assertEqual(y.a, 2)

    @raises_exception(TypeError, "Incorrect discriminator value is set for Entity2 attribute 'a' of 'int' type: 'zzz'")
    def test_10(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            _discriminator_ = 1
            a = Discriminator(int)
            b = Required(int)

        class Entity2(Entity1):
            _discriminator_ = 'zzz'
            c = Required(int)

    def test_11(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            _discriminator_ = 1
            a = Discriminator(int)
            b = Required(int)
            composite_index(a, b)

    @raises_exception(ERDiagramError, 'Invalid use of attribute Entity1.a in entity Entity2')
    def test_12(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Required(int)

        class Entity2(db.Entity):
            b = Required(int)
            composite_index(Entity1.a, b)
    def test_13(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Required(int)

        class Entity2(Entity1):
            b = Required(int)
            composite_index(Entity1.a, b)

        self.assertEqual([index.attrs for index in Entity2._indexes_],
                         [(Entity2.id,), (Entity2.a, Entity2.b)])

    def test_14(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            d = Discriminator(str)
            a = Required(int)

        class Entity2(Entity1):
            b = Required(int)
            composite_index(Entity1.d, Entity1.a, b)

        self.assertEqual([index.attrs for index in Entity2._indexes_],
                         [(Entity2.id,), (Entity2.d, Entity2.a, Entity2.b)])

    def test_15(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            d = Discriminator(str)
            a = Required(int)

        class Entity2(Entity1):
            b = Required(int)
            composite_index('d', 'id', 'a', 'b')

        self.assertEqual([index.attrs for index in Entity2._indexes_],
                         [(Entity2.id,), (Entity2.d, Entity2.id, Entity2.a, Entity2.b)])

    def test_16(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Required(int)

        class Entity2(Entity1):
            b = Required(int)
            composite_index('classtype', 'id', 'a', b)

        self.assertEqual([index.attrs for index in Entity2._indexes_],
                         [(Entity2.id,), (Entity2.classtype, Entity2.id, Entity2.a, Entity2.b)])

    def test_17(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            t = Discriminator(str)
            a = Required(int)
            b = Required(int)
            composite_index(t, a, b)

        class Entity2(Entity1):
            c = Required(int)

        self.assertEqual([index.attrs for index in Entity1._indexes_],
                         [(Entity1.id,), (Entity1.t, Entity1.a, Entity1.b)])

    def test_18(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Required(int)

        class Entity2(db.Entity1):
            b = Required(int)

        class Entity3(Entity1):
            c = Required(int)

        db.generate_mapping(create_tables=True)

        with db_session:
            x = Entity1(a=10)
            y = Entity2(a=20, b=30)
            z = Entity3(a=40, c=50)
        with db_session:
            result = select(e for e in Entity1 if e.b == 30 or e.c == 50)
            self.assertEqual([e.id for e in result], [2, 3])

    def test_discriminator_1(self):
        db = Database('sqlite', ':memory:')

        class Entity1(db.Entity):
            a = Discriminator(str)
            b = Required(int)
            PrimaryKey(a, b)

        class Entity2(db.Entity1):
            c = Required(int)

        db.generate_mapping(create_tables=True)

        with db_session:
            x = Entity1(b=10)
            y = Entity2(b=20, c=30)
        with db_session:
            obj = Entity1.get(b=20)
            self.assertEqual(obj.a, 'Entity2')
            self.assertEqual(obj.b, 20)
            self.assertEqual(obj._pkval_, ('Entity2', 20))
        with db_session:
            obj = Entity1['Entity2', 20]
            self.assertIsInstance(obj, Entity2)
            self.assertEqual(obj.a, 'Entity2')
            self.assertEqual(obj.b, 20)
            self.assertEqual(obj._pkval_, ('Entity2', 20))
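    # Sketch of the single-table inheritance scheme exercised here: subclasses
    # share the root entity's table and the Discriminator column records the
    # concrete class of each row (hypothetical entities, values illustrative):
    #
    #     class Vehicle(db.Entity):
    #         kind = Discriminator(str)  # stores 'Vehicle', 'Truck', ...
    #         wheels = Required(int)
    #     class Truck(Vehicle):
    #         payload = Required(int)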
    @raises_exception(TypeError, "Invalid discriminator attribute value for Foo. Expected: 'Foo', got: 'Baz'")
    def test_discriminator_2(self):
        db = Database('sqlite', ':memory:')

        class Foo(db.Entity):
            id = PrimaryKey(int)
            a = Discriminator(str)
            b = Required(int)

        class Bar(db.Entity):
            c = Required(int)

        db.generate_mapping(create_tables=True)

        with db_session:
            x = Foo(id=1, a='Baz', b=100)

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_inner_join_syntax.py

import unittest

from pony.orm import *
from pony import orm

class TestJoin(unittest.TestCase):
    exclude_fixtures = {'test': ['clear_tables']}

    @classmethod
    def setUpClass(cls):
        db = cls.db = Database('sqlite', ':memory:')

        class Genre(db.Entity):
            name = orm.Optional(str)  # TODO primary key
            artists = orm.Set('Artist')
            favorite = orm.Optional(bool)
            index = orm.Optional(int)

        class Hobby(db.Entity):
            name = orm.Required(str)
            artists = orm.Set('Artist')

        class Artist(db.Entity):
            name = orm.Required(str)
            age = orm.Optional(int)
            hobbies = orm.Set(Hobby)
            genres = orm.Set(Genre)

        db.generate_mapping(create_tables=True)

        with orm.db_session:
            pop = Genre(name='pop')
            rock = Genre(name='rock')
            Artist(name='Sia', age=40, genres=[pop, rock])
            Artist(name='Lady GaGa', age=30, genres=[pop])

        pony.options.INNER_JOIN_SYNTAX = True

    @db_session
    def test_join_1(self):
        result = select(g.id for g in self.db.Genre for a in g.artists if a.name.startswith('S'))[:]
        self.assertEqual(self.db.last_sql,
            """SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre" INNER JOIN "Artist" "a" ON "t-1"."artist" = "a"."id" WHERE "a"."name" LIKE 'S%'""")

    @db_session
    def test_join_2(self):
        result = select(g.id for g in self.db.Genre for a in self.db.Artist
                        if JOIN(a in g.artists) and a.name.startswith('S'))[:]
        self.assertEqual(self.db.last_sql,
            """SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre", "Artist" "a" WHERE "t-1"."artist" = "a"."id" AND "a"."name" LIKE 'S%'""")

    @db_session
    def test_join_3(self):
        result = select(g.id for g in self.db.Genre for x in self.db.Artist for a in self.db.Artist
                        if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:]
        self.assertEqual(self.db.last_sql,
            '''SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre", "Artist" "x", "Artist" "a" WHERE "t-1"."artist" = "a"."id" AND "a"."name" LIKE 'S%' AND "g"."id" = "x"."id"''')

    @db_session
    def test_join_4(self):
        result = select(g.id for g in self.db.Genre for a in self.db.Artist for x in self.db.Artist
                        if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:]
        self.assertEqual(self.db.last_sql,
            '''SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre", "Artist" "a", "Artist" "x" WHERE "t-1"."artist" = "a"."id" AND "a"."name" LIKE 'S%' AND "g"."id" = "x"."id"''')

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_isinstance.py

from datetime import date
from decimal import Decimal
import unittest

from pony.orm import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:', create_db=True)

class Person(db.Entity):
    id = PrimaryKey(int, auto=True)
    name = Required(str)
    dob = Optional(date)
    ssn = Required(str, unique=True)

class Student(Person):
    group = Required("Group")
    mentor = Optional("Teacher")
    attend_courses = Set("Course")

class Teacher(Person):
    teach_courses = Set("Course")
    apprentices = Set("Student")
    salary = Required(Decimal)

class Assistant(Student, Teacher):
    pass

class Professor(Teacher):
    position = Required(str)

class Group(db.Entity):
    number = PrimaryKey(int)
    students = Set("Student")

class Course(db.Entity):
    name = Required(str)
    semester = Required(int)
    students = Set(Student)
    teachers = Set(Teacher)
    PrimaryKey(name, semester)

db.generate_mapping(create_tables=True)

with db_session:
    p = Person(name='Person1', ssn='SSN1')
    g = Group(number=123)
    prof = Professor(name='Professor1', salary=1000, position='position1', ssn='SSN5')
    a1 = Assistant(name='Assistant1', group=g, salary=100, ssn='SSN4', mentor=prof)
    a2 = Assistant(name='Assistant2', group=g, salary=200, ssn='SSN6', mentor=prof)
    s1 = Student(name='Student1', group=g, ssn='SSN2', mentor=a1)
    s2 = Student(name='Student2', group=g, ssn='SSN3')

class TestVolatile(unittest.TestCase):
    @db_session
    def test_1(self):
        q = select(p.name for p in Person if isinstance(p, Student))
        self.assertEqual(set(q), {'Student1', 'Student2', 'Assistant1', 'Assistant2'})

    @db_session
    def test_2(self):
        q = select(p.name for p in Person if not isinstance(p, Student))
        self.assertEqual(set(q), {'Person1', 'Professor1'})

    @db_session
    def test_3(self):
        q = select(p.name for p in Student if isinstance(p, Professor))
        self.assertEqual(set(q), set())

    @db_session
    def test_4(self):
        q = select(p.name for p in Person if not isinstance(p, Person))
        self.assertEqual(set(q), set())

    @db_session
    def test_5(self):
        q = select(p.name for p in Person if isinstance(p, (Student, Teacher)))
        self.assertEqual(set(q), {'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'})

    @db_session
    def test_6(self):
        q = select(p.name for p in Person if isinstance(p, Student) and isinstance(p, Teacher))
        self.assertEqual(set(q), {'Assistant1', 'Assistant2'})

    @db_session
    def test_7(self):
        q = select(p.name for p in Person
                   if (isinstance(p, Student) and p.ssn == 'SSN2')
                   or (isinstance(p, Professor) and p.salary > 500))
        self.assertEqual(set(q), {'Student1', 'Professor1'})

    @db_session
    def test_8(self):
        q = select(p.name for p in Person if isinstance(p, Person))
        self.assertEqual(set(q), {'Person1', 'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'})

    @db_session
    def test_9(self):
        q = select(g.number for g in Group if isinstance(g, Group))
        self.assertEqual(set(q), {123})

if __name__ == '__main__':
    unittest.main()

# pony-0.7.11/pony/orm/tests/test_json.py

from pony.py23compat import basestring, pickle

import unittest

from pony.orm import *
from pony.orm.tests.testutils import raises_exception, raises_if
from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict

db = Database('sqlite', ':memory:')

class Product(db.Entity):
    name = Required(str)
    info = Optional(Json)
    tags = Optional(Json)

db.generate_mapping(create_tables=True)

class TestJson(unittest.TestCase):
    def setUp(self):
        with db_session:
            Product.select().delete(bulk=True)
            flush()
            Product(
                name='Apple iPad Air 2',
                info={
                    'name': 'Apple iPad Air 2',
                    'display': {
                        'size': 9.7,
                        'resolution': [2048, 1536],
                        'matrix-type': 'IPS',
                        'multi-touch': True
                    },
                    'os': {
                        'type': 'iOS',
                        'version': '8'
                    },
                    'cpu': 'Apple A8X',
                    'ram': '8GB',
                    'colors': ['Gold', 'Silver', 'Space Gray'],
                    'models': [
                        {
                            'name': 'Wi-Fi',
                            'capacity': ['16GB', '64GB'],
                            'height': 240,
                            'width': 169.5,
                            'depth': 6.1,
                            'weight': 437,
                        },
                        {
                            'name': 'Wi-Fi + Cellular',
                            'capacity': ['16GB', '64GB'],
                            'height': 240,
                            'width': 169.5,
                            'depth': 6.1,
                            'weight': 444,
                        },
                    ],
                    'discontinued': False,
                    'videoUrl': None,
                    'non-ascii-attr': u'\u0442\u0435\u0441\u0442'
                },
                tags=['Tablets', 'Apple', 'Retina'])

    def test(self):
        with db_session:
            result = select(p for p in Product)[:]
            self.assertEqual(len(result), 1)
            p = result[0]
            p.info['os']['version'] = '9'
        with db_session:
            result = select(p for p in Product)[:]
            self.assertEqual(len(result), 1)
            p = result[0]
            self.assertEqual(p.info['os']['version'], '9')

    @db_session
    def test_query_int(self):
        val = get(p.info['display']['resolution'][0] for p in Product)
        self.assertEqual(val, 2048)

    @db_session
    def test_query_float(self):
        val = get(p.info['display']['size'] for p in Product)
        self.assertAlmostEqual(val, 9.7)

    @db_session
    def test_query_true(self):
        val = get(p.info['display']['multi-touch'] for p in Product)
        self.assertIs(val, True)

    @db_session
    def test_query_false(self):
        val = get(p.info['discontinued'] for p in Product)
        self.assertIs(val, False)

    @db_session
    def test_query_null(self):
        val = get(p.info['videoUrl'] for p in Product)
        self.assertIs(val, None)

    @db_session
    def test_query_list(self):
        val = get(p.info['colors'] for p in Product)
        self.assertListEqual(val, ['Gold', 'Silver', 'Space Gray'])
        self.assertNotIsInstance(val, TrackedValue)

    @db_session
    def test_query_dict(self):
        val = get(p.info['display'] for p in Product)
        self.assertDictEqual(val, {
            'size': 9.7,
            'resolution': [2048, 1536],
            'matrix-type': 'IPS',
            'multi-touch': True
        })
        self.assertNotIsInstance(val, TrackedValue)

    @db_session
    def test_query_json_field(self):
        val = get(p.info for p in Product)
        self.assertDictEqual(val['display'], {
            'size': 9.7,
            'resolution': [2048, 1536],
            'matrix-type': 'IPS',
            'multi-touch': True
        })
        self.assertNotIsInstance(val['display'], TrackedDict)
        val = get(p.tags for p in Product)
        self.assertListEqual(val, ['Tablets', 'Apple', 'Retina'])
        self.assertNotIsInstance(val, TrackedList)

    @db_session
    def test_get_object(self):
        p = get(p for p in Product)
        self.assertDictEqual(p.info['display'], {
            'size': 9.7,
            'resolution': [2048, 1536],
            'matrix-type': 'IPS',
            'multi-touch': True
        })
        self.assertEqual(p.info['discontinued'], False)
        self.assertEqual(p.info['videoUrl'], None)
        self.assertListEqual(p.tags, ['Tablets', 'Apple', 'Retina'])
        self.assertIsInstance(p.info, TrackedDict)
        self.assertIsInstance(p.info['display'], TrackedDict)
        self.assertIsInstance(p.info['colors'], TrackedList)
        self.assertIsInstance(p.tags, TrackedList)

    def test_set_str(self):
        with db_session:
            p = get(p for p in Product)
            p.info['os']['version'] = '9'
        with db_session:
            p = get(p for p in Product)
            self.assertEqual(p.info['os']['version'], '9')

    def test_set_int(self):
        with db_session:
            p = get(p for p in Product)
            p.info['display']['resolution'][0] += 1
        with db_session:
            p = get(p for p in Product)
            self.assertEqual(p.info['display']['resolution'][0], 2049)

    def test_set_true(self):
        with db_session:
            p = get(p for p in Product)
            p.info['discontinued'] = True
        with db_session:
            p = get(p for p in Product)
            self.assertIs(p.info['discontinued'], True)

    def test_set_false(self):
        with db_session:
            p = get(p for p in Product)
            p.info['display']['multi-touch'] = False
        with db_session:
            p = get(p for p in Product)
            self.assertIs(p.info['display']['multi-touch'], False)
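    # The mutation tests below rely on Pony wrapping loaded Json attributes in
    # TrackedDict/TrackedList: in-place changes mark the attribute as dirty and
    # are written back on commit. Minimal sketch (entity from this module):
    #
    #     with db_session:
    #         p = get(p for p in Product)
    #         p.info['colors'].append('White')  # tracked, persisted on commit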
    def test_set_null(self):
        with db_session:
            p = get(p for p in Product)
            p.info['display'] = None
        with db_session:
            p = get(p for p in Product)
            self.assertIs(p.info['display'], None)

    def test_set_list(self):
        with db_session:
            p = get(p for p in Product)
            p.info['colors'] = ['Pink', 'Black']
        with db_session:
            p = get(p for p in Product)
            self.assertListEqual(p.info['colors'], ['Pink', 'Black'])

    def test_list_del(self):
        with db_session:
            p = get(p for p in Product)
            del p.info['colors'][1]
        with db_session:
            p = get(p for p in Product)
            self.assertListEqual(p.info['colors'], ['Gold', 'Space Gray'])

    def test_list_append(self):
        with db_session:
            p = get(p for p in Product)
            p.info['colors'].append('White')
        with db_session:
            p = get(p for p in Product)
            self.assertListEqual(p.info['colors'], ['Gold', 'Silver', 'Space Gray', 'White'])

    def test_list_set_slice(self):
        with db_session:
            p = get(p for p in Product)
            p.info['colors'][1:] = ['White']
        with db_session:
            p = get(p for p in Product)
            self.assertListEqual(p.info['colors'], ['Gold', 'White'])

    def test_list_set_item(self):
        with db_session:
            p = get(p for p in Product)
            p.info['colors'][1] = 'White'
        with db_session:
            p = get(p for p in Product)
            self.assertListEqual(p.info['colors'], ['Gold', 'White', 'Space Gray'])

    def test_set_dict(self):
        with db_session:
            p = get(p for p in Product)
            p.info['display']['resolution'] = {'width': 2048, 'height': 1536}
        with db_session:
            p = get(p for p in Product)
            self.assertDictEqual(p.info['display']['resolution'], {'width': 2048, 'height': 1536})

    def test_dict_del(self):
        with db_session:
            p = get(p for p in Product)
            del p.info['os']['version']
        with db_session:
            p = get(p for p in Product)
            self.assertDictEqual(p.info['os'], {'type': 'iOS'})

    def test_dict_pop(self):
        with db_session:
            p = get(p for p in Product)
            p.info['os'].pop('version')
        with db_session:
            p = get(p for p in Product)
            self.assertDictEqual(p.info['os'], {'type': 'iOS'})

    def test_dict_update(self):
        with db_session:
            p = get(p for p in Product)
            p.info['os'].update(version='9')
        with db_session:
            p = get(p for p in Product)
            self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'})

    def test_dict_set_item(self):
        with db_session:
            p = get(p for p in Product)
            p.info['os']['version'] = '9'
        with db_session:
            p = get(p for p in Product)
            self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'})

    @db_session
    def test_set_same_value(self):
        p = get(p for p in Product)
        p.info = p.info

    @db_session
    def test_len(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       'Oracle does not provide `length` function for JSON arrays'):
            val = select(len(p.tags) for p in Product).first()
            self.assertEqual(val, 3)
            val = select(len(p.info['colors']) for p in Product).first()
            self.assertEqual(val, 3)

    @db_session
    def test_equal_str(self):
        p = get(p for p in Product if p.info['name'] == 'Apple iPad Air 2')
        self.assertTrue(p)

    @db_session
    def test_unicode_key(self):
        p = get(p for p in Product if p.info[u'name'] == 'Apple iPad Air 2')
        self.assertTrue(p)

    @db_session
    def test_equal_string_attr(self):
        p = get(p for p in Product if p.info['name'] == p.name)
        self.assertTrue(p)

    @db_session
    def test_equal_param(self):
        x = 'Apple iPad Air 2'
        p = get(p for p in Product if p.name == x)
        self.assertTrue(p)

    @db_session
    def test_composite_param(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle doesn't allow parameters in JSON paths"):
            key = 'models'
            index = 0
            val = get(p.info[key][index]['name'] for p in Product)
            self.assertEqual(val, 'Wi-Fi')
    @db_session
    def test_composite_param_in_condition(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle doesn't allow parameters in JSON paths"):
            key = 'models'
            index = 0
            p = get(p for p in Product if p.info[key][index]['name'] == 'Wi-Fi')
            self.assertIsNotNone(p)

    @db_session
    def test_equal_json_1(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: "
                       "p.info['os'] == {'type':'iOS', 'version':'8'}"):
            p = get(p for p in Product if p.info['os'] == {'type': 'iOS', 'version': '8'})
            self.assertTrue(p)

    @db_session
    def test_equal_json_2(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: "
                       "p.info['os'] == Json({'type':'iOS', 'version':'8'})"):
            p = get(p for p in Product if p.info['os'] == Json({'type': 'iOS', 'version': '8'}))
            self.assertTrue(p)

    @db_session
    def test_ne_json_1(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: p.info['os'] != {}"):
            p = get(p for p in Product if p.info['os'] != {})
            self.assertTrue(p)
            p = get(p for p in Product if p.info['os'] != {'type': 'iOS', 'version': '8'})
            self.assertFalse(p)

    @db_session
    def test_ne_json_2(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: p.info['os'] != Json({})"):
            p = get(p for p in Product if p.info['os'] != Json({}))
            self.assertTrue(p)
            p = get(p for p in Product if p.info['os'] != {'type': 'iOS', 'version': '8'})
            self.assertFalse(p)

    @db_session
    def test_equal_list_1(self):
        colors = ['Gold', 'Silver', 'Space Gray']
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"):
            p = get(p for p in Product if p.info['colors'] == Json(colors))
            self.assertTrue(p)

    @db_session
    @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == ['Gold']")
    def test_equal_list_2(self):
        p = get(p for p in Product if p.info['colors'] == ['Gold'])

    @db_session
    def test_equal_list_3(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"):
            p = get(p for p in Product if p.info['colors'] != Json(['Gold']))
            self.assertIsNotNone(p)

    @db_session
    def test_equal_list_4(self):
        colors = ['Gold', 'Silver', 'Space Gray']
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"):
            p = get(p for p in Product if p.info['colors'] == Json(colors))
            self.assertTrue(p)

    @db_session
    @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == []")
    def test_equal_empty_list_1(self):
        p = get(p for p in Product if p.info['colors'] == [])

    @db_session
    def test_equal_empty_list_2(self):
        with raises_if(self, db.provider.dialect == 'Oracle', TranslationError,
                       "Oracle does not support comparison of json structures: p.info['colors'] == Json([])"):
            p = get(p for p in Product if p.info['colors'] == Json([]))
            self.assertIsNone(p)
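    # As the comparisons above show, a bare list or dict literal cannot be
    # compared with a Json attribute; wrapping the value in Json() tells the
    # translator to compare JSON structures instead:
    #
    #     get(p for p in Product if p.info['colors'] == Json(['Gold']))  # ok
    #     get(p for p in Product if p.info['colors'] == ['Gold'])  # TypeError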
if p.info['colors'] != Json(['Gold'])) self.assertTrue(p) @db_session def test_ne_empty_list(self): with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] != Json([])"): p = get(p for p in Product if p.info['colors'] != Json([])) self.assertTrue(p) @db_session def test_dbval2val(self): p = select(p for p in Product)[:][0] attr = Product.info val = p._vals_[attr] dbval = p._dbvals_[attr] self.assertIsInstance(dbval, basestring) self.assertIsInstance(val, TrackedValue) p.info['os']['version'] = '9' self.assertIs(val, p._vals_[attr]) self.assertIs(dbval, p._dbvals_[attr]) p.flush() self.assertIs(val, p._vals_[attr]) self.assertNotEqual(dbval, p._dbvals_[attr]) @db_session def test_wildcard_path_1(self): with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), TranslationError, '...does not support wildcards in JSON path...'): names = get(p.info['models'][:]['name'] for p in Product) self.assertSetEqual(set(names), {'Wi-Fi', 'Wi-Fi + Cellular'}) @db_session def test_wildcard_path_2(self): with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), TranslationError, '...does not support wildcards in JSON path...'): values = get(p.info['os'][...] for p in Product) self.assertSetEqual(set(values), {'iOS', '8'}) @db_session def test_wildcard_path_3(self): with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), TranslationError, '...does not support wildcards in JSON path...'): names = get(p.info[...][0]['name'] for p in Product) self.assertSetEqual(set(names), {'Wi-Fi'}) @db_session def test_wildcard_path_4(self): if db.provider.dialect == 'Oracle': raise unittest.SkipTest with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, '...does not support wildcards in JSON path...'): values = get(p.info[...][:][...][:] for p in Product)[:] self.assertSetEqual(set(values), {'16GB', '64GB'}) @db_session def test_wildcard_path_with_params(self): if db.provider.dialect != 'Oracle': exc_msg = '...does not support wildcards in JSON path...' else: exc_msg = "Oracle doesn't allow parameters in JSON paths" with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, exc_msg): key = 'models' index = 0 values = get(p.info[key][:]['capacity'][index] for p in Product) self.assertListEqual(values, ['16GB', '16GB']) @db_session def test_wildcard_path_with_params_as_string(self): if db.provider.dialect != 'Oracle': exc_msg = '...does not support wildcards in JSON path...' else: exc_msg = "Oracle doesn't allow parameters in JSON paths" with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, exc_msg): key = 'models' index = 0 values = get("p.info[key][:]['capacity'][index] for p in Product") self.assertListEqual(values, ['16GB', '16GB']) @db_session def test_wildcard_path_in_condition(self): errors = { 'MySQL': 'Wildcards are not allowed in json_contains()', 'SQLite': '...does not support wildcards in JSON path...', 'PostgreSQL': '...does not support wildcards in JSON path...' 
} dialect = db.provider.dialect with raises_if(self, dialect in errors, TranslationError, errors.get(dialect)): p = get(p for p in Product if '16GB' in p.info['models'][:]['capacity']) self.assertTrue(p) ##### 'key' in json @db_session def test_in_dict(self): obj = get(p for p in Product if 'resolution' in p.info['display']) self.assertTrue(obj) @db_session def test_not_in_dict(self): obj = get(p for p in Product if 'resolution' not in p.info['display']) self.assertIs(obj, None) obj = get(p for p in Product if 'xyz' not in p.info['display']) self.assertTrue(obj) @db_session def test_in_list(self): obj = get(p for p in Product if 'Gold' in p.info['colors']) self.assertTrue(obj) @db_session def test_not_in_list(self): obj = get(p for p in Product if 'White' not in p.info['colors']) self.assertTrue(obj) obj = get(p for p in Product if 'Gold' not in p.info['colors']) self.assertIs(obj, None) @db_session def test_var_in_json(self): with raises_if(self, db.provider.dialect == 'Oracle', TypeError, "For `key in JSON` operation Oracle supports literal key values only, " "parameters are not allowed: key in p.info['colors']"): key = 'Gold' obj = get(p for p in Product if key in p.info['colors']) self.assertTrue(obj) @db_session def test_select_first(self): # query should not contain ORDER BY obj = select(p.info for p in Product).first() self.assertNotIn('order by', db.last_sql.lower()) def test_sql_inject(self): # test quote in json is not causing error with db_session: p = select(p for p in Product).first() p.info['display']['size'] = "0' 9.7\"" with db_session: p = select(p for p in Product).first() self.assertEqual(p.info['display']['size'], "0' 9.7\"") @db_session def test_int_compare(self): p = get(p for p in Product if p.info['display']['resolution'][0] == 2048) self.assertTrue(p) p = get(p for p in Product if p.info['display']['resolution'][0] != 2048) self.assertIsNone(p) p = get(p for p in Product if p.info['display']['resolution'][0] < 2048) self.assertIs(p, None) p = get(p for p in Product if p.info['display']['resolution'][0] <= 2048) self.assertTrue(p) p = get(p for p in Product if p.info['display']['resolution'][0] > 2048) self.assertIs(p, None) p = get(p for p in Product if p.info['display']['resolution'][0] >= 2048) self.assertTrue(p) @db_session def test_float_compare(self): p = get(p for p in Product if p.info['display']['size'] > 9.5) self.assertTrue(p) p = get(p for p in Product if p.info['display']['size'] < 9.8) self.assertTrue(p) p = get(p for p in Product if p.info['display']['size'] < 9.5) self.assertIsNone(p) p = get(p for p in Product if p.info['display']['size'] > 9.8) self.assertIsNone(p) @db_session def test_str_compare(self): p = get(p for p in Product if p.info['ram'] == '8GB') self.assertTrue(p) p = get(p for p in Product if p.info['ram'] != '8GB') self.assertIsNone(p) p = get(p for p in Product if p.info['ram'] < '9GB') self.assertTrue(p) p = get(p for p in Product if p.info['ram'] > '7GB') self.assertTrue(p) p = get(p for p in Product if p.info['ram'] > '9GB') self.assertIsNone(p) p = get(p for p in Product if p.info['ram'] < '7GB') self.assertIsNone(p) @db_session def test_bool_compare(self): p = get(p for p in Product if p.info['display']['multi-touch'] == True) self.assertTrue(p) p = get(p for p in Product if p.info['display']['multi-touch'] is True) self.assertTrue(p) p = get(p for p in Product if p.info['display']['multi-touch'] == False) self.assertIsNone(p) p = get(p for p in Product if p.info['display']['multi-touch'] is False) self.assertIsNone(p) p = get(p 
for p in Product if p.info['discontinued'] == False) self.assertTrue(p) p = get(p for p in Product if p.info['discontinued'] == True) self.assertIsNone(p) @db_session def test_none_compare(self): p = get(p for p in Product if p.info['videoUrl'] is None) self.assertTrue(p) p = get(p for p in Product if p.info['videoUrl'] is not None) self.assertIsNone(p) @db_session def test_none_for_nonexistent_path(self): p = get(p for p in Product if p.info['some_attr'] is None) self.assertTrue(p) @db_session def test_str_cast(self): p = get(coalesce(str(p.name), 'empty') for p in Product) self.assertTrue('AS text' in db.last_sql) @db_session def test_int_cast(self): p = get(coalesce(int(p.info['os']['version']), 0) for p in Product) self.assertTrue('as integer' in db.last_sql) def test_nonzero(self): with db_session: delete(p for p in Product) Product(name='P1', info=dict(id=1, val=True)) Product(name='P2', info=dict(id=2, val=False)) Product(name='P3', info=dict(id=3, val=0)) Product(name='P4', info=dict(id=4, val=1)) Product(name='P5', info=dict(id=5, val='')) Product(name='P6', info=dict(id=6, val='x')) Product(name='P7', info=dict(id=7, val=[])) Product(name='P8', info=dict(id=8, val=[1, 2, 3])) Product(name='P9', info=dict(id=9, val={})) Product(name='P10', info=dict(id=10, val={'a': 'b'})) Product(name='P11', info=dict(id=11)) Product(name='P12', info=dict(id=12, val='True')) Product(name='P13', info=dict(id=13, val='False')) Product(name='P14', info=dict(id=14, val='0')) Product(name='P15', info=dict(id=15, val='1')) Product(name='P16', info=dict(id=16, val='""')) Product(name='P17', info=dict(id=17, val='[]')) Product(name='P18', info=dict(id=18, val='{}')) with db_session: val = select(p.info['id'] for p in Product if not p.info['val']) self.assertEqual(tuple(sorted(val)), (2, 3, 5, 7, 9, 11)) @db_session def test_optimistic_check(self): p1 = Product.select().first() p1.info['foo'] = 'bar' flush() p1.name = 'name2' flush() p1.name = 'name3' flush() @db_session def test_avg(self): result = select(avg(p.info['display']['size']) for p in Product).first() self.assertAlmostEqual(result, 9.7) def test_pickle(self): with db_session: p1 = Product.select().first() data = pickle.dumps(p1) with db_session: p1 = pickle.loads(data) p1.name = 'name2' flush() rollback() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_lazy.py0000666000000000000000000000405300000000000016236 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * class TestLazy(unittest.TestCase): def setUp(self): self.db = Database('sqlite', ':memory:') class X(self.db.Entity): a = Required(int) b = Required(unicode, lazy=True) self.X = X self.db.generate_mapping(create_tables=True) with db_session: x1 = X(a=1, b='first') x2 = X(a=2, b='second') x3 = X(a=3, b='third') @db_session def test_lazy_1(self): X = self.X x1 = X[1] self.assertTrue(X.a in x1._vals_) self.assertTrue(X.b not in x1._vals_) b = x1.b self.assertEqual(b, 'first') @db_session def test_lazy_2(self): X = self.X x1 = X[1] x2 = X[2] x3 = X[3] self.assertTrue(X.b not in x1._vals_) self.assertTrue(X.b not in x2._vals_) self.assertTrue(X.b not in x3._vals_) b = x1.b self.assertTrue(X.b in x1._vals_) self.assertTrue(X.b not in x2._vals_) self.assertTrue(X.b not in x3._vals_) @db_session def test_lazy_3(self): # coverage of https://github.com/ponyorm/pony/issues/49 X = self.X x1 = X.get(b='first') 
        self.assertTrue(X._bits_[X.b] & x1._rbits_)
        self.assertTrue(X.b in x1._vals_)

    @db_session
    def test_lazy_4(self):  # coverage of https://github.com/ponyorm/pony/issues/49
        X = self.X
        result = select(x for x in X if x.b == 'first')[:]
        for x in result:
            self.assertTrue(X._bits_[X.b] & x._rbits_)
            self.assertTrue(X.b in x._vals_)

    @db_session
    def test_lazy_5(self):  # coverage of https://github.com/ponyorm/pony/issues/49
        X = self.X
        result = select(x for x in X if x.b == 'first' if count() > 0)[:]
        for x in result:
            self.assertFalse(X._bits_[X.b] & x._rbits_)
            self.assertTrue(X.b not in x._vals_)
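# ---------------------------------------------------------------------------
# Editor's note: an illustrative, self-contained sketch of the lazy-attribute
# behaviour exercised by TestLazy above. It is not part of the distribution;
# the entity and attribute names are made up for the example. A `lazy=True`
# attribute is excluded from the default SELECT and is fetched by a separate
# query on first access:
#
#     from pony.orm import Database, Required, db_session
#
#     db = Database('sqlite', ':memory:')
#
#     class Article(db.Entity):
#         title = Required(str)
#         body = Required(str, lazy=True)   # not fetched with the row
#
#     db.generate_mapping(create_tables=True)
#
#     with db_session:
#         Article(title='t1', body='a very long text')
#
#     with db_session:
#         a = Article.select().first()      # SELECT omits the "body" column
#         print(a.body)                     # triggers an extra SELECT
# ---------------------------------------------------------------------------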
# ===== pony-0.7.11/pony/orm/tests/test_mapping.py =====

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
from pony.orm.dbschema import DBSchemaError
from pony.orm.tests.testutils import *

class TestColumnsMapping(unittest.TestCase):
    # raise exception if mapping table by default is not found
    @raises_exception(OperationalError, 'no such table: Student')
    def test_table_check1(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            name = PrimaryKey(str)
        sql = "drop table if exists Student;"
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()

    # no exception if table was specified
    def test_table_check2(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            name = PrimaryKey(str)
        sql = """
            drop table if exists Student;
            create table Student(
                name varchar(30)
            );
        """
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()
        self.assertEqual(db.schema.tables['Student'].column_list[0].name, 'name')

    # raise exception if specified mapping table is not found
    @raises_exception(OperationalError, 'no such table: Table1')
    def test_table_check3(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            _table_ = 'Table1'
            name = PrimaryKey(str)
        db.generate_mapping()

    # no exception if table was specified
    def test_table_check4(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            _table_ = 'Table1'
            name = PrimaryKey(str)
        sql = """
            drop table if exists Table1;
            create table Table1(
                name varchar(30)
            );
        """
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()
        self.assertEqual(db.schema.tables['Table1'].column_list[0].name, 'name')

    # 'id' field created if primary key is not defined
    @raises_exception(OperationalError, 'no such column: Student.id')
    def test_table_check5(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            name = Required(str)
        sql = """
            drop table if exists Student;
            create table Student(
                name varchar(30)
            );
        """
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()

    # 'id' field created if primary key is not defined
    def test_table_check6(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            name = Required(str)
        sql = """
            drop table if exists Student;
            create table Student(
                id integer primary key,
                name varchar(30)
            );
        """
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()
        self.assertEqual(db.schema.tables['Student'].column_list[0].name, 'id')

    @raises_exception(DBSchemaError, "Column 'name' already exists in table 'Student'")
    def test_table_check7(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            name = Required(str, column='name')
            record = Required(str, column='name')
        sql = """
            drop table if exists Student;
            create table Student(
                id integer primary key,
                name varchar(30)
            );
        """
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()

    # user can specify column name for an attribute
    def test_custom_column_name(self):
        db = Database('sqlite', ':memory:')
        class Student(db.Entity):
            name = PrimaryKey(str, column='name1')
        sql = """
            drop table if exists Student;
            create table Student(
                name1 varchar(30)
            );
        """
        with db_session:
            db.get_connection().executescript(sql)
        db.generate_mapping()
        self.assertEqual(db.schema.tables['Student'].column_list[0].name, 'name1')

    # Required-Required raises exception
    @raises_exception(ERDiagramError, 'At least one attribute of one-to-one relationship '
                                      'Entity1.attr1 - Entity2.attr2 must be optional')
    def test_relations1(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Required("Entity2")
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Required(Entity1)
        db.generate_mapping()

    # no exception Optional-Required
    def test_relations2(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Optional("Entity2")
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Required(Entity1)
        db.generate_mapping(create_tables=True)

    # no exception Optional-Required(column)
    def test_relations3(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Required("Entity2", column='a')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping(create_tables=True)

    def test_relations4(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Required("Entity2")
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1, column='a')
        db.generate_mapping(create_tables=True)
        self.assertEqual(Entity1.attr1.columns, ['attr1'])
        self.assertEqual(Entity2.attr2.columns, ['a'])

    # no exception Optional-Optional
    def test_relations5(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Optional("Entity2")
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping(create_tables=True)

    # no exception Optional-Optional(column)
    def test_relations6(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Optional("Entity2", column='a')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping(create_tables=True)

    def test_relations7(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Optional("Entity2", column='a')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1, column='a1')
        db.generate_mapping(create_tables=True)
        self.assertEqual(Entity1.attr1.columns, ['a'])
        self.assertEqual(Entity2.attr2.columns, ['a1'])

    def test_columns1(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = PrimaryKey(int)
            attr1 = Set("Entity2")
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping(create_tables=True)
        column_list = db.schema.tables['Entity2'].column_list
        self.assertEqual(len(column_list), 2)
        self.assertEqual(column_list[0].name, 'id')
        self.assertEqual(column_list[1].name, 'attr2')

    def test_columns2(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            a = Required(int)
            b = Required(int)
            PrimaryKey(a, b)
            attr1 = Set("Entity2")
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping(create_tables=True)
        column_list = db.schema.tables['Entity2'].column_list
        self.assertEqual(len(column_list), 3)
        self.assertEqual(column_list[0].name, 'id')
        self.assertEqual(column_list[1].name, 'attr2_a')
        self.assertEqual(column_list[2].name, 'attr2_b')

    def test_columns3(self):
        db = Database('sqlite', ':memory:')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Optional('Entity2')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional(Entity1)
        db.generate_mapping(create_tables=True)
        self.assertEqual(Entity1.attr1.columns, ['attr1'])
        self.assertEqual(Entity2.attr2.columns, [])

    def test_columns4(self):
        db = Database('sqlite', ':memory:')
        class Entity2(db.Entity):
            id = PrimaryKey(int)
            attr2 = Optional('Entity1')
        class Entity1(db.Entity):
            id = PrimaryKey(int)
            attr1 = Optional(Entity2)
        db.generate_mapping(create_tables=True)
        self.assertEqual(Entity1.attr1.columns, ['attr1'])
        self.assertEqual(Entity2.attr2.columns, [])

    @raises_exception(ERDiagramError, "Mapping is not generated for entity 'E1'")
    def test_generate_mapping1(self):
        db = Database('sqlite', ':memory:')
        class E1(db.Entity):
            a1 = Required(int)
        select(e for e in E1)

    @raises_exception(ERDiagramError, "Mapping is not generated for entity 'E1'")
    def test_generate_mapping2(self):
        db = Database('sqlite', ':memory:')
        class E1(db.Entity):
            a1 = Required(int)
        e = E1(a1=1)

if __name__ == '__main__':
    unittest.main()

# ===== pony-0.7.11/pony/orm/tests/test_objects_to_save_cleanup.py =====

import unittest

from pony.orm import *

db = Database()

class TestPost(db.Entity):
    category = Optional('TestCategory')
    name = Optional(str, default='Noname')

class TestCategory(db.Entity):
    posts = Set(TestPost)

db.bind('sqlite', ':memory:')
db.generate_mapping(create_tables=True)

class EntityStatusTestCase(object):
    def make_flush(self, obj=None):
        raise NotImplementedError

    @db_session
    def test_delete_updated(self):
        p = TestPost()
        self.make_flush(p)
        p.name = 'Pony'
        self.assertEqual(p._status_, 'modified')
        self.make_flush(p)
        self.assertEqual(p._status_, 'updated')
        p.delete()
        self.assertEqual(p._status_, 'marked_to_delete')
        self.make_flush(p)
        self.assertEqual(p._status_, 'deleted')

    @db_session
    def test_delete_inserted(self):
        p = TestPost()
        self.assertEqual(p._status_, 'created')
        self.make_flush(p)
        self.assertEqual(p._status_, 'inserted')
        p.delete()

    @db_session
    def test_cancelled(self):
        p = TestPost()
        self.assertEqual(p._status_, 'created')
        p.delete()
        self.assertEqual(p._status_, 'cancelled')
        self.make_flush(p)
        self.assertEqual(p._status_, 'cancelled')

class EntityStatusTestCase_ObjectFlush(EntityStatusTestCase, unittest.TestCase):
    def make_flush(self, obj=None):
        obj.flush()

class EntityStatusTestCase_FullFlush(EntityStatusTestCase, unittest.TestCase):
    def make_flush(self, obj=None):
        flush()  # full flush
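# ---------------------------------------------------------------------------
# Editor's note: a condensed sketch of the instance lifecycle that the status
# tests above walk through. `_status_` is an internal attribute of Pony
# entities and is shown here only for illustration:
#
#     with db_session:
#         p = TestPost()       # '_status_' == 'created'
#         flush()              # INSERT emitted  -> 'inserted'
#         p.name = 'Pony'      # pending change  -> 'modified'
#         flush()              # UPDATE emitted  -> 'updated'
#         p.delete()           #                 -> 'marked_to_delete'
#         flush()              # DELETE emitted  -> 'deleted'
# ---------------------------------------------------------------------------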
# ===== pony-0.7.11/pony/orm/tests/test_prefetching.py =====

import sys, unittest
from decimal import Decimal
from datetime import date

from pony.orm import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    name = Required(str)
    scholarship = Optional(int)
    gpa = Optional(Decimal, 3, 1)
    dob = Optional(date)
    group = Required('Group')
    courses = Set('Course')
    mentor = Optional('Teacher')
    biography = Optional(LongStr)

class Group(db.Entity):
    number = PrimaryKey(int)
    major = Required(str, lazy=True)
    students = Set(Student)

class Course(db.Entity):
    name = Required(str, unique=True)
    students = Set(Student)

class Teacher(db.Entity):
    name = Required(str)
    students = Set(Student)

db.generate_mapping(create_tables=True)

with db_session:
    g1 = Group(number=1, major='Math')
    g2 = Group(number=2, major='Computer Science')
    c1 = Course(name='Math')
    c2 = Course(name='Physics')
    c3 = Course(name='Computer Science')
    t1 = Teacher(name='T1')
    t2 = Teacher(name='T2')
    Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio', mentor=t1)
    Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio')
    Student(id=3, name='S3', group=g1, gpa=4.7, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3])
    Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3], mentor=t2)
    Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3])

class TestPrefetching(unittest.TestCase):
    def test_1(self):
        with db_session:
            s1 = Student.select().first()
            g = s1.group
            self.assertEqual(g.major, 'Math')

    @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over')
    def test_2(self):
        with db_session:
            s1 = Student.select().first()
            g = s1.group
        g.major

    def test_3(self):
        with db_session:
            s1 = Student.select().prefetch(Group, Group.major).first()
        g = s1.group
        self.assertEqual(g.major, 'Math')

    def test_4(self):
        with db_session:
            s1 = Student.select().prefetch(Student.group, Group.major).first()
        g = s1.group
        self.assertEqual(g.major, 'Math')

    @raises_exception(TypeError, 'Argument of prefetch() query method must be entity class or attribute. Got: 111')
    def test_5(self):
        with db_session:
            Student.select().prefetch(111)

    @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over')
    def test_6(self):
        with db_session:
            name, group = select((s.name, s.group) for s in Student).first()
        group.major

    def test_7(self):
        with db_session:
            name, group = select((s.name, s.group) for s in Student).prefetch(Group, Group.major).first()
        self.assertEqual(group.major, 'Math')

    @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Student[1].courses: the database session is over')
    def test_8(self):
        with db_session:
            s1 = Student.select().first()
        set(s1.courses)

    @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Student[1].courses: the database session is over')
    def test_9(self):
        with db_session:
            s1 = Student.select().prefetch(Course).first()
        set(s1.courses)

    def test_10(self):
        with db_session:
            s1 = Student.select().prefetch(Student.courses).first()
        self.assertEqual(set(s1.courses.name), {'Math', 'Physics'})

    @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Student[1].biography: the database session is over')
    def test_11(self):
        with db_session:
            s1 = Student.select().prefetch(Course).first()
        s1.biography

    def test_12(self):
        with db_session:
            s1 = Student.select().prefetch(Student.biography).first()
        self.assertEqual(s1.biography, 'S1 bio')
        self.assertEqual(db.last_sql, '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."mentor", "s"."biography"
FROM "Student" "s"
ORDER BY 1
LIMIT 1''')

    def test_13(self):
        db.merge_local_stats()
        with db_session:
            q = select(g for g in Group)
            for g in q:  # 1 query
                for s in g.students:  # 2 queries
                    b = s.biography  # 5 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 8)

    def test_14(self):
        db.merge_local_stats()
        with db_session:
            q = select(g for g in Group).prefetch(Group.students)
            for g in q:  # 1 query
                for s in g.students:  # 1 query
                    b = s.biography  # 5 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 7)

    def test_15(self):
        with db_session:
            q = select(g for g in Group).prefetch(Group.students)
            q[:]
        db.merge_local_stats()
        with db_session:
            q = select(g for g in Group).prefetch(Group.students, Student.biography)
            for g in q:  # 1 query
                for s in g.students:  # 1 query
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 2)

    def test_16(self):
        db.merge_local_stats()
        with db_session:
            q = select(c for c in Course).prefetch(Course.students, Student.biography)
            for c in q:  # 1 query
                for s in c.students:  # 2 queries (as it is many-to-many relationship)
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 3)

    def test_17(self):
        db.merge_local_stats()
        with db_session:
            q = select(c for c in Course).prefetch(Course.students, Student.biography, Group, Group.major)
            for c in q:  # 1 query
                for s in c.students:  # 2 queries (as it is many-to-many relationship)
                    m = s.group.major  # 1 query
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 4)

    def test_18(self):
        db.merge_local_stats()
        with db_session:
            q = Group.select().prefetch(Group.students, Student.biography)
            for g in q:  # 2 queries
                for s in g.students:
                    m = s.mentor  # 0 queries
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 2)

    def test_19(self):
        db.merge_local_stats()
        with db_session:
            q = Group.select().prefetch(Group.students, Student.biography, Student.mentor)
            mentors = set()
            for g in q:  # 3 queries
                for s in g.students:
                    m = s.mentor  # 0 queries
                    if m is not None:
                        mentors.add(m)
                    b = s.biography  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 3)
            for m in mentors:
                n = m.name  # 0 queries
            query_count = db.local_stats[None].db_count
            self.assertEqual(query_count, 3)

if __name__ == '__main__':
    unittest.main()
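# ---------------------------------------------------------------------------
# Editor's note: the N+1 pattern the query-count assertions above measure,
# shown in isolation (a sketch, not part of the distribution). Without
# prefetch(), every collection access and every lazy attribute costs extra
# queries; naming the relations up front collapses them into a few bulk loads:
#
#     with db_session:
#         q = select(g for g in Group).prefetch(Group.students, Student.biography)
#         for g in q:
#             for s in g.students:   # already loaded, no extra query
#                 b = s.biography    # already loaded, no extra query
# ---------------------------------------------------------------------------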
# ===== pony-0.7.11/pony/orm/tests/test_query.py =====

from __future__ import absolute_import, print_function, division
from pony.py23compat import PYPY2, pickle

import unittest
from datetime import date
from decimal import Decimal

from pony.orm.core import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Student(db.Entity):
    name = Required(unicode)
    scholarship = Optional(int)
    gpa = Optional(Decimal, 3, 1)
    group = Required('Group')
    dob = Optional(date)

class Group(db.Entity):
    number = PrimaryKey(int)
    students = Set(Student)

db.generate_mapping(create_tables=True)

with db_session:
    g1 = Group(number=1)
    Student(id=1, name='S1', group=g1, gpa=3.1)
    Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1))
    Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2))

class TestQuery(unittest.TestCase):
    def setUp(self):
        rollback()
        db_session.__enter__()

    def tearDown(self):
        rollback()
        db_session.__exit__()

    @raises_exception(TypeError, "Query can only iterate over entity or another query (not a list of objects)")
    def test1(self):
        select(s for s in [])

    @raises_exception(TypeError, "Cannot iterate over non-entity object X")
    def test2(self):
        X = [1, 2, 3]
        select('x for x in X')

    def test3(self):
        g = Group[1]
        students = select(s for s in g.students)
        self.assertEqual(set(g.students), set(students))

    @raises_exception(ExprEvalError, "`a` raises NameError: global name 'a' is not defined" if PYPY2
                                else "`a` raises NameError: name 'a' is not defined")
    def test4(self):
        select(a for s in Student)

    @raises_exception(TypeError, "Incomparable types '%s' and 'StrArray' in expression: s.name == x" % unicode.__name__)
    def test5(self):
        x = ['A']
        select(s for s in Student if s.name == x)

    def test6(self):
        def f1(x):
            return float(x) + 1
        students = select(s for s in Student if f1(s.gpa) > 4.25)[:]
        self.assertEqual({s.id for s in students}, {3})

    @raises_exception(NotImplementedError, "m1")
    def test7(self):
        class C1(object):
            def method1(self, a, b):
                return a + b
        c = C1()
        m1 = c.method1
        select(s for s in Student if m1(s.gpa, 1) > 3)

    @raises_exception(TypeError, "Expression `x` has unsupported type 'complex'")
    def test8(self):
        x = 1j
        select(s for s in Student if s.gpa == x)

    def test9(self):
        select(g for g in Group for s in db.Student)

    def test10(self):
        avg_gpa = avg(s.gpa for s in Student)
        self.assertEqual(round(avg_gpa, 6), 3.2)

    def test11(self):
        avg_gpa = avg(s.gpa for s in Student if s.id < 0)
        self.assertEqual(avg_gpa, None)

    def test12(self):
        sum_ss = sum(s.scholarship for s in Student)
        self.assertEqual(sum_ss, 300)

    def test13(self):
        sum_ss = sum(s.scholarship for s in Student if s.id < 0)
        self.assertEqual(sum_ss, 0)

    @raises_exception(TypeError, "'avg' is valid for numeric attributes only")
    def test14(self):
        avg(s.name for s in Student)

    def wrapper(self):
        return count(s for s in Student if s.scholarship > 0)

    def test15(self):
        c = self.wrapper()
        c = self.wrapper()
        self.assertEqual(c, 2)

    def test16(self):
        c = count(s.scholarship for s in Student if s.scholarship > 0)
        self.assertEqual(c, 2)

    def test17(self):
        s = get(s.scholarship for s in Student if s.id == 3)
        self.assertEqual(s, 200)

    def test18(self):
        s = get(s.scholarship for s in Student if s.id == 4)
        self.assertEqual(s, None)

    def test19(self):
        s = select(s for s in Student if s.id == 4).exists()
        self.assertEqual(s, False)

    def test20(self):
        r = min(s.scholarship for s in Student)
        self.assertEqual(r, 100)

    def test21(self):
        r = min(s.scholarship for s in Student if s.id < 2)
        self.assertEqual(r, None)

    def test22(self):
        r = max(s.scholarship for s in Student)
        self.assertEqual(r, 200)

    def test23(self):
        r = max(s.dob.year for s in Student)
        self.assertEqual(r, 2001)

    def test_first1(self):
        q = select(s for s in Student).order_by(Student.gpa)
        self.assertEqual(q.first(), Student[1])

    def test_first2(self):
        q = select((s.name, s.group) for s in Student)
        self.assertEqual(q.first(), ('S1', Group[1]))

    def test_first3(self):
        q = select(s for s in Student)
        self.assertEqual(q.first(), Student[1])

    def test_closures_1(self):
        def find_by_gpa(gpa):
            return lambda s: s.gpa > gpa
        fn = find_by_gpa(Decimal('3.1'))
        students = list(Student.select(fn))
        self.assertEqual(students, [Student[2], Student[3]])

    def test_closures_2(self):
        def find_by_gpa(gpa):
            return lambda s: s.gpa > gpa
        fn = find_by_gpa(Decimal('3.1'))
        q = select(s for s in Student)
        q = q.filter(fn)
        self.assertEqual(list(q), [Student[2], Student[3]])

    @raises_exception(NameError, 'Free variable `gpa` referenced before assignment in enclosing scope')
    def test_closures_3(self):
        def find_by_gpa():
            if False:
                gpa = Decimal('3.1')
            return lambda s: s.gpa > gpa
        fn = find_by_gpa()
        students = list(Student.select(fn))

    def test_pickle(self):
        objects = select(s for s in Student if s.scholarship > 0).order_by(desc(Student.id))
        data = pickle.dumps(objects)
        rollback()
        objects = pickle.loads(data)
        self.assertEqual([obj.id for obj in objects], [3, 2])

    def test_bulk_delete_clear_query_cache(self):
        students1 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:]
        self.assertEqual([s.id for s in students1], [2, 3])
        Student.select(lambda s: s.id < 3).delete(bulk=True)
        students2 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:]
        self.assertEqual([s.id for s in students2], [3])
        rollback()
        students1 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:]
        self.assertEqual([s.id for s in students1], [2, 3])

if __name__ == '__main__':
    unittest.main()
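# ---------------------------------------------------------------------------
# Editor's note: a short summary of the empty-result aggregate semantics that
# test10 through test23 above assert, since they differ by function: avg(),
# min() and max() return None when the result set is empty, while sum() and
# count() return 0 (and NULL values are skipped before aggregation):
#
#     avg(s.gpa for s in Student if s.id < 0)           # -> None
#     sum(s.scholarship for s in Student if s.id < 0)   # -> 0
#     min(s.scholarship for s in Student if s.id < 2)   # -> None (NULL skipped)
# ---------------------------------------------------------------------------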
# ===== pony-0.7.11/pony/orm/tests/test_random.py =====

import unittest

from pony.orm import *
from pony.orm.tests.testutils import *

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    id = PrimaryKey(int)
    name = Required(unicode)

db.generate_mapping(create_tables=True)

with db_session:
    Person(id=1, name='John')
    Person(id=2, name='Mary')
    Person(id=3, name='Bob')
    Person(id=4, name='Mike')
    Person(id=5, name='Ann')

class TestRandom(unittest.TestCase):
    @db_session
    def test_1(self):
        persons = Person.select().random(2)
        self.assertEqual(len(persons), 2)
        p1, p2 = persons
        self.assertNotEqual(p1.id, p2.id)
        self.assertTrue(p1.id in range(1, 6))
        self.assertTrue(p2.id in range(1, 6))

    @db_session
    def test_2(self):
        persons = Person.select_random(2)
        self.assertEqual(len(persons), 2)
        p1, p2 = persons
        self.assertNotEqual(p1.id, p2.id)
        self.assertTrue(p1.id in range(1, 6))
        self.assertTrue(p2.id in range(1, 6))

if __name__ == '__main__':
    unittest.main()
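# ---------------------------------------------------------------------------
# Editor's note (assumption, hedged): both calls above return `limit` distinct
# objects, but they appear to be translated differently. Query.random(n)
# orders the whole result randomly (ORDER BY random() LIMIT n), whereas
# Entity.select_random(n) is documented as a faster shortcut that probes
# random primary-key values, which matters on large tables:
#
#     Person.select().random(2)   # random sort of the full table, then LIMIT
#     Person.select_random(2)     # primary-key-probing shortcut
# ---------------------------------------------------------------------------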
# ===== pony-0.7.11/pony/orm/tests/test_raw_sql.py =====

from __future__ import absolute_import, print_function, division
from pony.py23compat import PYPY2

import unittest
from datetime import date

from pony.orm import *
from pony.orm.tests.testutils import raises_exception

db = Database('sqlite', ':memory:')

class Person(db.Entity):
    id = PrimaryKey(int)
    name = Required(str)
    age = Required(int)
    dob = Required(date)

db.generate_mapping(create_tables=True)

with db_session:
    Person(id=1, name='John', age=30, dob=date(1985, 1, 1))
    Person(id=2, name='Mike', age=32, dob=date(1983, 5, 20))
    Person(id=3, name='Mary', age=20, dob=date(1995, 2, 15))

class TestRawSQL(unittest.TestCase):
    @db_session
    def test_1(self):
        # raw_sql result can be treated as a logical expression
        persons = select(p for p in Person if raw_sql('abs("p"."age") > 25'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_2(self):
        # raw_sql result can be used for comparison
        persons = select(p for p in Person if raw_sql('abs("p"."age")') > 25)[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_3(self):
        # raw_sql can accept $parameters
        x = 25
        persons = select(p for p in Person if raw_sql('abs("p"."age") > $x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_4(self):
        # dynamic raw_sql content (1)
        x = 1
        s = 'p.id > $x'
        persons = select(p for p in Person if raw_sql(s))[:]
        self.assertEqual(set(persons), {Person[2], Person[3]})

    @db_session
    def test_5(self):
        # dynamic raw_sql content (2)
        x = 1
        cond = raw_sql('p.id > $x')
        persons = select(p for p in Person if cond)[:]
        self.assertEqual(set(persons), {Person[2], Person[3]})

    @db_session
    def test_6(self):
        # correct converter should be applied to raw_sql parameter type
        x = date(1990, 1, 1)
        persons = select(p for p in Person if raw_sql('p.dob < $x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_7(self):
        # raw_sql argument may be complex expression (1)
        x = 10
        y = 15
        persons = select(p for p in Person if raw_sql('p.age > $(x + y)'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_8(self):
        # raw_sql argument may be complex expression (2)
        persons = select(p for p in Person if raw_sql('p.dob < $date.today()'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2], Person[3]})

    @db_session
    def test_9(self):
        # using raw_sql in the expression part of the generator
        names = select(raw_sql('UPPER(p.name)') for p in Person)[:]
        self.assertEqual(set(names), {'JOHN', 'MIKE', 'MARY'})

    @db_session
    def test_10(self):
        # raw_sql does not know the result type and cannot apply the correct type converter automatically
        dates = select(raw_sql('(p.dob)') for p in Person).order_by(lambda: p.id)[:]
        self.assertEqual(dates, ['1985-01-01', '1983-05-20', '1995-02-15'])

    @db_session
    def test_11(self):
        # it is possible to specify raw_sql type manually
        dates = select(raw_sql('(p.dob)', result_type=date) for p in Person).order_by(lambda: p.id)[:]
        self.assertEqual(dates, [date(1985, 1, 1), date(1983, 5, 20), date(1995, 2, 15)])

    @db_session
    def test_12(self):
        # raw_sql can be used in lambdas
        x = 25
        persons = Person.select(lambda p: p.age > raw_sql('$x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_13(self):
        # raw_sql in filter()
        x = 25
        persons = select(p for p in Person).filter(lambda p: p.age > raw_sql('$x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_14(self):
        # raw_sql in filter() without using lambda
        x = 25
        persons = Person.select().filter(raw_sql('p.age > $x'))[:]
        self.assertEqual(set(persons), {Person[1], Person[2]})

    @db_session
    def test_15(self):
        # several raw_sql expressions in a single query
        x = '123'
        y = 'John'
        persons = Person.select(lambda p: raw_sql("UPPER(p.name) || $x") == raw_sql("UPPER($y || '123')"))[:]
        self.assertEqual(set(persons), {Person[1]})

    @db_session
    def test_16(self):
        # the same param name can be used several times with different types & values
        x = 10
        y = 31
        q = select(p for p in Person if p.age > x and p.age < raw_sql('$y'))
        x = date(1980, 1, 1)
        y = 'j'
        q = q.filter(lambda p: p.dob > x and p.name.startswith(raw_sql('UPPER($y)')))
        persons = q[:]
        self.assertEqual(set(persons), {Person[1]})

    @db_session
    def test_17(self):
        # raw_sql in order_by() section
        x = 9
        persons = Person.select().order_by(lambda p: raw_sql('SUBSTR(p.dob, $x)'))[:]
        self.assertEqual(persons, [Person[1], Person[3], Person[2]])

    @db_session
    def test_18(self):
        # raw_sql in order_by() section without using lambda
        x = 9
        persons = Person.select().order_by(raw_sql('SUBSTR(p.dob, $x)'))[:]
        self.assertEqual(persons, [Person[1], Person[3], Person[2]])

    @db_session
    @raises_exception(TranslationError, "Expression `raw_sql(p.name)` cannot be translated into SQL "
                                        "because raw SQL fragment will be different for each row")
    def test_19(self):
        # raw_sql argument cannot depend on iterator variables
        select(p for p in Person if raw_sql(p.name))[:]

    @db_session
    @raises_exception(ExprEvalError, "`raw_sql('p.dob < $x')` raises NameError: global name 'x' is not defined" if PYPY2
                                else "`raw_sql('p.dob < $x')` raises NameError: name 'x' is not defined")
    def test_20(self):
        # testing for situation where parameter variable is missing
        select(p for p in Person if raw_sql('p.dob < $x'))[:]

    @db_session
    def test_21(self):
        x = None
        persons = select(p for p in Person if p.id == 1 and raw_sql('$x') is None)[:]
        self.assertEqual(persons, [Person[1]])

if __name__ == '__main__':
    unittest.main()
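# ---------------------------------------------------------------------------
# Editor's note: the two interpolation forms used throughout TestRawSQL,
# condensed (a sketch, not part of the distribution). `$name` binds a Python
# variable as a query parameter and `$(expr)` evaluates an arbitrary Python
# expression; both are passed as bound parameters rather than spliced into
# the SQL text, which is why the tests need no manual quoting:
#
#     x, y = 10, 15
#     select(p for p in Person if raw_sql('p.age > $x'))        # parameter
#     select(p for p in Person if raw_sql('p.age > $(x + y)'))  # expression
# ---------------------------------------------------------------------------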
# ===== pony-0.7.11/pony/orm/tests/test_relations_m2m.py =====

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *

class TestManyToManyNonComposite(unittest.TestCase):
    def setUp(self):
        db = Database('sqlite', ':memory:')
        class Group(db.Entity):
            number = PrimaryKey(int)
            subjects = Set("Subject")
        class Subject(db.Entity):
            name = PrimaryKey(str)
            groups = Set(Group)
        self.db = db
        self.Group = Group
        self.Subject = Subject
        self.db.generate_mapping(create_tables=True)
        with db_session:
            g1 = Group(number=101)
            g2 = Group(number=102)
            s1 = Subject(name='Subj1')
            s2 = Subject(name='Subj2')
            s3 = Subject(name='Subj3')
            s4 = Subject(name='Subj4')
            g1.subjects = [s1, s2]

    def test_1(self):
        schema = self.db.schema
        m2m_table_name = 'Group_Subject'
        self.assertIn(m2m_table_name, schema.tables)
        m2m_table = schema.tables[m2m_table_name]
        fkeys = list(m2m_table.foreign_keys.values())
        self.assertEqual(len(fkeys), 2)
        for fk in fkeys:
            self.assertEqual(fk.on_delete, 'CASCADE')

    def test_2(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s = Subject.get(name='Subj1')
            g.subjects.add(s)
        with db_session:
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    def test_3(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s = Subject.get(name='Subj3')
            g.subjects.add(s)
        with db_session:
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2', 'Subj3'])

    def test_4(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s = Subject.get(name='Subj3')
            g.subjects.remove(s)
        with db_session:
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    def test_5(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s = Subject.get(name='Subj2')
            g.subjects.remove(s)
        with db_session:
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1'])

    def test_6(self):  # was a second "test_5" in the original, shadowing the one above
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s1, s2, s3, s4 = Subject.select()[:]
            g.subjects.remove([s1, s2])
            g.subjects.add([s3, s4])
        with db_session:
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj3', 'Subj4'])
            self.assertEqual(Group[101].subjects, {Subject['Subj3'], Subject['Subj4']})

    def test_7(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s = Subject.get(name='Subj3')
            g.subjects.add(s)
            g.subjects.remove(s)
            last_sql = db.last_sql
        with db_session:
            self.assertEqual(db.last_sql, last_sql)  # assert no DELETE statement on commit
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    def test_8(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s = Subject.get(name='Subj1')
            g.subjects.remove(s)
            g.subjects.add(s)
            last_sql = db.last_sql
        with db_session:
            self.assertEqual(db.last_sql, last_sql)  # assert no INSERT statement on commit
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    def test_9(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s1 = Subject.get(name='Subj1')
            s2 = Subject.get(name='Subj2')
            g.subjects.clear()
            g.subjects.add([s1, s2])
            last_sql = db.last_sql
        with db_session:
            self.assertEqual(db.last_sql, last_sql)  # assert no INSERT statement on commit
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    def test_10(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g2 = Group.get(number=102)
            s1 = Subject.get(name='Subj1')
            g2.subjects.add(s1)
            g2.subjects.clear()
            last_sql = db.last_sql
        with db_session:
            self.assertEqual(db.last_sql, last_sql)  # assert no DELETE statement on commit
            db_subjects = db.select('subject from Group_Subject where "group" = 102')
            self.assertEqual(db_subjects, [])

    def test_11(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s1, s2, s3, s4 = Subject.select()[:]
            g.subjects = [s2, s3]
        with db_session:
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj2', 'Subj3'])

    def test_12(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s1, s2, s3, s4 = Subject.select()[:]
            g.subjects.remove(s2)
            g.subjects = [s1, s2]
            last_sql = db.last_sql
        with db_session:
            self.assertEqual(db.last_sql, last_sql)  # assert no INSERT statement on commit
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    def test_13(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        with db_session:
            g = Group.get(number=101)
            s1, s2, s3, s4 = Subject.select()[:]
            g.subjects.add(s3)
            g.subjects = [s1, s2]
            last_sql = db.last_sql
        with db_session:
            self.assertEqual(db.last_sql, last_sql)  # assert no DELETE statement on commit
            db_subjects = db.select('subject from Group_Subject where "group" = 101')
            self.assertEqual(db_subjects, ['Subj1', 'Subj2'])

    @db_session
    def test_14(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        g1 = Group[101]
        s1 = Subject['Subj1']
        self.assertTrue(s1 in g1.subjects)
        group_setdata = g1._vals_[Group.subjects]
        self.assertTrue(s1 in group_setdata)
        self.assertEqual(group_setdata.added, None)
        self.assertEqual(group_setdata.removed, None)
        subj_setdata = s1._vals_[Subject.groups]
        self.assertTrue(g1 in subj_setdata)
        self.assertEqual(subj_setdata.added, None)
        self.assertEqual(subj_setdata.removed, None)
        g1.subjects.remove(s1)
        self.assertTrue(s1 not in group_setdata)
        self.assertEqual(group_setdata.added, None)
        self.assertEqual(group_setdata.removed, {s1})
        self.assertTrue(g1 not in subj_setdata)
        self.assertEqual(subj_setdata.added, None)
        self.assertEqual(subj_setdata.removed, {g1})
        g1.subjects.add(s1)
        self.assertTrue(s1 in group_setdata)
        self.assertEqual(group_setdata.added, set())
        self.assertEqual(group_setdata.removed, set())
        self.assertTrue(g1 in subj_setdata)
        self.assertEqual(subj_setdata.added, set())
        self.assertEqual(subj_setdata.removed, set())

    @db_session
    def test_15(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        g = Group[101]
        e = g.subjects.is_empty()
        self.assertEqual(e, False)
        db._dblocal.last_sql = None
        e = g.subjects.is_empty()  # should take result from the cache
        self.assertEqual(e, False)
        self.assertEqual(db.last_sql, None)
        g = Group[102]
        e = g.subjects.is_empty()  # should take SQL from the SQL cache
        self.assertEqual(e, True)
        db._dblocal.last_sql = None
        e = g.subjects.is_empty()  # should take result from the cache
        self.assertEqual(e, True)
        self.assertEqual(db.last_sql, None)

    @db_session
    def test_16(self):
        db, Group = self.db, self.Group
        g = Group[101]
        c = len(g.subjects)
        self.assertEqual(c, 2)
        db._dblocal.last_sql = None
        e = g.subjects.is_empty()  # should take result from the cache
        self.assertEqual(e, False)
        self.assertEqual(db.last_sql, None)
        g = Group[102]
        c = len(g.subjects)
        self.assertEqual(c, 0)
        db._dblocal.last_sql = None
        e = g.subjects.is_empty()  # should take result from the cache
        self.assertEqual(e, True)
        self.assertEqual(db.last_sql, None)

    @db_session
    def test_17(self):
        db, Group, Subject = self.db, self.Group, self.Subject
        g = Group[101]
        s1 = Subject['Subj1']
        s3 = Subject['Subj3']
        c = g.subjects.count()
        self.assertEqual(c, 2)
        db._dblocal.last_sql = None
        c = g.subjects.count()  # should take count from the cache
        self.assertEqual(c, 2)
        self.assertEqual(db.last_sql, None)
        g.subjects.add(s3)
        db._dblocal.last_sql = None
        c = g.subjects.count()  # should take modified count from the cache
        self.assertEqual(c, 3)
        self.assertEqual(db.last_sql, None)
        g.subjects.remove(s1)
        db._dblocal.last_sql = None
        c = g.subjects.count()  # should take modified count from the cache
        self.assertEqual(c, 2)
        self.assertEqual(db.last_sql, None)

if __name__ == "__main__":
    unittest.main()
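# ---------------------------------------------------------------------------
# Editor's note: the netting behaviour several tests above depend on, shown in
# isolation (a sketch). Opposite in-memory changes to a collection cancel out
# within one db_session, so commit issues no SQL against the intermediate
# "Group_Subject" table at all:
#
#     with db_session:
#         g = Group.get(number=101)
#         s = Subject.get(name='Subj3')
#         g.subjects.add(s)
#         g.subjects.remove(s)   # add + remove cancel: no INSERT, no DELETE
# ---------------------------------------------------------------------------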
# ===== pony-0.7.11/pony/orm/tests/test_relations_one2many.py =====

from __future__ import absolute_import, print_function, division

import unittest

from pony.orm.core import *
from pony.orm.tests.testutils import *

class TestOneToManyRequired(unittest.TestCase):
    def setUp(self):
        db = Database('sqlite', ':memory:', create_db=True)
        class Student(db.Entity):
            id = PrimaryKey(int)
            name = Required(unicode)
            group = Required('Group')
        class Group(db.Entity):
            number = PrimaryKey(int)
            students = Set(Student)
        self.db = db
        self.Group = Group
        self.Student = Student
        db.generate_mapping(create_tables=True)
        with db_session:
            g101 = Group(number=101)
            g102 = Group(number=102)
            g103 = Group(number=103)
            s1 = Student(id=1, name='Student1', group=g101)
            s2 = Student(id=2, name='Student2', group=g101)
            s3 = Student(id=3, name='Student3', group=g102)
            s4 = Student(id=4, name='Student3', group=g102)
        db_session.__enter__()

    def tearDown(self):
        rollback()
        db_session.__exit__()

    @raises_exception(ValueError, 'Attribute Student[1].group is required')
    def test_1(self):
        self.Student[1].group = None

    def test_2(self):
        Student, Group = self.Student, self.Group
        s1 = Student[1]
        g = Group[101]
        g.students.remove(s1)
        self.assertEqual(s1._status_, 'marked_to_delete')

    def test_3(self):
        Student, Group = self.Student, self.Group
        s1, s2 = Student[1], Student[2]
        g = Group[101]
        g.students.clear()
        self.assertEqual(s1._status_, 'marked_to_delete')
        self.assertEqual(s2._status_, 'marked_to_delete')
        self.assertEqual(set(g.students), set())

    def test_4(self):
        Student, Group = self.Student, self.Group
        s1, s2, s3, s4 = Student.select().order_by(Student.id)
        g1, g2 = Group[101], Group[102]
        g1.students = g2.students
        self.assertEqual(set(g1.students), {s3, s4})
        self.assertEqual(s1._status_, 'marked_to_delete')
        self.assertEqual(s2._status_, 'marked_to_delete')

    @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None')
    def test_5(self):
        Group, Student = self.Group, self.Student
        g = Group[101]
        g.students.add(None)

    @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None')
    def test_6(self):
        Group, Student = self.Group, self.Student
        g = Group[101]
        g.students.remove(None)

    @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None')
    def test_7(self):
        Group = self.Group
        g104 = Group(number=104, students=None)

    def test_8(self):
        db, Group, Student = self.db, self.Group, self.Student
        g = Group[101]
        s3 = Student[3]  # s3 is loaded now
        db._dblocal.last_sql = None
        g.students.add(s3)  # Group.students.load should not attempt to load s3 from db
        self.assertEqual(db.last_sql, None)

    def test_9(self):
        db, Group, Student = self.db, self.Group, self.Student
        g = Group[101]
        e = g.students.is_empty()
        self.assertEqual(e, False)
        db._dblocal.last_sql = None
        e = g.students.is_empty()  # should take result from the cache
        self.assertEqual(e, False)
        self.assertEqual(db.last_sql, None)
        g = Group[103]
        e = g.students.is_empty()  # should take SQL from the SQL cache
        self.assertEqual(e, True)
        db._dblocal.last_sql = None
        e = g.students.is_empty()  # should take result from the cache
        self.assertEqual(e, True)
        self.assertEqual(db.last_sql, None)

    def test_10(self):
        db, Group = self.db, self.Group
        g = Group[101]
        c = len(g.students)
        self.assertEqual(c, 2)
        db._dblocal.last_sql = None
        e = g.students.is_empty()  # should take result from the cache
        self.assertEqual(e, False)
        self.assertEqual(db.last_sql, None)
        g = Group[102]
        c = g.students.count()
        self.assertEqual(c, 2)
        db._dblocal.last_sql = None
        e = g.students.is_empty()  # should take result from the cache
        self.assertEqual(e, False)
        self.assertEqual(db.last_sql, None)
        g = Group[103]
        c = len(g.students)
        self.assertEqual(c, 0)
        db._dblocal.last_sql = None
        e = g.students.is_empty()  # should take result from the cache
        self.assertEqual(e, True)
        self.assertEqual(db.last_sql, None)

    def test_11(self):
        db, Group, Student = self.db, self.Group, self.Student
        g = Group[101]
        s3 = Student[3]
        c = g.students.count()
        self.assertEqual(c, 2)
        db._dblocal.last_sql = None
        c = g.students.count()  # should take count from the cache
        self.assertEqual(c, 2)
        self.assertEqual(db.last_sql, None)
        g.students.add(s3)
        c = g.students.count()  # should take modified count from the cache
        self.assertEqual(c, 3)
        self.assertEqual(db.last_sql, None)
        g2 = Group[102]
        c = g2.students.count()  # should send query to the database
        self.assertEqual(c, 1)
        self.assertTrue(db.last_sql is not None)

    def test_12(self):
        Group, Student = self.Group, self.Student
        g = Group[101]
        s1 = Student[1]
        self.assertEqual(s1._rbits_, 0)
        self.assertTrue(s1 in g.students)
        self.assertEqual(s1._rbits_, Student._bits_[Student.group])
        s3 = Student[3]
        self.assertEqual(s3._rbits_, 0)
        self.assertTrue(s3 not in g.students)
        self.assertEqual(s3._rbits_, Student._bits_[Student.group])
        s5 = Student(id=5, name='Student5', group=g)
        self.assertEqual(s5._rbits_, None)
        self.assertTrue(s5 in g.students)
        self.assertEqual(s5._rbits_, None)

class TestOneToManyOptional(unittest.TestCase):
    def setUp(self):
        db = Database('sqlite', ':memory:', create_db=True)
        class Student(db.Entity):
            id = PrimaryKey(int)
            name = Required(unicode)
            group = Optional('Group')
        class Group(db.Entity):
            number = PrimaryKey(int)
            students = Set(Student)
        self.db = db
        self.Group = Group
        self.Student = Student
        db.generate_mapping(create_tables=True)
        with db_session:
            g101 = Group(number=101)
            g102 = Group(number=102)
            g103 = Group(number=103)
            s1 = Student(id=1, name='Student1', group=g101)
            s2 = Student(id=2, name='Student2', group=g101)
            s3 = Student(id=3, name='Student3', group=g102)
            s4 = Student(id=4, name='Student3', group=g102)
        db_session.__enter__()

    def tearDown(self):
        rollback()
        db_session.__exit__()

    def test_1(self):
        self.Student[1].group = None
        self.assertEqual(set(self.Group[101].students), {self.Student[2]})

    def test_2(self):
        Student, Group = self.Student, self.Group
        s1 = Student[1]
        g = Group[101]
        g.students.remove(s1)
        self.assertEqual(s1.group, None)

    def test_3(self):
        Student, Group = self.Student, self.Group
        s1, s2 = Student[1], Student[2]
        g = Group[101]
        g.students.clear()
        self.assertEqual(s1.group, None)
        self.assertEqual(s2.group, None)
        self.assertEqual(set(g.students), set())

    def test_4(self):
        Student, Group = self.Student, self.Group
        s1, s2, s3, s4 = Student.select().order_by(Student.id)
        g1, g2 = Group[101], Group[102]
        g1.students = g2.students
        self.assertEqual(set(g1.students), {s3, s4})
        self.assertEqual(s1.group, None)
        self.assertEqual(s2.group, None)

    @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None')
    def test_5(self):
        Group, Student = self.Group, self.Student
        g = Group[101]
        g.students.add(None)

    @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None')
    def test_6(self):
        Group, Student = self.Group, self.Student
        g = Group[101]
        g.students.remove(None)

    @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None')
    def test_7(self):
        Group = self.Group
        g104 = Group(number=104, students=None)

if __name__ == '__main__':
    unittest.main()
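# ---------------------------------------------------------------------------
# Editor's note: the asymmetry between the two test classes above, condensed.
# Detaching a child whose relationship attribute is Required leaves it without
# a valid owner, so Pony marks it for deletion; with Optional the foreign key
# is simply set to NULL:
#
#     g.students.remove(s)
#     # group = Required('Group')  ->  s._status_ == 'marked_to_delete'
#     # group = Optional('Group')  ->  s.group is None
# ---------------------------------------------------------------------------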
from Male order by Male.id') self.assertEqual([3, 2, None], wives) @db_session def test_6(self): Female[3].husband = Male[1] self.assertEqual(Male[1]._vals_[Male.wife], Female[3]) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[3]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from Male order by Male.id') self.assertEqual([3, 2, None], wives) @db_session def test_7(self): Male[1].wife = Female[2] self.assertEqual(Male[1]._vals_[Male.wife], Female[2]) self.assertEqual(Male[2]._vals_[Male.wife], None) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[2]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from Male order by Male.id') self.assertEqual([2, None, None], wives) @db_session def test_8(self): Female[2].husband = Male[1] self.assertEqual(Male[1]._vals_[Male.wife], Female[2]) self.assertEqual(Male[2]._vals_[Male.wife], None) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[2]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from Male order by Male.id') self.assertEqual([2, None, None], wives) @db_session def test_9(self): f4 = Female(name='F4') m4 = Male(name='M4', wife=f4) flush() self.assertEqual(f4._status_, 'inserted') self.assertEqual(m4._status_, 'inserted') @db_session def test_10(self): m4 = Male(name='M4') f4 = Female(name='F4', husband=m4) flush() self.assertEqual(f4._status_, 'inserted') self.assertEqual(m4._status_, 'inserted') @db_session def test_to_dict_1(self): m = Male[1] d = m.to_dict() self.assertEqual(d, dict(id=1, name='M1', wife=1)) @db_session def test_to_dict_2(self): m = Male[3] d = m.to_dict() self.assertEqual(d, dict(id=3, name='M3', wife=None)) @db_session def test_to_dict_3(self): f = Female[1] d = f.to_dict() self.assertEqual(d, dict(id=1, name='F1', husband=1)) @db_session def test_to_dict_4(self): f = Female[3] d = f.to_dict() self.assertEqual(d, dict(id=3, name='F3', husband=None)) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1570625601.0 pony-0.7.11/pony/orm/tests/test_relations_one2one2.py0000666000000000000000000001326400000000000020772 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Male(db.Entity): name = Required(unicode) wife = Optional('Female', column='wife') class Female(db.Entity): name = Required(unicode) husband = Optional('Male', column='husband') db.generate_mapping(create_tables=True) class TestOneToOne2(unittest.TestCase): def setUp(self): with db_session: db.execute('update female set husband=null') db.execute('update male set wife=null') db.execute('delete from male') db.execute('delete from female') db.insert(Female, id=1, name='F1') db.insert(Female, id=2, name='F2') db.insert(Female, id=3, name='F3') db.insert(Male, id=1, name='M1', wife=1) db.insert(Male, id=2, name='M2', wife=2) db.insert(Male, id=3, name='M3', wife=None) db.execute('update female set husband=1 where id=1') db.execute('update female set husband=2 where id=2') db_session.__enter__() def tearDown(self): db_session.__exit__() def test_1(self): Male[3].wife = Female[3] self.assertEqual(Male[3]._vals_[Male.wife], Female[3]) self.assertEqual(Female[3]._vals_[Female.husband], Male[3]) commit() wives = db.select('wife from male order by male.id') self.assertEqual([1, 2, 3], 
wives) husbands = db.select('husband from female order by female.id') self.assertEqual([1, 2, 3], husbands) def test_2(self): Female[3].husband = Male[3] self.assertEqual(Male[3]._vals_[Male.wife], Female[3]) self.assertEqual(Female[3]._vals_[Female.husband], Male[3]) commit() wives = db.select('wife from male order by male.id') self.assertEqual([1, 2, 3], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([1, 2, 3], husbands) def test_3(self): Male[1].wife = None self.assertEqual(Male[1]._vals_[Male.wife], None) self.assertEqual(Female[1]._vals_[Female.husband], None) commit() wives = db.select('wife from male order by male.id') self.assertEqual([None, 2, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 2, None], husbands) def test_4(self): Female[1].husband = None self.assertEqual(Male[1]._vals_[Male.wife], None) self.assertEqual(Female[1]._vals_[Female.husband], None) commit() wives = db.select('wife from male order by male.id') self.assertEqual([None, 2, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 2, None], husbands) def test_5(self): Male[1].wife = Female[3] self.assertEqual(Male[1]._vals_[Male.wife], Female[3]) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[3]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from male order by male.id') self.assertEqual([3, 2, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 2, 1], husbands) def test_6(self): Female[3].husband = Male[1] self.assertEqual(Male[1]._vals_[Male.wife], Female[3]) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[3]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from male order by male.id') self.assertEqual([3, 2, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 2, 1], husbands) def test_7(self): Male[1].wife = Female[2] self.assertEqual(Male[1]._vals_[Male.wife], Female[2]) self.assertEqual(Male[2]._vals_[Male.wife], None) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[2]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from male order by male.id') self.assertEqual([2, None, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 1, None], husbands) def test_8(self): Female[2].husband = Male[1] self.assertEqual(Male[1]._vals_[Male.wife], Female[2]) self.assertEqual(Male[2]._vals_[Male.wife], None) self.assertEqual(Female[1]._vals_[Female.husband], None) self.assertEqual(Female[2]._vals_[Female.husband], Male[1]) commit() wives = db.select('wife from male order by male.id') self.assertEqual([2, None, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 1, None], husbands) @raises_exception(UnrepeatableReadError, 'Multiple Male objects linked with the same Female[1] object. 
' 'Maybe Female.husband attribute should be Set instead of Optional') def test_9(self): db.execute('update female set husband = 3 where id = 1') m1 = Male[1] f1 = m1.wife f1.name def test_10(self): db.execute('update female set husband = 3 where id = 1') m1 = Male[1] f1 = Female[1] f1.name self.assertTrue(Male.wife not in m1._vals_) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_relations_one2one3.py0000666000000000000000000000554200000000000020773 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import * class TestOneToOne3(unittest.TestCase): def setUp(self): self.db = Database('sqlite', ':memory:') class Person(self.db.Entity): name = Required(unicode) passport = Optional("Passport", cascade_delete=True) class Passport(self.db.Entity): code = Required(unicode) person = Required("Person") self.db.generate_mapping(create_tables=True) with db_session: p1 = Person(name='John') Passport(code='123', person=p1) def tearDown(self): self.db = None @db_session def test_1(self): obj = select(p for p in self.db.Person if p.passport.id).first() self.assertEqual(obj.name, 'John') self.assertEqual(obj.passport.code, '123') @db_session def test_2(self): select(p for p in self.db.Person if p.passport is None)[:] sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" LEFT JOIN "Passport" "passport" ON "p"."id" = "passport"."person" WHERE "passport"."id" IS NULL''') @db_session def test_3(self): select(p for p in self.db.Person if not p.passport)[:] sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" LEFT JOIN "Passport" "passport" ON "p"."id" = "passport"."person" WHERE "passport"."id" IS NULL''') @db_session def test_4(self): select(p for p in self.db.Person if p.passport)[:] sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" LEFT JOIN "Passport" "passport" ON "p"."id" = "passport"."person" WHERE "passport"."id" IS NOT NULL''') @db_session def test_5(self): p = self.db.Person.get(name='John') p.delete() flush() sql = self.db.last_sql self.assertEqual(sql, '''DELETE FROM "Person" WHERE "id" = ? 
AND "name" = ?''') @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session def test_6(self): p = self.db.Person.get(name='John') self.db.Passport(code='456', person=p) @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session def test7(self): p2 = self.db.Person(name='Mike') pas2 = self.db.Passport(code='456', person=p2) commit() p1 = self.db.Person.get(name='John') pas2.person = p1 if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_relations_one2one4.py0000666000000000000000000000221200000000000020763 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import * class TestOneToOne4(unittest.TestCase): def setUp(self): self.db = Database('sqlite', ':memory:') class Person(self.db.Entity): name = Required(unicode) passport = Optional("Passport") class Passport(self.db.Entity): code = Required(unicode) person = Required("Person") self.db.generate_mapping(create_tables=True) with db_session: p1 = Person(name='John') Passport(code='123', person=p1) def tearDown(self): self.db = None @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session def test1(self): p2 = self.db.Person(name='Mike') pas2 = self.db.Passport(code='456', person=p2) commit() p1 = self.db.Person.get(name='John') pas2.person = p1 if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571702850.0 pony-0.7.11/pony/orm/tests/test_relations_symmetric_m2m.py0000666000000000000000000000646700000000000022141 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) friends = Set('Person', reverse='friends') db.generate_mapping(create_tables=True) class TestSymmetricM2M(unittest.TestCase): def setUp(self): with db_session: for p in Person.select(): p.delete() with db_session: db.insert(Person, id=1, name='A') db.insert(Person, id=2, name='B') db.insert(Person, id=3, name='C') db.insert(Person, id=4, name='D') db.insert(Person, id=5, name='E') db.insert(Person.friends, person=1, person_2=2) db.insert(Person.friends, person=2, person_2=1) db.insert(Person.friends, person=1, person_2=3) db.insert(Person.friends, person=3, person_2=1) db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test1a(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) self.assertEqual(set(p4.friends), {p1}) def test1b(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) self.assertEqual(set(p1.friends), {Person[2], Person[3], p4}) def test1c(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) commit() rows = db.select("* from person_friends order by person, person_2") self.assertEqual(rows, [(1,2), (1,3), (1,4), (2,1), (3,1), (4,1)]) def test2a(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) self.assertEqual(set(p1.friends), {Person[3]}) def test2b(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) self.assertEqual(set(Person[3].friends), {p1}) def 
test2c(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) self.assertEqual(set(p2.friends), set()) def test2d(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) commit() rows = db.select("* from person_friends order by person, person_2") self.assertEqual(rows, [(1,3), (3,1)]) def test3a(self): db.execute('delete from person_friends') db.insert(Person.friends, person=1, person_2=2) p1 = Person[1] p2 = Person[2] p2_friends = set(p2.friends) self.assertEqual(p2_friends, set()) try: p1_friends = set(p1.friends) except UnrepeatableReadError as e: self.assertEqual(e.args[0], "Phantom object Person[1] appeared in collection Person[2].friends") else: self.fail() def test3b(self): db.execute('delete from person_friends') db.insert(Person.friends, person=1, person_2=2) p1 = Person[1] p2 = Person[2] p1_friends = set(p1.friends) self.assertEqual(p1_friends, {p2}) try: p2_friends = set(p2.friends) except UnrepeatableReadError as e: self.assertEqual(e.args[0], "Phantom object Person[1] disappeared from collection Person[2].friends") else: self.fail() if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571702850.0 pony-0.7.11/pony/orm/tests/test_relations_symmetric_one2one.py0000666000000000000000000000617500000000000023007 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import * from pony.orm.tests.testutils import raises_exception db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) spouse = Optional('Person', reverse='spouse') db.generate_mapping(create_tables=True) class TestSymmetricOne2One(unittest.TestCase): def setUp(self): with db_session: db.execute('update person set spouse=null') db.execute('delete from person') db.insert(Person, id=1, name='A') db.insert(Person, id=2, name='B', spouse=1) db.execute('update person set spouse=2 where id=1') db.insert(Person, id=3, name='C') db.insert(Person, id=4, name='D', spouse=3) db.execute('update person set spouse=4 where id=3') db.insert(Person, id=5, name='E', spouse=None) db_session.__enter__() def tearDown(self): db_session.__exit__() def test1(self): p1 = Person[1] p2 = Person[2] p5 = Person[5] p1.spouse = p5 commit() self.assertEqual(p1._vals_.get(Person.spouse), p5) self.assertEqual(p5._vals_.get(Person.spouse), p1) self.assertEqual(p2._vals_.get(Person.spouse), None) data = db.select('spouse from person order by id') self.assertEqual([5, None, 4, 3, 1], data) def test2(self): p1 = Person[1] p2 = Person[2] p1.spouse = None commit() self.assertEqual(p1._vals_.get(Person.spouse), None) self.assertEqual(p2._vals_.get(Person.spouse), None) data = db.select('spouse from person order by id') self.assertEqual([None, None, 4, 3, None], data) def test3(self): p1 = Person[1] p2 = Person[2] p3 = Person[3] p4 = Person[4] p1.spouse = p3 commit() self.assertEqual(p1._vals_.get(Person.spouse), p3) self.assertEqual(p2._vals_.get(Person.spouse), None) self.assertEqual(p3._vals_.get(Person.spouse), p1) self.assertEqual(p4._vals_.get(Person.spouse), None) data = db.select('spouse from person order by id') self.assertEqual([3, None, 1, None, None], data) def test4(self): persons = set(select(p for p in Person if p.spouse.name in ('B', 'D'))) self.assertEqual(persons, {Person[1], Person[3]}) @raises_exception(UnrepeatableReadError, 'Multiple Person objects linked with the same Person[2] object. 
' 'Maybe Person.spouse attribute should be Set instead of Optional') def test5(self): db.execute('update person set spouse = 3 where id = 2') p1 = Person[1] p1.spouse p2 = Person[2] p2.name def test6(self): db.execute('update person set spouse = 3 where id = 2') p1 = Person[1] p2 = Person[2] p2.name p1.spouse self.assertEqual(p2._vals_.get(Person.spouse), p1) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_select_from_select_queries.py0000666000000000000000000004060000000000000022653 0ustar0000000000000000import unittest from pony.orm import * from pony.orm.tests.testutils import * from pony.py23compat import PYPY2 db = Database('sqlite', ':memory:') class Group(db.Entity): number = PrimaryKey(int) major = Required(str) students = Set('Student') class Student(db.Entity): first_name = Required(unicode) last_name = Required(unicode) age = Required(int) group = Required('Group') scholarship = Required(int, default=0) courses = Set('Course') @property def full_name(self): return self.first_name + ' ' + self.last_name class Course(db.Entity): name = Required(unicode) semester = Required(int) credits = Required(int) PrimaryKey(name, semester) students = Set('Student') db.generate_mapping(create_tables=True) with db_session: g1 = Group(number=123, major='Computer Science') g2 = Group(number=456, major='Graphic Design') s1 = Student(id=1, first_name='John', last_name='Smith', age=20, group=g1, scholarship=0) s2 = Student(id=2, first_name='Alex', last_name='Green', age=24, group=g1, scholarship=100) s3 = Student(id=3, first_name='Mary', last_name='White', age=23, group=g1, scholarship=500) s4 = Student(id=4, first_name='John', last_name='Brown', age=20, group=g2, scholarship=400) s5 = Student(id=5, first_name='Bruce', last_name='Lee', age=22, group=g2, scholarship=300) c1 = Course(name='Math', semester=1, credits=10, students=[s1, s2, s4]) c2 = Course(name='Computer Science', semester=1, credits=20, students=[s2, s3]) c3 = Course(name='3D Modeling', semester=2, credits=15, students=[s3, s5]) class TestSelectFromSelect(unittest.TestCase): @db_session def test_1(self): # basic select from another query q = select(s for s in Student if s.scholarship > 0) q2 = select(s for s in q if s.scholarship < 500) self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'}) self.assertEqual(db.last_sql.count('SELECT'), 1) # single SELECT...FROM expression @db_session def test_2(self): # different variable name in the second query q = select(s for s in Student if s.scholarship > 0) q2 = select(x for x in q if x.scholarship < 500) self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_3(self): # selecting single column instead of entity in the second query q = select(s for s in Student if s.scholarship > 0) q2 = select(x.first_name for x in q if x.scholarship < 500) self.assertEqual(set(q2), {'Alex', 'Bruce', 'John'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_4(self): # selecting single column instead of entity in the first query q = select(s.first_name for s in Student if s.scholarship > 0) q2 = select(name for name in q if 'r' in name) self.assertEqual(set(q2), {'Bruce', 'Mary'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_5(self): # selecting hybrid property in the second query q = select(s for s in Student if s.scholarship 
> 0) q2 = select(x.full_name for x in q if x.scholarship < 500) self.assertEqual(set(q2), {'Alex Green', 'Bruce Lee', 'John Brown'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_6(self): # selecting hybrid property in the first query q = select(s.full_name for s in Student if s.scholarship < 500) q2 = select(x for x in q if x.startswith('J')) self.assertEqual(set(q2), {'John Smith', 'John Brown'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session @raises_exception(ExprEvalError, "`s.scholarship > 0` raises NameError: name 's' is not defined" if not PYPY2 else "`s.scholarship > 0` raises NameError: global name 's' is not defined") def test_7(self): # test access to original query var name from the new query q = select(s.first_name for s in Student if s.scholarship < 500) q2 = select(x for x in q if s.scholarship > 0) @db_session def test_8(self): # test using external name which is equal to original query var name class Dummy(object): scholarship = 1 s = Dummy() q = select(s.first_name for s in Student if s.scholarship < 500) q2 = select(x for x in q if s.scholarship > 0) self.assertEqual(set(q2), {'John', 'Alex', 'Bruce'}) @db_session def test_9(self): # test reusing variable name from the original query q = select(s for s in Student if s.scholarship > 0) q2 = select(x for x in q for s in Student if x.scholarship < s.scholarship) self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_10(self): # test .filter() q = select(s for s in Student if s.scholarship > 0) q2 = q.filter(lambda a: a.scholarship < 500) q3 = select(x for x in q2 if x.age > 20) q4 = q3.filter(lambda b: b.age < 24) self.assertEqual(set(s.first_name for s in q4), {'Bruce'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_11(self): # test .where() q = select(s for s in Student if s.scholarship > 0) q2 = q.where(lambda s: s.scholarship < 500) q3 = select(x for x in q2 if x.age > 20) q4 = q3.where(lambda x: x.age < 24) # the name should be accessible in previous generator self.assertEqual(set(s.first_name for s in q4), {'Bruce'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session @raises_exception(TypeError, 'Lambda argument `s` does not correspond to any variable in original query') def test_12(self): # test .where() q = select(s for s in Student if s.scholarship > 0) q2 = q.where(lambda s: s.scholarship < 500) q3 = select(x for x in q2 if x.age > 20) q4 = q3.where(lambda s: s.age < 24) @db_session def test_13(self): # select several expressions from the first query q = select((s.full_name, s.age) for s in Student if s.scholarship > 0) q2 = select(name for name, age in q if age < 24 and 'e' in name) self.assertEqual(set(q2), {'Mary White', 'Bruce Lee'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_14(self): # select from entity with composite key q = select(c for c in Course if c.semester == 1) q2 = select(x.name for x in q if x.name.startswith('M')) self.assertEqual(set(q2), {'Math'}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_15(self): # SELECT ... FROM (SELECT alias.* FROM ... 
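        # Plain select-from-select queries (tests 1-14 above) are inlined into a
        # single SELECT, which those tests assert via db.last_sql.count('SELECT') == 1.
        # Starting here, queries built from left_join() or from grouped/aggregate
        # queries are instead kept as a wrapped subquery, so tests 15-18 assert
        # two SELECTs and, for this test, a LEFT JOIN plus the "alias.*"
        # expansion in the generated SQL.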
q = left_join(s for g in Group for s in g.students if g.number == 123 and s.scholarship > 0) q2 = select(x.full_name for x in q if x.scholarship > 100) self.assertEqual(set(q2), {'Mary White'}) self.assertEqual(db.last_sql.count('SELECT'), 2) self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) self.assertTrue('*' in db.last_sql) @db_session def test_16(self): # SELECT ... FROM (grouped-query) q = select(g for g in Group if count(g.students) > 2) q2 = select(x.number for x in q) self.assertEqual(set(q2), {123}) self.assertEqual(db.last_sql.count('SELECT'), 2) self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) self.assertEqual(db.last_sql.count('GROUP BY'), 1) self.assertEqual(db.last_sql.count('HAVING'), 1) self.assertTrue('WHERE' not in db.last_sql) @db_session def test_17(self): # SELECT ... FROM (grouped-query), t1 WHERE ... q = select(g for g in Group if count(g.students) > 2) q2 = select(x.major for x in q) self.assertEqual(set(q2), {'Computer Science'}) self.assertEqual(db.last_sql.count('SELECT'), 2) self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) self.assertEqual(db.last_sql.count('GROUP BY'), 1) self.assertEqual(db.last_sql.count('HAVING'), 1) @db_session def test_18(self): # SELECT ... FROM (grouped-query returns composite keys), t1 WHERE ... q = select((c, count(c.students)) for c in Course if c.semester == 1 and count(c.students) > 1) q2 = select((x.name, x.credits, y) for x, y in q if x.credits > 10 and y < 3) self.assertEqual(set(q2), {('Computer Science', 20, 2)}) self.assertEqual(db.last_sql.count('SELECT'), 2) self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) self.assertEqual(db.last_sql.count('GROUP BY'), 1) self.assertEqual(db.last_sql.count('HAVING'), 1) self.assertEqual(db.last_sql.count('WHERE'), 2) @db_session def test_19(self): # multiple for loops in the inner query q = select((g, s.first_name.lower()) for g in Group for s in g.students) q2 = select((g.major, n) for g, n in q if g.number == 123 and n[0] == 'j') self.assertEqual(set(q2), {('Computer Science', 'john')}) @db_session def test_20(self): # additional for loop with inlined subquery q = select((g, x.first_name.upper()) for g in Group for x in select(s for s in Student if s.age < 22) if x.group == g and g.number == 123 and x.first_name[0] == 'J') q2 = select(name for g, name in q if g.number == 123) self.assertEqual(set(q2), {'JOHN'}) @db_session def test_21(self): objects = select(s for s in Student if s.scholarship > 200)[:] # not query, but query result q = select(s.first_name for s in Student if s not in objects) self.assertEqual(set(q), {'John', 'Alex'}) @db_session @raises_exception(TypeError, 'Query can only iterate over entity or another query (not a list of objects)') def test_22(self): objects = select(s for s in Student if s.scholarship > 200)[:] # not query, but query result q = select(s.first_name for s in objects) @db_session def test_23(self): q = select(s for s in Student) q2 = q.filter(lambda x: x.scholarship > 450) q3 = q2.where(lambda s: s.scholarship < 520) self.assertEqual(set(q3), {Student[3]}) @db_session def test_24(self): q = select(s for s in Student) q2 = q.where(lambda s: s.scholarship > 450) q3 = q2.filter(lambda x: x.scholarship < 520) self.assertEqual(set(q3), {Student[3]}) @db_session def test_25(self): q = Student.select().filter(lambda x: x.scholarship > 450) q2 = select(s for s in q) q3 = q2.where(lambda s: s.scholarship < 520) self.assertEqual(set(q3), {Student[3]}) @db_session def test_26(self): q = Student.select().filter(lambda x: x.scholarship > 450) q2 = 
q.where(lambda s: s.scholarship < 520) q3 = select(s for s in q2) self.assertEqual(set(q3), {Student[3]}) @db_session def test_27(self): q = Student.select().where(lambda s: s.scholarship > 450) q2 = select(s for s in q) q3 = q2.filter(lambda x: x.scholarship < 520) self.assertEqual(set(q3), {Student[3]}) @db_session def test_28(self): q = Student.select().where(lambda s: s.scholarship > 450) q2 = q.filter(lambda x: x.scholarship < 520) q3 = select(s for s in q2) self.assertEqual(set(q3), {Student[3]}) @db_session def test_29(self): q = select(s for s in Student) q2 = q.where(lambda s: s.scholarship > 450) q3 = q2.where(lambda s: s.scholarship < 520) self.assertEqual(set(q3), {Student[3]}) @db_session def test_30(self): q = select(s for s in Student) q2 = q.filter(lambda x: x.scholarship > 450) q3 = q2.filter(lambda z: z.scholarship < 520) self.assertEqual(set(q3), {Student[3]}) @db_session def test_31(self): q = select(s for s in Student).order_by(lambda s: s.scholarship) q2 = q.where(lambda s: s.scholarship > 450) self.assertEqual(set(q2), {Student[3]}) @db_session def test_32(self): q = select(s for s in Student).order_by(lambda s: s.scholarship) q2 = q.filter(lambda z: z.scholarship > 450) self.assertEqual(set(q2), {Student[3]}) @db_session def test_33(self): q = select(s for s in Student).sort_by(lambda x: x.scholarship) q2 = q.where(lambda s: s.scholarship > 450) self.assertEqual(set(q2), {Student[3]}) @db_session def test_34(self): q = select(s for s in Student).sort_by(lambda x: x.scholarship) q2 = q.filter(lambda s: s.scholarship > 450) self.assertEqual(set(q2), {Student[3]}) @db_session def test_35(self): q = select(s for s in Student if s.scholarship > 0) q2 = select(s.id for s in Student if s not in q) self.assertEqual(set(q2), {1}) self.assertEqual(db.last_sql.count('SELECT'), 2) @db_session def test_36(self): q = select(s for s in Student if s.scholarship > 0) q2 = select(s.id for s in Student if s not in q[:]) self.assertEqual(set(q2), {1}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_37(self): q = select(s.last_name for s in Student if s.scholarship > 0) q2 = select(s.id for s in Student if s.last_name not in q) self.assertEqual(set(q2), {1}) self.assertEqual(db.last_sql.count('SELECT'), 2) @db_session def test_38(self): q = select(s.last_name for s in Student if s.scholarship > 0) q2 = select(s.id for s in Student if s.last_name not in q[:]) self.assertEqual(set(q2), {1}) self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session def test_39(self): q = select((s.first_name, s.last_name) for s in Student if s.scholarship > 0) q2 = select(s.id for s in Student if (s.first_name, s.last_name) not in q) self.assertEqual(set(q2), {1}) self.assertTrue(db.last_sql.count('SELECT') > 1) # @db_session # def test_40(self): # TODO # q = select((s.first_name, s.last_name) for s in Student if s.scholarship > 0) # q2 = select(s.id for s in Student if (s.first_name, s.last_name) not in q[:]) # self.assertEqual(set(q2), {1}) # self.assertTrue(db.last_sql.count('SELECT'), 1) @db_session def test_41(self): def f1(): x = 21 return select(s for s in Student if s.age > x) def f2(q): x = 23 return select(s.last_name for s in Student if s.age < x and s in q) q = f1() q2 = f2(q) self.assertEqual(set(q2), {'Lee'}) @db_session def test_42(self): q = select(s for s in Student if s.scholarship > 0) q2 = select(g for g in Group if g.major == 'Computer Science')[:] q3 = select(s.first_name for s in q if s.group in q2) self.assertEqual(set(q3), {'Alex', 'Mary'}) @db_session def 
test_43(self): q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) q2 = select(s.first_name for s in Student if s in q) self.assertEqual(set(q2), {'John', 'Bruce'}) @db_session def test_44(self): q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) q2 = select(s.first_name for s in q) self.assertEqual(set(q2), {'Bruce', 'John', 'Mary'}) @db_session def test_45(self): q = select(s for s in Student).order_by(Student.first_name, Student.id).limit(3, offset=1) q2 = select(s for s in q if s.age > 18).limit(2, offset=1) q3 = select(s.last_name for s in q2).limit(2, offset=1) self.assertEqual(set(q3), {'Brown'}) @db_session def test_46(self): q = select((c, count(c.students)) for c in Course).order_by(-2).limit(2) q2 = select((c.name, c.credits, m) for c, m in q).limit(1, offset=1) self.assertEqual(set(q2), {('3D Modeling', 15, 2)}) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1568899785.0 pony-0.7.11/pony/orm/tests/test_show.py0000666000000000000000000000377000000000000016244 0ustar0000000000000000from pony.py23compat import StringIO import sys, unittest from decimal import Decimal from datetime import date from pony.orm import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) scholarship = Optional(int) gpa = Optional(Decimal, 3, 1) dob = Optional(date) group = Required('Group') courses = Set('Course') biography = Optional(LongUnicode) class Group(db.Entity): number = PrimaryKey(int) students = Set(Student) class Course(db.Entity): name = Required(unicode, unique=True) students = Set(Student) db.generate_mapping(create_tables=True) with db_session: g1 = Group(number=1) g2 = Group(number=2) c1 = Course(name='Math') c2 = Course(name='Physics') c3 = Course(name='Computer Science') Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) normal_stdout = sys.stdout class TestShow(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() sys.stdout = StringIO() def tearDown(self): sys.stdout = normal_stdout rollback() db_session.__exit__() def test_1(self): Student.select().show() self.assertEqual('\n'+sys.stdout.getvalue().replace(' ', '~'), ''' id|name|scholarship|gpa|dob~~~~~~~|group~~~ --+----+-----------+---+----------+-------- 1~|S1~~|None~~~~~~~|3.1|None~~~~~~|Group[1] 2~|S2~~|100~~~~~~~~|3.2|2000-01-01|Group[1] 3~|S3~~|200~~~~~~~~|3.3|2001-01-02|Group[1] ''') def test_2(self): Group.select().show() self.assertEqual('\n'+sys.stdout.getvalue().replace(' ', '~'), ''' number ------ 1~~~~~ 2~~~~~ ''') if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/test_sqlbuilding_formatstyles.py0000666000000000000000000000564500000000000022420 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.sqlsymbols import * from pony.orm.sqlbuilding import SQLBuilder from pony.orm.dbapiprovider import DBAPIProvider from pony.orm.tests.testutils import TestPool class TestFormatStyles(unittest.TestCase): def setUp(self): self.key1 = 'KEY1' self.key2 = 'KEY2' self.provider = 
DBAPIProvider(pony_pool_mockup=TestPool(None)) self.ast = [ SELECT, [ ALL, [COLUMN, None, 'A']], [ FROM, [None, TABLE, 'T1']], [ WHERE, [ EQ, [COLUMN, None, 'B'], [ PARAM, self.key1 ] ], [ EQ, [COLUMN, None, 'C'], [ PARAM, self.key2 ] ], [ EQ, [COLUMN, None, 'D'], [ PARAM, self.key2 ] ], [ EQ, [COLUMN, None, 'E'], [ PARAM, self.key1 ] ] ] ] def test_qmark(self): self.provider.paramstyle = 'qmark' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = ?\n AND "C" = ?\n AND "D" = ?\n AND "E" = ?') self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_numeric(self): self.provider.paramstyle = 'numeric' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = :1\n AND "C" = :2\n AND "D" = :2\n AND "E" = :1') self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_named(self): self.provider.paramstyle = 'named' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = :p1\n AND "C" = :p2\n AND "D" = :p2\n AND "E" = :p1') self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_format(self): self.provider.paramstyle = 'format' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = %s\n AND "C" = %s\n AND "D" = %s\n AND "E" = %s') self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_pyformat(self): self.provider.paramstyle = 'pyformat' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = %(p1)s\n AND "C" = %(p2)s\n AND "D" = %(p2)s\n AND "E" = %(p1)s') self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) if __name__ == "__main__": unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_sqlbuilding_sqlast.py0000666000000000000000000000221100000000000021155 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from pony.orm.core import Database, db_session from pony.orm.sqlsymbols import * class TestSQLAST(unittest.TestCase): def setUp(self): self.db = Database('sqlite', ':memory:') with db_session: conn = self.db.get_connection() conn.executescript(""" create table if not exists T1( a integer primary key, b varchar(20) not null ); insert or ignore into T1 values(1, 'abc'); """) @db_session def test_alias(self): sql_ast = [SELECT, [ALL, [COLUMN, "Group", "a"]], [FROM, ["Group", TABLE, "T1" ]]] sql, adapter = self.db._ast2sql(sql_ast) cursor = self.db._exec_sql(sql) @db_session def test_alias2(self): sql_ast = [SELECT, [ALL, [COLUMN, None, "a"]], [FROM, [None, TABLE, "T1"]]] sql, adapter = self.db._ast2sql(sql_ast) cursor = self.db._exec_sql(sql) if __name__ == "__main__": unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_sqlite_str_functions.py0000666000000000000000000000346600000000000021547 0ustar0000000000000000# coding: utf-8 from __future__ import absolute_import, print_function, division from binascii import unhexlify import unittest from pony.orm.core import * from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) age = Optional(int) image = Optional(buffer) db.generate_mapping(create_tables=True) 
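# A minimal illustrative sketch (sketch_db and Note are names invented here,
# not part of the test suite): SQLite's built-in upper()/lower() fold only
# ASCII characters, so Pony registers Python-backed py_upper()/py_lower() SQL
# functions on its SQLite connections.  That is why .upper()/.lower() in the
# queries of test2/test3 below work for Cyrillic text, and why the raw SQL in
# test4-test7 can call py_upper()/py_lower() directly.
sketch_db = Database('sqlite', ':memory:')

class Note(sketch_db.Entity):
    text = Required(unicode)

sketch_db.generate_mapping(create_tables=True)

with db_session:
    Note(text=u'Иван')

with db_session:
    # The translated query uppercases Cyrillic correctly, which SQLite's
    # native upper() would leave unchanged.
    assert select(n.text.upper() for n in Note)[:] == [u'ИВАН']
    assert sketch_db.select('select py_upper(text) from Note') == [u'ИВАН']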
with db_session: p1 = Person(name='John', age=20, image=unhexlify('abcdef')) p2 = Person(name=u'Иван') # u'\u0418\u0432\u0430\u043d' class TestUnicode(unittest.TestCase): @db_session def test1(self): names = select(p.name for p in Person).order_by(lambda: p.id)[:] self.assertEqual(names, ['John', u'Иван']) @db_session def test2(self): names = select(p.name.upper() for p in Person).order_by(lambda: p.id)[:] self.assertEqual(names, ['JOHN', u'ИВАН']) # u'\u0418\u0412\u0410\u041d' @db_session def test3(self): names = select(p.name.lower() for p in Person).order_by(lambda: p.id)[:] self.assertEqual(names, ['john', u'иван']) # u'\u0438\u0432\u0430\u043d' @db_session def test4(self): ages = db.select('select py_upper(age) from person') self.assertEqual(ages, ['20', None]) @db_session def test5(self): ages = db.select('select py_lower(age) from person') self.assertEqual(ages, ['20', None]) @db_session def test6(self): ages = db.select('select py_upper(image) from person') self.assertEqual(ages, [u'ABCDEF', None]) @db_session def test7(self): ages = db.select('select py_lower(image) from person') self.assertEqual(ages, [u'abcdef', None]) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/tests/test_time_parsing.py0000666000000000000000000001430300000000000017737 0ustar0000000000000000from __future__ import absolute_import, print_function, division import unittest from datetime import datetime, date, time from pony.orm.tests.testutils import raises_exception from pony.converting import str2time class TestTimeParsing(unittest.TestCase): def test_time_1(self): self.assertEqual(str2time('1:2'), time(1, 2)) self.assertEqual(str2time('01:02'), time(1, 2)) self.assertEqual(str2time('1:2:3'), time(1, 2, 3)) self.assertEqual(str2time('01:02:03'), time(1, 2, 3)) self.assertEqual(str2time('1:2:3.4'), time(1, 2, 3, 400000)) self.assertEqual(str2time('01:02:03.4'), time(1, 2, 3, 400000)) @raises_exception(ValueError, 'Unrecognized time format') def test_time_2(self): str2time('1:') @raises_exception(ValueError, 'Unrecognized time format') def test_time_3(self): str2time('1: 2') @raises_exception(ValueError, 'Unrecognized time format') def test_time_4(self): str2time('1:2:') @raises_exception(ValueError, 'Unrecognized time format') def test_time_5(self): str2time('1:2:3:') @raises_exception(ValueError, 'Unrecognized time format') def test_time_6(self): str2time('1:2:3.1234567') def test_time_7(self): self.assertEqual(str2time('1:33 am'), time(1, 33)) self.assertEqual(str2time('2:33 am'), time(2, 33)) self.assertEqual(str2time('11:33 am'), time(11, 33)) self.assertEqual(str2time('12:33 am'), time(0, 33)) def test_time_8(self): self.assertEqual(str2time('1:33 pm'), time(13, 33)) self.assertEqual(str2time('2:33 pm'), time(14, 33)) self.assertEqual(str2time('11:33 pm'), time(23, 33)) self.assertEqual(str2time('12:33 pm'), time(12, 33)) def test_time_9(self): self.assertEqual(str2time('1:33am'), time(1, 33)) self.assertEqual(str2time('1:33 am'), time(1, 33)) self.assertEqual(str2time('1:33 AM'), time(1, 33)) self.assertEqual(str2time('1:33 a.m'), time(1, 33)) self.assertEqual(str2time('1:33 A.M'), time(1, 33)) self.assertEqual(str2time('1:33 a.m.'), time(1, 33)) self.assertEqual(str2time('1:33 A.M.'), time(1, 33)) def test_time_10(self): self.assertEqual(str2time('1:33pm'), time(13, 33)) self.assertEqual(str2time('1:33 pm'), time(13, 33)) self.assertEqual(str2time('1:33 PM'), time(13, 33))
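        # The pm suffix accepts the same spellings as the am suffix in
        # test_time_9 ('pm', 'PM', 'p.m', 'P.M', 'p.m.', 'P.M.'); per
        # test_time_7 and test_time_8, 12-hour notation normalizes so that
        # '12:33 am' means 00:33 while '12:33 pm' stays 12:33.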
self.assertEqual(str2time('1:33 p.m'), time(13, 33)) self.assertEqual(str2time('1:33 P.M'), time(13, 33)) self.assertEqual(str2time('1:33 p.m.'), time(13, 33)) self.assertEqual(str2time('1:33 P.M.'), time(13, 33)) def test_time_11(self): self.assertEqual(str2time('12:34:56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12.34.56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12 34 56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12h34m56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12h 34m 56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12 h 34 m 56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12h 34m 56.789s'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12 h 34 m 56.789 s'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12h 34min 56.789'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12h 34min 56.789sec'), time(12, 34, 56, 789000)) self.assertEqual(str2time('12h 34 min 56.789 sec'), time(12, 34, 56, 789000)) def test_time_12(self): self.assertEqual(str2time('12:34:56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12.34.56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12 34 56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12h34m56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12h 34m 56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12 h 34 m 56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12h 34m 56.789s a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12 h 34 m 56.789 s a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12h 34min 56.789 a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12h 34min 56.789sec a.m.'), time(0, 34, 56, 789000)) self.assertEqual(str2time('12h 34 min 56.789 sec a.m.'), time(0, 34, 56, 789000)) def test_time_13(self): self.assertEqual(str2time('12:34'), time(12, 34)) self.assertEqual(str2time('12.34'), time(12, 34)) self.assertEqual(str2time('12 34'), time(12, 34)) self.assertEqual(str2time('12h34'), time(12, 34)) self.assertEqual(str2time('12h34m'), time(12, 34)) self.assertEqual(str2time('12h 34m'), time(12, 34)) self.assertEqual(str2time('12h34min'), time(12, 34)) self.assertEqual(str2time('12h 34min'), time(12, 34)) self.assertEqual(str2time('12 h 34 m'), time(12, 34)) self.assertEqual(str2time('12 h 34 min'), time(12, 34)) self.assertEqual(str2time('12u34'), time(12, 34)) # Belgium self.assertEqual(str2time('12h'), time(12)) self.assertEqual(str2time('12u'), time(12)) def test_time_14(self): self.assertEqual(str2time('12:34 am'), time(0, 34)) self.assertEqual(str2time('12.34 am'), time(0, 34)) self.assertEqual(str2time('12 34 am'), time(0, 34)) self.assertEqual(str2time('12h34 am'), time(0, 34)) self.assertEqual(str2time('12h34m am'), time(0, 34)) self.assertEqual(str2time('12h 34m am'), time(0, 34)) self.assertEqual(str2time('12h34min am'), time(0, 34)) self.assertEqual(str2time('12h 34min am'), time(0, 34)) self.assertEqual(str2time('12 h 34 m am'), time(0, 34)) self.assertEqual(str2time('12 h 34 min am'), time(0, 34)) self.assertEqual(str2time('12u34 am'), time(0, 34)) self.assertEqual(str2time('12h am'), time(0)) self.assertEqual(str2time('12u am'), time(0)) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_to_dict.py0000666000000000000000000001511500000000000016705 
0ustar0000000000000000import unittest from decimal import Decimal from datetime import date from pony.orm import * from pony.orm.serialization import to_dict from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) scholarship = Optional(int) gpa = Optional(Decimal, 3, 1) dob = Optional(date) group = Optional('Group') courses = Set('Course') biography = Optional(LongUnicode) class Group(db.Entity): number = PrimaryKey(int) students = Set(Student) class Course(db.Entity): name = Required(unicode, unique=True) students = Set(Student) db.generate_mapping(create_tables=True) with db_session: g1 = Group(number=1) g2 = Group(number=2) c1 = Course(name='Math') c2 = Course(name='Physics') c3 = Course(name='Computer Science') Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) Student(id=4, name='S4') class TestObjectToDict(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test1(self): s1 = Student[1] d = s1.to_dict() self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1)) def test2(self): s1 = Student[1] d = s1.to_dict(related_objects=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=Group[1])) def test3(self): s1 = Student[1] d = s1.to_dict(with_collections=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1, courses=[1, 2])) def test4(self): s1 = Student[1] d = s1.to_dict(with_collections=True, related_objects=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=Group[1], courses=[Course[1], Course[2]])) def test5(self): s1 = Student[1] d = s1.to_dict(with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1, biography='some text')) def test6(self): s1 = Student[1] d = s1.to_dict(only=['id', 'name', 'group']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test7(self): s1 = Student[1] d = s1.to_dict(['id', 'name', 'group']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test8(self): s1 = Student[1] d = s1.to_dict(only='id, name, group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test9(self): s1 = Student[1] d = s1.to_dict(only='id name group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test10(self): s1 = Student[1] d = s1.to_dict('id name group') self.assertEqual(d, dict(id=1, name='S1', group=1)) @raises_exception(AttributeError, 'Entity Student does not have attriute x') def test11(self): s1 = Student[1] d = s1.to_dict('id name x group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test12(self): s1 = Student[1] d = s1.to_dict('id name group', related_objects=True) self.assertEqual(d, dict(id=1, name='S1', group=Group[1])) def test13(self): s1 = Student[1] d = s1.to_dict(exclude=['dob', 'gpa', 'scholarship']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test14(self): s1 = Student[1] d = s1.to_dict(exclude='dob, gpa, scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test15(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) 
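    # Tests 6-10 and 13-15 above show that `only` and `exclude` accept three
    # equivalent spellings: a list of attribute names, a comma-separated
    # string, or a space-separated string.  Tests 22-23 below combine the two
    # options: an attribute named in both `only` and `exclude` is excluded.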
@raises_exception(AttributeError, 'Entity Student does not have attriute x') def test16(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa x scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test17(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa scholarship', related_objects=True) self.assertEqual(d, dict(id=1, name='S1', group=Group[1])) def test18(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa scholarship', with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', group=1, biography='some text')) def test19(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa scholarship biography', with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test20(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa scholarship', with_collections=True) self.assertEqual(d, dict(id=1, name='S1', group=1, courses=[1, 2])) def test21(self): s1 = Student[1] d = s1.to_dict(exclude='dob gpa scholarship courses', with_collections=True) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test22(self): s1 = Student[1] d = s1.to_dict(only='id name group', exclude='dob group') self.assertEqual(d, dict(id=1, name='S1')) def test23(self): s1 = Student[1] d = s1.to_dict(only='id name group', exclude='dob group', with_collections=True, with_lazy=True) self.assertEqual(d, dict(id=1, name='S1')) def test24(self): c = Course(name='New Course') d = c.to_dict() # should do flush and get c.id from the database self.assertEqual(d, dict(id=4, name='New Course')) class TestSerializationToDict(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() def tearDown(self): rollback() db_session.__exit__() def test1(self): s4 = Student[4] self.assertEqual(s4.group, None) d = to_dict(s4) self.assertEqual(d, dict(Student={ 4 : dict(id=4, name='S4', group=None, dob=None, gpa=None, scholarship=None, courses=[]) })) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/tests/test_tracked_value.py0000666000000000000000000000347400000000000020076 0ustar0000000000000000import unittest from pony.orm.ormtypes import TrackedList, TrackedDict, TrackedValue class Object(object): def __init__(self): self.on_attr_changed = None def _attr_changed_(self, attr): if self.on_attr_changed is not None: self.on_attr_changed(attr) class Attr(object): pass class TestTrackedValue(unittest.TestCase): def test_make(self): obj = Object() attr = Attr() value = {'items': ['one', 'two', 'three']} tracked_value = TrackedValue.make(obj, attr, value) self.assertEqual(type(tracked_value), TrackedDict) self.assertEqual(type(tracked_value['items']), TrackedList) def test_dict_setitem(self): obj = Object() attr = Attr() value = {'items': ['one', 'two', 'three']} tracked_value = TrackedValue.make(obj, attr, value) log = [] obj.on_attr_changed = lambda x: log.append(x) tracked_value['items'] = [1, 2, 3] self.assertEqual(log, [attr]) def test_list_append(self): obj = Object() attr = Attr() value = {'items': ['one', 'two', 'three']} tracked_value = TrackedValue.make(obj, attr, value) log = [] obj.on_attr_changed = lambda x: log.append(x) tracked_value['items'].append('four') self.assertEqual(log, [attr]) def test_list_setslice(self): obj = Object() attr = Attr() value = {'items': ['one', 'two', 'three']} tracked_value = TrackedValue.make(obj, attr, value) log = [] obj.on_attr_changed = lambda x: log.append(x) tracked_value['items'][1:2] = ['a', 'b', 'c'] self.assertEqual(log, [attr]) 
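        # The slice assignment replaced items[1:2] ('two') with 'a', 'b', 'c';
        # like append() and __setitem__ in the tests above, it notified the
        # owning object exactly once, and the resulting contents are checked
        # next.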
self.assertEqual(tracked_value['items'], ['one', 'a', 'b', 'c', 'three']) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1537636029.0 pony-0.7.11/pony/orm/tests/test_transaction_lock.py0000666000000000000000000000226500000000000020617 0ustar0000000000000000 import unittest from pony.orm import * db = Database() class TestPost(db.Entity): category = Optional('TestCategory') name = Optional(str, default='Noname') class TestCategory(db.Entity): posts = Set(TestPost) db.bind('sqlite', ':memory:') db.generate_mapping(create_tables=True) with db_session: post = TestPost() class TransactionLockTestCase(unittest.TestCase): __call__ = db_session(unittest.TestCase.__call__) def tearDown(self): rollback() def test_create(self): p = TestPost() p.flush() cache = db._get_cache() self.assertEqual(cache.immediate, True) self.assertEqual(cache.in_transaction, True) def test_update(self): p = TestPost[post.id] p.name = 'Trash' p.flush() cache = db._get_cache() self.assertEqual(cache.immediate, True) self.assertEqual(cache.in_transaction, True) def test_delete(self): p = TestPost[post.id] p.delete() flush() cache = db._get_cache() self.assertEqual(cache.immediate, True) self.assertEqual(cache.in_transaction, True) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862664.0 pony-0.7.11/pony/orm/tests/test_validate.py0000666000000000000000000000451200000000000017050 0ustar0000000000000000import unittest, warnings from pony.orm import * from pony.orm import core from pony.orm.tests.testutils import raises_exception db = Database('sqlite', ':memory:') class Person(db.Entity): id = PrimaryKey(int) name = Required(str) tel = Optional(str) db.generate_mapping(check_tables=False) with db_session: db.execute(""" create table Person( id int primary key, name text, tel text ) """) class TestValidate(unittest.TestCase): @db_session def setUp(self): db.execute('delete from Person') registry = getattr(core, '__warningregistry__', {}) for key in list(registry): if type(key) is not tuple: continue text, category, lineno = key if category is DatabaseContainsIncorrectEmptyValue: del registry[key] @db_session def test_1a(self): with warnings.catch_warnings(): warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue) db.insert('Person', id=1, name='', tel='111') p = Person.get(id=1) self.assertEqual(p.name, '') @raises_exception(DatabaseContainsIncorrectEmptyValue, 'Database contains empty string for required attribute Person.name') @db_session def test_1b(self): with warnings.catch_warnings(): warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue) db.insert('Person', id=1, name='', tel='111') p = Person.get(id=1) @db_session def test_2a(self): with warnings.catch_warnings(): warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue) db.insert('Person', id=1, name=None, tel='111') p = Person.get(id=1) self.assertEqual(p.name, None) @raises_exception(DatabaseContainsIncorrectEmptyValue, 'Database contains NULL for required attribute Person.name') @db_session def test_2b(self): with warnings.catch_warnings(): warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue) db.insert('Person', id=1, name=None, tel='111') p = Person.get(id=1) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/orm/tests/test_volatile.py0000666000000000000000000000246100000000000017077 
0ustar0000000000000000import sys, unittest from pony.orm import * from pony.orm.tests.testutils import * class TestVolatile(unittest.TestCase): def setUp(self): db = self.db = Database('sqlite', ':memory:') class Item(self.db.Entity): name = Required(str) index = Required(int, volatile=True) db.generate_mapping(create_tables=True) with db_session: Item(name='A', index=1) Item(name='B', index=2) Item(name='C', index=3) @db_session def test_1(self): db = self.db Item = db.Item db.execute('update "Item" set "index" = "index" + 1') items = Item.select(lambda item: item.index > 0).order_by(Item.id)[:] a, b, c = items self.assertEqual(a.index, 2) self.assertEqual(b.index, 3) self.assertEqual(c.index, 4) c.index = 1 items = Item.select()[:] # force re-read from the database self.assertEqual(c.index, 1) self.assertEqual(a.index, 2) self.assertEqual(b.index, 3) @db_session def test_2(self): Item = self.db.Item item = Item[1] item.name = 'X' item.flush() self.assertEqual(item.index, 1) if __name__ == '__main__': unittest.main() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 pony-0.7.11/pony/orm/tests/testutils.py0000666000000000000000000001217100000000000016260 0ustar0000000000000000from __future__ import absolute_import, print_function, division from pony.py23compat import basestring import re from contextlib import contextmanager from pony.orm.core import Database from pony.utils import import_module def test_exception_msg(test_case, exc_msg, test_msg=None): if test_msg is None: return error_template = "incorrect exception message. expected '%s', got '%s'" error_msg = error_template % (test_msg, exc_msg) assert test_msg not in ('...', '....', '.....', '......') if '...' not in test_msg: test_case.assertEqual(test_msg, exc_msg, error_msg) else: pattern = ''.join( '[%s]' % char for char in test_msg.replace('\\', '\\\\') .replace('[', '\\[') ).replace('[.][.][.]', '.*') regex = re.compile(pattern) if not regex.match(exc_msg): test_case.fail(error_template % (test_msg, exc_msg)) def raises_exception(exc_class, test_msg=None): def decorator(func): def wrapper(test_case, *args, **kwargs): try: func(test_case, *args, **kwargs) test_case.fail("Expected exception %s wasn't raised" % exc_class.__name__) except exc_class as e: if not e.args: test_case.assertEqual(test_msg, None) else: test_exception_msg(test_case, str(e), test_msg) wrapper.__name__ = func.__name__ return wrapper return decorator @contextmanager def raises_if(test_case, cond, exc_class, test_msg=None): try: yield except exc_class as e: test_case.assertTrue(cond) test_exception_msg(test_case, str(e), test_msg) else: test_case.assertFalse(cond, "Expected exception %s wasn't raised" % exc_class.__name__) def flatten(x): result = [] for el in x: if hasattr(el, "__iter__") and not isinstance(el, basestring): result.extend(flatten(el)) else: result.append(el) return result class TestConnection(object): def __init__(con, database): con.database = database if database and database.provider_name == 'postgres': con.autocommit = True def commit(con): pass def rollback(con): pass def cursor(con): return test_cursor class TestCursor(object): def __init__(cursor): cursor.description = [] cursor.rowcount = 0 def execute(cursor, sql, args=None): pass def fetchone(cursor): return None def fetchmany(cursor, size): return [] def fetchall(cursor): return [] test_cursor = TestCursor() class TestPool(object): def __init__(pool, database): pool.database = database def connect(pool): return 
TestConnection(pool.database), True def release(pool, con): pass def drop(pool, con): pass def disconnect(pool): pass class TestDatabase(Database): real_provider_name = None raw_server_version = None sql = None def bind(self, provider, *args, **kwargs): provider_name = provider assert isinstance(provider_name, basestring) if self.real_provider_name is not None: provider_name = self.real_provider_name self.provider_name = provider_name provider_module = import_module('pony.orm.dbproviders.' + provider_name) provider_cls = provider_module.provider_cls raw_server_version = self.raw_server_version if raw_server_version is None: if provider_name == 'sqlite': raw_server_version = '3.7.17' elif provider_name in ('postgres', 'pygresql'): raw_server_version = '9.2' elif provider_name == 'oracle': raw_server_version = '11.2.0.2.0' elif provider_name == 'mysql': raw_server_version = '5.6.11' else: assert False, provider_name # pragma: no cover t = [ int(component) for component in raw_server_version.split('.') ] if len(t) == 2: t.append(0) server_version = tuple(t) if provider_name in ('postgres', 'pygresql'): server_version = int('%d%02d%02d' % server_version) class TestProvider(provider_cls): json1_available = False # for SQLite def inspect_connection(provider, connection): pass TestProvider.server_version = server_version kwargs['pony_check_connection'] = False kwargs['pony_pool_mockup'] = TestPool(self) Database.bind(self, TestProvider, *args, **kwargs) def _execute(database, sql, globals, locals, frame_depth): assert False # pragma: no cover def _exec_sql(database, sql, arguments=None, returning_id=False): assert type(arguments) is not list and not returning_id database.sql = sql database.arguments = arguments return test_cursor def generate_mapping(database, filename=None, check_tables=True, create_tables=False): return Database.generate_mapping(database, filename, create_tables=False) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571663610.0 pony-0.7.11/pony/orm/tmp_cmp.py0000666000000000000000000000057500000000000014522 0ustar0000000000000000def tuple_cmp(a, b, op): res = [] for i in range(min(len(a), len(b))): cur = '' tmp = [] for j in range(i): tmp.append('%r == %r' % (a[j], b[j])) if tmp: cur += ' and '.join(tmp) cur += ' and ' cur += '%r %s %r' % (a[i], op, b[i]) res.append(cur) return ' OR\n'.join(res)././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1533647435.0 pony-0.7.11/pony/orm/zzzzz.py0000666000000000000000000000167600000000000014307 0ustar0000000000000000import types class OnConnectDecorator(object): def __init__(self, database, provider): self.provider = provider def __call__(self, func=None, provider=None): if isinstance(func, types.FunctionType): self.database._on_connect_funcs.append(func, provider or self.provider) if not provider and func is basestring: provider = func if not isinstance(provider, basestring): throw(TypeError) return OnConnectDecorator(self.database, provider) class Database(object): def on_connect(database): return OnConnectDecorator(database, None) @db.on_connect def f1(conection): pass def f1(connection): pass f1 = db.on_connect(f1) @db.on_connect(provider='sqlite') def f1(conection): pass def f1(connection): pass real_decorator = db.on_connect(provider='sqlite') f1 = real_decorator(f1) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862778.0 
pony-0.7.11/pony/py23compat.py0000666000000000000000000000321300000000000014257 0ustar0000000000000000import sys, platform PY2 = sys.version_info[0] == 2 PYPY = platform.python_implementation() == 'PyPy' PYPY2 = PYPY and PY2 PY37 = sys.version_info[:2] >= (3, 7) if PY2: from future_builtins import zip as izip, map as imap import __builtin__ as builtins import cPickle as pickle from cStringIO import StringIO xrange = xrange basestring = basestring unicode = unicode buffer = buffer int_types = (int, long) cmp = cmp def iteritems(dict): return dict.iteritems() def itervalues(dict): return dict.itervalues() def items_list(dict): return dict.items() def values_list(dict): return dict.values() else: import builtins, pickle from io import StringIO izip, imap, xrange = zip, map, range basestring = str unicode = str buffer = bytes int_types = (int,) def cmp(a, b): return (a > b) - (a < b) def iteritems(dict): return iter(dict.items()) def itervalues(dict): return iter(dict.values()) def items_list(dict): return list(dict.items()) def values_list(dict): return list(dict.values()) # Armin's recipe from http://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/ def with_metaclass(meta, *bases): class metaclass(meta): __call__ = type.__call__ __init__ = type.__init__ def __new__(cls, name, this_bases, d): if this_bases is None: return type.__new__(cls, name, (), d) return meta(name, bases, d) return metaclass('temporary_class', None, {}) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1571864710.9778328 pony-0.7.11/pony/thirdparty/0000777000000000000000000000000000000000000014077 5ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1544349699.0 pony-0.7.11/pony/thirdparty/__init__.py0000666000000000000000000000000000000000000016176 0ustar0000000000000000././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1571864711.0621655 pony-0.7.11/pony/thirdparty/compiler/0000777000000000000000000000000000000000000015711 5ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/__init__.py0000666000000000000000000000156000000000000020024 0ustar0000000000000000"""Package for parsing and compiling Python source code There are several functions defined at the top level that are imported from modules contained in the package. parse(buf, mode="exec") -> AST Converts a string containing Python source code to an abstract syntax tree (AST). The AST is defined in compiler.ast. parseFile(path) -> AST The same as parse(open(path)) walk(ast, visitor, verbose=None) Does a pre-order walk over the ast using the visitor instance. See compiler.visitor for details. compile(source, filename, mode, flags=None, dont_inherit=None) Returns a code object. A replacement for the builtin compile() function. compileFile(filename) Generates a .pyc file by compiling filename. 
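Illustrative usage (an added sketch, not part of the original docstring,
assuming the package is importable as pony.thirdparty.compiler):

    from pony.thirdparty.compiler import parse, compile
    tree = parse("x = 1 + 2")                       # -> Module AST
    code = compile("x = 1 + 2", "<string>", "exec")
    exec(code)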
""" from .transformer import parse, parseFile from .visitor import walk from .pycodegen import compile, compileFile ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571707432.0 pony-0.7.11/pony/thirdparty/compiler/ast.py0000666000000000000000000011613600000000000017062 0ustar0000000000000000"""Python abstract syntax node definitions This file is automatically generated by Tools/compiler/astgen.py """ from __future__ import absolute_import from pony.py23compat import items_list from .consts import CO_VARARGS, CO_VARKEYWORDS def flatten(seq): l = [] for elt in seq: t = type(elt) if t is tuple or t is list: for elt2 in flatten(elt): l.append(elt2) else: l.append(elt) return l def flatten_nodes(seq): return [n for n in flatten(seq) if isinstance(n, Node)] nodes = {} class Node: """Abstract base class for ast nodes.""" def getChildren(self): pass # implemented by subclasses def __iter__(self): for n in self.getChildren(): yield n def asList(self): # for backwards compatibility return self.getChildren() def getChildNodes(self): pass # implemented by subclasses class EmptyNode(Node): pass class Expression(Node): # Expression is an artificial node class to support "eval" nodes["expression"] = "Expression" def __init__(self, node): self.node = node def getChildren(self): return self.node, def getChildNodes(self): return self.node, def __repr__(self): return "Expression(%s)" % (repr(self.node)) class Add(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "Add((%s, %s))" % (repr(self.left), repr(self.right)) class And(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "And(%s)" % (repr(self.nodes),) class AssAttr(Node): def __init__(self, expr, attrname, flags, lineno=None): self.expr = expr self.attrname = attrname self.flags = flags self.lineno = lineno def getChildren(self): return self.expr, self.attrname, self.flags def getChildNodes(self): return self.expr, def __repr__(self): return "AssAttr(%s, %s, %s)" % (repr(self.expr), repr(self.attrname), repr(self.flags)) class AssList(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "AssList(%s)" % (repr(self.nodes),) class AssName(Node): def __init__(self, name, flags, lineno=None): self.name = name self.flags = flags self.lineno = lineno def getChildren(self): return self.name, self.flags def getChildNodes(self): return () def __repr__(self): return "AssName(%s, %s)" % (repr(self.name), repr(self.flags)) class AssTuple(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "AssTuple(%s)" % (repr(self.nodes),) class Assert(Node): def __init__(self, test, fail, lineno=None): self.test = test self.fail = fail self.lineno = lineno def 
getChildren(self): children = [] children.append(self.test) children.append(self.fail) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.test) if self.fail is not None: nodelist.append(self.fail) return tuple(nodelist) def __repr__(self): return "Assert(%s, %s)" % (repr(self.test), repr(self.fail)) class Assign(Node): def __init__(self, nodes, expr, lineno=None): self.nodes = nodes self.expr = expr self.lineno = lineno def getChildren(self): children = [] children.extend(flatten(self.nodes)) children.append(self.expr) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) nodelist.append(self.expr) return tuple(nodelist) def __repr__(self): return "Assign(%s, %s)" % (repr(self.nodes), repr(self.expr)) class AugAssign(Node): def __init__(self, node, op, expr, lineno=None): self.node = node self.op = op self.expr = expr self.lineno = lineno def getChildren(self): return self.node, self.op, self.expr def getChildNodes(self): return self.node, self.expr def __repr__(self): return "AugAssign(%s, %s, %s)" % (repr(self.node), repr(self.op), repr(self.expr)) class Backquote(Node): def __init__(self, expr, lineno=None): self.expr = expr self.lineno = lineno def getChildren(self): return self.expr, def getChildNodes(self): return self.expr, def __repr__(self): return "Backquote(%s)" % (repr(self.expr),) class Bitand(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Bitand(%s)" % (repr(self.nodes),) class Bitor(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Bitor(%s)" % (repr(self.nodes),) class Bitxor(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Bitxor(%s)" % (repr(self.nodes),) class Break(Node): def __init__(self, lineno=None): self.lineno = lineno def getChildren(self): return () def getChildNodes(self): return () def __repr__(self): return "Break()" class CallFunc(Node): def __init__(self, node, args, star_args = None, dstar_args = None, lineno=None): self.node = node self.args = args self.star_args = star_args self.dstar_args = dstar_args self.lineno = lineno def getChildren(self): children = [] children.append(self.node) children.extend(flatten(self.args)) children.append(self.star_args) children.append(self.dstar_args) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.node) nodelist.extend(flatten_nodes(self.args)) if self.star_args is not None: nodelist.append(self.star_args) if self.dstar_args is not None: nodelist.append(self.dstar_args) return tuple(nodelist) def __repr__(self): return "CallFunc(%s, %s, %s, %s)" % (repr(self.node), repr(self.args), repr(self.star_args), repr(self.dstar_args)) class Class(Node): def __init__(self, name, bases, doc, code, decorators = None, lineno=None): self.name = name self.bases = bases self.doc = doc self.code = code self.decorators = decorators self.lineno = 
lineno def getChildren(self): children = [] children.append(self.name) children.extend(flatten(self.bases)) children.append(self.doc) children.append(self.code) children.append(self.decorators) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.bases)) nodelist.append(self.code) if self.decorators is not None: nodelist.append(self.decorators) return tuple(nodelist) def __repr__(self): return "Class(%s, %s, %s, %s, %s)" % (repr(self.name), repr(self.bases), repr(self.doc), repr(self.code), repr(self.decorators)) class Compare(Node): def __init__(self, expr, ops, lineno=None): self.expr = expr self.ops = ops self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.extend(flatten(self.ops)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) nodelist.extend(flatten_nodes(self.ops)) return tuple(nodelist) def __repr__(self): return "Compare(%s, %s)" % (repr(self.expr), repr(self.ops)) class Const(Node): def __init__(self, value, lineno=None): self.value = value self.lineno = lineno def getChildren(self): return self.value, def getChildNodes(self): return () def __repr__(self): return "Const(%s)" % (repr(self.value),) class Continue(Node): def __init__(self, lineno=None): self.lineno = lineno def getChildren(self): return () def getChildNodes(self): return () def __repr__(self): return "Continue()" class Decorators(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Decorators(%s)" % (repr(self.nodes),) class Dict(Node): def __init__(self, items, lineno=None): self.items = items self.lineno = lineno def getChildren(self): return tuple(flatten(self.items)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.items)) return tuple(nodelist) def __repr__(self): return "Dict(%s)" % (repr(self.items),) class Discard(Node): def __init__(self, expr, lineno=None): self.expr = expr self.lineno = lineno def getChildren(self): return self.expr, def getChildNodes(self): return self.expr, def __repr__(self): return "Discard(%s)" % (repr(self.expr),) class Div(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "Div((%s, %s))" % (repr(self.left), repr(self.right)) class Ellipsis(Node): def __init__(self, lineno=None): self.lineno = lineno def getChildren(self): return () def getChildNodes(self): return () def __repr__(self): return "Ellipsis()" class Exec(Node): def __init__(self, expr, locals, globals, lineno=None): self.expr = expr self.locals = locals self.globals = globals self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.append(self.locals) children.append(self.globals) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) if self.locals is not None: nodelist.append(self.locals) if self.globals is not None: nodelist.append(self.globals) return tuple(nodelist) def __repr__(self): return "Exec(%s, %s, %s)" % (repr(self.expr), repr(self.locals), repr(self.globals)) class FloorDiv(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = 
leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "FloorDiv((%s, %s))" % (repr(self.left), repr(self.right)) class For(Node): def __init__(self, assign, list, body, else_, lineno=None): self.assign = assign self.list = list self.body = body self.else_ = else_ self.lineno = lineno def getChildren(self): children = [] children.append(self.assign) children.append(self.list) children.append(self.body) children.append(self.else_) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.assign) nodelist.append(self.list) nodelist.append(self.body) if self.else_ is not None: nodelist.append(self.else_) return tuple(nodelist) def __repr__(self): return "For(%s, %s, %s, %s)" % (repr(self.assign), repr(self.list), repr(self.body), repr(self.else_)) class FormattedValue(Node): def __init__(self, value, fmt_spec): self.value = value self.fmt_spec = fmt_spec def getChildren(self): return self.value, self.fmt_spec def getChildNodes(self): return self.value, self.fmt_spec def __repr__(self): return "FormattedValue(%s, %s)" % (self.value, self.fmt_spec) class From(Node): def __init__(self, modname, names, level, lineno=None): self.modname = modname self.names = names self.level = level self.lineno = lineno def getChildren(self): return self.modname, self.names, self.level def getChildNodes(self): return () def __repr__(self): return "From(%s, %s, %s)" % (repr(self.modname), repr(self.names), repr(self.level)) class Function(Node): def __init__(self, decorators, name, argnames, defaults, flags, doc, code, lineno=None): self.decorators = decorators self.name = name self.argnames = argnames self.defaults = defaults self.flags = flags self.doc = doc self.code = code self.lineno = lineno self.varargs = self.kwargs = None if flags & CO_VARARGS: self.varargs = 1 if flags & CO_VARKEYWORDS: self.kwargs = 1 def getChildren(self): children = [] children.append(self.decorators) children.append(self.name) children.append(self.argnames) children.extend(flatten(self.defaults)) children.append(self.flags) children.append(self.doc) children.append(self.code) return tuple(children) def getChildNodes(self): nodelist = [] if self.decorators is not None: nodelist.append(self.decorators) nodelist.extend(flatten_nodes(self.defaults)) nodelist.append(self.code) return tuple(nodelist) def __repr__(self): return "Function(%s, %s, %s, %s, %s, %s, %s)" % (repr(self.decorators), repr(self.name), repr(self.argnames), repr(self.defaults), repr(self.flags), repr(self.doc), repr(self.code)) class GenExpr(Node): def __init__(self, code, lineno=None): self.code = code self.lineno = lineno self.argnames = ['.0'] self.varargs = self.kwargs = None def getChildren(self): return self.code, def getChildNodes(self): return self.code, def __repr__(self): return "GenExpr(%s)" % (repr(self.code),) class GenExprFor(Node): def __init__(self, assign, iter, ifs, lineno=None): self.assign = assign self.iter = iter self.ifs = ifs self.lineno = lineno self.is_outmost = False def getChildren(self): children = [] children.append(self.assign) children.append(self.iter) children.extend(flatten(self.ifs)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.assign) nodelist.append(self.iter) nodelist.extend(flatten_nodes(self.ifs)) return tuple(nodelist) def __repr__(self): return "GenExprFor(%s, %s, %s)" % (repr(self.assign), repr(self.iter), repr(self.ifs)) class GenExprIf(Node): def 
__init__(self, test, lineno=None): self.test = test self.lineno = lineno def getChildren(self): return self.test, def getChildNodes(self): return self.test, def __repr__(self): return "GenExprIf(%s)" % (repr(self.test),) class GenExprInner(Node): def __init__(self, expr, quals, lineno=None): self.expr = expr self.quals = quals self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.extend(flatten(self.quals)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) nodelist.extend(flatten_nodes(self.quals)) return tuple(nodelist) def __repr__(self): return "GenExprInner(%s, %s)" % (repr(self.expr), repr(self.quals)) class Getattr(Node): def __init__(self, expr, attrname, lineno=None): self.expr = expr self.attrname = attrname self.lineno = lineno def getChildren(self): return self.expr, self.attrname def getChildNodes(self): return self.expr, def __repr__(self): return "Getattr(%s, %s)" % (repr(self.expr), repr(self.attrname)) class Global(Node): def __init__(self, names, lineno=None): self.names = names self.lineno = lineno def getChildren(self): return self.names, def getChildNodes(self): return () def __repr__(self): return "Global(%s)" % (repr(self.names),) class If(Node): def __init__(self, tests, else_, lineno=None): self.tests = tests self.else_ = else_ self.lineno = lineno def getChildren(self): children = [] children.extend(flatten(self.tests)) children.append(self.else_) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.tests)) if self.else_ is not None: nodelist.append(self.else_) return tuple(nodelist) def __repr__(self): return "If(%s, %s)" % (repr(self.tests), repr(self.else_)) class IfExp(Node): def __init__(self, test, then, else_, lineno=None): self.test = test self.then = then self.else_ = else_ self.lineno = lineno def getChildren(self): return self.test, self.then, self.else_ def getChildNodes(self): return self.test, self.then, self.else_ def __repr__(self): return "IfExp(%s, %s, %s)" % (repr(self.test), repr(self.then), repr(self.else_)) class Import(Node): def __init__(self, names, lineno=None): self.names = names self.lineno = lineno def getChildren(self): return self.names, def getChildNodes(self): return () def __repr__(self): return "Import(%s)" % (repr(self.names),) class Invert(Node): def __init__(self, expr, lineno=None): self.expr = expr self.lineno = lineno def getChildren(self): return self.expr, def getChildNodes(self): return self.expr, def __repr__(self): return "Invert(%s)" % (repr(self.expr),) class Keyword(Node): def __init__(self, name, expr, lineno=None): self.name = name self.expr = expr self.lineno = lineno def getChildren(self): return self.name, self.expr def getChildNodes(self): return self.expr, def __repr__(self): return "Keyword(%s, %s)" % (repr(self.name), repr(self.expr)) class Lambda(Node): def __init__(self, argnames, defaults, flags, code, lineno=None): self.argnames = argnames self.defaults = defaults self.flags = flags self.code = code self.lineno = lineno self.varargs = self.kwargs = None if flags & CO_VARARGS: self.varargs = 1 if flags & CO_VARKEYWORDS: self.kwargs = 1 def getChildren(self): children = [] children.append(self.argnames) children.extend(flatten(self.defaults)) children.append(self.flags) children.append(self.code) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.defaults)) nodelist.append(self.code) return tuple(nodelist) def __repr__(self): return 
"Lambda(%s, %s, %s, %s)" % (repr(self.argnames), repr(self.defaults), repr(self.flags), repr(self.code)) class LeftShift(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "LeftShift((%s, %s))" % (repr(self.left), repr(self.right)) class List(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "List(%s)" % (repr(self.nodes),) class ListComp(Node): def __init__(self, expr, quals, lineno=None): self.expr = expr self.quals = quals self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.extend(flatten(self.quals)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) nodelist.extend(flatten_nodes(self.quals)) return tuple(nodelist) def __repr__(self): return "ListComp(%s, %s)" % (repr(self.expr), repr(self.quals)) class ListCompFor(Node): def __init__(self, assign, list, ifs, lineno=None): self.assign = assign self.list = list self.ifs = ifs self.lineno = lineno def getChildren(self): children = [] children.append(self.assign) children.append(self.list) children.extend(flatten(self.ifs)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.assign) nodelist.append(self.list) nodelist.extend(flatten_nodes(self.ifs)) return tuple(nodelist) def __repr__(self): return "ListCompFor(%s, %s, %s)" % (repr(self.assign), repr(self.list), repr(self.ifs)) class ListCompIf(Node): def __init__(self, test, lineno=None): self.test = test self.lineno = lineno def getChildren(self): return self.test, def getChildNodes(self): return self.test, def __repr__(self): return "ListCompIf(%s)" % (repr(self.test),) class SetComp(Node): def __init__(self, expr, quals, lineno=None): self.expr = expr self.quals = quals self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.extend(flatten(self.quals)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) nodelist.extend(flatten_nodes(self.quals)) return tuple(nodelist) def __repr__(self): return "SetComp(%s, %s)" % (repr(self.expr), repr(self.quals)) class DictComp(Node): def __init__(self, key, value, quals, lineno=None): self.key = key self.value = value self.quals = quals self.lineno = lineno def getChildren(self): children = [] children.append(self.key) children.append(self.value) children.extend(flatten(self.quals)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.key) nodelist.append(self.value) nodelist.extend(flatten_nodes(self.quals)) return tuple(nodelist) def __repr__(self): return "DictComp(%s, %s, %s)" % (repr(self.key), repr(self.value), repr(self.quals)) class Mod(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "Mod((%s, %s))" % (repr(self.left), repr(self.right)) class Module(Node): def __init__(self, doc, node, lineno=None): self.doc = doc self.node = node self.lineno = lineno def getChildren(self): return self.doc, self.node 
def getChildNodes(self): return self.node, def __repr__(self): return "Module(%s, %s)" % (repr(self.doc), repr(self.node)) class Mul(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "Mul((%s, %s))" % (repr(self.left), repr(self.right)) class Name(Node): def __init__(self, name, lineno=None): self.name = name self.lineno = lineno def getChildren(self): return self.name, def getChildNodes(self): return () def __repr__(self): return "Name(%s)" % (repr(self.name),) class Not(Node): def __init__(self, expr, lineno=None): self.expr = expr self.lineno = lineno def getChildren(self): return self.expr, def getChildNodes(self): return self.expr, def __repr__(self): return "Not(%s)" % (repr(self.expr),) class Or(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Or(%s)" % (repr(self.nodes),) class Pass(Node): def __init__(self, lineno=None): self.lineno = lineno def getChildren(self): return () def getChildNodes(self): return () def __repr__(self): return "Pass()" class Power(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "Power((%s, %s))" % (repr(self.left), repr(self.right)) class Print(Node): def __init__(self, nodes, dest, lineno=None): self.nodes = nodes self.dest = dest self.lineno = lineno def getChildren(self): children = [] children.extend(flatten(self.nodes)) children.append(self.dest) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) if self.dest is not None: nodelist.append(self.dest) return tuple(nodelist) def __repr__(self): return "Print(%s, %s)" % (repr(self.nodes), repr(self.dest)) class Printnl(Node): def __init__(self, nodes, dest, lineno=None): self.nodes = nodes self.dest = dest self.lineno = lineno def getChildren(self): children = [] children.extend(flatten(self.nodes)) children.append(self.dest) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) if self.dest is not None: nodelist.append(self.dest) return tuple(nodelist) def __repr__(self): return "Printnl(%s, %s)" % (repr(self.nodes), repr(self.dest)) class Raise(Node): def __init__(self, expr1, expr2, expr3, lineno=None): self.expr1 = expr1 self.expr2 = expr2 self.expr3 = expr3 self.lineno = lineno def getChildren(self): children = [] children.append(self.expr1) children.append(self.expr2) children.append(self.expr3) return tuple(children) def getChildNodes(self): nodelist = [] if self.expr1 is not None: nodelist.append(self.expr1) if self.expr2 is not None: nodelist.append(self.expr2) if self.expr3 is not None: nodelist.append(self.expr3) return tuple(nodelist) def __repr__(self): return "Raise(%s, %s, %s)" % (repr(self.expr1), repr(self.expr2), repr(self.expr3)) class Return(Node): def __init__(self, value, lineno=None): self.value = value self.lineno = lineno def getChildren(self): return self.value, def getChildNodes(self): return self.value, def __repr__(self): return "Return(%s)" % 
(repr(self.value),) class RightShift(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "RightShift((%s, %s))" % (repr(self.left), repr(self.right)) class Set(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Set(%s)" % (repr(self.nodes),) class Slice(Node): def __init__(self, expr, flags, lower, upper, lineno=None): self.expr = expr self.flags = flags self.lower = lower self.upper = upper self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.append(self.flags) children.append(self.lower) children.append(self.upper) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) if self.lower is not None: nodelist.append(self.lower) if self.upper is not None: nodelist.append(self.upper) return tuple(nodelist) def __repr__(self): return "Slice(%s, %s, %s, %s)" % (repr(self.expr), repr(self.flags), repr(self.lower), repr(self.upper)) class Sliceobj(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Sliceobj(%s)" % (repr(self.nodes),) class Stmt(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Stmt(%s)" % (repr(self.nodes),) class Str(Node): def __init__(self, value, flags): self.value = value self.flags = flags def getChildren(self): return self.value, self.flags def getChildNodes(self): return self.value, def __repr__(self): return "Str(%s, %d)" % (self.value, self.flags) class JoinedStr(Node): def __init__(self, values): self.values = values def getChildren(self): return self.values def getChildNodes(self): return self.values def __repr__(self): return "JoinedStr(%s)" % (', '.join(repr(value) for value in self.values)) class Sub(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] self.right = leftright[1] self.lineno = lineno def getChildren(self): return self.left, self.right def getChildNodes(self): return self.left, self.right def __repr__(self): return "Sub((%s, %s))" % (repr(self.left), repr(self.right)) class Subscript(Node): def __init__(self, expr, flags, subs, lineno=None): self.expr = expr self.flags = flags self.subs = subs self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.append(self.flags) children.extend(flatten(self.subs)) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) nodelist.extend(flatten_nodes(self.subs)) return tuple(nodelist) def __repr__(self): return "Subscript(%s, %s, %s)" % (repr(self.expr), repr(self.flags), repr(self.subs)) class TryExcept(Node): def __init__(self, body, handlers, else_, lineno=None): self.body = body self.handlers = handlers self.else_ = else_ self.lineno = lineno def getChildren(self): children = [] 
children.append(self.body) children.extend(flatten(self.handlers)) children.append(self.else_) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.body) nodelist.extend(flatten_nodes(self.handlers)) if self.else_ is not None: nodelist.append(self.else_) return tuple(nodelist) def __repr__(self): return "TryExcept(%s, %s, %s)" % (repr(self.body), repr(self.handlers), repr(self.else_)) class TryFinally(Node): def __init__(self, body, final, lineno=None): self.body = body self.final = final self.lineno = lineno def getChildren(self): return self.body, self.final def getChildNodes(self): return self.body, self.final def __repr__(self): return "TryFinally(%s, %s)" % (repr(self.body), repr(self.final)) class Tuple(Node): def __init__(self, nodes, lineno=None): self.nodes = nodes self.lineno = lineno def getChildren(self): return tuple(flatten(self.nodes)) def getChildNodes(self): nodelist = [] nodelist.extend(flatten_nodes(self.nodes)) return tuple(nodelist) def __repr__(self): return "Tuple(%s)" % (repr(self.nodes),) class UnaryAdd(Node): def __init__(self, expr, lineno=None): self.expr = expr self.lineno = lineno def getChildren(self): return self.expr, def getChildNodes(self): return self.expr, def __repr__(self): return "UnaryAdd(%s)" % (repr(self.expr),) class UnarySub(Node): def __init__(self, expr, lineno=None): self.expr = expr self.lineno = lineno def getChildren(self): return self.expr, def getChildNodes(self): return self.expr, def __repr__(self): return "UnarySub(%s)" % (repr(self.expr),) class While(Node): def __init__(self, test, body, else_, lineno=None): self.test = test self.body = body self.else_ = else_ self.lineno = lineno def getChildren(self): children = [] children.append(self.test) children.append(self.body) children.append(self.else_) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.test) nodelist.append(self.body) if self.else_ is not None: nodelist.append(self.else_) return tuple(nodelist) def __repr__(self): return "While(%s, %s, %s)" % (repr(self.test), repr(self.body), repr(self.else_)) class With(Node): def __init__(self, expr, vars, body, lineno=None): self.expr = expr self.vars = vars self.body = body self.lineno = lineno def getChildren(self): children = [] children.append(self.expr) children.append(self.vars) children.append(self.body) return tuple(children) def getChildNodes(self): nodelist = [] nodelist.append(self.expr) if self.vars is not None: nodelist.append(self.vars) nodelist.append(self.body) return tuple(nodelist) def __repr__(self): return "With(%s, %s, %s)" % (repr(self.expr), repr(self.vars), repr(self.body)) class Yield(Node): def __init__(self, value, lineno=None): self.value = value self.lineno = lineno def getChildren(self): return self.value, def getChildNodes(self): return self.value, def __repr__(self): return "Yield(%s)" % (repr(self.value),) for name, obj in items_list(globals()): if isinstance(obj, type) and issubclass(obj, Node): nodes[name.lower()] = obj ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/consts.py0000666000000000000000000000075300000000000017601 0ustar0000000000000000# operation flags OP_ASSIGN = 'OP_ASSIGN' OP_DELETE = 'OP_DELETE' OP_APPLY = 'OP_APPLY' SC_LOCAL = 1 SC_GLOBAL_IMPLICIT = 2 SC_GLOBAL_EXPLICIT = 3 SC_FREE = 4 SC_CELL = 5 SC_UNKNOWN = 6 CO_OPTIMIZED = 0x0001 CO_NEWLOCALS = 0x0002 CO_VARARGS = 0x0004 CO_VARKEYWORDS = 0x0008 CO_NESTED = 0x0010 
CO_GENERATOR = 0x0020 CO_GENERATOR_ALLOWED = 0 CO_FUTURE_DIVISION = 0x2000 CO_FUTURE_ABSIMPORT = 0x4000 CO_FUTURE_WITH_STATEMENT = 0x8000 CO_FUTURE_PRINT_FUNCTION = 0x10000 ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/future.py0000666000000000000000000000367000000000000017603 0ustar0000000000000000"""Parser for future statements """ from __future__ import print_function from . import ast, walk def is_future(stmt): """Return true if statement is a well-formed future statement""" if not isinstance(stmt, ast.From): return 0 if stmt.modname == "__future__": return 1 else: return 0 class FutureParser: features = ("nested_scopes", "generators", "division", "absolute_import", "with_statement", "print_function", "unicode_literals") def __init__(self): self.found = {} # set def visitModule(self, node): stmt = node.node for s in stmt.nodes: if not self.check_stmt(s): break def check_stmt(self, stmt): if is_future(stmt): for name, asname in stmt.names: if name in self.features: self.found[name] = 1 else: raise SyntaxError("future feature %s is not defined" % name) stmt.valid_future = 1 return 1 return 0 def get_features(self): """Return list of features enabled by future statements""" return self.found.keys() class BadFutureParser: """Check for invalid future statements""" def visitFrom(self, node): if hasattr(node, 'valid_future'): return if node.modname != "__future__": return raise SyntaxError("invalid future statement " + repr(node)) def find_futures(node): p1 = FutureParser() p2 = BadFutureParser() walk(node, p1) walk(node, p2) return p1.get_features() if __name__ == "__main__": import sys from compiler import parseFile, walk for file in sys.argv[1:]: print(file) tree = parseFile(file) v = FutureParser() walk(tree, v) print(v.found) print() ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/misc.py0000666000000000000000000000351300000000000017220 0ustar0000000000000000 def flatten(tup): elts = [] for elt in tup: if isinstance(elt, tuple): elts = elts + flatten(elt) else: elts.append(elt) return elts class Set: def __init__(self): self.elts = {} def __len__(self): return len(self.elts) def __contains__(self, elt): return elt in self.elts def add(self, elt): self.elts[elt] = elt def elements(self): return self.elts.keys() def has_elt(self, elt): return elt in self.elts def remove(self, elt): del self.elts[elt] def copy(self): c = Set() c.elts.update(self.elts) return c class Stack: def __init__(self): self.stack = [] self.pop = self.stack.pop def __len__(self): return len(self.stack) def push(self, elt): self.stack.append(elt) def top(self): return self.stack[-1] def __getitem__(self, index): # needed by visitContinue() return self.stack[index] MANGLE_LEN = 256 # magic constant from compile.c def mangle(name, klass): if not name.startswith('__'): return name if len(name) + 2 >= MANGLE_LEN: return name if name.endswith('__'): return name try: i = 0 while klass[i] == '_': i = i + 1 except IndexError: return name klass = klass[i:] tlen = len(klass) + len(name) if tlen > MANGLE_LEN: klass = klass[:MANGLE_LEN-tlen] return "_%s%s" % (klass, name) def set_filename(filename, tree): """Set the filename attribute to filename on every node in tree""" worklist = [tree] while worklist: node = worklist.pop(0) node.filename = filename worklist.extend(node.getChildNodes()) 
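# A minimal usage sketch for mangle() above (an illustrative addition,
# not part of the original file; it assumes this module is importable
# as pony.thirdparty.compiler.misc).  mangle() reproduces CPython's
# private-name mangling: inside `class Spam`, the attribute `__egg`
# becomes `_Spam__egg`.
#
#     from pony.thirdparty.compiler.misc import mangle
#     assert mangle('__egg', 'Spam') == '_Spam__egg'   # private name: mangled
#     assert mangle('__init__', 'Spam') == '__init__'  # dunder: unchanged
#     assert mangle('egg', 'Spam') == 'egg'            # public name: unchanged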
././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/pyassem.py0000666000000000000000000006102200000000000017745 0ustar0000000000000000"""A flow graph representation for Python bytecode""" from __future__ import absolute_import, print_function from pony.py23compat import imap, items_list import dis import types import sys from . import misc from .consts import CO_OPTIMIZED, CO_NEWLOCALS, CO_VARARGS, CO_VARKEYWORDS class FlowGraph: def __init__(self): self.current = self.entry = Block() self.exit = Block("exit") self.blocks = misc.Set() self.blocks.add(self.entry) self.blocks.add(self.exit) def startBlock(self, block): if self._debug: if self.current: print("end", repr(self.current)) print(" next", self.current.next) print(" prev", self.current.prev) print(" ", self.current.get_children()) print(repr(block)) self.current = block def nextBlock(self, block=None): # XXX think we need to specify when there is implicit transfer # from one block to the next. might be better to represent this # with explicit JUMP_ABSOLUTE instructions that are optimized # out when they are unnecessary. # # I think this strategy works: each block has a child # designated as "next" which is returned as the last of the # children. because the nodes in a graph are emitted in # reverse post order, the "next" block will always be emitted # immediately after its parent. # Worry: maintaining this invariant could be tricky if block is None: block = self.newBlock() # Note: If the current block ends with an unconditional control # transfer, then it is techically incorrect to add an implicit # transfer to the block graph. Doing so results in code generation # for unreachable blocks. That doesn't appear to be very common # with Python code and since the built-in compiler doesn't optimize # it out we don't either. self.current.addNext(block) self.startBlock(block) def newBlock(self): b = Block() self.blocks.add(b) return b def startExitBlock(self): self.startBlock(self.exit) _debug = 0 def _enable_debug(self): self._debug = 1 def _disable_debug(self): self._debug = 0 def emit(self, *inst): if self._debug: print("\t", inst) if len(inst) == 2 and isinstance(inst[1], Block): self.current.addOutEdge(inst[1]) self.current.emit(inst) def getBlocksInOrder(self): """Return the blocks in reverse postorder i.e. each node appears before all of its successors """ order = order_blocks(self.entry, self.exit) return order def getBlocks(self): return self.blocks.elements() def getRoot(self): """Return nodes appropriate for use with dominator""" return self.entry def getContainedGraphs(self): l = [] for b in self.getBlocks(): l.extend(b.getContainedGraphs()) return l def order_blocks(start_block, exit_block): """Order blocks so that they are emitted in the right order""" # Rules: # - when a block has a next block, the next block must be emitted just after # - when a block has followers (relative jumps), it must be emitted before # them # - all reachable blocks must be emitted order = [] # Find all the blocks to be emitted. remaining = set() todo = [start_block] while todo: b = todo.pop() if b in remaining: continue remaining.add(b) for c in b.get_children(): if c not in remaining: todo.append(c) # A block is dominated by another block if that block must be emitted # before it. 
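# (Clarifying note, not in the original source: despite the name, this is
# not the classic control-flow dominator relation.  dominators[b] simply
# collects the blocks that must be emitted before b, derived from the
# fall-through "next" chains and the targets of relative jumps.)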
dominators = {} for b in remaining: if __debug__ and b.next: assert b is b.next[0].prev[0], (b, b.next) # Make sure every block appears in dominators, even if no # other block must precede it. dominators.setdefault(b, set()) # preceding blocks dominate following blocks for c in b.get_followers(): while 1: dominators.setdefault(c, set()).add(b) # Any block that has a next pointer leading to c is also # dominated because the whole chain will be emitted at once. # Walk backwards and add them all. if c.prev and c.prev[0] is not b: c = c.prev[0] else: break def find_next(): # Find a block that can be emitted next. for b in remaining: for c in dominators[b]: if c in remaining: break # can't emit yet, dominated by a remaining block else: return b assert 0, 'circular dependency, cannot find next block' b = start_block while 1: order.append(b) remaining.discard(b) if b.next: b = b.next[0] continue elif b is not exit_block and not b.has_unconditional_transfer(): order.append(exit_block) if not remaining: break b = find_next() return order class Block: _count = 0 def __init__(self, label=''): self.insts = [] self.outEdges = set() self.label = label self.bid = Block._count self.next = [] self.prev = [] Block._count = Block._count + 1 def __repr__(self): if self.label: return "" % (self.label, self.bid) else: return "" % (self.bid) def __str__(self): insts = imap(str, self.insts) return "" % (self.label, self.bid, '\n'.join(insts)) def emit(self, inst): op = inst[0] self.insts.append(inst) def getInstructions(self): return self.insts def addOutEdge(self, block): self.outEdges.add(block) def addNext(self, block): self.next.append(block) assert len(self.next) == 1, list(imap(str, self.next)) block.prev.append(self) assert len(block.prev) == 1, list(imap(str, block.prev)) _uncond_transfer = ('RETURN_VALUE', 'RAISE_VARARGS', 'JUMP_ABSOLUTE', 'JUMP_FORWARD', 'CONTINUE_LOOP', ) def has_unconditional_transfer(self): """Returns True if there is an unconditional transfer to an other block at the end of this block. This means there is no risk for the bytecode executer to go past this block's bytecode.""" try: op, arg = self.insts[-1] except (IndexError, ValueError): return return op in self._uncond_transfer def get_children(self): return list(self.outEdges) + self.next def get_followers(self): """Get the whole list of followers, including the next block.""" followers = set(self.next) # Blocks that must be emitted *after* this one, because of # bytecode offsets (e.g. relative jumps) pointing to them. for inst in self.insts: if inst[0] in PyFlowGraph.hasjrel: followers.add(inst[1]) return followers def getContainedGraphs(self): """Return all graphs contained within this block. For example, a MAKE_FUNCTION block will contain a reference to the graph for the function body. 
""" contained = [] for inst in self.insts: if len(inst) == 1: continue op = inst[1] if hasattr(op, 'graph'): contained.append(op.graph) return contained # flags for code objects # the FlowGraph is transformed in place; it exists in one of these states RAW = "RAW" FLAT = "FLAT" CONV = "CONV" DONE = "DONE" class PyFlowGraph(FlowGraph): super_init = FlowGraph.__init__ def __init__(self, name, filename, args=(), optimized=0, klass=None): self.super_init() self.name = name self.filename = filename self.docstring = None self.args = args # XXX self.argcount = getArgCount(args) self.klass = klass if optimized: self.flags = CO_OPTIMIZED | CO_NEWLOCALS else: self.flags = 0 self.consts = [] self.names = [] # Free variables found by the symbol table scan, including # variables used only in nested scopes, are included here. self.freevars = [] self.cellvars = [] # The closure list is used to track the order of cell # variables and free variables in the resulting code object. # The offsets used by LOAD_CLOSURE/LOAD_DEREF refer to both # kinds of variables. self.closure = [] self.varnames = list(args) or [] for i in range(len(self.varnames)): var = self.varnames[i] if isinstance(var, TupleArg): self.varnames[i] = var.getName() self.stage = RAW def setDocstring(self, doc): self.docstring = doc def setFlag(self, flag): self.flags = self.flags | flag if flag == CO_VARARGS: self.argcount = self.argcount - 1 def checkFlag(self, flag): if self.flags & flag: return 1 def setFreeVars(self, names): self.freevars = list(names) def setCellVars(self, names): self.cellvars = names def getCode(self): """Get a Python code object""" assert self.stage == RAW self.computeStackDepth() self.flattenGraph() assert self.stage == FLAT self.convertArgs() assert self.stage == CONV self.makeByteCode() assert self.stage == DONE return self.newCodeObject() def dump(self, io=None): if io: save = sys.stdout sys.stdout = io pc = 0 for t in self.insts: opname = t[0] if opname == "SET_LINENO": print() if len(t) == 1: print("\t", "%3d" % pc, opname) pc = pc + 1 else: print("\t", "%3d" % pc, opname, t[1]) pc = pc + 3 if io: sys.stdout = save def computeStackDepth(self): """Compute the max stack depth. Approach is to compute the stack effect of each basic block. Then find the path through the code with the largest total effect. 
""" depth = {} exit = None for b in self.getBlocks(): depth[b] = findDepth(b.getInstructions()) seen = {} def max_depth(b, d): if b in seen: return d seen[b] = 1 d = d + depth[b] children = b.get_children() if children: return max([max_depth(c, d) for c in children]) else: if not b.label == "exit": return max_depth(self.exit, d) else: return d self.stacksize = max_depth(self.entry, 0) def flattenGraph(self): """Arrange the blocks in order and resolve jumps""" assert self.stage == RAW self.insts = insts = [] pc = 0 begin = {} end = {} for b in self.getBlocksInOrder(): begin[b] = pc for inst in b.getInstructions(): insts.append(inst) if len(inst) == 1: pc = pc + 1 elif inst[0] != "SET_LINENO": # arg takes 2 bytes pc = pc + 3 end[b] = pc pc = 0 for i in range(len(insts)): inst = insts[i] if len(inst) == 1: pc = pc + 1 elif inst[0] != "SET_LINENO": pc = pc + 3 opname = inst[0] if opname in self.hasjrel: oparg = inst[1] offset = begin[oparg] - pc insts[i] = opname, offset elif opname in self.hasjabs: insts[i] = opname, begin[inst[1]] self.stage = FLAT hasjrel = set() for i in dis.hasjrel: hasjrel.add(dis.opname[i]) hasjabs = set() for i in dis.hasjabs: hasjabs.add(dis.opname[i]) def convertArgs(self): """Convert arguments from symbolic to concrete form""" assert self.stage == FLAT self.consts.insert(0, self.docstring) self.sort_cellvars() for i in range(len(self.insts)): t = self.insts[i] if len(t) == 2: opname, oparg = t conv = self._converters.get(opname, None) if conv: self.insts[i] = opname, conv(self, oparg) self.stage = CONV def sort_cellvars(self): """Sort cellvars in the order of varnames and prune from freevars. """ cells = {} for name in self.cellvars: cells[name] = 1 self.cellvars = [name for name in self.varnames if name in cells] for name in self.cellvars: del cells[name] self.cellvars = self.cellvars + cells.keys() self.closure = self.cellvars + self.freevars def _lookupName(self, name, list): """Return index of name in list, appending if necessary This routine uses a list instead of a dictionary, because a dictionary can't store two different keys if the keys have the same value but different types, e.g. 2 and 2L. The compiler must treat these two separately, so it does an explicit type comparison before comparing the values. 
""" t = type(name) for i in range(len(list)): if t == type(list[i]) and list[i] == name: return i end = len(list) list.append(name) return end _converters = {} def _convert_LOAD_CONST(self, arg): if hasattr(arg, 'getCode'): arg = arg.getCode() return self._lookupName(arg, self.consts) def _convert_LOAD_FAST(self, arg): self._lookupName(arg, self.names) return self._lookupName(arg, self.varnames) _convert_STORE_FAST = _convert_LOAD_FAST _convert_DELETE_FAST = _convert_LOAD_FAST def _convert_LOAD_NAME(self, arg): if self.klass is None: self._lookupName(arg, self.varnames) return self._lookupName(arg, self.names) def _convert_NAME(self, arg): if self.klass is None: self._lookupName(arg, self.varnames) return self._lookupName(arg, self.names) _convert_STORE_NAME = _convert_NAME _convert_DELETE_NAME = _convert_NAME _convert_IMPORT_NAME = _convert_NAME _convert_IMPORT_FROM = _convert_NAME _convert_STORE_ATTR = _convert_NAME _convert_LOAD_ATTR = _convert_NAME _convert_DELETE_ATTR = _convert_NAME _convert_LOAD_GLOBAL = _convert_NAME _convert_STORE_GLOBAL = _convert_NAME _convert_DELETE_GLOBAL = _convert_NAME def _convert_DEREF(self, arg): self._lookupName(arg, self.names) self._lookupName(arg, self.varnames) return self._lookupName(arg, self.closure) _convert_LOAD_DEREF = _convert_DEREF _convert_STORE_DEREF = _convert_DEREF def _convert_LOAD_CLOSURE(self, arg): self._lookupName(arg, self.varnames) return self._lookupName(arg, self.closure) _cmp = list(dis.cmp_op) def _convert_COMPARE_OP(self, arg): return self._cmp.index(arg) # similarly for other opcodes... for name, obj in items_list(locals()): if name[:9] == "_convert_": opname = name[9:] _converters[opname] = obj del name, obj, opname def makeByteCode(self): assert self.stage == CONV self.lnotab = lnotab = LineAddrTable() for t in self.insts: opname = t[0] if len(t) == 1: lnotab.addCode(self.opnum[opname]) else: oparg = t[1] if opname == "SET_LINENO": lnotab.nextLine(oparg) continue hi, lo = twobyte(oparg) try: lnotab.addCode(self.opnum[opname], lo, hi) except ValueError: print(opname, oparg) print(self.opnum[opname], lo, hi) raise self.stage = DONE opnum = {} for num in range(len(dis.opname)): opnum[dis.opname[num]] = num del num def newCodeObject(self): assert self.stage == DONE if (self.flags & CO_NEWLOCALS) == 0: nlocals = 0 else: nlocals = len(self.varnames) argcount = self.argcount if self.flags & CO_VARKEYWORDS: argcount = argcount - 1 return types.CodeType(argcount, nlocals, self.stacksize, self.flags, self.lnotab.getCode(), self.getConsts(), tuple(self.names), tuple(self.varnames), self.filename, self.name, self.lnotab.firstline, self.lnotab.getTable(), tuple(self.freevars), tuple(self.cellvars)) def getConsts(self): """Return a tuple for the const slot of the code object Must convert references to code (MAKE_FUNCTION) to code objects recursively. 
""" l = [] for elt in self.consts: if isinstance(elt, PyFlowGraph): elt = elt.getCode() l.append(elt) return tuple(l) def isJump(opname): if opname[:4] == 'JUMP': return 1 class TupleArg: """Helper for marking func defs with nested tuples in arglist""" def __init__(self, count, names): self.count = count self.names = names def __repr__(self): return "TupleArg(%s, %s)" % (self.count, self.names) def getName(self): return ".%d" % self.count def getArgCount(args): argcount = len(args) if args: for arg in args: if isinstance(arg, TupleArg): numNames = len(misc.flatten(arg.names)) argcount = argcount - numNames return argcount def twobyte(val): """Convert an int argument into high and low bytes""" assert isinstance(val, int) return divmod(val, 256) class LineAddrTable: """lnotab This class builds the lnotab, which is documented in compile.c. Here's a brief recap: For each SET_LINENO instruction after the first one, two bytes are added to lnotab. (In some cases, multiple two-byte entries are added.) The first byte is the distance in bytes between the instruction for the last SET_LINENO and the current SET_LINENO. The second byte is offset in line numbers. If either offset is greater than 255, multiple two-byte entries are added -- see compile.c for the delicate details. """ def __init__(self): self.code = [] self.codeOffset = 0 self.firstline = 0 self.lastline = 0 self.lastoff = 0 self.lnotab = [] def addCode(self, *args): for arg in args: self.code.append(chr(arg)) self.codeOffset = self.codeOffset + len(args) def nextLine(self, lineno): if self.firstline == 0: self.firstline = lineno self.lastline = lineno else: # compute deltas addr = self.codeOffset - self.lastoff line = lineno - self.lastline # Python assumes that lineno always increases with # increasing bytecode address (lnotab is unsigned char). # Depending on when SET_LINENO instructions are emitted # this is not always true. Consider the code: # a = (1, # b) # In the bytecode stream, the assignment to "a" occurs # after the loading of "b". This works with the C Python # compiler because it only generates a SET_LINENO instruction # for the assignment. if line >= 0: push = self.lnotab.append while addr > 255: push(255); push(0) addr -= 255 while line > 255: push(addr); push(255) line -= 255 addr = 0 if addr > 0 or line > 0: push(addr); push(line) self.lastline = lineno self.lastoff = self.codeOffset def getCode(self): return ''.join(self.code) def getTable(self): return ''.join(imap(chr, self.lnotab)) class StackDepthTracker: # XXX 1. need to keep track of stack depth on jumps # XXX 2. 
    # at least partly as a result, this code is broken

    def findDepth(self, insts, debug=0):
        depth = 0
        maxDepth = 0
        for i in insts:
            opname = i[0]
            if debug:
                print(i, end=' ')
            delta = self.effect.get(opname, None)
            if delta is not None:
                depth = depth + delta
            else:
                # now check patterns
                for pat, pat_delta in self.patterns:
                    if opname[:len(pat)] == pat:
                        delta = pat_delta
                        depth = depth + delta
                        break
                # if we still haven't found a match
                if delta is None:
                    meth = getattr(self, opname, None)
                    if meth is not None:
                        depth = depth + meth(i[1])
            if depth > maxDepth:
                maxDepth = depth
            if debug:
                print(depth, maxDepth)
        return maxDepth

    effect = {
        'POP_TOP': -1,
        'DUP_TOP': 1,
        'LIST_APPEND': -1,
        'SET_ADD': -1,
        'MAP_ADD': -2,
        'SLICE+1': -1,
        'SLICE+2': -1,
        'SLICE+3': -2,
        'STORE_SLICE+0': -1,
        'STORE_SLICE+1': -2,
        'STORE_SLICE+2': -2,
        'STORE_SLICE+3': -3,
        'DELETE_SLICE+0': -1,
        'DELETE_SLICE+1': -2,
        'DELETE_SLICE+2': -2,
        'DELETE_SLICE+3': -3,
        'STORE_SUBSCR': -3,
        'DELETE_SUBSCR': -2,
        # PRINT_EXPR?
        'PRINT_ITEM': -1,
        'RETURN_VALUE': -1,
        'YIELD_VALUE': -1,
        'EXEC_STMT': -3,
        'BUILD_CLASS': -2,
        'STORE_NAME': -1,
        'STORE_ATTR': -2,
        'DELETE_ATTR': -1,
        'STORE_GLOBAL': -1,
        'BUILD_MAP': 1,
        'COMPARE_OP': -1,
        'STORE_FAST': -1,
        'IMPORT_STAR': -1,
        'IMPORT_NAME': -1,
        'IMPORT_FROM': 1,
        'LOAD_ATTR': 0,  # unlike other loads
        # close enough...
        'SETUP_EXCEPT': 3,
        'SETUP_FINALLY': 3,
        'FOR_ITER': 1,
        'WITH_CLEANUP': -1,
        }
    # use pattern match
    patterns = [
        ('BINARY_', -1),
        ('LOAD_', 1),
        ]

    def UNPACK_SEQUENCE(self, count):
        return count-1
    def BUILD_TUPLE(self, count):
        return -count+1
    def BUILD_LIST(self, count):
        return -count+1
    def BUILD_SET(self, count):
        return -count+1
    def CALL_FUNCTION(self, argc):
        hi, lo = divmod(argc, 256)
        return -(lo + hi * 2)
    def CALL_FUNCTION_VAR(self, argc):
        return self.CALL_FUNCTION(argc)-1
    def CALL_FUNCTION_KW(self, argc):
        return self.CALL_FUNCTION(argc)-1
    def CALL_FUNCTION_VAR_KW(self, argc):
        return self.CALL_FUNCTION(argc)-2
    def MAKE_FUNCTION(self, argc):
        return -argc
    def MAKE_CLOSURE(self, argc):
        # XXX need to account for free variables too!
        return -argc
    def BUILD_SLICE(self, argc):
        if argc == 2:
            return -1
        elif argc == 3:
            return -2
    def DUP_TOPX(self, argc):
        return argc

findDepth = StackDepthTracker().findDepth

# ----- pony-0.7.11/pony/thirdparty/compiler/pycodegen.py -----

from __future__ import absolute_import, print_function
from pony.py23compat import izip

import os
import marshal
import struct
import sys

from . import ast, parse, walk, syntax
from . import pyassem, misc, future, symbols
from .consts import SC_LOCAL, SC_GLOBAL_IMPLICIT, SC_GLOBAL_EXPLICIT, \
    SC_FREE, SC_CELL
from .consts import (CO_VARARGS, CO_VARKEYWORDS, CO_NEWLOCALS,
    CO_NESTED, CO_GENERATOR, CO_FUTURE_DIVISION,
    CO_FUTURE_ABSIMPORT, CO_FUTURE_WITH_STATEMENT,
    CO_FUTURE_PRINT_FUNCTION)
from .pyassem import TupleArg

import pony.utils

# XXX The version-specific code can go, since this code only works with 2.x.
# Do we have Python 1.x or Python 2.x?
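# (The try/except below answers that question.)  Before pycodegen continues,
# a standalone illustration of the idea behind StackDepthTracker, which
# closed pyassem.py above: walk (opname, arg) tuples, add each opcode's net
# stack effect, and remember the running maximum.  The instruction list and
# the tiny effect table here are made-up examples, not part of this package.
def _toy_max_depth(insts, effect={'LOAD_CONST': 1, 'BINARY_ADD': -1,
                                  'RETURN_VALUE': -1}):
    depth = max_depth = 0
    for opname, *args in insts:
        depth += effect[opname]              # net effect of this opcode
        max_depth = max(max_depth, depth)    # high-water mark = stacksize
    return max_depth

# '1 + 2' pushes two constants, pops both for the add, then returns:
assert _toy_max_depth([('LOAD_CONST', 1), ('LOAD_CONST', 2),
                       ('BINARY_ADD',), ('RETURN_VALUE',)]) == 2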
try: VERSION = sys.version_info[0] except AttributeError: VERSION = 1 callfunc_opcode_info = { # (Have *args, Have **args) : opcode (0,0) : "CALL_FUNCTION", (1,0) : "CALL_FUNCTION_VAR", (0,1) : "CALL_FUNCTION_KW", (1,1) : "CALL_FUNCTION_VAR_KW", } LOOP = 1 EXCEPT = 2 TRY_FINALLY = 3 END_FINALLY = 4 def compileFile(filename, display=0): f = open(filename, 'U') buf = f.read() f.close() mod = Module(buf, filename) try: mod.compile(display) except SyntaxError: raise else: f = open(filename + "c", "wb") mod.dump(f) f.close() def compile(source, filename, mode, flags=None, dont_inherit=None): """Replacement for builtin compile() function""" if flags is not None or dont_inherit is not None: raise RuntimeError("not implemented yet") if mode == "single": gen = Interactive(source, filename) elif mode == "exec": gen = Module(source, filename) elif mode == "eval": gen = Expression(source, filename) else: raise ValueError("compile() 3rd arg must be 'exec' or 'eval' or 'single'") gen.compile() return gen.code class AbstractCompileMode: mode = None # defined by subclass def __init__(self, source, filename): self.source = source self.filename = filename self.code = None def _get_tree(self): tree = parse(self.source, self.mode) misc.set_filename(self.filename, tree) syntax.check(tree) return tree def compile(self): pass # implemented by subclass def getCode(self): return self.code class Expression(AbstractCompileMode): mode = "eval" def compile(self): tree = self._get_tree() gen = ExpressionCodeGenerator(tree) self.code = gen.getCode() class Interactive(AbstractCompileMode): mode = "single" def compile(self): tree = self._get_tree() gen = InteractiveCodeGenerator(tree) self.code = gen.getCode() class Module(AbstractCompileMode): mode = "exec" def compile(self, display=0): tree = self._get_tree() gen = ModuleCodeGenerator(tree) if display: import pprint print(pprint.pprint(tree)) self.code = gen.getCode() def dump(self, f): f.write(self.getPycHeader()) marshal.dump(self.code, f) if VERSION < 3: import imp MAGIC = imp.get_magic() else: import importlib.util MAGIC = importlib.util.MAGIC_NUMBER def getPycHeader(self): # compile.c uses marshal to write a long directly, with # calling the interface that would also generate a 1-byte code # to indicate the type of the value. simplest way to get the # same effect is to call marshal and then skip the code. 
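# getPycHeader's body continues below.  A standalone sketch of the same
# layout on Python 3: the pre-3.3 two-field .pyc header is just the
# interpreter's magic number followed by the source mtime packed as a
# little-endian 32-bit integer.  Later interpreters append further fields
# (a size field since 3.3, a flags word since 3.7), so this is a simplified
# reconstruction, not a drop-in replacement for any one version.
import importlib.util
import struct

def _toy_pyc_header(source_mtime):
    # 4-byte magic + 4-byte little-endian mtime
    return importlib.util.MAGIC_NUMBER + struct.pack('<I', int(source_mtime))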
mtime = os.path.getmtime(self.filename) mtime = struct.pack(' 0: top = top - 1 kind, loop_block = self.setups[top] if kind == LOOP: break if kind != LOOP: raise SyntaxError("'continue' outside loop (%s, %d)" % (node.filename, node.lineno)) self.emit('CONTINUE_LOOP', loop_block) self.nextBlock() elif kind == END_FINALLY: msg = "'continue' not allowed inside 'finally' clause (%s, %d)" raise SyntaxError(msg % (node.filename, node.lineno)) def visitTest(self, node, jump): end = self.newBlock() for child in node.nodes[:-1]: self.visit(child) self.emit(jump, end) self.nextBlock() self.visit(node.nodes[-1]) self.nextBlock(end) def visitAnd(self, node): self.visitTest(node, 'JUMP_IF_FALSE_OR_POP') def visitOr(self, node): self.visitTest(node, 'JUMP_IF_TRUE_OR_POP') def visitIfExp(self, node): endblock = self.newBlock() elseblock = self.newBlock() self.visit(node.test) self.emit('POP_JUMP_IF_FALSE', elseblock) self.visit(node.then) self.emit('JUMP_FORWARD', endblock) self.nextBlock(elseblock) self.visit(node.else_) self.nextBlock(endblock) def visitCompare(self, node): self.visit(node.expr) cleanup = self.newBlock() for op, code in node.ops[:-1]: self.visit(code) self.emit('DUP_TOP') self.emit('ROT_THREE') self.emit('COMPARE_OP', op) self.emit('JUMP_IF_FALSE_OR_POP', cleanup) self.nextBlock() # now do the last comparison if node.ops: op, code = node.ops[-1] self.visit(code) self.emit('COMPARE_OP', op) if len(node.ops) > 1: end = self.newBlock() self.emit('JUMP_FORWARD', end) self.startBlock(cleanup) self.emit('ROT_TWO') self.emit('POP_TOP') self.nextBlock(end) # list comprehensions def visitListComp(self, node): self.set_lineno(node) # setup list self.emit('BUILD_LIST', 0) stack = [] for i, for_ in izip(range(len(node.quals)), node.quals): start, anchor = self.visit(for_) cont = None for if_ in for_.ifs: if cont is None: cont = self.newBlock() self.visit(if_, cont) stack.insert(0, (start, cont, anchor)) self.visit(node.expr) self.emit('LIST_APPEND', len(node.quals) + 1) for start, cont, anchor in stack: if cont: self.nextBlock(cont) self.emit('JUMP_ABSOLUTE', start) self.startBlock(anchor) def visitSetComp(self, node): self.set_lineno(node) # setup list self.emit('BUILD_SET', 0) stack = [] for i, for_ in izip(range(len(node.quals)), node.quals): start, anchor = self.visit(for_) cont = None for if_ in for_.ifs: if cont is None: cont = self.newBlock() self.visit(if_, cont) stack.insert(0, (start, cont, anchor)) self.visit(node.expr) self.emit('SET_ADD', len(node.quals) + 1) for start, cont, anchor in stack: if cont: self.nextBlock(cont) self.emit('JUMP_ABSOLUTE', start) self.startBlock(anchor) def visitDictComp(self, node): self.set_lineno(node) # setup list self.emit('BUILD_MAP', 0) stack = [] for i, for_ in izip(range(len(node.quals)), node.quals): start, anchor = self.visit(for_) cont = None for if_ in for_.ifs: if cont is None: cont = self.newBlock() self.visit(if_, cont) stack.insert(0, (start, cont, anchor)) self.visit(node.value) self.visit(node.key) self.emit('MAP_ADD', len(node.quals) + 1) for start, cont, anchor in stack: if cont: self.nextBlock(cont) self.emit('JUMP_ABSOLUTE', start) self.startBlock(anchor) def visitListCompFor(self, node): start = self.newBlock() anchor = self.newBlock() self.visit(node.list) self.emit('GET_ITER') self.nextBlock(start) self.set_lineno(node, force=True) self.emit('FOR_ITER', anchor) self.nextBlock() self.visit(node.assign) return start, anchor def visitListCompIf(self, node, branch): self.set_lineno(node, force=True) self.visit(node.test) 
self.emit('POP_JUMP_IF_FALSE', branch) self.newBlock() def _makeClosure(self, gen, args): frees = gen.scope.get_free_vars() if frees: for name in frees: self.emit('LOAD_CLOSURE', name) self.emit('BUILD_TUPLE', len(frees)) self.emit('LOAD_CONST', gen) self.emit('MAKE_CLOSURE', args) else: self.emit('LOAD_CONST', gen) self.emit('MAKE_FUNCTION', args) def visitGenExpr(self, node): gen = GenExprCodeGenerator(node, self.scopes, self.class_name, self.get_module()) walk(node.code, gen) gen.finish() self.set_lineno(node) self._makeClosure(gen, 0) # precomputation of outmost iterable self.visit(node.code.quals[0].iter) self.emit('GET_ITER') self.emit('CALL_FUNCTION', 1) def visitGenExprInner(self, node): self.set_lineno(node) # setup list stack = [] for i, for_ in izip(range(len(node.quals)), node.quals): start, anchor, end = self.visit(for_) cont = None for if_ in for_.ifs: if cont is None: cont = self.newBlock() self.visit(if_, cont) stack.insert(0, (start, cont, anchor, end)) self.visit(node.expr) self.emit('YIELD_VALUE') self.emit('POP_TOP') for start, cont, anchor, end in stack: if cont: self.nextBlock(cont) self.emit('JUMP_ABSOLUTE', start) self.startBlock(anchor) self.emit('POP_BLOCK') self.setups.pop() self.nextBlock(end) self.emit('LOAD_CONST', None) def visitGenExprFor(self, node): start = self.newBlock() anchor = self.newBlock() end = self.newBlock() self.setups.push((LOOP, start)) self.emit('SETUP_LOOP', end) if node.is_outmost: self.loadName('.0') else: self.visit(node.iter) self.emit('GET_ITER') self.nextBlock(start) self.set_lineno(node, force=True) self.emit('FOR_ITER', anchor) self.nextBlock() self.visit(node.assign) return start, anchor, end def visitGenExprIf(self, node, branch): self.set_lineno(node, force=True) self.visit(node.test) self.emit('POP_JUMP_IF_FALSE', branch) self.newBlock() # exception related def visitAssert(self, node): # XXX would be interesting to implement this via a # transformation of the AST before this stage if __debug__: end = self.newBlock() self.set_lineno(node) # XXX AssertionError appears to be special case -- it is always # loaded as a global even if there is a local name. I guess this # is a sort of renaming op. 
self.nextBlock() self.visit(node.test) self.emit('POP_JUMP_IF_TRUE', end) self.nextBlock() self.emit('LOAD_GLOBAL', 'AssertionError') if node.fail: self.visit(node.fail) self.emit('RAISE_VARARGS', 2) else: self.emit('RAISE_VARARGS', 1) self.nextBlock(end) def visitRaise(self, node): self.set_lineno(node) n = 0 if node.expr1: self.visit(node.expr1) n = n + 1 if node.expr2: self.visit(node.expr2) n = n + 1 if node.expr3: self.visit(node.expr3) n = n + 1 self.emit('RAISE_VARARGS', n) def visitTryExcept(self, node): body = self.newBlock() handlers = self.newBlock() end = self.newBlock() if node.else_: lElse = self.newBlock() else: lElse = end self.set_lineno(node) self.emit('SETUP_EXCEPT', handlers) self.nextBlock(body) self.setups.push((EXCEPT, body)) self.visit(node.body) self.emit('POP_BLOCK') self.setups.pop() self.emit('JUMP_FORWARD', lElse) self.startBlock(handlers) last = len(node.handlers) - 1 for i in range(len(node.handlers)): expr, target, body = node.handlers[i] self.set_lineno(expr) if expr: self.emit('DUP_TOP') self.visit(expr) self.emit('COMPARE_OP', 'exception match') next = self.newBlock() self.emit('POP_JUMP_IF_FALSE', next) self.nextBlock() self.emit('POP_TOP') if target: self.visit(target) else: self.emit('POP_TOP') self.emit('POP_TOP') self.visit(body) self.emit('JUMP_FORWARD', end) if expr: self.nextBlock(next) else: self.nextBlock() self.emit('END_FINALLY') if node.else_: self.nextBlock(lElse) self.visit(node.else_) self.nextBlock(end) def visitTryFinally(self, node): body = self.newBlock() final = self.newBlock() self.set_lineno(node) self.emit('SETUP_FINALLY', final) self.nextBlock(body) self.setups.push((TRY_FINALLY, body)) self.visit(node.body) self.emit('POP_BLOCK') self.setups.pop() self.emit('LOAD_CONST', None) self.nextBlock(final) self.setups.push((END_FINALLY, final)) self.visit(node.final) self.emit('END_FINALLY') self.setups.pop() __with_count = 0 def visitWith(self, node): body = self.newBlock() final = self.newBlock() self.__with_count += 1 valuevar = "_[%d]" % self.__with_count self.set_lineno(node) self.visit(node.expr) self.emit('DUP_TOP') self.emit('LOAD_ATTR', '__exit__') self.emit('ROT_TWO') self.emit('LOAD_ATTR', '__enter__') self.emit('CALL_FUNCTION', 0) if node.vars is None: self.emit('POP_TOP') else: self._implicitNameOp('STORE', valuevar) self.emit('SETUP_FINALLY', final) self.nextBlock(body) self.setups.push((TRY_FINALLY, body)) if node.vars is not None: self._implicitNameOp('LOAD', valuevar) self._implicitNameOp('DELETE', valuevar) self.visit(node.vars) self.visit(node.body) self.emit('POP_BLOCK') self.setups.pop() self.emit('LOAD_CONST', None) self.nextBlock(final) self.setups.push((END_FINALLY, final)) self.emit('WITH_CLEANUP') self.emit('END_FINALLY') self.setups.pop() self.__with_count -= 1 # misc def visitDiscard(self, node): self.set_lineno(node) self.visit(node.expr) self.emit('POP_TOP') def visitConst(self, node): self.emit('LOAD_CONST', node.value) def visitKeyword(self, node): self.emit('LOAD_CONST', node.name) self.visit(node.expr) def visitGlobal(self, node): # no code to generate pass def visitName(self, node): self.set_lineno(node) self.loadName(node.name) def visitPass(self, node): self.set_lineno(node) def visitImport(self, node): self.set_lineno(node) level = 0 if self.graph.checkFlag(CO_FUTURE_ABSIMPORT) else -1 for name, alias in node.names: if VERSION > 1: self.emit('LOAD_CONST', level) self.emit('LOAD_CONST', None) self.emit('IMPORT_NAME', name) mod = name.split(".")[0] if alias: self._resolveDots(name) 
self.storeName(alias) else: self.storeName(mod) def visitFrom(self, node): self.set_lineno(node) level = node.level if level == 0 and not self.graph.checkFlag(CO_FUTURE_ABSIMPORT): level = -1 fromlist = tuple(name for (name, alias) in node.names) if VERSION > 1: self.emit('LOAD_CONST', level) self.emit('LOAD_CONST', fromlist) self.emit('IMPORT_NAME', node.modname) for name, alias in node.names: if VERSION > 1: if name == '*': self.namespace = 0 self.emit('IMPORT_STAR') # There can only be one name w/ from ... import * assert len(node.names) == 1 return else: self.emit('IMPORT_FROM', name) self._resolveDots(name) self.storeName(alias or name) else: self.emit('IMPORT_FROM', name) self.emit('POP_TOP') def _resolveDots(self, name): elts = name.split(".") if len(elts) == 1: return for elt in elts[1:]: self.emit('LOAD_ATTR', elt) def visitGetattr(self, node): self.visit(node.expr) self.emit('LOAD_ATTR', self.mangle(node.attrname)) # next five implement assignments def visitAssign(self, node): self.set_lineno(node) self.visit(node.expr) dups = len(node.nodes) - 1 for i in range(len(node.nodes)): elt = node.nodes[i] if i < dups: self.emit('DUP_TOP') if isinstance(elt, ast.Node): self.visit(elt) def visitAssName(self, node): if node.flags == 'OP_ASSIGN': self.storeName(node.name) elif node.flags == 'OP_DELETE': self.set_lineno(node) self.delName(node.name) else: print("oops", node.flags) def visitAssAttr(self, node): self.visit(node.expr) if node.flags == 'OP_ASSIGN': self.emit('STORE_ATTR', self.mangle(node.attrname)) elif node.flags == 'OP_DELETE': self.emit('DELETE_ATTR', self.mangle(node.attrname)) else: print("warning: unexpected flags:", node.flags) print(node) def _visitAssSequence(self, node, op='UNPACK_SEQUENCE'): if findOp(node) != 'OP_DELETE': self.emit(op, len(node.nodes)) for child in node.nodes: self.visit(child) if VERSION > 1: visitAssTuple = _visitAssSequence visitAssList = _visitAssSequence else: def visitAssTuple(self, node): self._visitAssSequence(node, 'UNPACK_TUPLE') def visitAssList(self, node): self._visitAssSequence(node, 'UNPACK_LIST') # augmented assignment def visitAugAssign(self, node): self.set_lineno(node) aug_node = wrap_aug(node.node) self.visit(aug_node, "load") self.visit(node.expr) self.emit(self._augmented_opcode[node.op]) self.visit(aug_node, "store") _augmented_opcode = { '+=' : 'INPLACE_ADD', '-=' : 'INPLACE_SUBTRACT', '*=' : 'INPLACE_MULTIPLY', '/=' : 'INPLACE_DIVIDE', '//=': 'INPLACE_FLOOR_DIVIDE', '%=' : 'INPLACE_MODULO', '**=': 'INPLACE_POWER', '>>=': 'INPLACE_RSHIFT', '<<=': 'INPLACE_LSHIFT', '&=' : 'INPLACE_AND', '^=' : 'INPLACE_XOR', '|=' : 'INPLACE_OR', } def visitAugName(self, node, mode): if mode == "load": self.loadName(node.name) elif mode == "store": self.storeName(node.name) def visitAugGetattr(self, node, mode): if mode == "load": self.visit(node.expr) self.emit('DUP_TOP') self.emit('LOAD_ATTR', self.mangle(node.attrname)) elif mode == "store": self.emit('ROT_TWO') self.emit('STORE_ATTR', self.mangle(node.attrname)) def visitAugSlice(self, node, mode): if mode == "load": self.visitSlice(node, 1) elif mode == "store": slice = 0 if node.lower: slice = slice | 1 if node.upper: slice = slice | 2 if slice == 0: self.emit('ROT_TWO') elif slice == 3: self.emit('ROT_FOUR') else: self.emit('ROT_THREE') self.emit('STORE_SLICE+%d' % slice) def visitAugSubscript(self, node, mode): if mode == "load": self.visitSubscript(node, 1) elif mode == "store": self.emit('ROT_THREE') self.emit('STORE_SUBSCR') def visitExec(self, node): self.visit(node.expr) if 
node.locals is None: self.emit('LOAD_CONST', None) else: self.visit(node.locals) if node.globals is None: self.emit('DUP_TOP') else: self.visit(node.globals) self.emit('EXEC_STMT') def visitCallFunc(self, node): pos = 0 kw = 0 self.set_lineno(node) self.visit(node.node) for arg in node.args: self.visit(arg) if isinstance(arg, ast.Keyword): kw = kw + 1 else: pos = pos + 1 if node.star_args is not None: self.visit(node.star_args) if node.dstar_args is not None: self.visit(node.dstar_args) have_star = node.star_args is not None have_dstar = node.dstar_args is not None opcode = callfunc_opcode_info[have_star, have_dstar] self.emit(opcode, kw << 8 | pos) def visitPrint(self, node, newline=0): self.set_lineno(node) if node.dest: self.visit(node.dest) for child in node.nodes: if node.dest: self.emit('DUP_TOP') self.visit(child) if node.dest: self.emit('ROT_TWO') self.emit('PRINT_ITEM_TO') else: self.emit('PRINT_ITEM') if node.dest and not newline: self.emit('POP_TOP') def visitPrintnl(self, node): self.visitPrint(node, newline=1) if node.dest: self.emit('PRINT_NEWLINE_TO') else: self.emit('PRINT_NEWLINE') def visitReturn(self, node): self.set_lineno(node) self.visit(node.value) self.emit('RETURN_VALUE') def visitYield(self, node): self.set_lineno(node) self.visit(node.value) self.emit('YIELD_VALUE') # slice and subscript stuff def visitSlice(self, node, aug_flag=None): # aug_flag is used by visitAugSlice self.visit(node.expr) slice = 0 if node.lower: self.visit(node.lower) slice = slice | 1 if node.upper: self.visit(node.upper) slice = slice | 2 if aug_flag: if slice == 0: self.emit('DUP_TOP') elif slice == 3: self.emit('DUP_TOPX', 3) else: self.emit('DUP_TOPX', 2) if node.flags == 'OP_APPLY': self.emit('SLICE+%d' % slice) elif node.flags == 'OP_ASSIGN': self.emit('STORE_SLICE+%d' % slice) elif node.flags == 'OP_DELETE': self.emit('DELETE_SLICE+%d' % slice) else: print("weird slice", node.flags) pony.utils.reraise(*sys.exc_info()) def visitSubscript(self, node, aug_flag=None): self.visit(node.expr) for sub in node.subs: self.visit(sub) if len(node.subs) > 1: self.emit('BUILD_TUPLE', len(node.subs)) if aug_flag: self.emit('DUP_TOPX', 2) if node.flags == 'OP_APPLY': self.emit('BINARY_SUBSCR') elif node.flags == 'OP_ASSIGN': self.emit('STORE_SUBSCR') elif node.flags == 'OP_DELETE': self.emit('DELETE_SUBSCR') # binary ops def binaryOp(self, node, op): self.visit(node.left) self.visit(node.right) self.emit(op) def visitAdd(self, node): return self.binaryOp(node, 'BINARY_ADD') def visitSub(self, node): return self.binaryOp(node, 'BINARY_SUBTRACT') def visitMul(self, node): return self.binaryOp(node, 'BINARY_MULTIPLY') def visitDiv(self, node): return self.binaryOp(node, self._div_op) def visitFloorDiv(self, node): return self.binaryOp(node, 'BINARY_FLOOR_DIVIDE') def visitMod(self, node): return self.binaryOp(node, 'BINARY_MODULO') def visitPower(self, node): return self.binaryOp(node, 'BINARY_POWER') def visitLeftShift(self, node): return self.binaryOp(node, 'BINARY_LSHIFT') def visitRightShift(self, node): return self.binaryOp(node, 'BINARY_RSHIFT') # unary ops def unaryOp(self, node, op): self.visit(node.expr) self.emit(op) def visitInvert(self, node): return self.unaryOp(node, 'UNARY_INVERT') def visitUnarySub(self, node): return self.unaryOp(node, 'UNARY_NEGATIVE') def visitUnaryAdd(self, node): return self.unaryOp(node, 'UNARY_POSITIVE') def visitUnaryInvert(self, node): return self.unaryOp(node, 'UNARY_INVERT') def visitNot(self, node): return self.unaryOp(node, 'UNARY_NOT') def 
visitBackquote(self, node): return self.unaryOp(node, 'UNARY_CONVERT') # bit ops def bitOp(self, nodes, op): self.visit(nodes[0]) for node in nodes[1:]: self.visit(node) self.emit(op) def visitBitand(self, node): return self.bitOp(node.nodes, 'BINARY_AND') def visitBitor(self, node): return self.bitOp(node.nodes, 'BINARY_OR') def visitBitxor(self, node): return self.bitOp(node.nodes, 'BINARY_XOR') # object constructors def visitEllipsis(self, node): self.emit('LOAD_CONST', Ellipsis) def visitTuple(self, node): self.set_lineno(node) for elt in node.nodes: self.visit(elt) self.emit('BUILD_TUPLE', len(node.nodes)) def visitList(self, node): self.set_lineno(node) for elt in node.nodes: self.visit(elt) self.emit('BUILD_LIST', len(node.nodes)) def visitSet(self, node): self.set_lineno(node) for elt in node.nodes: self.visit(elt) self.emit('BUILD_SET', len(node.nodes)) def visitSliceobj(self, node): for child in node.nodes: self.visit(child) self.emit('BUILD_SLICE', len(node.nodes)) def visitDict(self, node): self.set_lineno(node) self.emit('BUILD_MAP', 0) for k, v in node.items: self.emit('DUP_TOP') self.visit(k) self.visit(v) self.emit('ROT_THREE') self.emit('STORE_SUBSCR') class NestedScopeMixin: """Defines initClass() for nested scoping (Python 2.2-compatible)""" def initClass(self): self.__class__.NameFinder = LocalNameFinder self.__class__.FunctionGen = FunctionCodeGenerator self.__class__.ClassGen = ClassCodeGenerator class ModuleCodeGenerator(NestedScopeMixin, CodeGenerator): __super_init = CodeGenerator.__init__ scopes = None def __init__(self, tree): self.graph = pyassem.PyFlowGraph("", tree.filename) self.futures = future.find_futures(tree) self.__super_init() walk(tree, self) def get_module(self): return self class ExpressionCodeGenerator(NestedScopeMixin, CodeGenerator): __super_init = CodeGenerator.__init__ scopes = None futures = () def __init__(self, tree): self.graph = pyassem.PyFlowGraph("", tree.filename) self.__super_init() walk(tree, self) def get_module(self): return self class InteractiveCodeGenerator(NestedScopeMixin, CodeGenerator): __super_init = CodeGenerator.__init__ scopes = None futures = () def __init__(self, tree): self.graph = pyassem.PyFlowGraph("", tree.filename) self.__super_init() self.set_lineno(tree) walk(tree, self) self.emit('RETURN_VALUE') def get_module(self): return self def visitDiscard(self, node): # XXX Discard means it's an expression. Perhaps this is a bad # name. 
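# Discard wraps an expression used as a statement.  In 'exec' mode its value
# is simply popped; the interactive generator instead displays it (via the
# PRINT_EXPR opcode emitted below -- newer CPythons use a different opcode
# for the same job).  The builtin compile() shows the same split:
import dis
dis.dis(compile("1 + 1", "<stdin>", "exec"))    # value discarded
dis.dis(compile("1 + 1", "<stdin>", "single"))  # value displayed first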
self.visit(node.expr) self.emit('PRINT_EXPR') class AbstractFunctionCode: optimized = 1 lambdaCount = 0 def __init__(self, func, scopes, isLambda, class_name, mod): self.class_name = class_name self.module = mod if isLambda: klass = FunctionCodeGenerator name = "" % klass.lambdaCount klass.lambdaCount = klass.lambdaCount + 1 else: name = func.name args, hasTupleArg = generateArgList(func.argnames) self.graph = pyassem.PyFlowGraph(name, func.filename, args, optimized=1) self.isLambda = isLambda self.super_init() if not isLambda and func.doc: self.setDocstring(func.doc) lnf = walk(func.code, self.NameFinder(args), verbose=0) self.locals.push(lnf.getLocals()) if func.varargs: self.graph.setFlag(CO_VARARGS) if func.kwargs: self.graph.setFlag(CO_VARKEYWORDS) self.set_lineno(func) if hasTupleArg: self.generateArgUnpack(func.argnames) def get_module(self): return self.module def finish(self): self.graph.startExitBlock() if not self.isLambda: self.emit('LOAD_CONST', None) self.emit('RETURN_VALUE') def generateArgUnpack(self, args): for i in range(len(args)): arg = args[i] if isinstance(arg, tuple): self.emit('LOAD_FAST', '.%d' % (i * 2)) self.unpackSequence(arg) def unpackSequence(self, tup): if VERSION > 1: self.emit('UNPACK_SEQUENCE', len(tup)) else: self.emit('UNPACK_TUPLE', len(tup)) for elt in tup: if isinstance(elt, tuple): self.unpackSequence(elt) else: self._nameOp('STORE', elt) unpackTuple = unpackSequence class FunctionCodeGenerator(NestedScopeMixin, AbstractFunctionCode, CodeGenerator): super_init = CodeGenerator.__init__ # call be other init scopes = None __super_init = AbstractFunctionCode.__init__ def __init__(self, func, scopes, isLambda, class_name, mod): self.scopes = scopes self.scope = scopes[func] self.__super_init(func, scopes, isLambda, class_name, mod) self.graph.setFreeVars(self.scope.get_free_vars()) self.graph.setCellVars(self.scope.get_cell_vars()) if self.scope.generator is not None: self.graph.setFlag(CO_GENERATOR) class GenExprCodeGenerator(NestedScopeMixin, AbstractFunctionCode, CodeGenerator): super_init = CodeGenerator.__init__ # call be other init scopes = None __super_init = AbstractFunctionCode.__init__ def __init__(self, gexp, scopes, class_name, mod): self.scopes = scopes self.scope = scopes[gexp] self.__super_init(gexp, scopes, 1, class_name, mod) self.graph.setFreeVars(self.scope.get_free_vars()) self.graph.setCellVars(self.scope.get_cell_vars()) self.graph.setFlag(CO_GENERATOR) class AbstractClassCode: def __init__(self, klass, scopes, module): self.class_name = klass.name self.module = module self.graph = pyassem.PyFlowGraph(klass.name, klass.filename, optimized=0, klass=1) self.super_init() lnf = walk(klass.code, self.NameFinder(), verbose=0) self.locals.push(lnf.getLocals()) self.graph.setFlag(CO_NEWLOCALS) if klass.doc: self.setDocstring(klass.doc) def get_module(self): return self.module def finish(self): self.graph.startExitBlock() self.emit('LOAD_LOCALS') self.emit('RETURN_VALUE') class ClassCodeGenerator(NestedScopeMixin, AbstractClassCode, CodeGenerator): super_init = CodeGenerator.__init__ scopes = None __super_init = AbstractClassCode.__init__ def __init__(self, klass, scopes, module): self.scopes = scopes self.scope = scopes[klass] self.__super_init(klass, scopes, module) self.graph.setFreeVars(self.scope.get_free_vars()) self.graph.setCellVars(self.scope.get_cell_vars()) self.set_lineno(klass) self.emit("LOAD_GLOBAL", "__name__") self.storeName("__module__") if klass.doc: self.emit("LOAD_CONST", klass.doc) self.storeName('__doc__') def 
generateArgList(arglist): """Generate an arg list marking TupleArgs""" args = [] extra = [] count = 0 for i in range(len(arglist)): elt = arglist[i] if isinstance(elt, str): args.append(elt) elif isinstance(elt, tuple): args.append(TupleArg(i * 2, elt)) extra.extend(misc.flatten(elt)) count = count + 1 else: raise ValueError("unexpect argument type: %r" % type(elt)) return args + extra, count def findOp(node): """Find the op (DELETE, LOAD, STORE) in an AssTuple tree""" v = OpFinder() walk(node, v, verbose=0) return v.op class OpFinder: def __init__(self): self.op = None def visitAssName(self, node): if self.op is None: self.op = node.flags elif self.op != node.flags: raise ValueError("mixed ops in stmt") visitAssAttr = visitAssName visitSubscript = visitAssName class Delegator: """Base class to support delegation for augmented assignment nodes To generator code for augmented assignments, we use the following wrapper classes. In visitAugAssign, the left-hand expression node is visited twice. The first time the visit uses the normal method for that node . The second time the visit uses a different method that generates the appropriate code to perform the assignment. These delegator classes wrap the original AST nodes in order to support the variant visit methods. """ def __init__(self, obj): self.obj = obj def __getattr__(self, attr): return getattr(self.obj, attr) class AugGetattr(Delegator): pass class AugName(Delegator): pass class AugSlice(Delegator): pass class AugSubscript(Delegator): pass wrapper = { ast.Getattr: AugGetattr, ast.Name: AugName, ast.Slice: AugSlice, ast.Subscript: AugSubscript, } def wrap_aug(node): return wrapper[node.__class__](node) if __name__ == "__main__": for file in sys.argv[1:]: compileFile(file) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/symbols.py0000666000000000000000000003513500000000000017762 0ustar0000000000000000"""Module symbol-table generator""" from __future__ import print_function import sys, types from . import ast from .consts import SC_LOCAL, SC_GLOBAL_IMPLICIT, SC_GLOBAL_EXPLICIT, \ SC_FREE, SC_CELL, SC_UNKNOWN from .misc import mangle MANGLE_LEN = 256 class Scope: # XXX how much information do I need about each name? def __init__(self, name, module, klass=None): self.name = name self.module = module self.defs = {} self.uses = {} self.globals = {} self.params = {} self.frees = {} self.cells = {} self.children = [] # nested is true if the class could contain free variables, # i.e. if it is nested within another function. 
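# For comparison: the stdlib symtable module computes the same
# local/free/cell classification that this Scope class tracks by hand in its
# defs/uses/frees/cells dictionaries.  A standalone illustration:
import symtable
_table = symtable.symtable(
    "def outer():\n    x = 1\n    def inner():\n        return x\n",
    "<example>", "exec")
_outer = _table.get_children()[0]
_inner = _outer.get_children()[0]
print(_inner.lookup("x").is_free())    # True: free in the nested scope
print(_outer.lookup("x").is_local())   # True: bound (a cell) in the outer one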
self.nested = None self.generator = None self.klass = None if klass is not None: for i in range(len(klass)): if klass[i] != '_': self.klass = klass[i:] break def __repr__(self): return "<%s: %s>" % (self.__class__.__name__, self.name) def mangle(self, name): if self.klass is None: return name return mangle(name, self.klass) def add_def(self, name): self.defs[self.mangle(name)] = 1 def add_use(self, name): self.uses[self.mangle(name)] = 1 def add_global(self, name): name = self.mangle(name) if name in self.uses or name in self.defs: pass # XXX warn about global following def/use if name in self.params: raise SyntaxError("%s in %s is global and parameter" % (name, self.name)) self.globals[name] = 1 self.module.add_def(name) def add_param(self, name): name = self.mangle(name) self.defs[name] = 1 self.params[name] = 1 def get_names(self): d = {} d.update(self.defs) d.update(self.uses) d.update(self.globals) return d.keys() def add_child(self, child): self.children.append(child) def get_children(self): return self.children def DEBUG(self): print(self.name, self.nested and "nested" or "", file=sys.stderr) print("\tglobals: ", self.globals, file=sys.stderr) print("\tcells: ", self.cells, file=sys.stderr) print("\tdefs: ", self.defs, file=sys.stderr) print("\tuses: ", self.uses, file=sys.stderr) print("\tfrees:", self.frees, file=sys.stderr) def check_name(self, name): """Return scope of name. The scope of a name could be LOCAL, GLOBAL, FREE, or CELL. """ if name in self.globals: return SC_GLOBAL_EXPLICIT if name in self.cells: return SC_CELL if name in self.defs: return SC_LOCAL if self.nested and (name in self.frees or name in self.uses): return SC_FREE if self.nested: return SC_UNKNOWN else: return SC_GLOBAL_IMPLICIT def get_free_vars(self): if not self.nested: return () free = {} free.update(self.frees) for name in self.uses.keys(): if name not in self.defs and name not in self.globals: free[name] = 1 return free.keys() def handle_children(self): for child in self.children: frees = child.get_free_vars() globals = self.add_frees(frees) for name in globals: child.force_global(name) def force_global(self, name): """Force name to be global in scope. Some child of the current node had a free reference to name. When the child was processed, it was labelled a free variable. Now that all its enclosing scope have been processed, the name is known to be a global or builtin. So walk back down the child chain and set the name to be global rather than free. Be careful to stop if a child does not think the name is free. """ self.globals[name] = 1 if name in self.frees: del self.frees[name] for child in self.children: if child.check_name(name) == SC_FREE: child.force_global(name) def add_frees(self, names): """Process list of free vars from nested scope. Returns a list of names that are either 1) declared global in the parent or 2) undefined in a top-level parent. In either case, the nested scope should treat them as globals. 
""" child_globals = [] for name in names: sc = self.check_name(name) if self.nested: if sc == SC_UNKNOWN or sc == SC_FREE \ or isinstance(self, ClassScope): self.frees[name] = 1 elif sc == SC_GLOBAL_IMPLICIT: child_globals.append(name) elif isinstance(self, FunctionScope) and sc == SC_LOCAL: self.cells[name] = 1 elif sc != SC_CELL: child_globals.append(name) else: if sc == SC_LOCAL: self.cells[name] = 1 elif sc != SC_CELL: child_globals.append(name) return child_globals def get_cell_vars(self): return self.cells.keys() class ModuleScope(Scope): __super_init = Scope.__init__ def __init__(self): self.__super_init("global", self) class FunctionScope(Scope): pass class GenExprScope(Scope): __super_init = Scope.__init__ __counter = 1 def __init__(self, module, klass=None): i = self.__counter self.__counter += 1 self.__super_init("generator expression<%d>"%i, module, klass) self.add_param('.0') def get_names(self): keys = Scope.get_names(self) return keys class LambdaScope(FunctionScope): __super_init = Scope.__init__ __counter = 1 def __init__(self, module, klass=None): i = self.__counter self.__counter += 1 self.__super_init("lambda.%d" % i, module, klass) class ClassScope(Scope): __super_init = Scope.__init__ def __init__(self, name, module): self.__super_init(name, module, name) class SymbolVisitor: def __init__(self): self.scopes = {} self.klass = None # node that define new scopes def visitModule(self, node): scope = self.module = self.scopes[node] = ModuleScope() self.visit(node.node, scope) visitExpression = visitModule def visitFunction(self, node, parent): if node.decorators: self.visit(node.decorators, parent) parent.add_def(node.name) for n in node.defaults: self.visit(n, parent) scope = FunctionScope(node.name, self.module, self.klass) if parent.nested or isinstance(parent, FunctionScope): scope.nested = 1 self.scopes[node] = scope self._do_args(scope, node.argnames) self.visit(node.code, scope) self.handle_free_vars(scope, parent) def visitGenExpr(self, node, parent): scope = GenExprScope(self.module, self.klass); if parent.nested or isinstance(parent, FunctionScope) \ or isinstance(parent, GenExprScope): scope.nested = 1 self.scopes[node] = scope self.visit(node.code, scope) self.handle_free_vars(scope, parent) def visitGenExprInner(self, node, scope): for genfor in node.quals: self.visit(genfor, scope) self.visit(node.expr, scope) def visitGenExprFor(self, node, scope): self.visit(node.assign, scope, 1) self.visit(node.iter, scope) for if_ in node.ifs: self.visit(if_, scope) def visitGenExprIf(self, node, scope): self.visit(node.test, scope) def visitLambda(self, node, parent, assign=0): # Lambda is an expression, so it could appear in an expression # context where assign is passed. The transformer should catch # any code that has a lambda on the left-hand side. 
assert not assign for n in node.defaults: self.visit(n, parent) scope = LambdaScope(self.module, self.klass) if parent.nested or isinstance(parent, FunctionScope): scope.nested = 1 self.scopes[node] = scope self._do_args(scope, node.argnames) self.visit(node.code, scope) self.handle_free_vars(scope, parent) def _do_args(self, scope, args): for name in args: if type(name) == types.TupleType: self._do_args(scope, name) else: scope.add_param(name) def handle_free_vars(self, scope, parent): parent.add_child(scope) scope.handle_children() def visitClass(self, node, parent): parent.add_def(node.name) for n in node.bases: self.visit(n, parent) scope = ClassScope(node.name, self.module) if parent.nested or isinstance(parent, FunctionScope): scope.nested = 1 if node.doc is not None: scope.add_def('__doc__') scope.add_def('__module__') self.scopes[node] = scope prev = self.klass self.klass = node.name self.visit(node.code, scope) self.klass = prev self.handle_free_vars(scope, parent) # name can be a def or a use # XXX a few calls and nodes expect a third "assign" arg that is # true if the name is being used as an assignment. only # expressions contained within statements may have the assign arg. def visitName(self, node, scope, assign=0): if assign: scope.add_def(node.name) else: scope.add_use(node.name) # operations that bind new names def visitFor(self, node, scope): self.visit(node.assign, scope, 1) self.visit(node.list, scope) self.visit(node.body, scope) if node.else_: self.visit(node.else_, scope) def visitFrom(self, node, scope): for name, asname in node.names: if name == "*": continue scope.add_def(asname or name) def visitImport(self, node, scope): for name, asname in node.names: i = name.find(".") if i > -1: name = name[:i] scope.add_def(asname or name) def visitGlobal(self, node, scope): for name in node.names: scope.add_global(name) def visitAssign(self, node, scope): """Propagate assignment flag down to child nodes. The Assign node doesn't itself contains the variables being assigned to. Instead, the children in node.nodes are visited with the assign flag set to true. When the names occur in those nodes, they are marked as defs. Some names that occur in an assignment target are not bound by the assignment, e.g. a name occurring inside a slice. The visitor handles these nodes specially; they do not propagate the assign flag to their children. """ for n in node.nodes: self.visit(n, scope, 1) self.visit(node.expr, scope) def visitAssName(self, node, scope, assign=1): scope.add_def(node.name) def visitAssAttr(self, node, scope, assign=0): self.visit(node.expr, scope, 0) def visitSubscript(self, node, scope, assign=0): self.visit(node.expr, scope, 0) for n in node.subs: self.visit(n, scope, 0) def visitSlice(self, node, scope, assign=0): self.visit(node.expr, scope, 0) if node.lower: self.visit(node.lower, scope, 0) if node.upper: self.visit(node.upper, scope, 0) def visitAugAssign(self, node, scope): # If the LHS is a name, then this counts as assignment. # Otherwise, it's just use. 
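# The rule in the comment above is observable at runtime: the assignment
# half of 'x += 1' makes x local to the whole function body, so the read
# half fails when nothing has been bound yet -- even if a global x exists.
x = 10
def _bump():
    try:
        x += 1              # read-then-store, but x is already local here
    except UnboundLocalError as exc:
        return str(exc)     # exact message text varies by Python version
print(_bump())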
self.visit(node.node, scope) if isinstance(node.node, ast.Name): self.visit(node.node, scope, 1) # XXX worry about this self.visit(node.expr, scope) # prune if statements if tests are false _const_types = bytes, str, int, float def visitIf(self, node, scope): for test, body in node.tests: if isinstance(test, ast.Const): if type(test.value) in self._const_types: if not test.value: continue self.visit(test, scope) self.visit(body, scope) if node.else_: self.visit(node.else_, scope) # a yield statement signals a generator def visitYield(self, node, scope): scope.generator = 1 self.visit(node.value, scope) def list_eq(l1, l2): return sorted(l1) == sorted(l2) if __name__ == "__main__": import sys from compiler import parseFile, walk import symtable def get_names(syms): return [s for s in [s.get_name() for s in syms.get_symbols()] if not (s.startswith('_[') or s.startswith('.'))] for file in sys.argv[1:]: print(file) f = open(file) buf = f.read() f.close() syms = symtable.symtable(buf, file, "exec") mod_names = get_names(syms) tree = parseFile(file) s = SymbolVisitor() walk(tree, s) # compare module-level symbols names2 = s.scopes[tree].get_names() if not list_eq(mod_names, names2): print() print("oops", file) print(sorted(mod_names)) print(sorted(names2)) sys.exit(-1) d = {} d.update(s.scopes) del d[tree] scopes = d.values() del d for s in syms.get_symbols(): if s.is_namespace(): l = [sc for sc in scopes if sc.name == s.get_name()] if len(l) > 1: print("skipping", s.get_name()) else: if not list_eq(get_names(s.get_namespace()), l[0].get_names()): print(s.get_name()) print(sorted(get_names(s.get_namespace()))) print(sorted(l[0].get_names())) sys.exit(-1) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/syntax.py0000666000000000000000000000271400000000000017615 0ustar0000000000000000"""Check for errs in the AST. The Python parser does not catch all syntax errors. Others, like assignments with invalid targets, are caught in the code generation phase. The compiler package catches some errors in the transformer module. But it seems clearer to write checkers that use the AST to detect errors. """ from . import ast, walk def check(tree, multi=None): v = SyntaxErrorChecker(multi) walk(tree, v) return v.errors class SyntaxErrorChecker: """A visitor to find syntax errors in the AST.""" def __init__(self, multi=None): """Create new visitor object. If optional argument multi is not None, then print messages for each error rather than raising a SyntaxError for the first. """ self.multi = multi self.errors = 0 def error(self, node, msg): self.errors = self.errors + 1 if self.multi is not None: print("%s:%s: %s" % (node.filename, node.lineno, msg)) else: raise SyntaxError("%s (%s:%s)" % (msg, node.filename, node.lineno)) def visitAssign(self, node): # the transformer module handles many of these pass ## for target in node.nodes: ## if isinstance(target, ast.AssList): ## if target.lineno is None: ## target.lineno = node.lineno ## self.error(target, "can't assign to list comprehension") ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571862779.0 pony-0.7.11/pony/thirdparty/compiler/transformer.py0000666000000000000000000016040200000000000020630 0ustar0000000000000000"""Parse tree transformation module. Transforms Python source code into an abstract syntax tree (AST) defined in the ast module. The simplest ways to invoke this module are via parse and parseFile. 
parse(buf) -> AST parseFile(path) -> AST """ # Original version written by Greg Stein (gstein@lyra.org) # and Bill Tutt (rassilon@lima.mudlib.org) # February 1997. # # Modifications and improvements for Python 2.0 by Jeremy Hylton and # Mark Hammond # # Some fixes to try to have correct line number on almost all nodes # (except Module, Discard and Stmt) added by Sylvain Thenault # # Portions of this file are: # Copyright (C) 1997-1998 Greg Stein. All Rights Reserved. # # This module is provided under a BSD-ish license. See # http://www.opensource.org/licenses/bsd-license.html # and replace OWNER, ORGANIZATION, and YEAR as appropriate. from __future__ import absolute_import, print_function from pony.py23compat import PY2, unicode from .ast import * import parser import symbol import sys import token # Python 2.6 compatibility fix if not hasattr(symbol, 'testlist_comp'): symbol.testlist_comp = symbol.testlist_gexp if not hasattr(symbol, 'comp_iter'): symbol.comp_iter = symbol.gen_iter if not hasattr(symbol, 'comp_for'): symbol.comp_for = symbol.gen_for if not hasattr(symbol, 'comp_if'): symbol.comp_if = symbol.gen_if atom_expr = getattr(symbol, 'atom_expr', None) namedexpr_test = getattr(symbol, 'namedexpr_test', None) class WalkerError(Exception): pass from .consts import CO_VARARGS, CO_VARKEYWORDS from .consts import OP_ASSIGN, OP_DELETE, OP_APPLY def parseFile(path): f = open(path, "U") # XXX The parser API tolerates files without a trailing newline, # but not strings without a trailing newline. Always add an extra # newline to the file contents, since we're going through the string # version of the API. src = f.read() + "\n" f.close() return parse(src) def parse(buf, mode="exec"): if mode == "exec" or mode == "single": return Transformer().parsesuite(buf) elif mode == "eval": return Transformer().parseexpr(buf) else: raise ValueError("compile() arg 3 must be 'exec' or 'eval' or 'single'") def asList(nodes): l = [] for item in nodes: if hasattr(item, "asList"): l.append(item.asList()) else: if type(item) is type( (None, None) ): l.append(tuple(asList(item))) elif type(item) is type( [] ): l.append(asList(item)) else: l.append(item) return l def extractLineNo(ast): if not isinstance(ast[1], tuple): # get a terminal node return ast[2] for child in ast[1:]: if isinstance(child, tuple): lineno = extractLineNo(child) if lineno is not None: return lineno def Node(*args): kind = args[0] if kind in nodes: try: return nodes[kind](*args[1:]) except TypeError: print(nodes[kind], len(args), args) raise else: raise WalkerError("Can't find appropriate Node type: %s" % str(args)) #return apply(ast.Node, args) class Transformer: """Utility object for transforming Python parse trees. 
Exposes the following methods: tree = transform(ast_tree) tree = parsesuite(text) tree = parseexpr(text) tree = parsefile(fileob | filename) """ def atom_expr(self, nodelist): atom_nodelist = nodelist[0] assert atom_nodelist[0] == symbol.atom, atom_nodelist[0] node = self.atom(atom_nodelist[1:]) for i in range(1, len(nodelist)): elt = nodelist[i] node = self.com_apply_trailer(node, elt) return node def namedexpr_test(self, nodelist): return self.test(nodelist[0][1:]) def __init__(self): self._dispatch = {} for value, name in symbol.sym_name.items(): if hasattr(self, name): self._dispatch[value] = getattr(self, name) self._dispatch[token.NEWLINE] = self.com_NEWLINE self._atom_dispatch = {token.LPAR: self.atom_lpar, token.LSQB: self.atom_lsqb, token.LBRACE: self.atom_lbrace, token.NUMBER: self.atom_number, token.STRING: self.atom_string, token.NAME: self.atom_name, } if PY2: self._atom_dispatch.update({ token.BACKQUOTE: self.atom_backquote, }) if not PY2: self._atom_dispatch.update({ token.ELLIPSIS: self.atom_ellipsis }) self.encoding = None def print_tree(self, tree, indent=''): for item in tree: if isinstance(item, tuple): self.print_tree(item, indent+' ') else: print(indent, symbol.sym_name.get(item, item)) def transform(self, tree): """Transform an AST into a modified parse tree.""" if not (isinstance(tree, tuple) or isinstance(tree, list)): tree = parser.st2tuple(tree, line_info=1) # self.print_tree(tree) return self.compile_node(tree) def parsesuite(self, text): """Return a modified parse tree for the given suite text.""" return self.transform(parser.suite(text)) def parseexpr(self, text): """Return a modified parse tree for the given expression text.""" return self.transform(parser.expr(text)) def parsefile(self, file): """Return a modified parse tree for the contents of the given file.""" if type(file) == type(''): file = open(file) return self.parsesuite(file.read()) # -------------------------------------------------------------- # # PRIVATE METHODS # def compile_node(self, node): ### emit a line-number node? n = node[0] if n == symbol.encoding_decl: self.encoding = node[2] node = node[1] n = node[0] if n == symbol.single_input: return self.single_input(node[1:]) if n == symbol.file_input: return self.file_input(node[1:]) if n == symbol.eval_input: return self.eval_input(node[1:]) if n == symbol.lambdef: return self.lambdef(node[1:]) if n == symbol.funcdef: return self.funcdef(node[1:]) if n == symbol.classdef: return self.classdef(node[1:]) raise WalkerError('unexpected node type: %r' % n) def single_input(self, node): ### do we want to do anything about being "interactive" ? # NEWLINE | simple_stmt | compound_stmt NEWLINE n = node[0][0] if n != token.NEWLINE: return self.com_stmt(node[0]) return Pass() def file_input(self, nodelist): doc = self.get_docstring(nodelist, symbol.file_input) if doc is not None: i = 1 else: i = 0 stmts = [] for node in nodelist[i:]: if node[0] != token.ENDMARKER and node[0] != token.NEWLINE: self.com_append_stmt(stmts, node) return Module(doc, Stmt(stmts)) def eval_input(self, nodelist): # from the built-in function input() ### is this sufficient? 
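# The Transformer consumes the low-level tuple trees of the old parser
# module (deprecated for years and removed in Python 3.10), where every node
# is a (symbol_or_token_id, ...) tuple produced by parser.st2tuple().
# Modern code would reach for the ast module instead; roughly the moral
# equivalent of parseexpr(text):
import ast
print(ast.dump(ast.parse("1 + 2", mode="eval")))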
return Expression(self.com_node(nodelist[0])) def decorator_name(self, nodelist): listlen = len(nodelist) assert listlen >= 1 and listlen % 2 == 1 item = self.atom_name(nodelist) i = 1 while i < listlen: assert nodelist[i][0] == token.DOT assert nodelist[i + 1][0] == token.NAME item = Getattr(item, nodelist[i + 1][1]) i += 2 return item def decorator(self, nodelist): # '@' dotted_name [ '(' [arglist] ')' ] assert len(nodelist) in (3, 5, 6) assert nodelist[0][0] == token.AT assert nodelist[-1][0] == token.NEWLINE assert nodelist[1][0] == symbol.dotted_name funcname = self.decorator_name(nodelist[1][1:]) if len(nodelist) > 3: assert nodelist[2][0] == token.LPAR expr = self.com_call_function(funcname, nodelist[3]) else: expr = funcname return expr def decorators(self, nodelist): # decorators: decorator ([NEWLINE] decorator)* NEWLINE items = [] for dec_nodelist in nodelist: assert dec_nodelist[0] == symbol.decorator items.append(self.decorator(dec_nodelist[1:])) return Decorators(items) def decorated(self, nodelist): assert nodelist[0][0] == symbol.decorators if nodelist[1][0] == symbol.funcdef: n = [nodelist[0]] + list(nodelist[1][1:]) return self.funcdef(n) elif nodelist[1][0] == symbol.classdef: decorators = self.decorators(nodelist[0][1:]) cls = self.classdef(nodelist[1][1:]) cls.decorators = decorators return cls raise WalkerError() def funcdef(self, nodelist): # -6 -5 -4 -3 -2 -1 # funcdef: [decorators] 'def' NAME parameters ':' suite # parameters: '(' [varargslist] ')' if len(nodelist) == 6: assert nodelist[0][0] == symbol.decorators decorators = self.decorators(nodelist[0][1:]) else: assert len(nodelist) == 5 decorators = None lineno = nodelist[-4][2] name = nodelist[-4][1] args = nodelist[-3][2] if args[0] == symbol.varargslist: names, defaults, flags = self.com_arglist(args[1:]) else: names = defaults = () flags = 0 doc = self.get_docstring(nodelist[-1]) # code for function code = self.com_node(nodelist[-1]) if doc is not None: assert isinstance(code, Stmt) assert isinstance(code.nodes[0], Discard) del code.nodes[0] return Function(decorators, name, names, defaults, flags, doc, code, lineno=lineno) def lambdef(self, nodelist): # lambdef: 'lambda' [varargslist] ':' test if nodelist[2][0] == symbol.varargslist: names, defaults, flags = self.com_arglist(nodelist[2][1:]) else: names = defaults = () flags = 0 # code for lambda code = self.com_node(nodelist[-1]) return Lambda(names, defaults, flags, code, lineno=nodelist[1][2]) old_lambdef = lambdef def classdef(self, nodelist): # classdef: 'class' NAME ['(' [testlist] ')'] ':' suite name = nodelist[1][1] doc = self.get_docstring(nodelist[-1]) if nodelist[2][0] == token.COLON: bases = [] elif nodelist[3][0] == token.RPAR: bases = [] else: bases = self.com_bases(nodelist[3]) # code for class code = self.com_node(nodelist[-1]) if doc is not None: assert isinstance(code, Stmt) assert isinstance(code.nodes[0], Discard) del code.nodes[0] return Class(name, bases, doc, code, lineno=nodelist[1][2]) def stmt(self, nodelist): return self.com_stmt(nodelist[0]) small_stmt = stmt flow_stmt = stmt compound_stmt = stmt def simple_stmt(self, nodelist): # small_stmt (';' small_stmt)* [';'] NEWLINE stmts = [] for i in range(0, len(nodelist), 2): self.com_append_stmt(stmts, nodelist[i]) return Stmt(stmts) def parameters(self, nodelist): raise WalkerError() def varargslist(self, nodelist): raise WalkerError() def fpdef(self, nodelist): raise WalkerError() def fplist(self, nodelist): raise WalkerError() def dotted_name(self, nodelist): raise WalkerError() def 
comp_op(self, nodelist): raise WalkerError() def trailer(self, nodelist): raise WalkerError() def sliceop(self, nodelist): raise WalkerError() def argument(self, nodelist): raise WalkerError() # -------------------------------------------------------------- # # STATEMENT NODES (invoked by com_node()) # def expr_stmt(self, nodelist): # augassign testlist | testlist ('=' testlist)* en = nodelist[-1] exprNode = self.lookup_node(en)(en[1:]) if len(nodelist) == 1: return Discard(exprNode, lineno=exprNode.lineno) if nodelist[1][0] == token.EQUAL: nodesl = [] for i in range(0, len(nodelist) - 2, 2): nodesl.append(self.com_assign(nodelist[i], OP_ASSIGN)) return Assign(nodesl, exprNode, lineno=nodelist[1][2]) else: lval = self.com_augassign(nodelist[0]) op = self.com_augassign_op(nodelist[1]) return AugAssign(lval, op[1], exprNode, lineno=op[2]) raise WalkerError("can't get here") def print_stmt(self, nodelist): # print ([ test (',' test)* [','] ] | '>>' test [ (',' test)+ [','] ]) items = [] if len(nodelist) == 1: start = 1 dest = None elif nodelist[1][0] == token.RIGHTSHIFT: assert len(nodelist) == 3 \ or nodelist[3][0] == token.COMMA dest = self.com_node(nodelist[2]) start = 4 else: dest = None start = 1 for i in range(start, len(nodelist), 2): items.append(self.com_node(nodelist[i])) if nodelist[-1][0] == token.COMMA: return Print(items, dest, lineno=nodelist[0][2]) return Printnl(items, dest, lineno=nodelist[0][2]) def del_stmt(self, nodelist): return self.com_assign(nodelist[1], OP_DELETE) def pass_stmt(self, nodelist): return Pass(lineno=nodelist[0][2]) def break_stmt(self, nodelist): return Break(lineno=nodelist[0][2]) def continue_stmt(self, nodelist): return Continue(lineno=nodelist[0][2]) def return_stmt(self, nodelist): # return: [testlist] if len(nodelist) < 2: return Return(Const(None), lineno=nodelist[0][2]) return Return(self.com_node(nodelist[1]), lineno=nodelist[0][2]) def yield_stmt(self, nodelist): expr = self.com_node(nodelist[0]) return Discard(expr, lineno=expr.lineno) def yield_expr(self, nodelist): if len(nodelist) > 1: value = self.com_node(nodelist[1]) else: value = Const(None) return Yield(value, lineno=nodelist[0][2]) def raise_stmt(self, nodelist): # raise: [test [',' test [',' test]]] if len(nodelist) > 5: expr3 = self.com_node(nodelist[5]) else: expr3 = None if len(nodelist) > 3: expr2 = self.com_node(nodelist[3]) else: expr2 = None if len(nodelist) > 1: expr1 = self.com_node(nodelist[1]) else: expr1 = None return Raise(expr1, expr2, expr3, lineno=nodelist[0][2]) def import_stmt(self, nodelist): # import_stmt: import_name | import_from assert len(nodelist) == 1 return self.com_node(nodelist[0]) def import_name(self, nodelist): # import_name: 'import' dotted_as_names return Import(self.com_dotted_as_names(nodelist[1]), lineno=nodelist[0][2]) def import_from(self, nodelist): # import_from: 'from' ('.'* dotted_name | '.') 'import' ('*' | # '(' import_as_names ')' | import_as_names) assert nodelist[0][1] == 'from' idx = 1 while nodelist[idx][1] == '.': idx += 1 level = idx - 1 if nodelist[idx][0] == symbol.dotted_name: fromname = self.com_dotted_name(nodelist[idx]) idx += 1 else: fromname = "" assert nodelist[idx][1] == 'import' if nodelist[idx + 1][0] == token.STAR: return From(fromname, [('*', None)], level, lineno=nodelist[0][2]) else: node = nodelist[idx + 1 + (nodelist[idx + 1][0] == token.LPAR)] return From(fromname, self.com_import_as_names(node), level, lineno=nodelist[0][2]) def global_stmt(self, nodelist): # global: NAME (',' NAME)* names = [] for i in range(1, 
len(nodelist), 2): names.append(nodelist[i][1]) return Global(names, lineno=nodelist[0][2]) def exec_stmt(self, nodelist): # exec_stmt: 'exec' expr ['in' expr [',' expr]] expr1 = self.com_node(nodelist[1]) if len(nodelist) >= 4: expr2 = self.com_node(nodelist[3]) if len(nodelist) >= 6: expr3 = self.com_node(nodelist[5]) else: expr3 = None else: expr2 = expr3 = None return Exec(expr1, expr2, expr3, lineno=nodelist[0][2]) def assert_stmt(self, nodelist): # 'assert': test, [',' test] expr1 = self.com_node(nodelist[1]) if (len(nodelist) == 4): expr2 = self.com_node(nodelist[3]) else: expr2 = None return Assert(expr1, expr2, lineno=nodelist[0][2]) def if_stmt(self, nodelist): # if: test ':' suite ('elif' test ':' suite)* ['else' ':' suite] tests = [] for i in range(0, len(nodelist) - 3, 4): testNode = self.com_node(nodelist[i + 1]) suiteNode = self.com_node(nodelist[i + 3]) tests.append((testNode, suiteNode)) if len(nodelist) % 4 == 3: elseNode = self.com_node(nodelist[-1]) ## elseNode.lineno = nodelist[-1][1][2] else: elseNode = None return If(tests, elseNode, lineno=nodelist[0][2]) def while_stmt(self, nodelist): # 'while' test ':' suite ['else' ':' suite] testNode = self.com_node(nodelist[1]) bodyNode = self.com_node(nodelist[3]) if len(nodelist) > 4: elseNode = self.com_node(nodelist[6]) else: elseNode = None return While(testNode, bodyNode, elseNode, lineno=nodelist[0][2]) def for_stmt(self, nodelist): # 'for' exprlist 'in' exprlist ':' suite ['else' ':' suite] assignNode = self.com_assign(nodelist[1], OP_ASSIGN) listNode = self.com_node(nodelist[3]) bodyNode = self.com_node(nodelist[5]) if len(nodelist) > 8: elseNode = self.com_node(nodelist[8]) else: elseNode = None return For(assignNode, listNode, bodyNode, elseNode, lineno=nodelist[0][2]) def try_stmt(self, nodelist): return self.com_try_except_finally(nodelist) def with_stmt(self, nodelist): return self.com_with(nodelist) def with_var(self, nodelist): return self.com_with_var(nodelist) def suite(self, nodelist): # simple_stmt | NEWLINE INDENT NEWLINE* (stmt NEWLINE*)+ DEDENT if len(nodelist) == 1: return self.com_stmt(nodelist[0]) stmts = [] for node in nodelist: if node[0] == symbol.stmt: self.com_append_stmt(stmts, node) return Stmt(stmts) # -------------------------------------------------------------- # # EXPRESSION NODES (invoked by com_node()) # def testlist(self, nodelist): # testlist: expr (',' expr)* [','] # testlist_safe: test [(',' test)+ [',']] # exprlist: expr (',' expr)* [','] return self.com_binary(Tuple, nodelist) def testlist_star_expr(self, nodelist): return self.com_binary(Tuple, nodelist) def star_expr(self, *args): raise NotImplementedError testlist_safe = testlist # XXX testlist1 = testlist exprlist = testlist def testlist_comp(self, nodelist): # test ( comp_for | (',' test)* [','] ) PY38 = sys.version_info >= (3, 8) code = nodelist[0][0] if code not in (symbol.test, namedexpr_test): assert False, symbol.sym_name.get(code, code) if len(nodelist) == 2 and nodelist[1][0] == symbol.comp_for: test = self.com_node(nodelist[0]) return self.com_generator_expression(test, nodelist[1]) return self.testlist(nodelist) testlist_gexp = testlist_comp # Python 2.6 compatibility fix def test(self, nodelist): # or_test ['if' or_test 'else' test] | lambdef if len(nodelist) == 1 and nodelist[0][0] == symbol.lambdef: return self.lambdef(nodelist[0]) then = self.com_node(nodelist[0]) if len(nodelist) > 1: assert len(nodelist) == 5 assert nodelist[1][1] == 'if' assert nodelist[3][1] == 'else' test = self.com_node(nodelist[2]) else_ = 
self.com_node(nodelist[4]) return IfExp(test, then, else_, lineno=nodelist[1][2]) return then def or_test(self, nodelist): # and_test ('or' and_test)* | lambdef if len(nodelist) == 1 and nodelist[0][0] == symbol.lambdef: return self.lambdef(nodelist[0]) return self.com_binary(Or, nodelist) old_test = or_test test_nocond = old_test def and_test(self, nodelist): # not_test ('and' not_test)* return self.com_binary(And, nodelist) def not_test(self, nodelist): # 'not' not_test | comparison result = self.com_node(nodelist[-1]) if len(nodelist) == 2: return Not(result, lineno=nodelist[0][2]) return result def comparison(self, nodelist): # comparison: expr (comp_op expr)* node = self.com_node(nodelist[0]) if len(nodelist) == 1: return node results = [] for i in range(2, len(nodelist), 2): nl = nodelist[i-1] # comp_op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '==' # | 'in' | 'not' 'in' | 'is' | 'is' 'not' n = nl[1] if n[0] == token.NAME: type = n[1] if len(nl) == 3: if type == 'not': type = 'not in' else: type = 'is not' else: type = _cmp_types[n[0]] lineno = nl[1][2] results.append((type, self.com_node(nodelist[i]))) # we need a special "compare" node so that we can distinguish # 3 < x < 5 from (3 < x) < 5 # the two have very different semantics and results (note that the # latter form is always true) return Compare(node, results, lineno=lineno) def expr(self, nodelist): # xor_expr ('|' xor_expr)* return self.com_binary(Bitor, nodelist) def xor_expr(self, nodelist): # xor_expr ('^' xor_expr)* return self.com_binary(Bitxor, nodelist) def and_expr(self, nodelist): # xor_expr ('&' xor_expr)* return self.com_binary(Bitand, nodelist) def shift_expr(self, nodelist): # shift_expr ('<<'|'>>' shift_expr)* node = self.com_node(nodelist[0]) for i in range(2, len(nodelist), 2): right = self.com_node(nodelist[i]) if nodelist[i-1][0] == token.LEFTSHIFT: node = LeftShift([node, right], lineno=nodelist[1][2]) elif nodelist[i-1][0] == token.RIGHTSHIFT: node = RightShift([node, right], lineno=nodelist[1][2]) else: raise ValueError("unexpected token: %s" % nodelist[i-1][0]) return node def arith_expr(self, nodelist): node = self.com_node(nodelist[0]) for i in range(2, len(nodelist), 2): right = self.com_node(nodelist[i]) if nodelist[i-1][0] == token.PLUS: node = Add([node, right], lineno=nodelist[1][2]) elif nodelist[i-1][0] == token.MINUS: node = Sub([node, right], lineno=nodelist[1][2]) else: raise ValueError("unexpected token: %s" % nodelist[i-1][0]) return node def term(self, nodelist): node = self.com_node(nodelist[0]) for i in range(2, len(nodelist), 2): right = self.com_node(nodelist[i]) t = nodelist[i-1][0] if t == token.STAR: node = Mul([node, right]) elif t == token.SLASH: node = Div([node, right]) elif t == token.PERCENT: node = Mod([node, right]) elif t == token.DOUBLESLASH: node = FloorDiv([node, right]) else: raise ValueError("unexpected token: %s" % t) node.lineno = nodelist[1][2] return node def factor(self, nodelist): elt = nodelist[0] t = elt[0] node = self.lookup_node(nodelist[-1])(nodelist[-1][1:]) # need to handle (unary op)constant here... 
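        # The grammar at this point is:  factor: ('+'|'-'|'~') factor | power.
        # `elt` is the leading operator token when a unary op is present, and
        # nodelist[-1] is the operand subtree, so `-x` arrives as a MINUS token
        # followed by the factor for `x`.  The note above refers to folding
        # forms such as `-5` into a single Const node, presumably to match
        # compile.c's constant folding; as written, this transformer leaves
        # them as UnarySub(Const(5)).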
if t == token.PLUS: return UnaryAdd(node, lineno=elt[2]) elif t == token.MINUS: return UnarySub(node, lineno=elt[2]) elif t == token.TILDE: node = Invert(node, lineno=elt[2]) return node def power(self, nodelist): # power: atom trailer* ('**' factor)* node = self.com_node(nodelist[0]) for i in range(1, len(nodelist)): elt = nodelist[i] if elt[0] == token.DOUBLESTAR: return Power([node, self.com_node(nodelist[i+1])], lineno=elt[2]) node = self.com_apply_trailer(node, elt) return node def atom(self, nodelist): return self._atom_dispatch[nodelist[0][0]](nodelist) def atom_lpar(self, nodelist): if nodelist[1][0] == token.RPAR: return Tuple((), lineno=nodelist[0][2]) return self.com_node(nodelist[1]) def atom_lsqb(self, nodelist): if nodelist[1][0] == token.RSQB: return List((), lineno=nodelist[0][2]) return self.com_list_constructor(nodelist[1]) def atom_lbrace(self, nodelist): if nodelist[1][0] == token.RBRACE: return Dict((), lineno=nodelist[0][2]) return self.com_dictorsetmaker(nodelist[1]) def atom_backquote(self, nodelist): return Backquote(self.com_node(nodelist[1])) def atom_ellipsis(self, nodelist): return Ellipsis() def atom_number(self, nodelist): ### need to verify this matches compile.c k = eval(nodelist[0][1]) return Const(k, lineno=nodelist[0][2]) def decode_literal(self, lit): if self.encoding: # this is particularly fragile & a bit of a # hack... changes in compile.c:parsestr and # tokenizer.c must be reflected here. if self.encoding not in ['utf-8', 'iso-8859-1']: lit = unicode(lit, 'utf-8').encode(self.encoding) return eval("# coding: %s\n%s" % (self.encoding, lit)) else: return eval(lit) def atom_string(self, nodelist): k = '' for node in nodelist: k += self.decode_literal(node[1]) return Const(k, lineno=nodelist[0][2]) def atom_name(self, nodelist): return Name(nodelist[0][1], lineno=nodelist[0][2]) # -------------------------------------------------------------- # # INTERNAL PARSING UTILITIES # # The use of com_node() introduces a lot of extra stack frames, # enough to cause a stack overflow compiling test.test_parser with # the standard interpreter recursionlimit. The com_node() is a # convenience function that hides the dispatch details, but comes # at a very high cost. It is more efficient to dispatch directly # in the callers. In these cases, use lookup_node() and call the # dispatched node directly. def lookup_node(self, node): return self._dispatch[node[0]] def com_node(self, node): # Note: compile.c has handling in com_node for del_stmt, pass_stmt, # break_stmt, stmt, small_stmt, flow_stmt, simple_stmt, # and compound_stmt. # We'll just dispatch them. return self._dispatch[node[0]](node[1:]) def com_NEWLINE(self, *args): # A ';' at the end of a line can make a NEWLINE token appear # here, Render it harmless. 
(genc discards ('discard', # ('const', xxxx)) Nodes) return Discard(Const(None)) def com_arglist(self, nodelist): # varargslist: # (fpdef ['=' test] ',')* ('*' NAME [',' '**' NAME] | '**' NAME) # | fpdef ['=' test] (',' fpdef ['=' test])* [','] # fpdef: NAME | '(' fplist ')' # fplist: fpdef (',' fpdef)* [','] names = [] defaults = [] flags = 0 i = 0 while i < len(nodelist): node = nodelist[i] if node[0] == token.STAR or node[0] == token.DOUBLESTAR: if node[0] == token.STAR: node = nodelist[i+1] if node[0] == token.NAME: names.append(node[1]) flags = flags | CO_VARARGS i = i + 3 if i < len(nodelist): # should be DOUBLESTAR t = nodelist[i][0] if t == token.DOUBLESTAR: node = nodelist[i+1] else: raise ValueError("unexpected token: %s" % t) names.append(node[1]) flags = flags | CO_VARKEYWORDS break # fpdef: NAME | '(' fplist ')' names.append(self.com_fpdef(node)) i = i + 1 if i < len(nodelist) and nodelist[i][0] == token.EQUAL: defaults.append(self.com_node(nodelist[i + 1])) i = i + 2 elif len(defaults): # we have already seen an argument with default, but here # came one without raise SyntaxError("non-default argument follows default argument") # skip the comma i = i + 1 return names, defaults, flags def com_fpdef(self, node): # fpdef: NAME | '(' fplist ')' if node[1][0] == token.LPAR: return self.com_fplist(node[2]) return node[1][1] def com_fplist(self, node): # fplist: fpdef (',' fpdef)* [','] if len(node) == 2: return self.com_fpdef(node[1]) list = [] for i in range(1, len(node), 2): list.append(self.com_fpdef(node[i])) return tuple(list) def com_dotted_name(self, node): # String together the dotted names and return the string name = "" for n in node: if type(n) == type(()) and n[0] == 1: name = name + n[1] + '.' return name[:-1] def com_dotted_as_name(self, node): assert node[0] == symbol.dotted_as_name node = node[1:] dot = self.com_dotted_name(node[0][1:]) if len(node) == 1: return dot, None assert node[1][1] == 'as' assert node[2][0] == token.NAME return dot, node[2][1] def com_dotted_as_names(self, node): assert node[0] == symbol.dotted_as_names node = node[1:] names = [self.com_dotted_as_name(node[0])] for i in range(2, len(node), 2): names.append(self.com_dotted_as_name(node[i])) return names def com_import_as_name(self, node): assert node[0] == symbol.import_as_name node = node[1:] assert node[0][0] == token.NAME if len(node) == 1: return node[0][1], None assert node[1][1] == 'as', node assert node[2][0] == token.NAME return node[0][1], node[2][1] def com_import_as_names(self, node): assert node[0] == symbol.import_as_names node = node[1:] names = [self.com_import_as_name(node[0])] for i in range(2, len(node), 2): names.append(self.com_import_as_name(node[i])) return names def com_bases(self, node): bases = [] for i in range(1, len(node), 2): bases.append(self.com_node(node[i])) return bases def com_try_except_finally(self, nodelist): # ('try' ':' suite # ((except_clause ':' suite)+ ['else' ':' suite] ['finally' ':' suite] # | 'finally' ':' suite)) if nodelist[3][0] == token.NAME: # first clause is a finally clause: only try-finally return TryFinally(self.com_node(nodelist[2]), self.com_node(nodelist[5]), lineno=nodelist[0][2]) #tryexcept: [TryNode, [except_clauses], elseNode)] clauses = [] elseNode = None finallyNode = None for i in range(3, len(nodelist), 3): node = nodelist[i] if node[0] == symbol.except_clause: # except_clause: 'except' [expr [(',' | 'as') expr]] */ if len(node) > 2: expr1 = self.com_node(node[2]) if len(node) > 4: expr2 = self.com_assign(node[4], OP_ASSIGN) 
else: expr2 = None else: expr1 = expr2 = None clauses.append((expr1, expr2, self.com_node(nodelist[i+2]))) if node[0] == token.NAME: if node[1] == 'else': elseNode = self.com_node(nodelist[i+2]) elif node[1] == 'finally': finallyNode = self.com_node(nodelist[i+2]) try_except = TryExcept(self.com_node(nodelist[2]), clauses, elseNode, lineno=nodelist[0][2]) if finallyNode: return TryFinally(try_except, finallyNode, lineno=nodelist[0][2]) else: return try_except def com_with(self, nodelist): # with_stmt: 'with' with_item (',' with_item)* ':' suite body = self.com_node(nodelist[-1]) for i in range(len(nodelist) - 3, 0, -2): ret = self.com_with_item(nodelist[i], body, nodelist[0][2]) if i == 1: return ret body = ret def com_with_item(self, nodelist, body, lineno): # with_item: test ['as' expr] if len(nodelist) == 4: var = self.com_assign(nodelist[3], OP_ASSIGN) else: var = None expr = self.com_node(nodelist[1]) return With(expr, var, body, lineno=lineno) def com_augassign_op(self, node): assert node[0] == symbol.augassign return node[1] def com_augassign(self, node): """Return node suitable for lvalue of augmented assignment Names, slices, and attributes are the only allowable nodes. """ l = self.com_node(node) if l.__class__ in (Name, Slice, Subscript, Getattr): return l raise SyntaxError("can't assign to %s" % l.__class__.__name__) def com_assign(self, node, assigning): # return a node suitable for use as an "lvalue" # loop to avoid trivial recursion while 1: t = node[0] if t in (symbol.exprlist, symbol.testlist, symbol.testlist_comp) \ or PY2 and t == symbol.testlist_safe: if len(node) > 2: return self.com_assign_tuple(node, assigning) node = node[1] elif t in _assign_types: if len(node) > 2: raise SyntaxError("can't assign to operator") node = node[1] elif t == symbol.power: if node[1][0] not in (symbol.atom, atom_expr): raise SyntaxError("can't assign to operator") if len(node) > 2: primary = self.com_node(node[1]) for i in range(2, len(node)-1): ch = node[i] if ch[0] == token.DOUBLESTAR: raise SyntaxError("can't assign to operator") primary = self.com_apply_trailer(primary, ch) return self.com_assign_trailer(primary, node[-1], assigning) node = node[1] elif t == atom_expr: if node[1][0] != symbol.atom: raise SyntaxError("can't assign to operator") if len(node) > 2: primary = self.com_node(node[1]) for i in range(2, len(node)-1): ch = node[i] primary = self.com_apply_trailer(primary, ch) return self.com_assign_trailer(primary, node[-1], assigning) node = node[1] elif t == symbol.atom: t = node[1][0] if t == token.LPAR: node = node[2] if node[0] == token.RPAR: raise SyntaxError("can't assign to ()") elif t == token.LSQB: node = node[2] if node[0] == token.RSQB: raise SyntaxError("can't assign to []") return self.com_assign_list(node, assigning) elif t == token.NAME: return self.com_assign_name(node[1], assigning) else: raise SyntaxError("can't assign to literal") else: raise SyntaxError("bad assignment (%s)" % t) def com_assign_tuple(self, node, assigning): assigns = [] for i in range(1, len(node), 2): assigns.append(self.com_assign(node[i], assigning)) return AssTuple(assigns, lineno=extractLineNo(node)) def com_assign_list(self, node, assigning): assigns = [] for i in range(1, len(node), 2): if i + 1 < len(node): if node[i + 1][0] == symbol.list_for: raise SyntaxError("can't assign to list comprehension") assert node[i + 1][0] == token.COMMA, node[i + 1] assigns.append(self.com_assign(node[i], assigning)) return AssList(assigns, lineno=extractLineNo(node)) def com_assign_name(self, node, 
assigning): return AssName(node[1], assigning, lineno=node[2]) def com_assign_trailer(self, primary, node, assigning): t = node[1][0] if t == token.DOT: return self.com_assign_attr(primary, node[2], assigning) if t == token.LSQB: return self.com_subscriptlist(primary, node[2], assigning) if t == token.LPAR: raise SyntaxError("can't assign to function call") raise SyntaxError("unknown trailer type: %s" % t) def com_assign_attr(self, primary, node, assigning): return AssAttr(primary, node[1], assigning, lineno=node[-1]) def com_binary(self, constructor, nodelist): "Compile 'NODE (OP NODE)*' into (type, [ node1, ..., nodeN ])." l = len(nodelist) if l == 1: n = nodelist[0] return self.lookup_node(n)(n[1:]) items = [] for i in range(0, l, 2): n = nodelist[i] items.append(self.lookup_node(n)(n[1:])) return constructor(items, lineno=extractLineNo(nodelist)) def com_stmt(self, node): result = self.lookup_node(node)(node[1:]) assert result is not None if isinstance(result, Stmt): return result return Stmt([result]) def com_append_stmt(self, stmts, node): result = self.lookup_node(node)(node[1:]) assert result is not None if isinstance(result, Stmt): stmts.extend(result.nodes) else: stmts.append(result) def com_list_constructor(self, nodelist): # listmaker: test ( list_for | (',' test)* [','] ) values = [] for i in range(1, len(nodelist)): if PY2 and nodelist[i][0] == symbol.list_for: assert len(nodelist[i:]) == 1 return self.com_list_comprehension(values[0], nodelist[i]) elif nodelist[i][0] == token.COMMA: continue values.append(self.com_node(nodelist[i])) return List(values, lineno=values[0].lineno) def com_list_comprehension(self, expr, node): return self.com_comprehension(expr, None, node, 'list') def com_comprehension(self, expr1, expr2, node, type): # list_iter: list_for | list_if # list_for: 'for' exprlist 'in' testlist [list_iter] # list_if: 'if' test [list_iter] # XXX should raise SyntaxError for assignment # XXX(avassalotti) Set and dict comprehensions should have generator # semantics. In other words, they shouldn't leak # variables outside of the comprehension's scope. 
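        # For example, under Python 2 scoping the list comprehension
        #     [x for x in range(3)]
        # leaves x == 2 bound in the enclosing scope after it runs, while the
        # generator expression (x for x in range(3)) binds nothing there; the
        # XXX note above asks for the generator behaviour for set and dict
        # comprehensions as well.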
lineno = node[1][2] fors = [] while node: t = node[1][1] if t == 'for': assignNode = self.com_assign(node[2], OP_ASSIGN) compNode = self.com_node(node[4]) newfor = ListCompFor(assignNode, compNode, []) newfor.lineno = node[1][2] fors.append(newfor) if len(node) == 5: node = None elif type == 'list': node = self.com_list_iter(node[5]) else: node = self.com_comp_iter(node[5]) elif t == 'if': test = self.com_node(node[2]) newif = ListCompIf(test, lineno=node[1][2]) newfor.ifs.append(newif) if len(node) == 3: node = None elif type == 'list': node = self.com_list_iter(node[3]) else: node = self.com_comp_iter(node[3]) else: raise SyntaxError("unexpected comprehension element: %s %d" % (node, lineno)) if type == 'list': return ListComp(expr1, fors, lineno=lineno) elif type == 'set': return SetComp(expr1, fors, lineno=lineno) elif type == 'dict': return DictComp(expr1, expr2, fors, lineno=lineno) else: raise ValueError("unexpected comprehension type: " + repr(type)) def com_list_iter(self, node): assert node[0] == symbol.list_iter return node[1] def com_comp_iter(self, node): assert node[0] == symbol.comp_iter return node[1] def com_generator_expression(self, expr, node): # comp_iter: comp_for | comp_if # comp_for: 'for' exprlist 'in' test [comp_iter] # comp_if: 'if' test [comp_iter] PY37 = sys.version_info >= (3, 7) fors = [] while node: if PY37 and node[0] == symbol.comp_for: node = node[1] assert node[0] == symbol.sync_comp_for lineno = node[1][2] assert lineno is None or isinstance(lineno, int) t = node[1][1] if t == 'for': assignNode = self.com_assign(node[2], OP_ASSIGN) genNode = self.com_node(node[4]) newfor = GenExprFor(assignNode, genNode, [], lineno=node[1][2]) fors.append(newfor) if (len(node)) == 5: node = None else: node = self.com_comp_iter(node[5]) elif t == 'if': test = self.com_node(node[2]) newif = GenExprIf(test, lineno=node[1][2]) newfor.ifs.append(newif) if len(node) == 3: node = None else: node = self.com_comp_iter(node[3]) else: raise SyntaxError("unexpected generator expression element: %s %d" % (node, lineno)) fors[0].is_outmost = True return GenExpr(GenExprInner(expr, fors), lineno=expr.lineno) def com_dictorsetmaker(self, nodelist): # dictorsetmaker: ( (test ':' test (comp_for | (',' test ':' test)* [','])) | # (test (comp_for | (',' test)* [','])) ) assert nodelist[0] == symbol.dictorsetmaker nodelist = nodelist[1:] if len(nodelist) == 1 or nodelist[1][0] == token.COMMA: # set literal items = [] for i in range(0, len(nodelist), 2): items.append(self.com_node(nodelist[i])) return Set(items, lineno=items[0].lineno) elif nodelist[1][0] == symbol.comp_for: # set comprehension expr = self.com_node(nodelist[0]) return self.com_comprehension(expr, None, nodelist[1], 'set') elif len(nodelist) > 3 and nodelist[3][0] == symbol.comp_for: # dict comprehension assert nodelist[1][0] == token.COLON key = self.com_node(nodelist[0]) value = self.com_node(nodelist[2]) return self.com_comprehension(key, value, nodelist[3], 'dict') else: # dict literal items = [] for i in range(0, len(nodelist), 4): items.append((self.com_node(nodelist[i]), self.com_node(nodelist[i+2]))) return Dict(items, lineno=items[0][0].lineno) def com_apply_trailer(self, primaryNode, nodelist): t = nodelist[1][0] if t == token.LPAR: return self.com_call_function(primaryNode, nodelist[2]) if t == token.DOT: return self.com_select_member(primaryNode, nodelist[2]) if t == token.LSQB: return self.com_subscriptlist(primaryNode, nodelist[2], OP_APPLY) raise SyntaxError('unknown node type: %s' % t) def 
com_select_member(self, primaryNode, nodelist): if nodelist[0] != token.NAME: raise SyntaxError("member must be a name") return Getattr(primaryNode, nodelist[1], lineno=nodelist[2]) def com_call_function(self, primaryNode, nodelist): if nodelist[0] == token.RPAR: return CallFunc(primaryNode, [], lineno=extractLineNo(nodelist)) args = [] kw = 0 star_node = dstar_node = None len_nodelist = len(nodelist) i = 1 while i < len_nodelist: node = nodelist[i] if node[0]==token.STAR: if star_node is not None: raise SyntaxError('already have the varargs indentifier') star_node = self.com_node(nodelist[i+1]) i = i + 3 continue elif node[0]==token.DOUBLESTAR: if dstar_node is not None: raise SyntaxError('already have the kwargs indentifier') dstar_node = self.com_node(nodelist[i+1]) i = i + 3 continue # positional or named parameters kw, result = self.com_argument(node, kw, star_node) if len_nodelist != 2 and isinstance(result, GenExpr) \ and len(node) == 3 and node[2][0] == symbol.comp_for: # allow f(x for x in y), but reject f(x for x in y, 1) # should use f((x for x in y), 1) instead of f(x for x in y, 1) raise SyntaxError('generator expression needs parenthesis') args.append(result) i = i + 2 return CallFunc(primaryNode, args, star_node, dstar_node, lineno=extractLineNo(nodelist)) def com_argument(self, nodelist, kw, star_node): if len(nodelist) == 3 and nodelist[2][0] == symbol.comp_for: test = self.com_node(nodelist[1]) return 0, self.com_generator_expression(test, nodelist[2]) if len(nodelist) == 2: if kw: raise SyntaxError("non-keyword arg after keyword arg") if star_node: raise SyntaxError("only named arguments may follow *expression") return 0, self.com_node(nodelist[1]) assert len(nodelist) > 3, [kw, star_node, nodelist] result = self.com_node(nodelist[3]) n = nodelist[1] while len(n) == 2 and n[0] != token.NAME: n = n[1] if n[0] != token.NAME: raise SyntaxError("keyword can't be an expression (%s)" % n[0]) node = Keyword(n[1], result, lineno=n[2]) return 1, node def com_subscriptlist(self, primary, nodelist, assigning): # slicing: simple_slicing | extended_slicing # simple_slicing: primary "[" short_slice "]" # extended_slicing: primary "[" slice_list "]" # slice_list: slice_item ("," slice_item)* [","] # backwards compat slice for '[i:j]' if len(nodelist) == 2: sub = nodelist[1] if (sub[1][0] == token.COLON or \ (len(sub) > 2 and sub[2][0] == token.COLON)) and \ sub[-1][0] != symbol.sliceop: return self.com_slice(primary, sub, assigning) subscripts = [] for i in range(1, len(nodelist), 2): subscripts.append(self.com_subscript(nodelist[i])) return Subscript(primary, assigning, subscripts, lineno=extractLineNo(nodelist)) def com_subscript(self, node): # slice_item: expression | proper_slice | ellipsis ch = node[1] t = ch[0] if t == token.DOT and node[2][0] == token.DOT: return Ellipsis() if t == token.COLON or len(node) > 2: return self.com_sliceobj(node) return self.com_node(ch) def com_sliceobj(self, node): # proper_slice: short_slice | long_slice # short_slice: [lower_bound] ":" [upper_bound] # long_slice: short_slice ":" [stride] # lower_bound: expression # upper_bound: expression # stride: expression # # Note: a stride may be further slicing... items = [] if node[1][0] == token.COLON: items.append(Const(None)) i = 2 else: items.append(self.com_node(node[1])) # i == 2 is a COLON i = 3 if i < len(node) and node[i][0] == symbol.test: items.append(self.com_node(node[i])) i = i + 1 else: items.append(Const(None)) # a short_slice has been built. 
        # look for a long_slice now by looking for strides...
        for j in range(i, len(node)):
            ch = node[j]
            if len(ch) == 2: items.append(Const(None))
            else: items.append(self.com_node(ch[2]))
        return Sliceobj(items, lineno=extractLineNo(node))
    def com_slice(self, primary, node, assigning):
        # short_slice: [lower_bound] ":" [upper_bound]
        lower = upper = None
        if len(node) == 3:
            if node[1][0] == token.COLON:
                upper = self.com_node(node[2])
            else:
                lower = self.com_node(node[1])
        elif len(node) == 4:
            lower = self.com_node(node[1])
            upper = self.com_node(node[3])
        return Slice(primary, assigning, lower, upper, lineno=extractLineNo(node))
    def get_docstring(self, node, n=None):
        if n is None:
            n = node[0]
            node = node[1:]
        if n == symbol.suite:
            if len(node) == 1:
                return self.get_docstring(node[0])
            for sub in node:
                if sub[0] == symbol.stmt:
                    return self.get_docstring(sub)
            return None
        if n == symbol.file_input:
            for sub in node:
                if sub[0] == symbol.stmt:
                    return self.get_docstring(sub)
            return None
        if n == symbol.atom:
            if node[0][0] == token.STRING:
                s = ''
                for t in node:
                    s = s + eval(t[1])
                return s
            return None
        if n == symbol.stmt or n == symbol.simple_stmt \
           or n == symbol.small_stmt:
            return self.get_docstring(node[0])
        if n in _doc_nodes and len(node) == 1:
            return self.get_docstring(node[0])
        return None

_doc_nodes = [
    symbol.expr_stmt,
    symbol.testlist,
    symbol.test,
    symbol.or_test,
    symbol.and_test,
    symbol.not_test,
    symbol.comparison,
    symbol.expr,
    symbol.xor_expr,
    symbol.and_expr,
    symbol.shift_expr,
    symbol.arith_expr,
    symbol.term,
    symbol.factor,
    symbol.power,
    ]

if hasattr(symbol, 'testlist_safe'):
    _doc_nodes.append(symbol.testlist_safe)

# comp_op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '=='
#        | 'in' | 'not' 'in' | 'is' | 'is' 'not'
_cmp_types = {
    token.LESS : '<',
    token.GREATER : '>',
    token.EQEQUAL : '==',
    token.EQUAL : '==',
    token.LESSEQUAL : '<=',
    token.GREATEREQUAL : '>=',
    token.NOTEQUAL : '!=',
    }

_legal_node_types = [
    symbol.funcdef,
    symbol.classdef,
    symbol.stmt,
    symbol.small_stmt,
    symbol.flow_stmt,
    symbol.simple_stmt,
    symbol.compound_stmt,
    symbol.expr_stmt,
    symbol.del_stmt,
    symbol.pass_stmt,
    symbol.break_stmt,
    symbol.continue_stmt,
    symbol.return_stmt,
    symbol.raise_stmt,
    symbol.import_stmt,
    symbol.global_stmt,
    symbol.assert_stmt,
    symbol.if_stmt,
    symbol.while_stmt,
    symbol.for_stmt,
    symbol.try_stmt,
    symbol.with_stmt,
    symbol.suite,
    symbol.testlist,
    symbol.test,
    symbol.and_test,
    symbol.not_test,
    symbol.comparison,
    symbol.exprlist,
    symbol.expr,
    symbol.xor_expr,
    symbol.and_expr,
    symbol.shift_expr,
    symbol.arith_expr,
    symbol.term,
    symbol.factor,
    symbol.power,
    symbol.atom,
    ]

if hasattr(symbol, 'yield_stmt'):
    _legal_node_types.append(symbol.yield_stmt)
if hasattr(symbol, 'yield_expr'):
    _legal_node_types.append(symbol.yield_expr)
if hasattr(symbol, 'print_stmt'):
    _legal_node_types.append(symbol.print_stmt)
if hasattr(symbol, 'exec_stmt'):
    _legal_node_types.append(symbol.exec_stmt)
if hasattr(symbol, 'testlist_safe'):
    _legal_node_types.append(symbol.testlist_safe)

_assign_types = [
    symbol.test,
    symbol.or_test,
    symbol.and_test,
    symbol.not_test,
    symbol.comparison,
    symbol.expr,
    symbol.xor_expr,
    symbol.and_expr,
    symbol.shift_expr,
    symbol.arith_expr,
    symbol.term,
    symbol.factor,
    ]

_names = {}
for k, v in symbol.sym_name.items():
    _names[k] = v
for k, v in token.tok_name.items():
    _names[k] = v

def debug_tree(tree):
    l = []
    for elt in tree:
        if isinstance(elt, int): l.append(_names.get(elt, elt))
        elif isinstance(elt, str): l.append(elt)
        else: l.append(debug_tree(elt))
    return l
././@PaxHeader0000000000000000000000000000002600000000000011453
xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/compiler/visitor.py0000666000000000000000000000772700000000000017777 0ustar0000000000000000from __future__ import print_function from . import ast # XXX should probably rename ASTVisitor to ASTWalker # XXX can it be made even more generic? class ASTVisitor: """Performs a depth-first walk of the AST The ASTVisitor will walk the AST, performing either a preorder or postorder traversal depending on which method is called. methods: preorder(tree, visitor) postorder(tree, visitor) tree: an instance of ast.Node visitor: an instance with visitXXX methods The ASTVisitor is responsible for walking over the tree in the correct order. For each node, it checks the visitor argument for a method named 'visitNodeType' where NodeType is the name of the node's class, e.g. Class. If the method exists, it is called with the node as its sole argument. The visitor method for a particular node type can control how child nodes are visited during a preorder walk. (It can't control the order during a postorder walk, because it is called _after_ the walk has occurred.) The ASTVisitor modifies the visitor argument by adding a visit method to the visitor; this method can be used to visit a child node of arbitrary type. """ VERBOSE = 0 def __init__(self): self.node = None self._cache = {} def default(self, node, *args): for child in node.getChildNodes(): self.dispatch(child, *args) def dispatch(self, node, *args): self.node = node klass = node.__class__ meth = self._cache.get(klass, None) if meth is None: className = klass.__name__ meth = getattr(self.visitor, 'visit' + className, self.default) self._cache[klass] = meth ## if self.VERBOSE > 0: ## className = klass.__name__ ## if self.VERBOSE == 1: ## if meth == 0: ## print("dispatch", className) ## else: ## print("dispatch", className, (meth and meth.__name__ or '')) return meth(node, *args) def preorder(self, tree, visitor, *args): """Do preorder walk of tree using visitor""" self.visitor = visitor visitor.visit = self.dispatch self.dispatch(tree, *args) # XXX *args make sense? class ExampleASTVisitor(ASTVisitor): """Prints examples of the nodes that aren't visited This visitor-driver is only useful for development, when it's helpful to develop a visitor incrementally, and get feedback on what you still have to do. 
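
    A hypothetical session (the visitor name below is illustrative only)
    passes this driver to walk() with VERBOSE turned on:

        visitor = NameCollector()              # defines only visitName()
        walk(tree, visitor, ExampleASTVisitor(), verbose=1)

    Every node type that lacks a visitXXX method is then printed once,
    together with its attributes, as a reminder of what remains to be
    handled.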
""" examples = {} def dispatch(self, node, *args): self.node = node meth = self._cache.get(node.__class__, None) className = node.__class__.__name__ if meth is None: meth = getattr(self.visitor, 'visit' + className, 0) self._cache[node.__class__] = meth if self.VERBOSE > 1: print("dispatch", className, (meth and meth.__name__ or '')) if meth: meth(node, *args) elif self.VERBOSE > 0: klass = node.__class__ if klass not in self.examples: self.examples[klass] = klass print() print(self.visitor) print(klass) for attr in dir(node): if attr[0] != '_': print("\t", "%-12.12s" % attr, getattr(node, attr)) print() return self.default(node, *args) # XXX this is an API change _walker = ASTVisitor def walk(tree, visitor, walker=None, verbose=None): if walker is None: walker = _walker() if verbose is not None: walker.VERBOSE = verbose walker.preorder(tree, visitor) return walker.visitor def dumpNode(node): print(node.__class__) for attr in dir(node): if attr[0] != '_': print("\t", "%-10.10s" % attr, getattr(node, attr)) ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/thirdparty/decorator.py0000666000000000000000000002640600000000000016443 0ustar0000000000000000########################## LICENCE ############################### # Copyright (c) 2005-2012, Michele Simionato # All rights reserved. # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # Redistributions in bytecode form must reproduce the above copyright # notice, this list of conditions and the following disclaimer in # the documentation and/or other materials provided with the # distribution. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS # OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR # TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE # USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH # DAMAGE. """ Decorator module, see http://pypi.python.org/pypi/decorator for the documentation. 
""" from __future__ import absolute_import, print_function from pony.py23compat import PY2 __version__ = '3.4.0' __all__ = ["decorator", "FunctionMaker", "contextmanager"] import sys, re, inspect if sys.version >= '3': from inspect import getfullargspec def get_init(cls): return cls.__init__ else: class getfullargspec(object): "A quick and dirty replacement for getfullargspec for Python 2.X" def __init__(self, f): self.args, self.varargs, self.varkw, self.defaults = \ inspect.getargspec(f) self.kwonlyargs = [] self.kwonlydefaults = None def __iter__(self): yield self.args yield self.varargs yield self.varkw yield self.defaults def get_init(cls): return cls.__init__.im_func DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') # basic functionality class FunctionMaker(object): """ An object with the ability to create functions with a given signature. It has attributes name, doc, module, signature, defaults, dict and methods update and make. """ def __init__(self, func=None, name=None, signature=None, defaults=None, doc=None, module=None, funcdict=None): self.shortsignature = signature if func: # func can be a class or a callable, but not an instance method self.name = func.__name__ if self.name == '': # small hack for lambda functions self.name = '_lambda_' self.doc = func.__doc__ self.module = func.__module__ if inspect.isfunction(func): argspec = getfullargspec(func) self.annotations = getattr(func, '__annotations__', {}) for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', 'kwonlydefaults'): setattr(self, a, getattr(argspec, a)) for i, arg in enumerate(self.args): setattr(self, 'arg%d' % i, arg) if sys.version < '3': # easy way self.shortsignature = self.signature = \ inspect.formatargspec( formatvalue=lambda val: "", *argspec)[1:-1] else: # Python 3 way allargs = list(self.args) allshortargs = list(self.args) if self.varargs: allargs.append('*' + self.varargs) allshortargs.append('*' + self.varargs) elif self.kwonlyargs: allargs.append('*') # single star syntax for a in self.kwonlyargs: allargs.append('%s=None' % a) allshortargs.append('%s=%s' % (a, a)) if self.varkw: allargs.append('**' + self.varkw) allshortargs.append('**' + self.varkw) self.signature = ', '.join(allargs) self.shortsignature = ', '.join(allshortargs) self.dict = func.__dict__.copy() # func=None happens when decorating a caller if name: self.name = name if signature is not None: self.signature = signature if defaults: self.defaults = defaults if doc: self.doc = doc if module: self.module = module if funcdict: self.dict = funcdict # check existence required attributes assert hasattr(self, 'name') if not hasattr(self, 'signature'): raise TypeError('You are decorating a non function: %s' % func) def update(self, func, **kw): "Update the signature of func with the data in self" func.__name__ = self.name func.__doc__ = getattr(self, 'doc', None) func.__dict__ = getattr(self, 'dict', {}) if PY2: func.func_defaults = getattr(self, 'defaults', ()) else: func.__defaults__ = getattr(self, 'defaults', ()) func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None) func.__annotations__ = getattr(self, 'annotations', None) callermodule = sys._getframe(3).f_globals.get('__name__', '?') func.__module__ = getattr(self, 'module', callermodule) func.__dict__.update(kw) def make(self, src_templ, evaldict=None, addsource=False, **attrs): "Make a new function from a given template and update the signature" src = src_templ % vars(self) # expand name and signature evaldict = evaldict or {} mo = DEF.match(src) if mo is None: raise 
SyntaxError('not a valid function template\n%s' % src) name = mo.group(1) # extract the function name names = set([name] + [arg.strip(' *') for arg in self.shortsignature.split(',')]) for n in names: if n in ('_func_', '_call_'): raise NameError('%s is overridden in\n%s' % (n, src)) if not src.endswith('\n'): # add a newline just for safety src += '\n' # this is needed in old versions of Python try: # print(src) if PY2: code = compile(src, '' % self.name, 'single') exec('exec code in evaldict') else: code = compile(src, '', 'single') exec(code, evaldict) except: print('Error in generated code:', file=sys.stderr) print(src, file=sys.stderr) raise func = evaldict[name] if addsource: attrs['__source__'] = src self.update(func, **attrs) return func @classmethod def create(cls, obj, body, evaldict, defaults=None, doc=None, module=None, addsource=True, **attrs): """ Create a function from the strings name, signature and body. evaldict is the evaluation dictionary. If addsource is true an attribute __source__ is added to the result. The attributes attrs are added, if any. """ if isinstance(obj, str): # "name(signature)" name, rest = obj.strip().split('(', 1) signature = rest[:-1] #strip a right parens func = None else: # a function name = None signature = None func = obj self = cls(func, name, signature, defaults, doc, module) ibody = '\n'.join(' ' + line for line in body.splitlines()) return self.make('def %(name)s(%(signature)s):\n' + ibody, evaldict, addsource, **attrs) def decorator(caller, func=None): """ decorator(caller) converts a caller function into a decorator; decorator(caller, func) decorates a function using a caller. """ if func is not None: # returns a decorated function if PY2: evaldict = func.func_globals.copy() else: evaldict = func.__globals__.copy() evaldict['_call_'] = caller evaldict['_func_'] = func return FunctionMaker.create( func, "return _call_(_func_, %(shortsignature)s)", evaldict, undecorated=func, __wrapped__=func) else: # returns a decorator if inspect.isclass(caller): name = caller.__name__.lower() callerfunc = get_init(caller) doc = 'decorator(%s) converts functions/generators into ' \ 'factories of %s objects' % (caller.__name__, caller.__name__) fun = getfullargspec(callerfunc).args[1] # second arg elif inspect.isfunction(caller): name = '_lambda_' if caller.__name__ == '' \ else caller.__name__ callerfunc = caller doc = caller.__doc__ fun = getfullargspec(callerfunc).args[0] # first arg else: # assume caller is an object with a __call__ method name = caller.__class__.__name__.lower() if PY2: callerfunc = caller.__call__.im_func else: callerfunc = caller.__call__.__func__ doc = caller.__call__.__doc__ fun = getfullargspec(callerfunc).args[1] # second arg if PY2: evaldict = callerfunc.func_globals.copy() else: evaldict = callerfunc.__globals__.copy() evaldict['_call_'] = caller evaldict['decorator'] = decorator return FunctionMaker.create( '%s(%s)' % (name, fun), 'return decorator(_call_, %s)' % fun, evaldict, undecorated=caller, __wrapped__=caller, doc=doc, module=caller.__module__) ######################### contextmanager ######################## def __call__(self, func): 'Context manager decorator' return FunctionMaker.create( func, "with _self_: return _func_(%(shortsignature)s)", dict(_self_=self, _func_=func), __wrapped__=func) try: # Python >= 3.2 from contextlib import _GeneratorContextManager ContextManager = type( 'ContextManager', (_GeneratorContextManager,), dict(__call__=__call__)) except ImportError: # Python >= 2.5 from contextlib import 
GeneratorContextManager def __init__(self, f, *a, **k): return GeneratorContextManager.__init__(self, f(*a, **k)) ContextManager = type( 'ContextManager', (GeneratorContextManager,), dict(__call__=__call__, __init__=__init__)) contextmanager = decorator(ContextManager) ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1571864711.0806293 pony-0.7.11/pony/utils/0000777000000000000000000000000000000000000013045 5ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/utils/__init__.py0000666000000000000000000000006300000000000015155 0ustar0000000000000000 from .utils import * from .properties import *././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1495318476.0 pony-0.7.11/pony/utils/properties.py0000666000000000000000000000203100000000000015607 0ustar0000000000000000 class cached_property(object): """ A property that is only computed once per instance and then replaces itself with an ordinary attribute. Deleting the attribute resets the property. Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 """ # noqa def __init__(self, func): self.__doc__ = getattr(func, '__doc__') self.func = func def __get__(self, obj, cls): if obj is None: return self value = obj.__dict__[self.func.__name__] = self.func(obj) return value class class_property(object): """ Read-only class property """ def __init__(self, func): self.func = func def __get__(self, instance, cls): return self.func(cls) class class_cached_property(object): def __init__(self, func): self.func = func def __get__(self, obj, cls): value = self.func(cls) setattr(cls, self.func.__name__, value) return value././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571864283.0 pony-0.7.11/pony/utils/utils.py0000666000000000000000000003644300000000000014571 0ustar0000000000000000from __future__ import absolute_import, print_function from pony.py23compat import PY2, imap, basestring, unicode, pickle, iteritems import io, re, os.path, sys, inspect, types, warnings from datetime import datetime from itertools import count as _count from inspect import isfunction from time import strptime from collections import defaultdict from functools import update_wrapper, wraps from xml.etree import cElementTree from copy import deepcopy import pony from pony import options from pony.thirdparty.compiler import ast from pony.thirdparty.decorator import decorator as _decorator if pony.MODE.startswith('GAE-'): localbase = object else: from threading import local as localbase class PonyDeprecationWarning(DeprecationWarning): pass def deprecated(stacklevel, message): warnings.warn(message, PonyDeprecationWarning, stacklevel) warnings.simplefilter('once', PonyDeprecationWarning) def _improved_decorator(caller, func): if isfunction(func): return _decorator(caller, func) def pony_wrapper(*args, **kwargs): return caller(func, *args, **kwargs) return pony_wrapper def decorator(caller, func=None): if func is not None: return _improved_decorator(caller, func) def new_decorator(func): return _improved_decorator(caller, func) if isfunction(caller): update_wrapper(new_decorator, caller) return new_decorator def decorator_with_params(dec): def parameterized_decorator(*args, **kwargs): if len(args) == 1 and isfunction(args[0]) and not kwargs: return decorator(dec(), args[0]) return decorator(dec(*args, **kwargs)) return 
parameterized_decorator @decorator def cut_traceback(func, *args, **kwargs): if not options.CUT_TRACEBACK: return func(*args, **kwargs) try: return func(*args, **kwargs) except AssertionError: raise except Exception: exc_type, exc, tb = sys.exc_info() full_tb = tb last_pony_tb = None try: while tb.tb_next: module_name = tb.tb_frame.f_globals['__name__'] if module_name == 'pony' or (module_name is not None # may be None during import and module_name.startswith('pony.')): last_pony_tb = tb tb = tb.tb_next if last_pony_tb is None: raise module_name = tb.tb_frame.f_globals.get('__name__') or '' if module_name.startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': reraise(exc_type, exc, last_pony_tb) reraise(exc_type, exc, full_tb) finally: del exc, full_tb, tb, last_pony_tb cut_traceback_depth = 2 if pony.MODE != 'INTERACTIVE': cut_traceback_depth = 0 def cut_traceback(func): return func if PY2: exec('''def reraise(exc_type, exc, tb): try: raise exc_type, exc, tb finally: del tb''') else: def reraise(exc_type, exc, tb): try: raise exc.with_traceback(tb) finally: del exc, tb def throw(exc_type, *args, **kwargs): if isinstance(exc_type, Exception): assert not args and not kwargs exc = exc_type else: exc = exc_type(*args, **kwargs) exc.__cause__ = None try: if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): raise exc else: raise exc # Set "pony.options.CUT_TRACEBACK = False" to see full traceback finally: del exc def truncate_repr(s, max_len=100): s = repr(s) return s if len(s) <= max_len else s[:max_len-3] + '...' codeobjects = {} def get_codeobject_id(codeobject): codeobject_id = id(codeobject) if codeobject_id not in codeobjects: codeobjects[codeobject_id] = codeobject return codeobject_id lambda_args_cache = {} def get_lambda_args(func): if type(func) is types.FunctionType: codeobject = func.func_code if PY2 else func.__code__ cache_key = get_codeobject_id(codeobject) elif isinstance(func, ast.Lambda): cache_key = func else: assert False # pragma: no cover names = lambda_args_cache.get(cache_key) if names is not None: return names if type(func) is types.FunctionType: if hasattr(inspect, 'signature'): names, argsname, kwname, defaults = [], None, None, [] for p in inspect.signature(func).parameters.values(): if p.default is not p.empty: defaults.append(p.default) if p.kind == p.POSITIONAL_OR_KEYWORD: names.append(p.name) elif p.kind == p.VAR_POSITIONAL: argsname = p.name elif p.kind == p.VAR_KEYWORD: kwname = p.name elif p.kind == p.POSITIONAL_ONLY: throw(TypeError, 'Positional-only arguments like %s are not supported' % p.name) elif p.kind == p.KEYWORD_ONLY: throw(TypeError, 'Keyword-only arguments like %s are not supported' % p.name) else: assert False else: names, argsname, kwname, defaults = inspect.getargspec(func) elif isinstance(func, ast.Lambda): names = func.argnames if func.kwargs: names, kwname = names[:-1], names[-1] else: kwname = None if func.varargs: names, argsname = names[:-1], names[-1] else: argsname = None defaults = func.defaults else: assert False # pragma: no cover if argsname: throw(TypeError, '*%s is not supported' % argsname) if kwname: throw(TypeError, '**%s is not supported' % kwname) if defaults: throw(TypeError, 'Defaults are not supported') lambda_args_cache[cache_key] = names return names def error_method(*args, **kwargs): raise TypeError() _ident_re = re.compile(r'^[A-Za-z_]\w*\Z') # is_ident = ident_re.match def is_ident(string): 'is_ident(string) -> bool' return bool(_ident_re.match(string)) _name_parts_re = re.compile(r''' 
[A-Z][A-Z0-9]+(?![a-z]) # ACRONYM | [A-Z][a-z]* # Capitalized or single capital | [a-z]+ # all-lowercase | [0-9]+ # numbers | _+ # underscores ''', re.VERBOSE) def split_name(name): "split_name('Some_FUNNYName') -> ['Some', 'FUNNY', 'Name']" if not _ident_re.match(name): raise ValueError('Name is not correct Python identifier') list = _name_parts_re.findall(name) if not (list[0].strip('_') and list[-1].strip('_')): raise ValueError('Name must not starting or ending with underscores') return [ s for s in list if s.strip('_') ] def uppercase_name(name): "uppercase_name('Some_FUNNYName') -> 'SOME_FUNNY_NAME'" return '_'.join(s.upper() for s in split_name(name)) def lowercase_name(name): "uppercase_name('Some_FUNNYName') -> 'some_funny_name'" return '_'.join(s.lower() for s in split_name(name)) def camelcase_name(name): "uppercase_name('Some_FUNNYName') -> 'SomeFunnyName'" return ''.join(s.capitalize() for s in split_name(name)) def mixedcase_name(name): "mixedcase_name('Some_FUNNYName') -> 'someFunnyName'" list = split_name(name) return list[0].lower() + ''.join(s.capitalize() for s in list[1:]) def import_module(name): "import_module('a.b.c') -> " mod = sys.modules.get(name) if mod is not None: return mod mod = __import__(name) components = name.split('.') for comp in components[1:]: mod = getattr(mod, comp) return mod if sys.platform == 'win32': _absolute_re = re.compile(r'^(?:[A-Za-z]:)?[\\/]') else: _absolute_re = re.compile(r'^/') def is_absolute_path(filename): return bool(_absolute_re.match(filename)) def absolutize_path(filename, frame_depth): if is_absolute_path(filename): return filename code_filename = sys._getframe(frame_depth+1).f_code.co_filename if not is_absolute_path(code_filename): if code_filename.startswith('<') and code_filename.endswith('>'): if pony.MODE == 'INTERACTIVE': raise ValueError( 'When in interactive mode, please provide absolute file path. Got: %r' % filename) raise EnvironmentError('Unexpected module filename, which is not absolute file path: %r' % code_filename) code_path = os.path.dirname(code_filename) return os.path.join(code_path, filename) def current_timestamp(): return datetime2timestamp(datetime.now()) def datetime2timestamp(d): result = d.isoformat(' ') if len(result) == 19: return result + '.000000' return result def timestamp2datetime(t): time_tuple = strptime(t[:19], '%Y-%m-%d %H:%M:%S') microseconds = int((t[20:26] + '000000')[:6]) return datetime(*(time_tuple[:6] + (microseconds,))) expr1_re = re.compile(r''' ([A-Za-z_]\w*) # identifier (group 1) | ([(]) # open parenthesis (group 2) ''', re.VERBOSE) expr2_re = re.compile(r''' \s*(?: (;) # semicolon (group 1) | (\.\s*[A-Za-z_]\w*) # dot + identifier (group 2) | ([([]) # open parenthesis or braces (group 3) ) ''', re.VERBOSE) expr3_re = re.compile(r""" [()[\]] # parenthesis or braces (group 1) | '''(?:[^\\]|\\.)*?''' # '''triple-quoted string''' | \"""(?:[^\\]|\\.)*?\""" # \"""triple-quoted string\""" | '(?:[^'\\]|\\.)*?' # 'string' | "(?:[^"\\]|\\.)*?" 
# "string" """, re.VERBOSE) def parse_expr(s, pos=0): z = 0 match = expr1_re.match(s, pos) if match is None: raise ValueError() start = pos i = match.lastindex if i == 1: pos = match.end() # identifier elif i == 2: z = 2 # "(" else: assert False # pragma: no cover while True: match = expr2_re.match(s, pos) if match is None: return s[start:pos], z==1 pos = match.end() i = match.lastindex if i == 1: return s[start:pos], False # ";" - explicit end of expression elif i == 2: z = 2 # .identifier elif i == 3: # "(" or "[" pos = match.end() counter = 1 open = match.group(i) if open == '(': close = ')' elif open == '[': close = ']'; z = 2 else: assert False # pragma: no cover while True: match = expr3_re.search(s, pos) if match is None: raise ValueError() pos = match.end() x = match.group() if x == open: counter += 1 elif x == close: counter -= 1 if not counter: z += 1; break else: assert False # pragma: no cover def tostring(x): if isinstance(x, basestring): return x if hasattr(x, '__unicode__'): try: return unicode(x) except: pass if hasattr(x, 'makeelement'): return cElementTree.tostring(x) try: return str(x) except: pass try: return repr(x) except: pass if type(x) == types.InstanceType: return '<%s instance at 0x%X>' % (x.__class__.__name__) return '<%s object at 0x%X>' % (x.__class__.__name__) def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None): "Can join mix of unicode and byte strings in different encodings" strings = list(strings) try: return sep.join(strings) except UnicodeDecodeError: pass for i, s in enumerate(strings): if isinstance(s, str): strings[i] = s.decode(source_encoding, 'replace').replace(u'\ufffd', '?') result = sep.join(strings) if dest_encoding is None: return result return result.encode(dest_encoding, 'replace') def count(*args, **kwargs): if kwargs: return _count(*args, **kwargs) if len(args) != 1: return _count(*args) arg = args[0] if hasattr(arg, 'count'): return arg.count() try: it = iter(arg) except TypeError: return _count(arg) return len(set(it)) def avg(iter): count = 0 sum = 0.0 for elem in iter: if elem is None: continue sum += elem count += 1 if not count: return None return sum / count def group_concat(items, sep=','): if items is None: return None return str(sep).join(str(item) for item in items) def coalesce(*args): for arg in args: if arg is not None: return arg return None def distinct(iter): d = defaultdict(int) for item in iter: d[item] = d[item] + 1 return d def concat(*args): return ''.join(tostring(arg) for arg in args) def between(x, a, b): return a <= x <= b def is_utf8(encoding): return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') def _persistent_id(obj): if obj is Ellipsis: return "Ellipsis" def _persistent_load(persid): if persid == "Ellipsis": return Ellipsis raise pickle.UnpicklingError("unsupported persistent object") def pickle_ast(val): pickled = io.BytesIO() pickler = pickle.Pickler(pickled) pickler.persistent_id = _persistent_id pickler.dump(val) return pickled def unpickle_ast(pickled): pickled.seek(0) unpickler = pickle.Unpickler(pickled) unpickler.persistent_load = _persistent_load return unpickler.load() def copy_ast(tree): return unpickle_ast(pickle_ast(tree)) def _hashable_wrap(func): @wraps(func, assigned=('__name__', '__doc__')) def new_func(self, *args, **kwargs): if getattr(self, '_hash', None) is not None: assert False, 'Cannot mutate HashableDict instance after the hash value is calculated' return func(self, *args, **kwargs) return new_func class HashableDict(dict): def 
__hash__(self): result = getattr(self, '_hash', None) if result is None: result = 0 for key, value in self.items(): result ^= hash(key) result ^= hash(value) self._hash = result return result def __deepcopy__(self, memo): if getattr(self, '_hash', None) is not None: return self return HashableDict({deepcopy(key, memo): deepcopy(value, memo) for key, value in iteritems(self)}) __setitem__ = _hashable_wrap(dict.__setitem__) __delitem__ = _hashable_wrap(dict.__delitem__) clear = _hashable_wrap(dict.clear) pop = _hashable_wrap(dict.pop) popitem = _hashable_wrap(dict.popitem) setdefault = _hashable_wrap(dict.setdefault) update = _hashable_wrap(dict.update) def deref_proxy(value): t = type(value) if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__: # Flask local proxy value = value._get_current_object() elif t.__name__ == 'EntityProxy': # Pony proxy value = value._get_object() return value def deduplicate(value, deduplication_cache): t = type(value) try: return deduplication_cache.setdefault(t, t).setdefault(value, value) except: return value ././@PaxHeader0000000000000000000000000000003400000000000011452 xustar000000000000000028 mtime=1571864710.2173421 pony-0.7.11/pony.egg-info/0000777000000000000000000000000000000000000013377 5ustar0000000000000000././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571864709.0 pony-0.7.11/pony.egg-info/PKG-INFO0000666000000000000000000000617700000000000014507 0ustar0000000000000000Metadata-Version: 1.1 Name: pony Version: 0.7.11 Summary: Pony Object-Relational Mapper Home-page: https://ponyorm.com Author: Alexander Kozlovsky, Alexey Malashkevich Author-email: team@ponyorm.com License: Apache License Version 2.0 Download-URL: http://pypi.python.org/pypi/pony/ Description: About ========= Pony ORM is easy to use and powerful object-relational mapper for Python. Using Pony, developers can create and maintain database-oriented software applications faster and with less effort. One of the most interesting features of Pony is its ability to write queries to the database using generator expressions. Pony then analyzes the abstract syntax tree of a generator and translates it to its SQL equivalent. Following is an example of a query in Pony:: select(p for p in Product if p.name.startswith('A') and p.cost <= 1000) Such approach simplify the code and allows a programmer to concentrate on the business logic of the application. Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. The package `pony.orm.examples `_ contains several examples. Installation ================= :: pip install pony Entity-Relationship Diagram Editor ============================================= `Pony online ER Diagram Editor `_ is a great tool for prototyping. You can draw your ER diagram online, generate Pony entity declarations or SQL script for creating database schema based on the diagram and start working with the database in seconds. 
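Example
=================

For illustration, the query shown above assumes entity declarations
along these lines (a minimal sketch, not a complete application)::

    from decimal import Decimal
    from pony.orm import Database, Required, db_session, select

    db = Database('sqlite', ':memory:')

    class Product(db.Entity):
        name = Required(str)
        cost = Required(Decimal)

    db.generate_mapping(create_tables=True)

    with db_session:
        products = select(p for p in Product
                          if p.name.startswith('A') and p.cost <= 1000)[:]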
Pony ORM Links: ================= - Main site: https://ponyorm.com - Documentation: https://docs.ponyorm.com - GitHub: https://github.com/ponyorm/pony - Mailing list: http://ponyorm-list.ponyorm.com - ER Diagram Editor: https://editor.ponyorm.com - Blog: https://blog.ponyorm.com Platform: UNKNOWN Classifier: Development Status :: 4 - Beta Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Operating System :: OS Independent Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 2 Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Software Development :: Libraries Classifier: Topic :: Database ././@PaxHeader0000000000000000000000000000002600000000000011453 xustar000000000000000022 mtime=1571864710.0 pony-0.7.11/pony.egg-info/SOURCES.txt0000666000000000000000000001164600000000000015273 0ustar0000000000000000LICENSE MANIFEST.in README.md setup.py pony/__init__.py pony/converting.py pony/options.py pony/py23compat.py pony.egg-info/PKG-INFO pony.egg-info/SOURCES.txt pony.egg-info/dependency_links.txt pony.egg-info/top_level.txt pony/flask/__init__.py pony/flask/example/__init__.py pony/flask/example/__main__.py pony/flask/example/app.py pony/flask/example/config.py pony/flask/example/models.py pony/flask/example/views.py pony/flask/example/templates/index.html pony/flask/example/templates/login.html pony/flask/example/templates/reg.html pony/orm/__init__.py pony/orm/asttranslation.py pony/orm/core.py pony/orm/dbapiprovider.py pony/orm/dbschema.py pony/orm/decompiling.py pony/orm/ormtypes.py pony/orm/serialization.py pony/orm/sqlbuilding.py pony/orm/sqlsymbols.py pony/orm/sqltranslation.py pony/orm/tmp_cmp.py pony/orm/zzzzz.py pony/orm/dbproviders/__init__.py pony/orm/dbproviders/mysql.py pony/orm/dbproviders/oracle.py pony/orm/dbproviders/postgres.py pony/orm/dbproviders/sqlite.py pony/orm/examples/__init__.py pony/orm/examples/bottle_example.py pony/orm/examples/compositekeys.py pony/orm/examples/demo.py pony/orm/examples/estore.py pony/orm/examples/inheritance1.py pony/orm/examples/test_numbers.py pony/orm/examples/tmp.py pony/orm/examples/tmp2.py pony/orm/examples/tmp3.py pony/orm/examples/university1.py pony/orm/examples/university2.py pony/orm/integration/__init__.py pony/orm/integration/bottle_plugin.py pony/orm/tests/__init__.py pony/orm/tests/fixtures.py pony/orm/tests/model1.py pony/orm/tests/py36_test_f_strings.py pony/orm/tests/queries.txt pony/orm/tests/sql_tests.py pony/orm/tests/test_array.py pony/orm/tests/test_attribute_options.py pony/orm/tests/test_autostrip.py pony/orm/tests/test_buffer.py pony/orm/tests/test_bug_170.py pony/orm/tests/test_bug_182.py pony/orm/tests/test_bug_331.py pony/orm/tests/test_bug_386.py pony/orm/tests/test_cascade.py pony/orm/tests/test_cascade_delete.py pony/orm/tests/test_collections.py pony/orm/tests/test_core_find_in_cache.py pony/orm/tests/test_core_multiset.py pony/orm/tests/test_crud.py pony/orm/tests/test_crud_raw_sql.py pony/orm/tests/test_datetime.py pony/orm/tests/test_db_session.py 
pony/orm/tests/test_declarative_attr_set_monad.py
pony/orm/tests/test_declarative_exceptions.py
pony/orm/tests/test_declarative_func_monad.py
pony/orm/tests/test_declarative_join_optimization.py
pony/orm/tests/test_declarative_object_flat_monad.py
pony/orm/tests/test_declarative_orderby_limit.py
pony/orm/tests/test_declarative_query_set_monad.py
pony/orm/tests/test_declarative_sqltranslator.py
pony/orm/tests/test_declarative_sqltranslator2.py
pony/orm/tests/test_declarative_strings.py
pony/orm/tests/test_decompiler.py
pony/orm/tests/test_deduplication.py
pony/orm/tests/test_diagram.py
pony/orm/tests/test_diagram_attribute.py
pony/orm/tests/test_diagram_keys.py
pony/orm/tests/test_distinct.py
pony/orm/tests/test_entity_init.py
pony/orm/tests/test_entity_instances.py
pony/orm/tests/test_entity_proxy.py
pony/orm/tests/test_exists.py
pony/orm/tests/test_f_strings.py
pony/orm/tests/test_filter.py
pony/orm/tests/test_flush.py
pony/orm/tests/test_frames.py
pony/orm/tests/test_generator_db_session.py
pony/orm/tests/test_get_pk.py
pony/orm/tests/test_getattr.py
pony/orm/tests/test_hooks.py
pony/orm/tests/test_hybrid_methods_and_properties.py
pony/orm/tests/test_indexes.py
pony/orm/tests/test_inheritance.py
pony/orm/tests/test_inner_join_syntax.py
pony/orm/tests/test_isinstance.py
pony/orm/tests/test_json.py
pony/orm/tests/test_lazy.py
pony/orm/tests/test_mapping.py
pony/orm/tests/test_objects_to_save_cleanup.py
pony/orm/tests/test_prefetching.py
pony/orm/tests/test_query.py
pony/orm/tests/test_random.py
pony/orm/tests/test_raw_sql.py
pony/orm/tests/test_relations_m2m.py
pony/orm/tests/test_relations_one2many.py
pony/orm/tests/test_relations_one2one1.py
pony/orm/tests/test_relations_one2one2.py
pony/orm/tests/test_relations_one2one3.py
pony/orm/tests/test_relations_one2one4.py
pony/orm/tests/test_relations_symmetric_m2m.py
pony/orm/tests/test_relations_symmetric_one2one.py
pony/orm/tests/test_select_from_select_queries.py
pony/orm/tests/test_show.py
pony/orm/tests/test_sqlbuilding_formatstyles.py
pony/orm/tests/test_sqlbuilding_sqlast.py
pony/orm/tests/test_sqlite_str_functions.py
pony/orm/tests/test_time_parsing.py
pony/orm/tests/test_to_dict.py
pony/orm/tests/test_tracked_value.py
pony/orm/tests/test_transaction_lock.py
pony/orm/tests/test_validate.py
pony/orm/tests/test_volatile.py
pony/orm/tests/testutils.py
pony/thirdparty/__init__.py
pony/thirdparty/decorator.py
pony/thirdparty/compiler/__init__.py
pony/thirdparty/compiler/ast.py
pony/thirdparty/compiler/consts.py
pony/thirdparty/compiler/future.py
pony/thirdparty/compiler/misc.py
pony/thirdparty/compiler/pyassem.py
pony/thirdparty/compiler/pycodegen.py
pony/thirdparty/compiler/symbols.py
pony/thirdparty/compiler/syntax.py
pony/thirdparty/compiler/transformer.py
pony/thirdparty/compiler/visitor.py
pony/utils/__init__.py
pony/utils/properties.py
pony/utils/utils.py

pony-0.7.11/pony.egg-info/dependency_links.txt

pony-0.7.11/pony.egg-info/top_level.txt

pony

pony-0.7.11/setup.cfg

[egg_info]
tag_build =
tag_date = 0
pony-0.7.11/setup.py

from __future__ import print_function
from setuptools import setup
import sys
import unittest

def test_suite():
    # Discover the bundled test modules; referenced below via
    # test_suite='setup.test_suite'.
    test_loader = unittest.TestLoader()
    test_suite = test_loader.discover('pony.orm.tests', pattern='test_*.py')
    return test_suite

name = "pony"
version = __import__('pony').__version__
description = "Pony Object-Relational Mapper"
long_description = """
About
=========
Pony ORM is an easy-to-use and powerful object-relational mapper for Python.
Using Pony, developers can create and maintain database-oriented software
applications faster and with less effort. One of the most interesting features
of Pony is its ability to write queries to the database using generator
expressions. Pony then analyzes the abstract syntax tree of the generator and
translates it to its SQL equivalent.

Following is an example of a query in Pony::

    select(p for p in Product if p.name.startswith('A') and p.cost <= 1000)

This approach simplifies the code and allows a programmer to concentrate on
the business logic of the application.

Pony translates queries to SQL using a specific database dialect. Currently
Pony works with SQLite, MySQL, PostgreSQL and Oracle databases.

The package ``pony.orm.examples`` contains several examples.

Installation
=================
::

    pip install pony

Entity-Relationship Diagram Editor
=============================================
`Pony online ER Diagram Editor <https://editor.ponyorm.com>`_ is a great tool
for prototyping. You can draw your ER diagram online, generate Pony entity
declarations or an SQL script for creating the database schema based on the
diagram, and start working with the database in seconds.

Pony ORM Links:
=================
- Main site: https://ponyorm.com
- Documentation: https://docs.ponyorm.com
- GitHub: https://github.com/ponyorm/pony
- Mailing list: http://ponyorm-list.ponyorm.com
- ER Diagram Editor: https://editor.ponyorm.com
- Blog: https://blog.ponyorm.com
"""

classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'License :: OSI Approved :: Apache Software License',
    'Operating System :: OS Independent',
    'Programming Language :: Python',
    'Programming Language :: Python :: 2',
    'Programming Language :: Python :: 2.7',
    'Programming Language :: Python :: 3',
    'Programming Language :: Python :: 3.3',
    'Programming Language :: Python :: 3.4',
    'Programming Language :: Python :: 3.5',
    'Programming Language :: Python :: 3.6',
    'Programming Language :: Python :: 3.7',
    'Programming Language :: Python :: 3.8',
    'Programming Language :: Python :: Implementation :: PyPy',
    'Topic :: Software Development :: Libraries',
    'Topic :: Database'
]

author = "Alexander Kozlovsky, Alexey Malashkevich"
author_email = "team@ponyorm.com"
url = "https://ponyorm.com"
licence = "Apache License Version 2.0"

packages = [
    "pony",
    "pony.flask",
    "pony.flask.example",
    "pony.orm",
    "pony.orm.dbproviders",
    "pony.orm.examples",
    "pony.orm.integration",
    "pony.orm.tests",
    "pony.thirdparty",
    "pony.thirdparty.compiler",
    "pony.utils"
]

package_data = {
    'pony.flask.example': ['templates/*.html'],
    'pony.orm.tests': ['queries.txt']
}

download_url = "http://pypi.python.org/pypi/pony/"

if __name__ == "__main__":
    # Refuse to install on unsupported Python versions.
    pv = sys.version_info[:2]
    if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8)):
        s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.8." \
            " You have version %s"
        print(s % (name, version, sys.version.split(' ', 1)[0]))
        sys.exit(1)

    setup(
        name=name,
        version=version,
        description=description,
        long_description=long_description,
        classifiers=classifiers,
        author=author,
        author_email=author_email,
        url=url,
        license=licence,
        packages=packages,
        package_data=package_data,
        download_url=download_url,
        test_suite='setup.test_suite'
    )