storm-1.0/LICENSE
=================

                  GNU LESSER GENERAL PUBLIC LICENSE
                       Version 2.1, February 1999

Copyright (C) 1991, 1999 Free Software Foundation, Inc.
51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

[This is the first released version of the Lesser GPL.  It also counts as
the successor of the GNU Library Public License, version 2, hence the
version number 2.1.]

                            Preamble

The licenses for most software are designed to take away your freedom to
share and change it.  By contrast, the GNU General Public Licenses are
intended to guarantee your freedom to share and change free software--to
make sure the software is free for all its users.

This license, the Lesser General Public License, applies to some
specially designated software packages--typically libraries--of the Free
Software Foundation and other authors who decide to use it.  You can use
it too, but we suggest you first think carefully about whether this
license or the ordinary General Public License is the better strategy to
use in any particular case, based on the explanations below.

When we speak of free software, we are referring to freedom of use, not
price.  Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
this service if you wish); that you receive source code or can get it if
you want it; that you can change the software and use pieces of it in new
free programs; and that you are informed that you can do these things.

To protect your rights, we need to make restrictions that forbid
distributors to deny you these rights or to ask you to surrender these
rights.  These restrictions translate to certain responsibilities for you
if you distribute copies of the library or if you modify it.

For example, if you distribute copies of the library, whether gratis or
for a fee, you must give the recipients all the rights that we gave you.
You must make sure that they, too, receive or can get the source code.
If you link other code with the library, you must provide complete object
files to the recipients, so that they can relink them with the library
after making changes to the library and recompiling it.  And you must
show them these terms so they know their rights.

We protect your rights with a two-step method: (1) we copyright the
library, and (2) we offer you this license, which gives you legal
permission to copy, distribute and/or modify the library.

To protect each distributor, we want to make it very clear that there is
no warranty for the free library.  Also, if the library is modified by
someone else and passed on, the recipients should know that what they
have is not the original version, so that the original author's
reputation will not be affected by problems that might be introduced by
others.

Finally, software patents pose a constant threat to the existence of any
free program.  We wish to make sure that a company cannot effectively
restrict the users of a free program by obtaining a restrictive license
from a patent holder.  Therefore, we insist that any patent license
obtained for a version of the library must be consistent with the full
freedom of use specified in this license.

Most GNU software, including some libraries, is covered by the ordinary
GNU General Public License.  This license, the GNU Lesser General Public
License, applies to certain designated libraries, and is quite different
from the ordinary General Public License.  We use this license for
certain libraries in order to permit linking those libraries into
non-free programs.

When a program is linked with a library, whether statically or using a
shared library, the combination of the two is legally speaking a combined
work, a derivative of the original library.  The ordinary General Public
License therefore permits such linking only if the entire combination
fits its criteria of freedom.  The Lesser General Public License permits
more lax criteria for linking other code with the library.

We call this license the "Lesser" General Public License because it does
Less to protect the user's freedom than the ordinary General Public
License.  It also provides other free software developers Less of an
advantage over competing non-free programs.  These disadvantages are the
reason we use the ordinary General Public License for many libraries.
However, the Lesser license provides advantages in certain special
circumstances.

For example, on rare occasions, there may be a special need to encourage
the widest possible use of a certain library, so that it becomes a
de-facto standard.  To achieve this, non-free programs must be allowed to
use the library.  A more frequent case is that a free library does the
same job as widely used non-free libraries.  In this case, there is
little to gain by limiting the free library to free software only, so we
use the Lesser General Public License.

In other cases, permission to use a particular library in non-free
programs enables a greater number of people to use a large body of free
software.  For example, permission to use the GNU C Library in non-free
programs enables many more people to use the whole GNU operating system,
as well as its variant, the GNU/Linux operating system.

Although the Lesser General Public License is Less protective of the
users' freedom, it does ensure that the user of a program that is linked
with the Library has the freedom and the wherewithal to run that program
using a modified version of the Library.

The precise terms and conditions for copying, distribution and
modification follow.  Pay close attention to the difference between a
"work based on the library" and a "work that uses the library".  The
former contains code derived from the library, whereas the latter must be
combined with the library in order to run.

                  GNU LESSER GENERAL PUBLIC LICENSE
   TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION

0. This License Agreement applies to any software library or other
program which contains a notice placed by the copyright holder or other
authorized party saying it may be distributed under the terms of this
Lesser General Public License (also called "this License").  Each
licensee is addressed as "you".

A "library" means a collection of software functions and/or data prepared
so as to be conveniently linked with application programs (which use some
of those functions and data) to form executables.

The "Library", below, refers to any such software library or work which
has been distributed under these terms.  A "work based on the Library"
means either the Library or any derivative work under copyright law: that
is to say, a work containing the Library or a portion of it, either
verbatim or with modifications and/or translated straightforwardly into
another language.  (Hereinafter, translation is included without
limitation in the term "modification".)

"Source code" for a work means the preferred form of the work for making
modifications to it.  For a library, complete source code means all the
source code for all modules it contains, plus any associated interface
definition files, plus the scripts used to control compilation and
installation of the library.

Activities other than copying, distribution and modification are not
covered by this License; they are outside its scope.  The act of running
a program using the Library is not restricted, and output from such a
program is covered only if its contents constitute a work based on the
Library (independent of the use of the Library in a tool for writing it).
Whether that is true depends on what the Library does and what the
program that uses the Library does.

1. You may copy and distribute verbatim copies of the Library's complete
source code as you receive it, in any medium, provided that you
conspicuously and appropriately publish on each copy an appropriate
copyright notice and disclaimer of warranty; keep intact all the notices
that refer to this License and to the absence of any warranty; and
distribute a copy of this License along with the Library.

You may charge a fee for the physical act of transferring a copy, and you
may at your option offer warranty protection in exchange for a fee.

2. You may modify your copy or copies of the Library or any portion of
it, thus forming a work based on the Library, and copy and distribute
such modifications or work under the terms of Section 1 above, provided
that you also meet all of these conditions:

  a) The modified work must itself be a software library.

  b) You must cause the files modified to carry prominent notices stating
  that you changed the files and the date of any change.

  c) You must cause the whole of the work to be licensed at no charge to
  all third parties under the terms of this License.

  d) If a facility in the modified Library refers to a function or a
  table of data to be supplied by an application program that uses the
  facility, other than as an argument passed when the facility is
  invoked, then you must make a good faith effort to ensure that, in the
  event an application does not supply such function or table, the
  facility still operates, and performs whatever part of its purpose
  remains meaningful.

  (For example, a function in a library to compute square roots has a
  purpose that is entirely well-defined independent of the application.
  Therefore, Subsection 2d requires that any application-supplied
  function or table used by this function must be optional: if the
  application does not supply it, the square root function must still
  compute square roots.)

These requirements apply to the modified work as a whole.  If
identifiable sections of that work are not derived from the Library, and
can be reasonably considered independent and separate works in
themselves, then this License, and its terms, do not apply to those
sections when you distribute them as separate works.  But when you
distribute the same sections as part of a whole which is a work based on
the Library, the distribution of the whole must be on the terms of this
License, whose permissions for other licensees extend to the entire
whole, and thus to each and every part regardless of who wrote it.

Thus, it is not the intent of this section to claim rights or contest
your rights to work written entirely by you; rather, the intent is to
exercise the right to control the distribution of derivative or
collective works based on the Library.

In addition, mere aggregation of another work not based on the Library
with the Library (or with a work based on the Library) on a volume of a
storage or distribution medium does not bring the other work under the
scope of this License.

3. You may opt to apply the terms of the ordinary GNU General Public
License instead of this License to a given copy of the Library.  To do
this, you must alter all the notices that refer to this License, so that
they refer to the ordinary GNU General Public License, version 2, instead
of to this License.  (If a newer version than version 2 of the ordinary
GNU General Public License has appeared, then you can specify that
version instead if you wish.)  Do not make any other change in these
notices.

Once this change is made in a given copy, it is irreversible for that
copy, so the ordinary GNU General Public License applies to all
subsequent copies and derivative works made from that copy.

This option is useful when you wish to copy part of the code of the
Library into a program that is not a library.

4. You may copy and distribute the Library (or a portion or derivative of
it, under Section 2) in object code or executable form under the terms of
Sections 1 and 2 above provided that you accompany it with the complete
corresponding machine-readable source code, which must be distributed
under the terms of Sections 1 and 2 above on a medium customarily used
for software interchange.

If distribution of object code is made by offering access to copy from a
designated place, then offering equivalent access to copy the source code
from the same place satisfies the requirement to distribute the source
code, even though third parties are not compelled to copy the source
along with the object code.

5. A program that contains no derivative of any portion of the Library,
but is designed to work with the Library by being compiled or linked with
it, is called a "work that uses the Library".  Such a work, in isolation,
is not a derivative work of the Library, and therefore falls outside the
scope of this License.

However, linking a "work that uses the Library" with the Library creates
an executable that is a derivative of the Library (because it contains
portions of the Library), rather than a "work that uses the library".
The executable is therefore covered by this License.  Section 6 states
terms for distribution of such executables.

When a "work that uses the Library" uses material from a header file that
is part of the Library, the object code for the work may be a derivative
work of the Library even though the source code is not.  Whether this is
true is especially significant if the work can be linked without the
Library, or if the work is itself a library.  The threshold for this to
be true is not precisely defined by law.

If such an object file uses only numerical parameters, data structure
layouts and accessors, and small macros and small inline functions (ten
lines or less in length), then the use of the object file is
unrestricted, regardless of whether it is legally a derivative work.
(Executables containing this object code plus portions of the Library
will still fall under Section 6.)

Otherwise, if the work is a derivative of the Library, you may distribute
the object code for the work under the terms of Section 6.  Any
executables containing that work also fall under Section 6, whether or
not they are linked directly with the Library itself.

6. As an exception to the Sections above, you may also combine or link a
"work that uses the Library" with the Library to produce a work
containing portions of the Library, and distribute that work under terms
of your choice, provided that the terms permit modification of the work
for the customer's own use and reverse engineering for debugging such
modifications.

You must give prominent notice with each copy of the work that the
Library is used in it and that the Library and its use are covered by
this License.  You must supply a copy of this License.  If the work
during execution displays copyright notices, you must include the
copyright notice for the Library among them, as well as a reference
directing the user to the copy of this License.  Also, you must do one of
these things:

  a) Accompany the work with the complete corresponding machine-readable
  source code for the Library including whatever changes were used in the
  work (which must be distributed under Sections 1 and 2 above); and, if
  the work is an executable linked with the Library, with the complete
  machine-readable "work that uses the Library", as object code and/or
  source code, so that the user can modify the Library and then relink to
  produce a modified executable containing the modified Library.  (It is
  understood that the user who changes the contents of definitions files
  in the Library will not necessarily be able to recompile the
  application to use the modified definitions.)

  b) Use a suitable shared library mechanism for linking with the
  Library.  A suitable mechanism is one that (1) uses at run time a copy
  of the library already present on the user's computer system, rather
  than copying library functions into the executable, and (2) will
  operate properly with a modified version of the library, if the user
  installs one, as long as the modified version is interface-compatible
  with the version that the work was made with.

  c) Accompany the work with a written offer, valid for at least three
  years, to give the same user the materials specified in Subsection 6a,
  above, for a charge no more than the cost of performing this
  distribution.

  d) If distribution of the work is made by offering access to copy from
  a designated place, offer equivalent access to copy the above specified
  materials from the same place.

  e) Verify that the user has already received a copy of these materials
  or that you have already sent this user a copy.

For an executable, the required form of the "work that uses the Library"
must include any data and utility programs needed for reproducing the
executable from it.  However, as a special exception, the materials to be
distributed need not include anything that is normally distributed (in
either source or binary form) with the major components (compiler,
kernel, and so on) of the operating system on which the executable runs,
unless that component itself accompanies the executable.

It may happen that this requirement contradicts the license restrictions
of other proprietary libraries that do not normally accompany the
operating system.  Such a contradiction means you cannot use both them
and the Library together in an executable that you distribute.

7. You may place library facilities that are a work based on the Library
side-by-side in a single library together with other library facilities
not covered by this License, and distribute such a combined library,
provided that the separate distribution of the work based on the Library
and of the other library facilities is otherwise permitted, and provided
that you do these two things:

  a) Accompany the combined library with a copy of the same work based on
  the Library, uncombined with any other library facilities.  This must
  be distributed under the terms of the Sections above.

  b) Give prominent notice with the combined library of the fact that
  part of it is a work based on the Library, and explaining where to find
  the accompanying uncombined form of the same work.

8. You may not copy, modify, sublicense, link with, or distribute the
Library except as expressly provided under this License.  Any attempt
otherwise to copy, modify, sublicense, link with, or distribute the
Library is void, and will automatically terminate your rights under this
License.  However, parties who have received copies, or rights, from you
under this License will not have their licenses terminated so long as
such parties remain in full compliance.

9. You are not required to accept this License, since you have not signed
it.  However, nothing else grants you permission to modify or distribute
the Library or its derivative works.  These actions are prohibited by law
if you do not accept this License.  Therefore, by modifying or
distributing the Library (or any work based on the Library), you indicate
your acceptance of this License to do so, and all its terms and
conditions for copying, distributing or modifying the Library or works
based on it.

10. Each time you redistribute the Library (or any work based on the
Library), the recipient automatically receives a license from the
original licensor to copy, distribute, link with or modify the Library
subject to these terms and conditions.  You may not impose any further
restrictions on the recipients' exercise of the rights granted herein.
You are not responsible for enforcing compliance by third parties with
this License.

11. If, as a consequence of a court judgment or allegation of patent
infringement or for any other reason (not limited to patent issues),
conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License.  If you cannot distribute
so as to satisfy simultaneously your obligations under this License and
any other pertinent obligations, then as a consequence you may not
distribute the Library at all.  For example, if a patent license would
not permit royalty-free redistribution of the Library by all those who
receive copies directly or indirectly through you, then the only way you
could satisfy both it and this License would be to refrain entirely from
distribution of the Library.

If any portion of this section is held invalid or unenforceable under any
particular circumstance, the balance of the section is intended to apply,
and the section as a whole is intended to apply in other circumstances.

It is not the purpose of this section to induce you to infringe any
patents or other property right claims or to contest validity of any such
claims; this section has the sole purpose of protecting the integrity of
the free software distribution system which is implemented by public
license practices.  Many people have made generous contributions to the
wide range of software distributed through that system in reliance on
consistent application of that system; it is up to the author/donor to
decide if he or she is willing to distribute software through any other
system and a licensee cannot impose that choice.

This section is intended to make thoroughly clear what is believed to be
a consequence of the rest of this License.

12. If the distribution and/or use of the Library is restricted in
certain countries either by patents or by copyrighted interfaces, the
original copyright holder who places the Library under this License may
add an explicit geographical distribution limitation excluding those
countries, so that distribution is permitted only in or among countries
not thus excluded.  In such case, this License incorporates the
limitation as if written in the body of this License.

13. The Free Software Foundation may publish revised and/or new versions
of the Lesser General Public License from time to time.  Such new
versions will be similar in spirit to the present version, but may differ
in detail to address new problems or concerns.

Each version is given a distinguishing version number.  If the Library
specifies a version number of this License which applies to it and "any
later version", you have the option of following the terms and conditions
either of that version or of any later version published by the Free
Software Foundation.  If the Library does not specify a license version
number, you may choose any version ever published by the Free Software
Foundation.

14. If you wish to incorporate parts of the Library into other free
programs whose distribution conditions are incompatible with these, write
to the author to ask for permission.  For software which is copyrighted
by the Free Software Foundation, write to the Free Software Foundation;
we sometimes make exceptions for this.  Our decision will be guided by
the two goals of preserving the free status of all derivatives of our
free software and of promoting the sharing and reuse of software
generally.

                            NO WARRANTY

15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW.  EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.  THE
ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE LIBRARY IS WITH YOU.
SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY
SERVICING, REPAIR OR CORRECTION.

16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU FOR
DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL
DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE LIBRARY (INCLUDING
BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR
LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE LIBRARY TO
OPERATE WITH ANY OTHER SOFTWARE), EVEN IF SUCH HOLDER OR OTHER PARTY HAS
BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.

                     END OF TERMS AND CONDITIONS

           How to Apply These Terms to Your New Libraries

If you develop a new library, and you want it to be of the greatest
possible use to the public, we recommend making it free software that
everyone can redistribute and change.  You can do so by permitting
redistribution under these terms (or, alternatively, under the terms of
the ordinary General Public License).

To apply these terms, attach the following notices to the library.  It is
safest to attach them to the start of each source file to most
effectively convey the exclusion of warranty; and each file should have
at least the "copyright" line and a pointer to where the full notice is
found.

  Copyright (C) <year>  <name of author>

  This library is free software; you can redistribute it and/or modify it
  under the terms of the GNU Lesser General Public License as published
  by the Free Software Foundation; either version 2.1 of the License, or
  (at your option) any later version.

  This library is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public
  License along with this library; if not, write to the Free Software
  Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
  USA

Also add information on how to contact you by electronic and paper mail.

You should also get your employer (if you work as a programmer) or your
school, if any, to sign a "copyright disclaimer" for the library, if
necessary.  Here is a sample; alter the names:

  Yoyodyne, Inc., hereby disclaims all copyright interest in the library
  `Frob' (a library for tweaking knobs) written by James Random Hacker.

  <signature of Ty Coon>, 1 April 1990
  Ty Coon, President of Vice

That's all there is to it!
storm-1.0/MANIFEST.in
=====================

recursive-include storm *.py *.c *.zcml *.rst
include MANIFEST.in LICENSE README TODO NEWS Makefile dev/test setup.cfg tox.ini
include storm/docs/Makefile
prune storm/docs/_build
prune db
prune db-mysql

storm-1.0/Makefile
==================

PYTHON ?= python
PYDOCTOR ?= pydoctor

TEST_COMMAND = $(PYTHON) setup.py test

all: build

build:
	$(PYTHON) setup.py build_ext -i

develop:
	$(TEST_COMMAND) --quiet --dry-run

check:
	tox

check-with-trial:
	STORM_TEST_RUNNER=trial tox

doc:
	$(PYDOCTOR) --make-html --html-output apidoc --add-package storm

release:
	$(PYTHON) setup.py sdist

clean:
	rm -rf build
	rm -rf build-stamp
	rm -rf dist
	rm -rf storm.egg-info
	rm -rf debian/files
	rm -rf debian/python-storm
	rm -rf debian/python-storm.*
	rm -rf .tox
	rm -rf *.egg
	rm -rf _trial_temp
	find . -name "*.so" -type f -exec rm -f {} \;
	find . -name "*.pyc" -type f -exec rm -f {} \;
	find . -name "*~" -type f -exec rm -f {} \;

.PHONY: all build check clean develop doc release

storm-1.0/NEWS
==============

1.0 (2024-07-16)
================

Improvements
------------
- Fix several syntax warnings from recent Python versions.
- Support Python 3.12.

API changes
-----------
- Remove support for Python < 3.6 (including Python 2).
- Remove the storm.compat module.
- Deprecate storm.sqlobject.AutoUnicodeVariable and
  storm.sqlobject.AutoUnicode.

0.26 (2023-07-04)
=================

Improvements
------------
- Clarify the exception raised when creating a property with both
  allow_none=False and default=None.
- Add storm.expr.Is and storm.expr.IsNot operators.
- Support Python 3.10 and 3.11.

Bug fixes
---------
- Avoid traceback reference cycles when wrapping exceptions.
- Fix test compatibility with MySQL >= 8.0.24.

0.25 (2021-04-19)
=================

Improvements
------------
- Add an optional case_sensitive argument to Comparable.startswith,
  Comparable.endswith, and Comparable.contains_string.  This is only
  supported in the PostgreSQL backend.
- Restore MySQL support.
- Support Python 3.9.

0.24 (2020-06-12)
=================

Improvements
------------
- Add Sphinx documentation.
- Convert CaptureTracer to the improved API in fixtures >= 1.3.0.
- Add block_access to storm.store.__all__.

Bug fixes
---------
- Fix list() on security-proxied ResultSets on Python 3.
- Fix slicing of security-proxied bound ReferenceSets on Python 2.

0.23 (2020-03-18)
=================

Improvements
------------
- Add whitespace around "<<", ">>", "+", "-", "*", "/", and "%" operators.
  This helps to avoid confusion in tracers, since they run before
  parameter substitution: for example, sqlparse parses "expr+%s" as
  "expr", "+%", "s" rather than "expr", "+", "%s".
- Implement __getitem__ on bound ReferenceSets.
- The storm.twisted.testing.FakeTransactor test helper now uses a fake
  transaction by default, so tests won't perform real commits, as that's
  not generally needed (bug #1009983).
- A new block_access context manager blocks database access for one or
  more stores in the managed scope (bug #617182).
- Implement is_empty on bound ReferenceSets.  (block_access, is_empty, and
  __getitem__ are all demonstrated in the sketch after this list.)
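  The following is a minimal, hedged sketch of these three additions; the
  Team/Employee model and the in-memory SQLite schema are invented for
  illustration and are not taken from the Storm documentation::

      from storm.locals import (
          Int, ReferenceSet, Store, Unicode, create_database)
      from storm.store import block_access

      class Team:
          __storm_table__ = "team"
          id = Int(primary=True)
          title = Unicode()

      class Employee:
          __storm_table__ = "employee"
          id = Int(primary=True)
          team_id = Int()
          name = Unicode()

      Team.members = ReferenceSet(Team.id, Employee.team_id)

      store = Store(create_database("sqlite:"))
      store.execute("CREATE TABLE team (id INTEGER PRIMARY KEY, title TEXT)")
      store.execute("CREATE TABLE employee "
                    "(id INTEGER PRIMARY KEY, team_id INTEGER, name TEXT)")

      team = store.add(Team())
      team.title = u"storm"
      store.flush()
      print(team.members.is_empty())  # True: is_empty on a bound ReferenceSet
      print(list(team.members[:5]))   # []: slicing via the new __getitem__

      with block_access(store):
          # Any query made through `store` inside this block raises an
          # error instead of touching the database.
          pass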
Bug fixes
---------
- Stop using the deprecated assertEquals/assertNotEquals unittest methods.
- Return self from EmptyResultSet.config() to match ResultSet
  (bug #905529).

API changes
-----------
- Rename RawStr and RawStrVariable to Bytes and BytesVariable
  respectively, since that matches Python 3's terminology.  RawStr and
  RawStrVariable still exist as deprecated aliases.

0.22 (2019-11-21)
=================

Improvements
------------
- Use the postgresfixture package to set up a temporary cluster for
  PostgreSQL tests, simplifying developer setup.

Bug fixes
---------
- Fix incorrect caching of wrapped DB-API exceptions (bug #1845702).
- Support Python 3.6, 3.7, and 3.8.
- Fix the incorrect expected type for the 'join' parameter to the C
  version of Compile.__call__ on Python 3.

0.21 (2019-09-20)
=================

Improvements
------------
- A new storm.schema.sharding.Sharding class has been added to the schema
  package, which can apply patches "in parallel" against a set of stores.
  See the module docstring for more information.
- Added tox testing support.
- Re-raise DB-API exceptions wrapped in exception types that inherit from
  both StormError and the appropriate DB-API exception type, rather than
  injecting virtual base classes.  This preserves existing exception
  handling in applications while also being a viable approach in
  Python 3.  (See the sketch after this list.)
- Port to Python 3 (bug #1530734, based partly on contributions from
  Thiago Bellini).  Existing Python 2 users should be unaffected.  For
  people porting to Python 3, note the following API changes relative to
  Python 2:

  - On Python 3, raw=True and token=True in storm.expr.Compile.__call__
    only treat str specially, not bytes and str, because ultimately the
    compiler is assembling a text string to send to the database.
  - On Python 3, storm.tracer.BaseStatementTracer.connection_raw_execute
    renders text parameters using ascii() rather than by encoding to
    bytes and then calling repr().  While this does result in slightly
    different output from Python 2, it's normally more useful since the
    encoding is in terms of Unicode codepoints rather than UTF-8.
  - storm.sqlobject.AutoUnicodeVariable (and hence StringCol) explicitly
    documents that it only accepts text on Python 3, since native strings
    are already Unicode there, so there's much less need for the porting
    affordance.
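  As a hedged illustration of the exception-wrapping change, reusing the
  SQLite-backed store and employee table from the sketch above: a handler
  written against the driver's own exception type keeps working, and the
  same exception can equally be caught as a StormError subclass::

      import sqlite3
      from storm.exceptions import IntegrityError

      store.execute("INSERT INTO employee (id) VALUES (1)")
      try:
          # Duplicate primary key, so the driver raises IntegrityError.
          store.execute("INSERT INTO employee (id) VALUES (1)")
      except sqlite3.IntegrityError as error:
          # The re-raised exception is an instance of both the driver's
          # IntegrityError and storm.exceptions.IntegrityError.
          assert isinstance(error, IntegrityError)
          store.rollback()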
Bug fixes
---------
- Fixed SQLObject tests to work with SQLObject 2.x by using Unicode
  strings for LIKE operations.
- Cope with ThreadTransactionManager changes in transaction 2.4.0.
- Move tests to a new storm.tests package to avoid package manager
  conflicts (bug #1199578).

API changes
-----------
- Removed Django support; storm.django is no more.
- Removed MySQL support.
- Removed support for Python < 2.6.

0.20 (2013-06-28)
=================

Improvements
------------
- A new CaptureTracer has been added to storm.testing, for capturing all
  SQL statements executed by Storm.  It can be used like this::

      with CaptureTracer() as tracer:
          # Run queries
          pass
      print(tracer.queries)  # Print all queries run in the block

  You will need the python-fixtures package in order to use this feature.
- Setuptools is now required to run setup.py.  This makes it much easier
  to install the majority of the dependencies required to run the test
  suite in its entirety.
- Disconnection errors arising from PostgreSQL are now more reliably
  detected, especially with regard to recent libpq updates in Ubuntu.
  There are also disconnection tests that simulate sudden termination of
  pgbouncer.
- Insert expressions now support multi-row and subquery INSERT
  statements.
- Support in the postgres backend for using the RETURNING extension with
  UPDATE statements, optionally specifying the columns to return.
  (PostgreSQL >= 8.2 only)
- Add a new Distinct expression for prepending the 'DISTINCT' token to
  SQL expressions that support it (like columns).
- Switch to REPEATABLE READ as the isolation level for Postgres.  If you
  use psycopg2 < 2.4, it doesn't change anything.  For psycopg2 2.4 and
  newer, it will keep the same behavior on Postgres < 9, as it's
  equivalent to SERIALIZABLE there, but it will be less strict on
  Postgres >= 9.
- Add support for two-phase commits, using the DB API version 2.0, which
  is only supported by PostgreSQL.
- ZStormResourceManager now has an optional schema_stamp_dir instance
  attribute that, if set, will be used to save timestamps of the schema's
  patch packages, so schema upgrades will be performed only when needed.

Bug fixes
---------
- When a SQLObjectResultSet object was sliced with slice.start and
  slice.end both zero (sqlobject[0:0]), the full, unsliced resultset was
  returned (bug #872086).
- Fixes some test failures around Django disconnections from PostgreSQL
  stores.
- Some of the proxy-less disconnection tests were causing segfaults, so
  they're now run in a subprocess via subunit's IsolatedTestCase, when
  it's available.
- Fix a few non-uses of super() in TestHelper.
- Abort the transaction and reset ZStorm at the end of
  tests/zope/README.txt.
- Fix the ./test script to not import tests until after local eggs have
  been added to sys.path.  This ensures that all possible features are
  available to the tests.
- If transaction.commit() fails, call transaction.abort() to roll back
  everything and leave the transaction in a clean state when using
  Django's ZopeTransactionMiddleware and DatabaseWrapper.
- It's now again possible to use the Desc() expression when passing an
  order_by parameter to a ReferenceSet (bug #620369).

0.19 (2011-10-03)
=================

Improvements
------------
- A new Cast expression compiles an input and the type to cast it to into
  a call to the CAST function (bug #681121); see the sketch after this
  list.
- The storm.zope.testing.ZStormResourceManager now supports applying
  database schemas using a custom URI, typically for connecting to the
  database using a different user with greater privileges than the user
  running the tests.  Note that the format of the 'databases' parameter
  passed to the constructor of the ZStormResourceManager class has
  changed, so now you have to create your resource manager roughly like
  this::

      databases = [{"name": "test",
                    "uri": "postgres://user@host/db",
                    "schema": Schema(...),
                    "schema-uri": "postgres://schema_user@host/db"}]
      manager = ZStormResourceManager(databases)

  where the "schema-uri" key is optional and defaults to "uri" if not
  given.  The old format of the 'databases' parameter is still supported
  but deprecated.  (bug #772258)
- A new storm.twisted.transact module has been added with facilities to
  integrate Storm with Twisted, by running transactions in a separate
  thread from the main one in order to not block the reactor.
  (bug #777047)
- ResultSet.config's "distinct" argument now also accepts a tuple of
  columns, which will be turned into a DISTINCT ON clause.
- Provide wrapped cursor objects in the Django integration layer.  Catch
  some disconnection errors that might otherwise be missed, leading to
  broken connections.  (bug #816049)
- A new JSON property is available.  It's similar to the existing Pickle
  property, except that it serializes data as JSON, and must back onto a
  text column rather than a byte column.  (bugs #726799, #846867)
- Cache the compilation of columns and tables (bugs #826170, #848925).
- Add two new tracers extracted from the Launchpad codebase.
  BaseStatementTracer provides statements with parameters substituted to
  its subclasses.  TimelineTracer records queries in a timeline (useful
  for OOPS reports).
- New ROW constructor (bug #698344).
- Add support for Postgres DISTINCT ON queries.  (bug #374777)
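  For example, assuming the hypothetical Employee model and store from the
  earlier sketch, the new Cast expression can be used directly in a
  find()::

      from storm.expr import Cast

      # Compiles to CAST(employee.id AS TEXT) in the generated SQL.
      result = store.find(Cast(Employee.id, "TEXT"))
      print(list(result))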
Bug fixes
---------
- When retrieving and using an object with Store.get(), Storm will no
  longer issue two queries when there is a live invalidated object
  (bug #697513).
- When a datetime object is returned by the database driver, DateVariable
  failed to detect and convert it to a date() object (bug #391601).
- The ISQLObjectResultSet interface declares an is_empty method, which
  matches the existing implementation.  This makes it possible to call
  the method in security-proxied environments (bug #759384).
- The UUIDVariable correctly converts inputs to unicode before sending
  them to the database.  This makes the UUID property usable
  (bug #691752).
- Move the firing of the register-transaction event in Connection.execute
  before the connection checking, to make sure that the store gets
  registered properly for future rollbacks (bug #819282).
- Skip tests/zope/README.txt when zope.security is not found.
  (bug #848848)
- Fix the handling of disconnection errors in the storm.django bridge.
  (bug #854787)

0.18 (2010-10-25)
=================

Improvements
------------
- Include code to manage and migrate database schemas.  See the
  storm.schema sub-package (bug #250412).
- Added a storm.zope.testing.ZStormResourceManager class to manage a set
  of stores registered with ZStorm (bug #618704).  It can be used roughly
  like this::

      from testresources import ResourcedTestCase
      from storm.zope.testing import ZStormResourceManager
      from storm.schema import Schema

      name = "test"
      uri = "sqlite:"
      schema = Schema(...)
      manager = ZStormResourceManager({name: (uri, schema)})

      class MyTest(ResourcedTestCase):

          resources = [("zstorm", manager)]

          def test_stuff(self):
              store = self.zstorm.get("test")
              store.execute(...)

- When a TimeoutError is raised it includes a description of why the
  exception was raised, to help make it easier to reason about
  timeout-related issues (bug #617973).
- Improved the IResultSet interface to document the rationale for why
  some attributes are not included (bug #659883).

Bug fixes
---------
- Make Storm compatible with psycopg2 2.2 (bug #585704).
- Fix bug #620615, which caused lazy expressions to make subsequent
  loading of objects explode if unflushed.
- Fix bug #620508, which caused slicing a ResultSet to break subsequent
  count() calls.
- Fix bug #659708, correcting the behavior of the sqlobject is_empty and
  __nonzero__ methods.
- Fix bug #619017, which caused __storm_loaded__ to be called without its
  object's variables defined if the object were in the alive cache but
  disappeared.

0.17 (2010-08-05)
=================

Improvements
------------
- The order_by parameter defined on ReferenceSet can now be specified as
  a string, to work around circular dependency issues.  The order-by
  property column will be resolved after import time (bug #580037).
- The Store and Connection classes have block_access() and
  unblock_access() methods that can be used to block access to the
  database connection.  This can be used to ensure that an application
  doesn't access the database at unexpected times.
- When using stores managed by ZStorm, a ZStormError will be raised on
  attempts to use a per-thread store from the wrong thread (bug #348815).
- ResultSet.is_empty strips the ORDER BY clause, when present, to provide
  a performance boost for queries that would match a large number of rows
  (bug #246200).
- A new ResultSet.get_select_expr method returns a Select expression
  built for a specified set of columns, based on the settings of the
  result set (bug #337494); see the sketch after this list.
- ResultSet.any and ReferenceSet.any strip the ORDER BY clause, when
  present, to provide a performance boost for queries that would match a
  large number of rows (bug #608825).
- SQLObjectResultSet has a new is_empty method which should be used in
  preference to __nonzero__, because it is compatible with ResultSet.
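  A brief sketch of ResultSet.get_select_expr with the hypothetical model
  used above: the Select expression it returns can serve as a subquery::

      ids = store.find(
          Employee, Employee.name != None).get_select_expr(Employee.team_id)
      result = store.find(Team, Team.id.is_in(ids))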
Bug fixes
---------
- SQLite reserved words are handled properly (bug #593633).
- A bug in the change-checkpointing logic has been fixed to detect
  changes in mutable objects correctly and to prevent useless (or
  potentially harmful!) columns from being specified in updates
  (bug #553334).

0.16 (2009-11-28)
=================

Improvements
------------
- The set expression constructor will now flatten its first argument if
  it is of the same type.  The resulting expression tree uses less stack
  when compiling, so it reduces the chance of hitting Python's recursion
  limit (bug #242813).
- Add startswith(), endswith() and contains_string() methods to
  Comparable.  These methods perform prefix, suffix and substring checks
  respectively using the LIKE operator, taking care of escaping for you
  (bug #387840); see the sketch after this list.
- C extensions are enabled by default.  Define the STORM_CEXTENSIONS=0
  environment variable to disable them (bug #410592).
- The README file contains information about Storm's license and detailed
  instructions on setting up a development environment suitable for
  running the entire test suite.
- 'make doc' uses Pydoctor to generate API documentation.
- Integration tests for Django now work with Django 1.1.
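  A short sketch of the new prefix/substring methods (the Employee model
  is the hypothetical one from earlier; the case_sensitive argument shown
  in the last comment was only added later, in 0.25, and only for
  PostgreSQL)::

      # LIKE 'Jo%', with wildcard characters in the value escaped for you.
      prefixed = store.find(Employee, Employee.name.startswith(u"Jo"))
      # The "_" is escaped rather than treated as a LIKE wildcard.
      infixed = store.find(Employee, Employee.name.contains_string(u"o_o"))
      # PostgreSQL only: Employee.name.startswith(u"jo", case_sensitive=False)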
Bug fixes
---------
- Remove a leak when mutable variables (ListVariable or PickleVariable
  instances) are collected before store.flush, leaving hooks behind them.
- The ResultSet min, max and sum methods now work correctly when the
  result set is empty and the column has allow_none=False set.
  Previously this resulted in a NoneError (bug #457801).
- MySQL reserved words are handled properly (bug #433833).
- Test loading code has been simplified.  Support for py.test has been
  removed in this process, as it was not functioning correctly and didn't
  fit into the PyUnit framework (bug #331905).
- Remote diverged and remote deleted references now use a weak (Python)
  reference to the local object.  This prevents a leak when the remote
  object stays in memory (bug #475148).
- Check for invalidated state when returning the remote object of a
  relation: this fixes a bug when the local key of the Reference is the
  primary key (bug #435962).
- The default Cache instance created for a Store honours Cache's default
  size.  Store's docstring has been updated to reflect this
  (bug #374180).

0.15 (2009-08-07)
=================

Improvements
------------
- Add support for the latest version of Django by not checking arguments
  passed to _cursor.
- Added support for using Expressions with ResultSet.set().
- The default cache size was changed from 100 to 1000 objects.
- Added the new GenerationalCache with a faster implementation of the
  caching API (by Jeroen Vermeulen).  This will likely become the default
  implementation in the future, so please test it if possible.
- A new UUID property type has been added.  It depends on the uuid module
  introduced in Python 2.5 to represent values.
- The StoreDataManager now gets passed a transaction manager from the
  ZStorm utility.  This will make it easier to support non-default
  transaction managers in the future.
- An adapter is now available for converting ISQLObjectResultSet objects
  to IResultSet.  This is intended to help in gradually porting SQLObject
  applications to Storm's native API (bug #338184, by Gavin Panella).
- If a disconnection occurs outside of Storm's control, a
  DisconnectionError will still be raised next time the connection is
  used through Storm.  This is useful if the connection is being shared
  with another framework like Django (bug #374909).
- The PostgreSQL backend now requires psycopg2 >= 2.0.7.  The workaround
  for broken quoting behaviour in older psycopg2 versions has been
  removed (bug #322206).
- A new Neg expression is available.  It provides unary minus by
  prepending a minus sign to whatever expression is passed to it
  (bug #397654, by Michael Hudson).
- A new Coalesce expression is available.
- ResultSets now have a find() method.  It acts like Store.find(), but
  without the first argument (it uses the same classes as the original
  result), and only returns results found in the original result set
  (bug #338255); see the sketch after this list.
- Result.rowcount exposes the number of rows affected by the query, when
  known.
- ResultSet.remove returns the number of rows deleted from the database
  (bug #180122).
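  A minimal sketch of refining a result set with the new ResultSet.find,
  using the same hypothetical Employee model as the earlier sketches::

      named = store.find(Employee, Employee.name != None)
      johns = named.find(Employee.name.startswith(u"John"))
      print(johns.count())  # only rows matched by *both* finds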
Bug fixes
---------
- The fix for ResultSet.count() on distinct results from the last release
  has been improved, and the fix has been extended to the other
  aggregates.  This may change the result of some count(), min(), max(),
  sum() and avg() calls for results using distinct, limit or offset.
- The test suite now passes when run with Python 2.6.
- ListVariable now converts its elements to database representation
  correctly (bug #136806, reported by Marc Tardif).
- compile_python now works for values that don't produce valid Python
  expressions with repr().
- The C extension should now build under Windows.

0.14 (2009-01-09)
=================

Improvements
------------
- A new doctest describing the 'Infoheritance' pattern is integrated into
  the test suite.
- A new storm.django package has been added to allow use of Storm within
  Django applications.
- The way Storm interacts with the Zope transaction manager has changed.
  Rather than using a synchronizer to join each new transaction, it now
  delays joining the transaction until the store is actually used.
- The Store constructor takes an optional cache keyword argument.
- ResultSets now offer an is_empty() method.

Bug fixes
---------
- Handle rows full of NULLs in the case of LEFT JOINs better, without
  validating them against object constraints.
- The Reference class now has an __ne__() method, so inequality checks
  against a reference now work in find expressions (bug #244768, reported
  by Stuart Bishop).
- Make ResultSet.count() handle the distinct, limit and offset flags, so
  that it really reflects the length of the current ResultSet.
- The store doesn't iterate over all the alive objects anymore, instead
  using events that objects can subscribe to.  This improves performance
  drastically when lots of objects are in the cache.

0.13 (2008-08-28)
=================

Improvements
------------
- Add group_by/having methods on ResultSet objects, to allow access to
  the "GROUP BY" and "HAVING" statements.
- Change tests/store to keep the connection during the tests to make them
  faster.
- Implemented support for plugging generic "tracers".  Statement
  debugging is now implemented using a tracer, and easily enabled with
  storm.tracer.debug(True) (storm.database.DEBUG = True is gone); see the
  sketch after this list.
- All properties now accept a "validator" parameter.  When used, a
  function like validate(object, attribute_name, value) should be given,
  and it may validate or modify the value before it's set in the
  property.  The value assigned to the property is the result of the
  validator, so the original value should be returned if changing it
  isn't intended.
- Expressions can be passed to Store.find() as well as classes.  This
  makes it possible to request individual columns from a table, computed
  expressions or aggregates.
- Objects will be flushed in the order they become dirty by default.
  This means that, generally, the order in which Python operations are
  performed will define the order in which flushes are done, which is
  what is most commonly expected.
- The cextensions module was fixed, optimized, and improved.  It's now
  built by default, but to actually enable it the environment variable
  STORM_CEXTENSIONS=1 must be defined at runtime.  The module will likely
  be enabled by default in a future release.
- ClassAlias will now cache all explicitly named aliases, to avoid the
  cost of rebuilding them.
- Result sets and reference sets now have a __contains__() method.
  While code like "item in set" was previously possible, it involved
  iterating over the result set, which is expensive for large databases.
- The storm.zope.zstorm code can now be used with only the zope.interface
  and transaction packages installed.  This makes it easier to reuse the
  per-thread store management and global transaction handling from other
  web frameworks.
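  For instance, statement debugging can be toggled like this (a sketch;
  the store and Employee model are the hypothetical ones from the earlier
  sketches)::

      from storm.tracer import debug

      debug(True)   # install a tracer that prints each statement executed
      store.find(Employee, Employee.id == 1).one()
      debug(False)  # remove the debug tracer again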
Bug fixes
---------
- Make is_in return False instead of NULL on an empty result set.
- ZStorm now keeps strong references to named stores.  Previously it only
  kept weak ones, so stores were recreated more often than necessary.
- References now won't flush the store or query the database when the
  foreign key is None.
- When a reference is set to an object that wasn't yet inserted in the
  database, the foreign key is immediately unset instead of keeping the
  old value up to the flushing.
- Setting a reference to None works even if the previously referenced
  object isn't in memory.
- When setting a reference, flush ordering is only enforced if the key is
  dirty.  This allows a number of changes that would previously raise
  OrderLoopError.
- If the remote object in a back reference is removed, the reference will
  now be broken.
- Fixed a race condition when two threads try to initialize the ClassInfo
  for a given class at the same time.
- Improve handling of AUTO INCREMENT columns in the MySQL backend to
  remove an unnecessary query when adding objects to a store.

0.12 (2008-01-28)
=================

Improvements
------------
- The Connection will reconnect automatically when connection drops are
  detected and a rollback is performed.  As a result, the Store should
  handle reconnections in a seamless way in most circumstances
  (bug #94986, by James Henstridge).  This is supported in the MySQL and
  PostgreSQL backends.
- Store.flush() will not load values inserted in the database.  Instead,
  undefined variables are set to AutoReload, and resolved once first
  accessed.  This won't be noticeable in normal usage, but will boost the
  performance of inserts.
- Support in the postgres backend for using the RETURNING extension of
  the INSERT statement to retrieve the primary key on inserts for object
  identity mapping (PostgreSQL >= 8.2 only).
- Introduced a cache mechanism that keeps the N last retrieved objects in
  memory to optimize cases where the same object is retrieved often while
  no strong references are kept elsewhere.  Implemented by Bernd Dorn,
  from Lovely Systems.
- Improved support for TimeDelta properties on all backends.  Many more
  formats are accepted now, and some issues were fixed.

Bug fixes
---------
- TimeDelta was added to storm.locals.
- Fixed TimeDelta support in SQLite, MySQL, and PostgreSQL, and enabled a
  base test for all backends to ensure that it continues to work.
- Schema names are accepted in __storm_table__ when using PostgreSQL as
  the database (e.g. "schema.table") (bug #146580, reported by James
  Mayfield).
- The test runner handles paths correctly on Windows, and SQLite tests
  won't break (patch by Ali Afshar).
- In the SQLite backend, ensure that we're able to recommit a transaction
  after "database is locked" errors.  Also make sure that when this
  happens the timeout is actually the expected one (patch by Richard
  Boulton).
- TransactionFailedError is now imported from the public place:
  ZODB.POSException (bug #129715, by James Henstridge).
- Tables named with reserved keywords are properly escaped.
- Reserved keywords in column names are properly escaped when part of an
  insert or update statement (bug #184849, by Thomas Herve).
- Prevent cached objects from issuing selects to retrieve their data when
  find()s were previously made and had already brought their data
  (bug #174388, reported and fixed by Bernd Dorn).
- Fixed a bug which caused an object to be re-added to the store when a
  reference of an object that had already been removed was looked up.
- Prevent the pathological case which happens when a statement like
  "SELECT ... WHERE table.id = currval(...)" is executed in PostgreSQL.
  The change is only meaningful on PostgreSQL < 8.2, as newer versions
  will use the RETURNING extension instead.
- Specify both of the joining tables explicitly when compiling Proxy, so
  that it doesn't break due to incorrect references in the ON clause when
  multiple tables are used (reported in bug #162528 by S3ym0ur and
  Hamilton Tran).
- The MySQL client charset now defaults to UTF-8 (reported by Brad
  Crittenden).

0.11 (2007-10-08)
=================

Improvements
------------
- Added case-insensitive support to the Like expression and
  Class.attr.like() constructions; see the sketch at the end of this
  entry.
- ZStorm.get_name() for obtaining the name of a given store.

Bug fixes
---------
- storm.zope wasn't included in the tarball due to an error in setup.py.
- Binary strings are now properly quoted with the E'' formatting if
  needed, in the postgres backend.
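  A hedged sketch of the case-insensitive matching added here, again with
  the hypothetical Employee model from the earlier sketches::

      # Matches "John", "john", "JOHNNY", ...
      result = store.find(
          Employee, Employee.name.like(u"jo%", case_sensitive=False))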
0.10 (2007-08-08)
=================

Improvements
------------
- Improvements were made to the tutorial.
- There is now a setup.py script for installing Storm.
- Count and ClassAlias are now available through the storm.locals module.
- A new hook, __storm_pre_flush__, can be implemented on objects in a
  Store.  It is called before an object is flushed to the database; see
  the sketch after this list.
- Storm can now use the built-in sqlite support in Python 2.5 and above.
- There is now a storm.properties.Decimal, which allows you to store
  Decimal (as opposed to binary) floating point values.
- storm.zope was added, which offers a simple integration mechanism with
  the Zope transaction machinery.
- Complex expressions other than simple Columns can now be passed to the
  aggregation methods of ResultSet (avg, max, min, sum).
- Backend implementors can now override preset_primary_key on their
  Database object to come up with primary key values before an Insert.
- A large amount of API documentation was added.
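  A small sketch of the __storm_pre_flush__ hook (the Tag model is
  hypothetical, invented for illustration)::

      from storm.locals import Int, Unicode

      class Tag:
          __storm_table__ = "tag"
          id = Int(primary=True)
          name = Unicode()

          def __storm_pre_flush__(self):
              # Called once just before this object is flushed to the
              # database, so it can normalize or validate its own state.
              if self.name is not None:
                  self.name = self.name.strip()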
Bug fixes
---------
- SQL reserved words are now properly escaped in SQL statements.
- GROUP BY and ORDER BY statements are now ordered correctly.
- Running the tests with trial now works.
- All backends are now initialized such that their transactions are truly
  SERIALIZABLE.  Psycopg2 and Pysqlite2 both did not previously have
  serializable transactions by default, but this has been fixed.
- A bug in ResultSet.cached which could occasionally cause
  inconsistencies in ResultSet.set was fixed.

API Changes
-----------
Most changes are backwards compatible.  There were some incompatible
changes which may affect alternative database backends.

- Chars was renamed to RawStr.  Chars still exists, but is deprecated.
  All raw 8-bit data in your database should be represented with RawStr.
- Compiler handlers have had their arguments reordered.
- The Compile.__call__ method now returns only the Statement.
- Compile.fork was renamed to Compile.create_child.
- Many methods which previously had underscores were renamed to get rid
  of the underscores, reflecting their status as things which can be
  safely touched in subclasses.  Documentation was added clarifying their
  intended use.

storm-1.0/PKG-INFO
==================

Metadata-Version: 2.1
Name: storm
Version: 1.0
Summary: Storm is an object-relational mapper (ORM) for Python developed at Canonical.
Home-page: https://storm.canonical.com
Download-URL: https://launchpad.net/storm/+download
Author: Gustavo Niemeyer
Author-email: gustavo@niemeyer.net
Maintainer: Storm Developers
Maintainer-email: storm@lists.canonical.com
License: LGPL
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Database
Classifier: Topic :: Database :: Front-Ends
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.6
Description-Content-Type: text/x-rst
License-File: LICENSE
Requires-Dist: packaging>=14.1
Provides-Extra: doc
Requires-Dist: fixtures; extra == "doc"
Requires-Dist: sphinx; extra == "doc"
Requires-Dist: sphinx-epytext; extra == "doc"
Provides-Extra: test
Requires-Dist: fixtures>=1.3.0; extra == "test"
Requires-Dist: mysqlclient; extra == "test"
Requires-Dist: pgbouncer>=0.0.7; extra == "test"
Requires-Dist: postgresfixture; extra == "test"
Requires-Dist: psycopg2>=2.3.0; extra == "test"
Requires-Dist: testresources>=0.2.4; extra == "test"
Requires-Dist: testtools>=0.9.8; extra == "test"
Requires-Dist: timeline>=0.0.2; extra == "test"
Requires-Dist: transaction>=1.0.0; extra == "test"
Requires-Dist: Twisted>=10.0.0; extra == "test"
Requires-Dist: zope.component>=3.8.0; extra == "test"
Requires-Dist: zope.configuration; extra == "test"
Requires-Dist: zope.interface>=4.0.0; extra == "test"
Requires-Dist: zope.security>=3.7.2; extra == "test"

Storm is an Object Relational Mapper for Python developed at Canonical.
API docs, a manual, and a tutorial are available from:

https://storm.canonical.com/


Introduction
============

The project was in development for more than a year for use in Canonical
projects such as Launchpad and Landscape before being released as free
software on July 9th, 2007.

Design:

* A clean and lightweight API offers a short learning curve and long-term
  maintainability.
* Storm is developed in a test-driven manner.  An untested line of code
  is considered a bug.
* Storm needs no special class constructors, nor imperative base classes.
* Storm is well designed (different classes have very clear boundaries,
  with small and clean public APIs).
* Designed from day one to work both with thin relational databases, such
  as SQLite, and big iron systems like PostgreSQL and MySQL.
* Storm is easy to debug, since its code is written with a KISS
  principle, and thus is easy to understand.
* Designed from day one to work both at the low end, with trivial small
  databases, and the high end, with applications accessing billion-row
  tables and committing to multiple database backends.
* It's very easy to write and support backends for Storm (current
  backends have around 100 lines of code).

Features:

* Storm is fast.
* Storm lets you efficiently access and update large datasets by allowing
  you to formulate complex queries spanning multiple tables using Python.
* Storm allows you to fall back to SQL if needed (or if you just prefer),
  allowing you to mix "old school" code and ORM code.
* Storm handles composed primary keys with ease (no need for surrogate
  keys).
* Storm doesn't do schema management, and as a result you're free to
  manage the schema as wanted, and creating classes that work with Storm
  is clean and simple.
* Storm works very well connecting to several databases and using the
  same Python types (or different ones) with all of them.
* Storm can handle obj.attr = <expression> assignments, when that's
  really needed (the expression is executed at INSERT/UPDATE time); see
  the sketch after this list.
* Storm handles relationships between objects even before they were added
  to a database.
* Storm works well with existing database schemas.
* Storm will flush changes to the database automatically when needed, so
  that queries made affect recently modified objects.
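As a hedged illustration of the expression-assignment feature above (the
Employee model, table, and SQLite store are invented for this sketch)::

    from storm.expr import SQL
    from storm.locals import Int, Store, Unicode, create_database

    class Employee:
        __storm_table__ = "employee"
        id = Int(primary=True)
        name = Unicode()

    store = Store(create_database("sqlite:"))
    store.execute("CREATE TABLE employee (id INTEGER PRIMARY KEY, name TEXT)")

    emp = store.add(Employee())
    emp.name = u"alice"
    store.flush()

    emp.name = SQL("UPPER(name)")  # executed at UPDATE time, not immediately
    store.flush()
    print(emp.name)  # u"ALICE", reloaded automatically from the database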
* Storm allows you to fall back to SQL if needed (or if you just prefer), allowing you to mix "old school" code and ORM code.
* Storm handles composed primary keys with ease (no need for surrogate keys).
* Storm doesn't do schema management, and as a result you're free to manage the schema as you want, and creating classes that work with Storm is clean and simple.
* Storm works very well connecting to several databases and using the same Python types (or different ones) with all of them.
* Storm can handle obj.attr = <SQL expression> assignments, when that's really needed (the expression is executed at INSERT/UPDATE time).
* Storm handles relationships between objects even before they are added to a database.
* Storm works well with existing database schemas.
* Storm will flush changes to the database automatically when needed, so that queries reflect recently modified objects.
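To make these points concrete, here is a minimal, self-contained sketch of
typical Storm usage (the Person class, its table, and the in-memory SQLite
URI are illustrative assumptions, not part of this package):

    from storm.locals import Int, Store, Storm, Unicode, create_database

    class Person(Storm):
        __storm_table__ = "person"
        id = Int(primary=True)
        name = Unicode()

    database = create_database("sqlite:")  # in-memory SQLite database
    store = Store(database)
    store.execute(
        "CREATE TABLE person (id INTEGER PRIMARY KEY, name VARCHAR)")

    person = Person()
    person.name = "Alice"
    store.add(person)
    store.flush()  # assigns person.id from the database

    print(store.find(Person, Person.name == "Alice").one().id)

Note that Person needs no special constructor or base-class hooks beyond
Storm itself, and the find() call is plain Python rather than SQL.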
License
=======

Copyright (C) 2006-2020 Canonical, Ltd. All contributions must have copyright assigned to Canonical.

This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

On Ubuntu systems, the complete text of the GNU Lesser General Public Version 2.1 License is in /usr/share/common-licenses/LGPL-2.1

Developing Storm
================

SHORT VERSION: If you are running Ubuntu, or probably Debian, the following should work. If not, and for reference, the long version is below.

$ dev/ubuntu-deps
$ echo "$PWD/** rwk," | sudo tee /etc/apparmor.d/local/usr.sbin.mysqld >/dev/null
$ sudo aa-enforce /usr/sbin/mysqld
$ make develop
$ make check

LONG VERSION: The following instructions describe the procedure for setting up a development environment and running the test suite.

Installing dependencies
-----------------------

The following instructions assume that you're using Ubuntu. The same procedure will probably work without changes on a Debian system and with minimal changes on a non-Debian-based Linux distribution.

In order to run the test suite, and exercise all supported backends, you will need to install MySQL and PostgreSQL, along with the related Python database drivers:

$ sudo apt-get install \
    mysql-server \
    postgresql pgbouncer \
    build-essential

These will take a few minutes to download.

The Python dependencies for running tests can be installed with apt-get:

$ sudo apt-get install \
    python3-fixtures \
    python3-pgbouncer \
    python3-psycopg2 \
    python3-testresources \
    python3-timeline \
    python3-transaction \
    python3-twisted \
    python3-zope.component \
    python3-zope.security

Alternatively, dependencies can be downloaded as eggs into the current directory with:

$ make develop

This ensures that all dependencies are available, downloading from PyPI as appropriate.

Database setup
--------------

Most database setup is done automatically by the test suite. However, Ubuntu's default MySQL packaging ships an AppArmor profile that prevents it from writing to a local data directory. To allow the test suite to do this, you will need to grant it access, which is most easily done by adding a line such as this to /etc/apparmor.d/local/usr.sbin.mysqld:

/path/to/storm/** rwk,

Then reload the profile:

$ sudo aa-enforce /usr/sbin/mysqld

Running the tests
-----------------

Finally, it's time to run the tests! Go into the base directory of the storm branch you want to test, and run:

$ make check

They'll take a while to run. All tests should pass: failures mean there's a problem with your environment or a bug in Storm.

././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/README0000644000175000017500000001315114571373456013720 0ustar00cjwatsoncjwatsonStorm is an Object Relational Mapper for Python developed at Canonical.

API docs, a manual, and a tutorial are available from: https://storm.canonical.com/

Introduction
============

The project was in development for more than a year for use in Canonical projects such as Launchpad and Landscape before being released as free software on July 9th, 2007.

Design:

* Clean and lightweight API offers a short learning curve and long-term maintainability.
* Storm is developed in a test-driven manner. An untested line of code is considered a bug.
* Storm needs no special class constructors, nor imperative base classes.
* Storm is well designed (different classes have very clear boundaries, with small and clean public APIs).
* Designed from day one to work both with thin relational databases, such as SQLite, and big iron systems like PostgreSQL and MySQL.
* Storm is easy to debug, since its code is written with a KISS principle, and thus is easy to understand.
* Designed from day one to work both at the low end, with trivial small databases, and the high end, with applications accessing billion row tables and committing to multiple database backends.
* It's very easy to write and support backends for Storm (current backends have around 100 lines of code).

Features:

* Storm is fast.
* Storm lets you efficiently access and update large datasets by allowing you to formulate complex queries spanning multiple tables using Python.
* Storm allows you to fall back to SQL if needed (or if you just prefer), allowing you to mix "old school" code and ORM code.
* Storm handles composed primary keys with ease (no need for surrogate keys).
* Storm doesn't do schema management, and as a result you're free to manage the schema as you want, and creating classes that work with Storm is clean and simple.
* Storm works very well connecting to several databases and using the same Python types (or different ones) with all of them.
* Storm can handle obj.attr = <SQL expression> assignments, when that's really needed (the expression is executed at INSERT/UPDATE time).
* Storm handles relationships between objects even before they are added to a database.
* Storm works well with existing database schemas.
* Storm will flush changes to the database automatically when needed, so that queries reflect recently modified objects.

License
=======

Copyright (C) 2006-2020 Canonical, Ltd. All contributions must have copyright assigned to Canonical.

This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

On Ubuntu systems, the complete text of the GNU Lesser General Public Version 2.1 License is in /usr/share/common-licenses/LGPL-2.1

Developing Storm
================

SHORT VERSION: If you are running Ubuntu, or probably Debian, the following should work. If not, and for reference, the long version is below.

$ dev/ubuntu-deps
$ echo "$PWD/** rwk," | sudo tee /etc/apparmor.d/local/usr.sbin.mysqld >/dev/null
$ sudo aa-enforce /usr/sbin/mysqld
$ make develop
$ make check

LONG VERSION: The following instructions describe the procedure for setting up a development environment and running the test suite.

Installing dependencies
-----------------------

The following instructions assume that you're using Ubuntu. The same procedure will probably work without changes on a Debian system and with minimal changes on a non-Debian-based Linux distribution.

In order to run the test suite, and exercise all supported backends, you will need to install MySQL and PostgreSQL, along with the related Python database drivers:

$ sudo apt-get install \
    mysql-server \
    postgresql pgbouncer \
    build-essential

These will take a few minutes to download.

The Python dependencies for running tests can be installed with apt-get:

$ sudo apt-get install \
    python3-fixtures \
    python3-pgbouncer \
    python3-psycopg2 \
    python3-testresources \
    python3-timeline \
    python3-transaction \
    python3-twisted \
    python3-zope.component \
    python3-zope.security

Alternatively, dependencies can be downloaded as eggs into the current directory with:

$ make develop

This ensures that all dependencies are available, downloading from PyPI as appropriate.

Database setup
--------------

Most database setup is done automatically by the test suite. However, Ubuntu's default MySQL packaging ships an AppArmor profile that prevents it from writing to a local data directory. To allow the test suite to do this, you will need to grant it access, which is most easily done by adding a line such as this to /etc/apparmor.d/local/usr.sbin.mysqld:

/path/to/storm/** rwk,

Then reload the profile:

$ sudo aa-enforce /usr/sbin/mysqld

Running the tests
-----------------

Finally, it's time to run the tests! Go into the base directory of the storm branch you want to test, and run:

$ make check

They'll take a while to run. All tests should pass: failures mean there's a problem with your environment or a bug in Storm.

././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/TODO0000644000175000017500000000465111752263216013523 0ustar00cjwatsoncjwatson- The on_remote flag of references should be inferred when the local property
  is a primary key (or part of it?).

- Allow something like Int(unique=True), so that it may be used for cached
  gets (perhaps with get(Class.the_key, value), or get(Class, value,
  attribute=Class.the_key)).

- Unicode(autoreload=True) will mark the field as autoreload by default.

- Lazy-by-default attributes:

      class C(object):
          ...
          attr = Unicode(lazy=True)

  This would make attr be loaded only if touched. Or maybe lazy groups:

      class C(object):
          ...
          lazy_group = LazyGroup()
          attr = Unicode(lazy=True, lazy_group=lazy_group)

  Once one of the group attributes is accessed, all of them are retrieved at
  the same time. Lazy groups may be integers as well:

      class C(object):
          ...
          attr = Unicode(lazy_group=1)

  lazy_group=None means not lazy.

- Implement ResultSet.reverse[d]() to invert order_by()?

- Add support for cyclic references when all elements of the cycle are
  flushed at the same time.

- Implement support for negative caches to tell when an object isn't
  available.

- Implement support for complex removes and updates with Exists().

- Log SQL statements and Store actions.

- Support for quoted strings.

- Option to keep object in cache until explicitly removed?

- Implement store.copy()

- Implement must_define in properties.

- Implement slicing ([:]) in BoundReferenceSet

- Handle $foo$bar$foo$ literals

- Could Reference(Set)s include a "where" clause? Readonly perhaps?

- Make the primary key for a class optional. If it's not provided the object
  isn't cached and updates aren't tracked.

- Between()

- Automatic class generation, perhaps based on Django's inspectdb:
  http://www.djangoproject.com/documentation/legacy_databases/
  http://www.djangoproject.com/documentation/django_admin/

- Support allow_microseconds=False in DateTime properties/variables.

- Support allow_self in Reference and ReferenceSet, and default to false.

- Set operations in ReferenceSets (suggested by Stephan Diehl):

      accessGroups = set([grp1, grp2, grp3])
      if usr.groups & accessGroups:
          doSomething

- Think about something like a store.cache(...) method with the same
  signature as store.find(...) which stores objects in the cache so that they
  don't get deallocated during the current transaction. The Cache class
  interface would have to be expanded to handle these cases in a special way.

././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4091246 storm-1.0/dev/0000755000175000017500000000000014645532536013613 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/dev/test0000755000175000017500000001776214571373456014527 0ustar00cjwatsoncjwatson#!/usr/bin/env python3 # # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import glob import optparse import os import re import shutil import socket import subprocess import sys import time import unittest from pkg_resources import parse_version def add_eggs_to_path(): here = os.path.dirname(__file__) egg_paths = glob.glob(os.path.join(here, "*.egg")) sys.path[:0] = map(os.path.abspath, egg_paths) # python setup.py test [--dry-run] puts $package.egg directories in the # top directory, so we add them to sys.path here for convenience.
add_eggs_to_path() def test_with_runner(runner): usage = "test.py [options] [, ...]" parser = optparse.OptionParser(usage=usage) parser.add_option('--verbose', action='store_true') opts, args = parser.parse_args() if opts.verbose: runner.verbosity = 2 # Import late, after any and all sys.path jiggery pokery. from storm.tests import find_tests suite = find_tests(args) result = runner.run(suite) return not result.wasSuccessful() def test_with_trial(): from twisted.trial import unittest as trial_unittest from twisted.trial.reporter import TreeReporter from twisted.trial.runner import TrialRunner # This is only imported to work around # https://twistedmatrix.com/trac/ticket/8267. trial_unittest runner = TrialRunner(reporterFactory=TreeReporter, realTimeErrors=True) return test_with_runner(runner) def test_with_unittest(): runner = unittest.TextTestRunner() return test_with_runner(runner) def with_postgresfixture(runner_func): """If possible, wrap a test runner with code to set up PostgreSQL.""" try: from postgresfixture import ClusterFixture except ImportError: return runner_func from urllib.parse import urlunsplit from storm.uri import escape def wrapper(): cluster = ClusterFixture("db") cluster.create() # Allow use of prepared transactions, which some tests need. pg_conf_path = os.path.join(cluster.datadir, "postgresql.conf") with open(pg_conf_path) as pg_conf_old: with open(pg_conf_path + ".new", "w") as pg_conf_new: for line in pg_conf_old: pg_conf_new.write(re.sub( r"^#(max_prepared_transactions.*)= 0", r"\1= 200", line)) os.fchmod( pg_conf_new.fileno(), os.fstat(pg_conf_old.fileno()).st_mode) os.rename(pg_conf_path + ".new", pg_conf_path) with cluster: cluster.createdb("storm_test") uri = urlunsplit( ("postgres", escape(cluster.datadir), "/storm_test", "", "")) os.environ["STORM_POSTGRES_URI"] = uri os.environ["STORM_POSTGRES_HOST_URI"] = uri return runner_func() return wrapper def with_mysql(runner_func): """If possible, wrap a test runner with code to set up MySQL. Loosely based on the approach taken by pytest-mysql, although implemented separately. """ try: import MySQLdb except ImportError: return runner_func from urllib.parse import ( urlencode, urlunsplit, ) def wrapper(): basedir = os.path.abspath("db-mysql") datadir = os.path.join(basedir, "data") unix_socket = os.path.join(basedir, "mysql.sock") logfile = os.path.join(basedir, "mysql.log") if os.path.exists(basedir): shutil.rmtree(basedir) os.makedirs(basedir) mysqld_version_output = subprocess.check_output( ["mysqld", "--version"], universal_newlines=True).rstrip("\n") version = re.search( r"Ver ([\d.]+)", mysqld_version_output, flags=re.I).group(1) if ("MariaDB" not in mysqld_version_output and parse_version(version) >= parse_version("5.7.6")): subprocess.check_call([ "mysqld", "--initialize-insecure", "--datadir=%s" % datadir, "--tmpdir=%s" % basedir, "--log-error=%s" % logfile, ]) else: subprocess.check_call([ "mysql_install_db", "--datadir=%s" % datadir, "--tmpdir=%s" % basedir, ]) with open("/dev/null", "w") as devnull: server_proc = subprocess.Popen([ "mysqld_safe", "--datadir=%s" % datadir, "--pid-file=%s" % os.path.join(basedir, "mysql.pid"), "--socket=%s" % unix_socket, "--skip-networking", "--log-error=%s" % logfile, "--tmpdir=%s" % basedir, "--skip-syslog", # We don't care about durability of test data. Try to # persuade MySQL to agree. 
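            # Durability/speed trade-offs: disable the InnoDB doublewrite
            # buffer, stop flushing the redo log on every commit, and leave
            # binlog syncing to the OS. A crash can lose data, which is fine
            # here because the datadir is recreated from scratch on each run.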
"--innodb-doublewrite=0", "--innodb-flush-log-at-trx-commit=0", "--innodb-flush-method=O_DIRECT_NO_FSYNC", "--skip-innodb-file-per-table", "--sync-binlog=0", ], stdout=devnull) try: start_time = time.time() while time.time() < start_time + 60: code = server_proc.poll() if code is not None and code != 0: raise Exception("mysqld_base exited %d" % code) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) try: sock.connect(unix_socket) break except socket.error: pass # try again finally: sock.close() time.sleep(0.1) try: connection = MySQLdb.connect( user="root", unix_socket=unix_socket) try: cursor = connection.cursor() try: cursor.execute( "CREATE DATABASE storm_test " "CHARACTER SET utf8mb4;") finally: cursor.close() finally: connection.close() uri = urlunsplit( ("mysql", "root@localhost", "/storm_test", urlencode({"unix_socket": unix_socket}), "")) os.environ["STORM_MYSQL_URI"] = uri os.environ["STORM_MYSQL_HOST_URI"] = uri return runner_func() finally: subprocess.check_call([ "mysqladmin", "--socket=%s" % unix_socket, "--user=root", "shutdown", ]) finally: if server_proc.poll() is None: server_proc.kill() server_proc.wait() return wrapper if __name__ == "__main__": runner = os.environ.get("STORM_TEST_RUNNER") if not runner: runner = "unittest" runner_func = globals().get("test_with_%s" % runner.replace(".", "_")) if not runner_func: sys.exit("Test runner not found: %s" % runner) runner_func = with_postgresfixture(runner_func) runner_func = with_mysql(runner_func) sys.exit(runner_func()) # vim:ts=4:sw=4:et ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4291248 storm-1.0/setup.cfg0000644000175000017500000000010614645532536014653 0ustar00cjwatsoncjwatson[sdist] formats = bztar, gztar [egg_info] tag_build = tag_date = 0 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039171.0 storm-1.0/setup.py0000755000175000017500000000536414645174503014556 0ustar00cjwatsoncjwatson#!/usr/bin/env python3 import os import re from setuptools import setup, Extension, find_packages if os.path.isfile("MANIFEST"): os.unlink("MANIFEST") BUILD_CEXTENSIONS = True with open("storm/__init__.py") as init_py: VERSION = re.search('version = "([^"]+)"', init_py.read()).group(1) with open("README") as readme: long_description = readme.read() tests_require = [ "fixtures >= 1.3.0", "mysqlclient", "pgbouncer >= 0.0.7", "postgresfixture", "psycopg2 >= 2.3.0", "testresources >= 0.2.4", "testtools >= 0.9.8", "timeline >= 0.0.2", "transaction >= 1.0.0", "Twisted >= 10.0.0", "zope.component >= 3.8.0", "zope.configuration", "zope.interface >= 4.0.0", "zope.security >= 3.7.2", ] setup( name="storm", version=VERSION, description=( "Storm is an object-relational mapper (ORM) for Python " "developed at Canonical."), long_description=long_description, long_description_content_type="text/x-rst", author="Gustavo Niemeyer", author_email="gustavo@niemeyer.net", maintainer="Storm Developers", maintainer_email="storm@lists.canonical.com", license="LGPL", url="https://storm.canonical.com", download_url="https://launchpad.net/storm/+download", packages=find_packages(), package_data={"": ["*.zcml"]}, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", ("License :: OSI Approved :: GNU Library or " "Lesser General Public License (LGPL)"), "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: 
Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Topic :: Database", "Topic :: Database :: Front-Ends", "Topic :: Software Development :: Libraries :: Python Modules", ], ext_modules=(BUILD_CEXTENSIONS and [Extension("storm.cextensions", ["storm/cextensions.c"])]), # The following options are specific to setuptools but ignored (with a # warning) by distutils. include_package_data=True, zip_safe=False, python_requires=">=3.6", install_requires=["packaging >= 14.1"], test_suite="storm.tests.find_tests", tests_require=tests_require, extras_require={ "doc": [ "fixtures", # so that storm.testing can be imported "sphinx", "sphinx-epytext", ], "test": tests_require, }, ) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4131246 storm-1.0/storm/0000755000175000017500000000000014645532536014201 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039247.0 storm-1.0/storm/__init__.py0000644000175000017500000000363014645174617016316 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os # Use a 4-digit version for development versions with 99 as the final version. # For example, if 0.15 is the currently released version of Storm, the # development version should be version 0.15.0.99. This will make it obvious # that this isn't the 0.15 release version while also allowing us to release # an 0.15.1 if need be. Release versions should use 2-digit version numbers, # with 0.16 being the next release version in this example. version = "1.0" version_info = tuple([int(x) for x in version.split(".")]) class UndefType: def __repr__(self): return "Undef" def __reduce__(self): return "Undef" Undef = UndefType() # C extensions are enabled by default. They are not used if the # STORM_CEXTENSIONS environment variable is set to '0'. If they can't be # imported Storm will automatically use Python versions of the optimized code # in the C extension. has_cextensions = False if os.environ.get("STORM_CEXTENSIONS") != "0": try: from storm import cextensions has_cextensions = True except ImportError as e: if "cextensions" not in str(e): raise ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/base.py0000644000175000017500000000220614571373456015467 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from storm.properties import PropertyPublisherMeta __all__ = ["Storm"] class Storm(metaclass=PropertyPublisherMeta): """An optional base class for objects stored in a Storm Store. It causes your subclasses to be associated with a Storm PropertyRegistry. It is necessary to use this if you want to specify References with strings. """ ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/cache.py0000644000175000017500000001326514645174376015631 0ustar00cjwatsoncjwatsonimport itertools class Cache: """Prevents recently used objects from being deallocated. This prevents recently used objects from being deallocated by Python even if the user isn't holding any strong references to it. It does that by holding strong references to the objects referenced by the last C{N} C{obj_info}\\ s added to it (where C{N} is the cache size). """ def __init__(self, size=1000): self._size = size self._cache = {} # {obj_info: obj, ...} self._order = [] # [obj_info, ...] def clear(self): """Clear the entire cache at once.""" self._cache.clear() del self._order[:] def add(self, obj_info): """Add C{obj_info} as the most recent entry in the cache. If the C{obj_info} is already in the cache, it remains in the cache and has its order changed to become the most recent entry (IOW, will be the last to leave). """ if self._size != 0: if obj_info in self._cache: self._order.remove(obj_info) else: self._cache[obj_info] = obj_info.get_obj() self._order.insert(0, obj_info) if len(self._cache) > self._size: del self._cache[self._order.pop()] def remove(self, obj_info): """Remove C{obj_info} from the cache, if present. @return: True if C{obj_info} was cached, False otherwise. """ if obj_info in self._cache: self._order.remove(obj_info) del self._cache[obj_info] return True return False def set_size(self, size): """Set the maximum number of objects that may be held in this cache. If the size is reduced, older C{obj_info}\\ s may be dropped from the cache to respect the new size. """ if size == 0: self.clear() else: # Remove all entries above the new size. while len(self._cache) > size: del self._cache[self._order.pop()] self._size = size def get_cached(self): """Return an ordered list of the currently cached C{obj_info}\\ s. The most recently added objects come first in the list. """ return list(self._order) class GenerationalCache: """Generational replacement for Storm's LRU cache. This cache approximates LRU without keeping exact track. Instead, it keeps a primary dict of "recently used" objects plus a similar, secondary dict of objects used in a previous timeframe. When the "most recently used" dict reaches its size limit, it is demoted to secondary dict and a fresh primary dict is set up. The previous secondary dict is evicted in its entirety. Use this to replace the LRU cache for sizes where LRU tracking overhead becomes too large (e.g. 
100,000 objects) or the `StupidCache` when it eats up too much memory. """ def __init__(self, size=1000): """Create a generational cache with the given size limit. The size limit applies not to the overall cache, but to the primary one only. When this reaches the size limit, the real number of cached objects will be somewhere between size and 2*size depending on how much overlap there is between the primary and secondary caches. """ self._size = size self._new_cache = {} self._old_cache = {} def clear(self): """See L{Cache.clear}. Clears both the primary and the secondary caches. """ self._new_cache.clear() self._old_cache.clear() def _bump_generation(self): """Start a new generation of the cache. The primary generation becomes the secondary one, and the old secondary generation is evicted. Kids at home: do not try this for yourself. We are trained professionals working with specially-bred generations. This would not be an appropriate way of treating older generations of actual people. """ self._old_cache, self._new_cache = self._new_cache, self._old_cache self._new_cache.clear() def add(self, obj_info): """See L{Cache.add}.""" if self._size != 0 and obj_info not in self._new_cache: if len(self._new_cache) >= self._size: self._bump_generation() self._new_cache[obj_info] = obj_info.get_obj() def remove(self, obj_info): """See L{Cache.remove}.""" in_new_cache = self._new_cache.pop(obj_info, None) is not None in_old_cache = self._old_cache.pop(obj_info, None) is not None return in_new_cache or in_old_cache def set_size(self, size): """See L{Cache.set_size}. After calling this, the cache may still contain more than `size` objects, but no more than twice that number. """ self._size = size cache = itertools.islice( itertools.chain(self._new_cache.items(), self._old_cache.items()), 0, size) self._new_cache = dict(cache) self._old_cache.clear() def get_cached(self): """See L{Cache.get_cached}. The result is a loosely-ordered list. Any object in the primary generation comes before any object that is only in the secondary generation, but objects within a generation are not ordered and there is no indication of the boundary between the two. Objects that are in both the primary and the secondary generation are listed only as part of the primary generation. """ cached = self._new_cache.copy() cached.update(self._old_cache) return list(cached) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/cextensions.c0000644000175000017500000020562714571373456016725 0ustar00cjwatsoncjwatson/* # # Copyright (c) 2006-2008 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# */ #include #include #define CATCH(error_value, expression) \ do { \ if ((expression) == error_value) {\ /*printf("GOT AN ERROR AT LINE %d!\n", __LINE__);*/ \ goto error; \ } \ } while (0) #define REPLACE(variable, new_value) \ do { \ PyObject *tmp = variable; \ variable = new_value; \ Py_DECREF(tmp); \ } while(0) static PyObject *Undef = NULL; static PyObject *LazyValue = NULL; static PyObject *raise_none_error = NULL; static PyObject *get_cls_info = NULL; static PyObject *EventSystem = NULL; static PyObject *SQLRaw = NULL; static PyObject *SQLToken = NULL; static PyObject *State = NULL; static PyObject *CompileError = NULL; static PyObject *parenthesis_format = NULL; static PyObject *default_compile_join = NULL; typedef struct { PyObject_HEAD PyObject *_weakreflist; PyObject *_owner_ref; PyObject *_hooks; } EventSystemObject; typedef struct { PyObject_HEAD PyObject *_value; PyObject *_lazy_value; PyObject *_checkpoint_state; PyObject *_allow_none; PyObject *_validator; PyObject *_validator_object_factory; PyObject *_validator_attribute; PyObject *column; PyObject *event; } VariableObject; typedef struct { PyObject_HEAD PyObject *_weakreflist; PyObject *_local_dispatch_table; PyObject *_local_precedence; PyObject *_local_reserved_words; PyObject *_dispatch_table; PyObject *_precedence; PyObject *_reserved_words; PyObject *_children; PyObject *_parents; } CompileObject; typedef struct { PyDictObject super; PyObject *_weakreflist; PyObject *_obj_ref; PyObject *_obj_ref_callback; PyObject *cls_info; PyObject *event; PyObject *variables; PyObject *primary_vars; } ObjectInfoObject; static int initialize_globals(void) { static int initialized = -1; PyObject *module; if (initialized >= 0) { /* This function should never fail under normal circumstances, * but if it does, raise an error on subsequent calls instead * of segfaulting. This should make it easier to track down * the cause of such errors. 
* * https://bugs.launchpad.net/storm/+bug/1006284 */ if (!initialized) PyErr_SetString(PyExc_RuntimeError, "initialize_globals() failed the first time " "it was run"); return initialized; } initialized = 0; /* Import objects from storm module */ module = PyImport_ImportModule("storm"); if (!module) return 0; Undef = PyObject_GetAttrString(module, "Undef"); if (!Undef) return 0; Py_DECREF(module); /* Import objects from storm.variables module */ module = PyImport_ImportModule("storm.variables"); if (!module) return 0; raise_none_error = PyObject_GetAttrString(module, "raise_none_error"); if (!raise_none_error) return 0; LazyValue = PyObject_GetAttrString(module, "LazyValue"); if (!LazyValue) return 0; Py_DECREF(module); /* Import objects from storm.info module */ module = PyImport_ImportModule("storm.info"); if (!module) return 0; get_cls_info = PyObject_GetAttrString(module, "get_cls_info"); if (!get_cls_info) return 0; Py_DECREF(module); /* Import objects from storm.event module */ module = PyImport_ImportModule("storm.event"); if (!module) return 0; EventSystem = PyObject_GetAttrString(module, "EventSystem"); if (!EventSystem) return 0; Py_DECREF(module); /* Import objects from storm.expr module */ module = PyImport_ImportModule("storm.expr"); if (!module) return 0; SQLRaw = PyObject_GetAttrString(module, "SQLRaw"); if (!SQLRaw) return 0; SQLToken = PyObject_GetAttrString(module, "SQLToken"); if (!SQLToken) return 0; State = PyObject_GetAttrString(module, "State"); if (!State) return 0; CompileError = PyObject_GetAttrString(module, "CompileError"); if (!CompileError) return 0; Py_DECREF(module); /* A few frequently used objects which are part of the fast path. */ parenthesis_format = PyUnicode_DecodeASCII("(%s)", 4, "strict"); default_compile_join = PyUnicode_DecodeASCII(", ", 2, "strict"); initialized = 1; return initialized; } static int EventSystem_init(EventSystemObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"owner", NULL}; PyObject *owner; int result = -1; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O", kwlist, &owner)) return -1; self->_weakreflist = NULL; /* self._owner_ref = weakref.ref(owner) */ self->_owner_ref = PyWeakref_NewRef(owner, NULL); if (self->_owner_ref) { /* self._hooks = {} */ self->_hooks = PyDict_New(); if (self->_hooks) { result = 0; } } return result; } static int EventSystem_traverse(EventSystemObject *self, visitproc visit, void *arg) { Py_VISIT(self->_owner_ref); Py_VISIT(self->_hooks); return 0; } static int EventSystem_clear(EventSystemObject *self) { if (self->_weakreflist) PyObject_ClearWeakRefs((PyObject *)self); Py_CLEAR(self->_owner_ref); Py_CLEAR(self->_hooks); return 0; } static void EventSystem_dealloc(EventSystemObject *self) { EventSystem_clear(self); Py_TYPE(self)->tp_free((PyObject *)self); } static PyObject * EventSystem_hook(EventSystemObject *self, PyObject *args) { PyObject *result = NULL; PyObject *name, *callback, *data; if (PyTuple_GET_SIZE(args) < 2) { PyErr_SetString(PyExc_TypeError, "Invalid number of arguments"); return NULL; } name = PyTuple_GET_ITEM(args, 0); callback = PyTuple_GET_ITEM(args, 1); data = PyTuple_GetSlice(args, 2, PyTuple_GET_SIZE(args)); if (data) { /* callbacks = self._hooks.get(name) if callbacks is None: self._hooks.setdefault(name, set()).add((callback, data)) else: callbacks.add((callback, data)) */ PyObject *callbacks = PyDict_GetItem(self->_hooks, name); if (!PyErr_Occurred()) { if (callbacks == NULL) { callbacks = PySet_New(NULL); if (callbacks && PyDict_SetItem(self->_hooks, 
name, callbacks) == -1) { Py_DECREF(callbacks); callbacks = NULL; } } else { Py_INCREF(callbacks); } if (callbacks) { PyObject *tuple = PyTuple_New(2); if (tuple) { Py_INCREF(callback); PyTuple_SET_ITEM(tuple, 0, callback); Py_INCREF(data); PyTuple_SET_ITEM(tuple, 1, data); if (PySet_Add(callbacks, tuple) != -1) { Py_INCREF(Py_None); result = Py_None; } Py_DECREF(tuple); } Py_DECREF(callbacks); } } Py_DECREF(data); } return result; } static PyObject * EventSystem_unhook(EventSystemObject *self, PyObject *args) { PyObject *result = NULL; PyObject *name, *callback, *data; if (PyTuple_GET_SIZE(args) < 2) { PyErr_SetString(PyExc_TypeError, "Invalid number of arguments"); return NULL; } name = PyTuple_GET_ITEM(args, 0); callback = PyTuple_GET_ITEM(args, 1); data = PyTuple_GetSlice(args, 2, PyTuple_GET_SIZE(args)); if (data) { /* callbacks = self._hooks.get(name) if callbacks is not None: callbacks.discard((callback, data)) */ PyObject *callbacks = PyDict_GetItem(self->_hooks, name); if (callbacks) { PyObject *tuple = PyTuple_New(2); if (tuple) { Py_INCREF(callback); PyTuple_SET_ITEM(tuple, 0, callback); Py_INCREF(data); PyTuple_SET_ITEM(tuple, 1, data); if (PySet_Discard(callbacks, tuple) != -1) { Py_INCREF(Py_None); result = Py_None; } Py_DECREF(tuple); } } else if (!PyErr_Occurred()) { Py_INCREF(Py_None); result = Py_None; } Py_DECREF(data); } return result; } static PyObject * EventSystem__do_emit_call(PyObject *callback, PyObject *owner, PyObject *args, PyObject *data) { /* return callback(owner, *(args+data)) */ PyObject *result = NULL; PyObject *tuple = PyTuple_New(PyTuple_GET_SIZE(args) + PyTuple_GET_SIZE(data) + 1); if (tuple) { Py_ssize_t i, tuple_i; Py_INCREF(owner); PyTuple_SET_ITEM(tuple, 0, owner); tuple_i = 1; for (i = 0; i != PyTuple_GET_SIZE(args); i++) { PyObject *item = PyTuple_GET_ITEM(args, i); Py_INCREF(item); PyTuple_SET_ITEM(tuple, tuple_i++, item); } for (i = 0; i != PyTuple_GET_SIZE(data); i++) { PyObject *item = PyTuple_GET_ITEM(data, i); Py_INCREF(item); PyTuple_SET_ITEM(tuple, tuple_i++, item); } result = PyObject_Call(callback, tuple, NULL); Py_DECREF(tuple); } return result; } static PyObject * EventSystem_emit(EventSystemObject *self, PyObject *all_args) { PyObject *result = NULL; PyObject *name, *args; if (PyTuple_GET_SIZE(all_args) == 0) { PyErr_SetString(PyExc_TypeError, "Invalid number of arguments"); return NULL; } /* XXX In the following code we trust on the format inserted by * the hook() method. If it's hacked somehow, it may blow up. 
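 * Specifically, each entry of the callbacks set is unpacked below as a
 * (callback, data) 2-tuple with no further type checking, which is exactly
 * the format that hook() stores.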
*/ name = PyTuple_GET_ITEM(all_args, 0); args = PyTuple_GetSlice(all_args, 1, PyTuple_GET_SIZE(all_args)); if (args) { /* owner = self._owner_ref() */ PyObject *owner = PyWeakref_GET_OBJECT(self->_owner_ref); /* if owner is not None: */ if (owner != Py_None) { /* callbacks = self._hooks.get(name) */ PyObject *callbacks = PyDict_GetItem(self->_hooks, name); Py_INCREF(owner); /* if callbacks: */ if (callbacks && PySet_GET_SIZE(callbacks) != 0) { /* for callback, data in tuple(callbacks): */ PyObject *sequence = \ PySequence_Fast(callbacks, "callbacks object isn't a set"); if (sequence) { int failed = 0; Py_ssize_t i; for (i = 0; i != PySequence_Fast_GET_SIZE(sequence); i++) { PyObject *item = PySequence_Fast_GET_ITEM(sequence, i); PyObject *callback = PyTuple_GET_ITEM(item, 0); PyObject *data = PyTuple_GET_ITEM(item, 1); PyObject *res; /* if callback(owner, *(args+data)) is False: callbacks.discard((callback, data)) */ res = EventSystem__do_emit_call(callback, owner, args, data); Py_XDECREF(res); if (res == NULL || (res == Py_False && PySet_Discard(callbacks, item) == -1)) { failed = 1; break; } } if (!failed) { Py_INCREF(Py_None); result = Py_None; } Py_DECREF(sequence); } } else if (!PyErr_Occurred()) { Py_INCREF(Py_None); result = Py_None; } Py_DECREF(owner); } else { Py_INCREF(Py_None); result = Py_None; } Py_DECREF(args); } return result; } static PyMethodDef EventSystem_methods[] = { {"hook", (PyCFunction)EventSystem_hook, METH_VARARGS, NULL}, {"unhook", (PyCFunction)EventSystem_unhook, METH_VARARGS, NULL}, {"emit", (PyCFunction)EventSystem_emit, METH_VARARGS, NULL}, {NULL, NULL} }; #define OFFSETOF(x) offsetof(EventSystemObject, x) static PyMemberDef EventSystem_members[] = { {"_object_ref", T_OBJECT, OFFSETOF(_owner_ref), READONLY, 0}, {"_hooks", T_OBJECT, OFFSETOF(_hooks), READONLY, 0}, {NULL} }; #undef OFFSETOF static PyTypeObject EventSystem_Type = { PyVarObject_HEAD_INIT(NULL, 0) "storm.variables.EventSystem", /*tp_name*/ sizeof(EventSystemObject), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)EventSystem_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash*/ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /*tp_flags*/ 0, /*tp_doc*/ (traverseproc)EventSystem_traverse, /*tp_traverse*/ (inquiry)EventSystem_clear, /*tp_clear*/ 0, /*tp_richcompare*/ offsetof(EventSystemObject, _weakreflist), /*tp_weaklistoffset*/ 0, /*tp_iter*/ 0, /*tp_iternext*/ EventSystem_methods, /*tp_methods*/ EventSystem_members, /*tp_members*/ 0, /*tp_getset*/ 0, /*tp_base*/ 0, /*tp_dict*/ 0, /*tp_descr_get*/ 0, /*tp_descr_set*/ 0, /*tp_dictoffset*/ (initproc)EventSystem_init, /*tp_init*/ 0, /*tp_alloc*/ 0, /*tp_new*/ 0, /*tp_free*/ 0, /*tp_is_gc*/ }; static PyObject * Variable_new(PyTypeObject *type, PyObject *args, PyObject *kwargs) { VariableObject *self = (VariableObject *)type->tp_alloc(type, 0); if (!initialize_globals()) return NULL; /* The following are defined as class properties, so we must initialize them here for methods to work with the same logic. 
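       Concretely: _value, _lazy_value and _checkpoint_state default to
       Undef, _allow_none to True, and column and event to None, mirroring
       the class attributes of the pure-Python Variable.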
*/ Py_INCREF(Undef); self->_value = Undef; Py_INCREF(Undef); self->_lazy_value = Undef; Py_INCREF(Undef); self->_checkpoint_state = Undef; Py_INCREF(Py_True); self->_allow_none = Py_True; Py_INCREF(Py_None); self->event = Py_None; Py_INCREF(Py_None); self->column = Py_None; return (PyObject *)self; } static int Variable_init(VariableObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"value", "value_factory", "from_db", "allow_none", "column", "event", "validator", "validator_object_factory", "validator_attribute", NULL}; PyObject *value = Undef; PyObject *value_factory = Undef; PyObject *from_db = Py_False; PyObject *allow_none = Py_True; PyObject *column = Py_None; PyObject *event = Py_None; PyObject *validator = Py_None; PyObject *validator_object_factory = Py_None; PyObject *validator_attribute = Py_None; PyObject *tmp; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OOOOOOOOO", kwlist, &value, &value_factory, &from_db, &allow_none, &column, &event, &validator, &validator_object_factory, &validator_attribute)) return -1; /* if not allow_none: */ if (allow_none != Py_True && (allow_none == Py_False || !PyObject_IsTrue(allow_none))) { /* self._allow_none = False */ Py_INCREF(Py_False); REPLACE(self->_allow_none, Py_False); /* if value is None: */ if (value == Py_None) { /* raise_none_error(column, default=True) */ tmp = PyObject_CallFunctionObjArgs(raise_none_error, column, Py_True, NULL); /* tmp should always be NULL here. */ Py_XDECREF(tmp); goto error; } } /* if value is not Undef: */ if (value != Undef) { /* self.set(value, from_db) */ CATCH(NULL, tmp = PyObject_CallMethod((PyObject *)self, "set", "OO", value, from_db)); Py_DECREF(tmp); } /* elif value_factory is not Undef: */ else if (value_factory != Undef) { /* self.set(value_factory(), from_db) */ CATCH(NULL, value = PyObject_CallFunctionObjArgs(value_factory, NULL)); tmp = PyObject_CallMethod((PyObject *)self, "set", "OO", value, from_db); Py_DECREF(value); CATCH(NULL, tmp); Py_DECREF(tmp); } /* if validator is not None: */ if (validator != Py_None) { /* self._validator = validator */ Py_INCREF(validator); self->_validator = validator; /* self._validator_object_factory = validator_object_factory */ Py_INCREF(validator_object_factory); self->_validator_object_factory = validator_object_factory; /* self._validator_attribute = validator_attribute */ Py_INCREF(validator_attribute); self->_validator_attribute = validator_attribute; } /* self.column = column */ Py_DECREF(self->column); Py_INCREF(column); self->column = column; /* self.event = weakref.proxy(event) if event is not None else None */ Py_DECREF(self->event); if (event != Py_None) { PyObject *event_proxy = PyWeakref_NewProxy(event, NULL); if (event_proxy) self->event = event_proxy; else goto error; } else { Py_INCREF(Py_None); self->event = Py_None; } return 0; error: return -1; } static int Variable_traverse(VariableObject *self, visitproc visit, void *arg) { Py_VISIT(self->_value); Py_VISIT(self->_lazy_value); Py_VISIT(self->_checkpoint_state); /* Py_VISIT(self->_allow_none); */ Py_VISIT(self->_validator); Py_VISIT(self->_validator_object_factory); Py_VISIT(self->_validator_attribute); Py_VISIT(self->column); Py_VISIT(self->event); return 0; } static int Variable_clear(VariableObject *self) { Py_CLEAR(self->_value); Py_CLEAR(self->_lazy_value); Py_CLEAR(self->_checkpoint_state); Py_CLEAR(self->_allow_none); Py_CLEAR(self->_validator); Py_CLEAR(self->_validator_object_factory); Py_CLEAR(self->_validator_attribute); Py_CLEAR(self->column); 
Py_CLEAR(self->event); return 0; } static void Variable_dealloc(VariableObject *self) { Variable_clear(self); Py_TYPE(self)->tp_free((PyObject *)self); } static PyObject * Variable_parse_get(VariableObject *self, PyObject *args) { /* return value */ PyObject *value, *to_db; if (!PyArg_ParseTuple(args, "OO:parse_get", &value, &to_db)) return NULL; Py_INCREF(value); return value; } static PyObject * Variable_parse_set(VariableObject *self, PyObject *args) { /* return value */ PyObject *value, *from_db; if (!PyArg_ParseTuple(args, "OO:parse_set", &value, &from_db)) return NULL; Py_INCREF(value); return value; } static PyObject * Variable_get_lazy(VariableObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"default", NULL}; PyObject *default_ = Py_None; PyObject *result; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O:get_lazy", kwlist, &default_)) return NULL; /* if self._lazy_value is Undef: return default return self._lazy_value */ if (self->_lazy_value == Undef) { result = default_; } else { result = self->_lazy_value; } Py_INCREF(result); return result; } static PyObject * Variable_get(VariableObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"default", "to_db", NULL}; PyObject *default_ = Py_None; PyObject *to_db = Py_False; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO:get", kwlist, &default_, &to_db)) return NULL; /* if self._lazy_value is not Undef and self.event is not None: */ if (self->_lazy_value != Undef && self->event != Py_None) { PyObject *tmp; /* self.event.emit("resolve-lazy-value", self, self._lazy_value) */ CATCH(NULL, tmp = PyObject_CallMethod(self->event, "emit", "sOO", "resolve-lazy-value", self, self->_lazy_value)); Py_DECREF(tmp); } /* value = self->_value */ /* if value is Undef: */ if (self->_value == Undef) { /* return default */ Py_INCREF(default_); return default_; } /* if value is None: */ if (self->_value == Py_None) { /* return None */ Py_RETURN_NONE; } /* return self.parse_get(value, to_db) */ return PyObject_CallMethod((PyObject *)self, "parse_get", "OO", self->_value, to_db); error: return NULL; } static PyObject * Variable_set(VariableObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"value", "from_db", NULL}; PyObject *value = Py_None; PyObject *from_db = Py_False; PyObject *old_value = NULL; PyObject *new_value = NULL; PyObject *tmp; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|OO:set", kwlist, &value, &from_db)) return NULL; Py_INCREF(value); /* if isinstance(value, LazyValue): */ if (PyObject_IsInstance(value, LazyValue)) { /* self._lazy_value = value */ Py_INCREF(value); REPLACE(self->_lazy_value, value); /* self._checkpoint_state = new_value = Undef */ Py_INCREF(Undef); Py_INCREF(Undef); new_value = Undef; Py_DECREF(self->_checkpoint_state); self->_checkpoint_state = Undef; } /* else: */ else { /* if not from_db and self._validator is not None: */ if (self->_validator && !PyObject_IsTrue(from_db)) { /* value = self._validator(self._validator_object_factory and self._validator_object_factory(), self._validator_attribute, value) */ PyObject *validator_object, *tmp; if (self->_validator_object_factory == Py_None) { Py_INCREF(Py_None); validator_object = Py_None; } else { CATCH(NULL, validator_object = PyObject_CallFunctionObjArgs( self->_validator_object_factory, NULL)); } tmp = PyObject_CallFunctionObjArgs(self->_validator, validator_object, self->_validator_attribute, value, NULL); Py_DECREF(validator_object); CATCH(NULL, tmp); Py_DECREF(value); value = tmp; } /* 
self._lazy_value = Undef */ Py_INCREF(Undef); Py_DECREF(self->_lazy_value); self->_lazy_value = Undef; /* if value is None: */ if (value == Py_None) { /* if self._allow_none is False: */ if (self->_allow_none == Py_False) { /* raise_none_error(self.column) */ tmp = PyObject_CallFunctionObjArgs(raise_none_error, self->column, NULL); /* tmp should always be NULL here. */ Py_XDECREF(tmp); goto error; } /* new_value = None */ Py_INCREF(Py_None); new_value = Py_None; } /* else: */ else { /* new_value = self.parse_set(value, from_db) */ CATCH(NULL, new_value = PyObject_CallMethod((PyObject *)self, "parse_set", "OO", value, from_db)); /* if from_db: */ if (PyObject_IsTrue(from_db)) { /* value = self.parse_get(new_value, False) */ Py_DECREF(value); CATCH(NULL, value = PyObject_CallMethod((PyObject *)self, "parse_get", "OO", new_value, Py_False)); } } } /* old_value = self._value */ old_value = self->_value; /* Keep the reference with old_value. */ /* self._value = new_value */ Py_INCREF(new_value); self->_value = new_value; /* if (self.event is not None and (self._lazy_value is not Undef or new_value != old_value)): */ if (self->event != Py_None && (self->_lazy_value != Undef || PyObject_RichCompareBool(new_value, old_value, Py_NE))) { /* if old_value is not None and old_value is not Undef: */ if (old_value != Py_None && old_value != Undef) { /* old_value = self.parse_get(old_value, False) */ CATCH(NULL, tmp = PyObject_CallMethod((PyObject *)self, "parse_get", "OO", old_value, Py_False)); Py_DECREF(old_value); old_value = tmp; } /* self.event.emit("changed", self, old_value, value, from_db) */ CATCH(NULL, tmp = PyObject_CallMethod((PyObject *)self->event, "emit", "sOOOO", "changed", self, old_value, value, from_db)); Py_DECREF(tmp); } Py_DECREF(value); Py_DECREF(old_value); Py_DECREF(new_value); Py_RETURN_NONE; error: Py_XDECREF(value); Py_XDECREF(old_value); Py_XDECREF(new_value); return NULL; } static PyObject * Variable_delete(VariableObject *self, PyObject *args) { PyObject *old_value; PyObject *tmp; /* old_value = self._value */ old_value = self->_value; Py_INCREF(old_value); /* if old_value is not Undef: */ if (old_value != Undef) { /* self._value = Undef */ Py_DECREF(self->_value); Py_INCREF(Undef); self->_value = Undef; /* if self.event is not None: */ if (self->event != Py_None) { /* if old_value is not None and old_value is not Undef: */ if (old_value != Py_None && old_value != Undef) { /* old_value = self.parse_get(old_value, False) */ CATCH(NULL, tmp = PyObject_CallMethod((PyObject *)self, "parse_get", "OO", old_value, Py_False)); Py_DECREF(old_value); old_value = tmp; } /* self.event.emit("changed", self, old_value, Undef, False) */ CATCH(NULL, tmp = PyObject_CallMethod((PyObject *)self->event, "emit", "sOOOO", "changed", self, old_value, Undef, Py_False)); Py_DECREF(tmp); } } Py_DECREF(old_value); Py_RETURN_NONE; error: Py_XDECREF(old_value); return NULL; } static PyObject * Variable_is_defined(VariableObject *self, PyObject *args) { /* return self._value is not Undef */ return PyBool_FromLong(self->_value != Undef); } static PyObject * Variable_has_changed(VariableObject *self, PyObject *args) { /* return (self._lazy_value is not Undef or self.get_state() != self._checkpoint_state) */ PyObject *result = Py_True; if (self->_lazy_value == Undef) { PyObject *state; int res; CATCH(NULL, state = PyObject_CallMethod((PyObject *)self, "get_state", NULL)); res = PyObject_RichCompareBool(state, self->_checkpoint_state, Py_EQ); Py_DECREF(state); CATCH(-1, res); if (res) result = Py_False; } 
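    /* result remains Py_True unless the lazy value was Undef and the
       current state compared equal to the checkpointed state above. */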
Py_INCREF(result); return result; error: return NULL; } static PyObject * Variable_get_state(VariableObject *self, PyObject *args) { /* return (self._lazy_value, self._value) */ PyObject *result; CATCH(NULL, result = PyTuple_New(2)); Py_INCREF(self->_lazy_value); Py_INCREF(self->_value); PyTuple_SET_ITEM(result, 0, self->_lazy_value); PyTuple_SET_ITEM(result, 1, self->_value); return result; error: return NULL; } static PyObject * Variable_set_state(VariableObject *self, PyObject *args) { /* self._lazy_value, self._value = state */ PyObject *lazy_value, *value; if (!PyArg_ParseTuple(args, "(OO):set_state", &lazy_value, &value)) return NULL; Py_INCREF(lazy_value); REPLACE(self->_lazy_value, lazy_value); Py_INCREF(value); REPLACE(self->_value, value); Py_RETURN_NONE; } static PyObject * Variable_checkpoint(VariableObject *self, PyObject *args) { /* self._checkpoint_state = self.get_state() */ PyObject *state = PyObject_CallMethod((PyObject *)self, "get_state", NULL); if (!state) return NULL; Py_DECREF(self->_checkpoint_state); self->_checkpoint_state = state; Py_RETURN_NONE; } static PyObject * Variable_copy(VariableObject *self, PyObject *args) { PyObject *noargs = NULL; PyObject *variable = NULL; PyObject *state = NULL; PyObject *tmp; /* variable = self.__class__.__new__(self.__class__) */ noargs = PyTuple_New(0); CATCH(NULL, variable = Py_TYPE(self)->tp_new(Py_TYPE(self), noargs, NULL)); /* variable.set_state(self.get_state()) */ CATCH(NULL, state = PyObject_CallMethod((PyObject *)self, "get_state", NULL)); CATCH(NULL, tmp = PyObject_CallMethod((PyObject *)variable, "set_state", "(O)", state)); Py_DECREF(tmp); Py_DECREF(noargs); Py_DECREF(state); return variable; error: Py_XDECREF(noargs); Py_XDECREF(state); Py_XDECREF(variable); return NULL; } static PyMethodDef Variable_methods[] = { {"parse_get", (PyCFunction)Variable_parse_get, METH_VARARGS, NULL}, {"parse_set", (PyCFunction)Variable_parse_set, METH_VARARGS, NULL}, {"get_lazy", (PyCFunction)Variable_get_lazy, METH_VARARGS | METH_KEYWORDS, NULL}, {"get", (PyCFunction)Variable_get, METH_VARARGS | METH_KEYWORDS, NULL}, {"set", (PyCFunction)Variable_set, METH_VARARGS | METH_KEYWORDS, NULL}, {"delete", (PyCFunction)Variable_delete, METH_VARARGS | METH_KEYWORDS, NULL}, {"is_defined", (PyCFunction)Variable_is_defined, METH_NOARGS, NULL}, {"has_changed", (PyCFunction)Variable_has_changed, METH_NOARGS, NULL}, {"get_state", (PyCFunction)Variable_get_state, METH_NOARGS, NULL}, {"set_state", (PyCFunction)Variable_set_state, METH_VARARGS, NULL}, {"checkpoint", (PyCFunction)Variable_checkpoint, METH_NOARGS, NULL}, {"copy", (PyCFunction)Variable_copy, METH_NOARGS, NULL}, {NULL, NULL} }; #define OFFSETOF(x) offsetof(VariableObject, x) static PyMemberDef Variable_members[] = { {"_value", T_OBJECT, OFFSETOF(_value), 0, 0}, {"_lazy_value", T_OBJECT, OFFSETOF(_lazy_value), 0, 0}, {"_checkpoint_state", T_OBJECT, OFFSETOF(_checkpoint_state), 0, 0}, {"_allow_none", T_OBJECT, OFFSETOF(_allow_none), 0, 0}, {"column", T_OBJECT, OFFSETOF(column), 0, 0}, {"event", T_OBJECT, OFFSETOF(event), 0, 0}, {NULL} }; #undef OFFSETOF static PyTypeObject Variable_Type = { PyVarObject_HEAD_INIT(NULL, 0) "storm.variables.Variable", /*tp_name*/ sizeof(VariableObject), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)Variable_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash*/ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, 
/*tp_as_buffer*/ Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ 0, /*tp_doc*/ (traverseproc)Variable_traverse, /*tp_traverse*/ (inquiry)Variable_clear, /*tp_clear*/ 0, /*tp_richcompare*/ 0, /*tp_weaklistoffset*/ 0, /*tp_iter*/ 0, /*tp_iternext*/ Variable_methods, /*tp_methods*/ Variable_members, /*tp_members*/ 0, /*tp_getset*/ 0, /*tp_base*/ 0, /*tp_dict*/ 0, /*tp_descr_get*/ 0, /*tp_descr_set*/ 0, /*tp_dictoffset*/ (initproc)Variable_init, /*tp_init*/ 0, /*tp_alloc*/ Variable_new, /*tp_new*/ 0, /*tp_free*/ 0, /*tp_is_gc*/ }; static PyObject * Compile__update_cache(CompileObject *self, PyObject *args); static int Compile_init(CompileObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"parent", NULL}; PyObject *parent = Py_None; PyObject *module = NULL; PyObject *WeakKeyDictionary = NULL; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &parent)) return -1; /* self._local_dispatch_table = {} self._local_precedence = {} self._local_reserved_words = {} self._dispatch_table = {} self._precedence = {} self._reserved_words = {} */ CATCH(NULL, self->_local_dispatch_table = PyDict_New()); CATCH(NULL, self->_local_precedence = PyDict_New()); CATCH(NULL, self->_local_reserved_words = PyDict_New()); CATCH(NULL, self->_dispatch_table = PyDict_New()); CATCH(NULL, self->_precedence = PyDict_New()); CATCH(NULL, self->_reserved_words = PyDict_New()); /* self._children = WeakKeyDictionary() */ CATCH(NULL, module = PyImport_ImportModule("weakref")); CATCH(NULL, WeakKeyDictionary = \ PyObject_GetAttrString(module, "WeakKeyDictionary")); Py_CLEAR(module); CATCH(NULL, self->_children = \ PyObject_CallFunctionObjArgs(WeakKeyDictionary, NULL)); Py_CLEAR(WeakKeyDictionary); /* self._parents = [] */ CATCH(NULL, self->_parents = PyList_New(0)); /* if parent: */ if (parent != Py_None) { PyObject *tmp; /* self._parents.extend(parent._parents) */ CompileObject *parent_object = (CompileObject *)parent; CATCH(-1, PyList_SetSlice(self->_parents, 0, 0, parent_object->_parents)); /* self._parents.append(parent) */ CATCH(-1, PyList_Append(self->_parents, parent)); /* parent._children[self] = True */ CATCH(-1, PyObject_SetItem(parent_object->_children, (PyObject *)self, Py_True)); /* self._update_cache() */ CATCH(NULL, tmp = Compile__update_cache(self, NULL)); Py_DECREF(tmp); } return 0; error: Py_XDECREF(module); Py_XDECREF(WeakKeyDictionary); return -1; } static int Compile_traverse(CompileObject *self, visitproc visit, void *arg) { Py_VISIT(self->_local_dispatch_table); Py_VISIT(self->_local_precedence); Py_VISIT(self->_local_reserved_words); Py_VISIT(self->_dispatch_table); Py_VISIT(self->_precedence); Py_VISIT(self->_reserved_words); Py_VISIT(self->_children); Py_VISIT(self->_parents); return 0; } static int Compile_clear(CompileObject *self) { if (self->_weakreflist) PyObject_ClearWeakRefs((PyObject *)self); Py_CLEAR(self->_local_dispatch_table); Py_CLEAR(self->_local_precedence); Py_CLEAR(self->_local_reserved_words); Py_CLEAR(self->_dispatch_table); Py_CLEAR(self->_precedence); Py_CLEAR(self->_reserved_words); Py_CLEAR(self->_children); Py_CLEAR(self->_parents); return 0; } static void Compile_dealloc(CompileObject *self) { Compile_clear(self); Py_TYPE(self)->tp_free((PyObject *)self); } static PyObject * Compile__update_cache(CompileObject *self, PyObject *args) { PyObject *iter = NULL; PyObject *child = NULL; Py_ssize_t size; int i; /* for parent in self._parents: */ size = PyList_GET_SIZE(self->_parents); for (i = 0; i != size; i++) { CompileObject 
*parent = \ (CompileObject *)PyList_GET_ITEM(self->_parents, i); /* self._dispatch_table.update(parent._local_dispatch_table) */ CATCH(-1, PyDict_Update(self->_dispatch_table, parent->_local_dispatch_table)); /* self._precedence.update(parent._local_precedence) */ CATCH(-1, PyDict_Update(self->_precedence, parent->_local_precedence)); /* self._reserved_words.update(parent._local_reserved_words) */ CATCH(-1, PyDict_Update(self->_reserved_words, parent->_local_reserved_words)); } /* self._dispatch_table.update(self._local_dispatch_table) */ CATCH(-1, PyDict_Update(self->_dispatch_table, self->_local_dispatch_table)); /* self._precedence.update(self._local_precedence) */ CATCH(-1, PyDict_Update(self->_precedence, self->_local_precedence)); /* self._reserved_words.update(self._local_reserved_words) */ CATCH(-1, PyDict_Update(self->_reserved_words, self->_local_reserved_words)); /* for child in self._children: */ CATCH(NULL, iter = PyObject_GetIter(self->_children)); while((child = PyIter_Next(iter))) { PyObject *tmp; /* child._update_cache() */ CATCH(NULL, tmp = Compile__update_cache((CompileObject *)child, NULL)); Py_DECREF(tmp); Py_DECREF(child); } if (PyErr_Occurred()) goto error; Py_CLEAR(iter); Py_RETURN_NONE; error: Py_XDECREF(child); Py_XDECREF(iter); return NULL; } static PyObject * Compile_when(CompileObject *self, PyObject *types) { PyObject *result = NULL; PyObject *module = PyImport_ImportModule("storm.expr"); if (module) { PyObject *_when = PyObject_GetAttrString(module, "_when"); if (_when) { result = PyObject_CallFunctionObjArgs(_when, self, types, NULL); Py_DECREF(_when); } Py_DECREF(module); } return result; } static PyObject * Compile_add_reserved_words(CompileObject *self, PyObject *words) { PyObject *lower_word = NULL; PyObject *iter = NULL; PyObject *word = NULL; PyObject *tmp; /* self._local_reserved_words.update((word.lower(), True) for word in words) */ CATCH(NULL, iter = PyObject_GetIter(words)); while ((word = PyIter_Next(iter))) { CATCH(NULL, lower_word = PyObject_CallMethod(word, "lower", NULL)); CATCH(-1, PyDict_SetItem(self->_local_reserved_words, lower_word, Py_True)); Py_CLEAR(lower_word); Py_DECREF(word); } if (PyErr_Occurred()) goto error; Py_CLEAR(iter); /* self._update_cache() */ CATCH(NULL, tmp = Compile__update_cache(self, NULL)); Py_DECREF(tmp); Py_RETURN_NONE; error: Py_XDECREF(lower_word); Py_XDECREF(word); Py_XDECREF(iter); return NULL; } static PyObject * Compile_remove_reserved_words(CompileObject *self, PyObject *words) { PyObject *lower_word = NULL; PyObject *word = NULL; PyObject *iter = NULL; PyObject *tmp; /* self._local_reserved_words.update((word.lower(), None) for word in words) */ CATCH(NULL, iter = PyObject_GetIter(words)); while ((word = PyIter_Next(iter))) { CATCH(NULL, lower_word = PyObject_CallMethod(word, "lower", NULL)); CATCH(-1, PyDict_SetItem(self->_local_reserved_words, lower_word, Py_None)); Py_CLEAR(lower_word); Py_DECREF(word); } if (PyErr_Occurred()) goto error; Py_CLEAR(iter); /* self._update_cache() */ CATCH(NULL, tmp = Compile__update_cache(self, NULL)); Py_DECREF(tmp); Py_RETURN_NONE; error: Py_XDECREF(lower_word); Py_XDECREF(word); Py_XDECREF(iter); return NULL; } static PyObject * Compile_is_reserved_word(CompileObject *self, PyObject *word) { PyObject *lower_word = NULL; PyObject *result = Py_False; PyObject *item; /* return self._reserved_words.get(word.lower()) is not None */ CATCH(NULL, lower_word = PyObject_CallMethod(word, "lower", NULL)); item = PyDict_GetItem(self->_reserved_words, lower_word); if (item == NULL 
&& PyErr_Occurred()) { goto error; } else if (item != NULL && item != Py_None) { result = Py_True; } Py_DECREF(lower_word); Py_INCREF(result); return result; error: Py_XDECREF(lower_word); return NULL; } static PyTypeObject Compile_Type; static PyObject * Compile_create_child(CompileObject *self, PyObject *args) { /* return self.__class__(self) */ return PyObject_CallFunctionObjArgs((PyObject *)Py_TYPE(self), self, NULL); } static PyObject * Compile_get_precedence(CompileObject *self, PyObject *type) { /* return self._precedence.get(type, MAX_PRECEDENCE) */ PyObject *result = PyDict_GetItem(self->_precedence, type); if (result == NULL && !PyErr_Occurred()) { /* That should be MAX_PRECEDENCE, defined in expr.py */ return PyLong_FromLong(1000); } Py_INCREF(result); return result; } static PyObject * Compile_set_precedence(CompileObject *self, PyObject *args) { Py_ssize_t size = PyTuple_GET_SIZE(args); PyObject *precedence = NULL; PyObject *tmp; int i; if (size < 2) { PyErr_SetString(PyExc_TypeError, "set_precedence() takes at least 2 arguments."); return NULL; } /* for type in types: */ precedence = PyTuple_GET_ITEM(args, 0); for (i = 1; i != size; i++) { PyObject *type = PyTuple_GET_ITEM(args, i); /* self._local_precedence[type] = precedence */ CATCH(-1, PyDict_SetItem(self->_local_precedence, type, precedence)); } /* self._update_cache() */ CATCH(NULL, tmp = Compile__update_cache(self, NULL)); Py_DECREF(tmp); Py_RETURN_NONE; error: return NULL; } PyObject * Compile_single(CompileObject *self, PyObject *expr, PyObject *state, PyObject *outer_precedence) { PyObject *inner_precedence = NULL; PyObject *statement = NULL; /* cls = expr.__class__ */ PyObject *cls = (PyObject *)expr->ob_type; /* dispatch_table = self._dispatch_table if cls in dispatch_table: handler = dispatch_table[cls] else: */ PyObject *handler = PyDict_GetItem(self->_dispatch_table, cls); if (!handler) { PyObject *mro; Py_ssize_t size, i; if (PyErr_Occurred()) goto error; /* for mro_cls in cls.__mro__: */ mro = expr->ob_type->tp_mro; size = PyTuple_GET_SIZE(mro); for (i = 0; i != size; i++) { PyObject *mro_cls = PyTuple_GET_ITEM(mro, i); /* if mro_cls in dispatch_table: handler = dispatch_table[mro_cls] break */ handler = PyDict_GetItem(self->_dispatch_table, mro_cls); if (handler) break; if (PyErr_Occurred()) goto error; } /* else: */ if (i == size) { /* raise CompileError("Don't know how to compile type %r of %r" % (expr.__class__, expr)) */ PyObject *repr = PyObject_Repr(expr); if (repr) { PyErr_Format(CompileError, "Don't know how to compile type %s of %s", expr->ob_type->tp_name, _PyUnicode_AsString(repr)); Py_DECREF(repr); } goto error; } } /* inner_precedence = state.precedence = \ self._precedence.get(cls, MAX_PRECEDENCE) */ CATCH(NULL, inner_precedence = Compile_get_precedence(self, cls)); CATCH(-1, PyObject_SetAttrString(state, "precedence", inner_precedence)); /* statement = handler(self, expr, state) */ CATCH(NULL, statement = PyObject_CallFunctionObjArgs(handler, self, expr, state, NULL)); /* if inner_precedence < outer_precedence: */ if (PyObject_RichCompareBool(inner_precedence, outer_precedence, Py_LT)) { PyObject *args, *tmp; if (PyErr_Occurred()) goto error; /* return "(%s)" % statement */ CATCH(NULL, args = PyTuple_Pack(1, statement)); tmp = PyUnicode_Format(parenthesis_format, args); Py_DECREF(args); CATCH(NULL, tmp); Py_DECREF(statement); statement = tmp; } Py_DECREF(inner_precedence); return statement; error: Py_XDECREF(inner_precedence); Py_XDECREF(statement); return NULL; } PyObject * 
Compile_one_or_many(CompileObject *self, PyObject *expr, PyObject *state, PyObject *join, int raw, int token) { PyObject *outer_precedence = NULL; PyObject *compiled = NULL; PyObject *sequence = NULL; PyObject *statement = NULL; Py_ssize_t size, i; Py_INCREF(expr); /* expr_type = type(expr) if expr_type is SQLRaw or (raw and expr_type is str): return expr */ if ((PyObject *)expr->ob_type == SQLRaw || (raw && PyUnicode_CheckExact(expr))) { /* Pass our reference on. */ return expr; } /* if token and expr_type is str: expr = SQLToken(expr) */ if (token && PyUnicode_CheckExact(expr)) { PyObject *tmp; CATCH(NULL, tmp = PyObject_CallFunctionObjArgs(SQLToken, expr, NULL)); Py_DECREF(expr); expr = tmp; } /* if state is None: state = State() */ /* That's done in Compile__call__ just once. */ /* outer_precedence = state.precedence */ CATCH(NULL, outer_precedence = PyObject_GetAttrString(state, "precedence")); /* if expr_type is tuple or expr_type is list: */ if (PyTuple_CheckExact(expr) || PyList_CheckExact(expr)) { /* compiled = [] */ CATCH(NULL, compiled = PyList_New(0)); /* for subexpr in expr: */ sequence = PySequence_Fast(expr, "This can't actually fail! ;-)"); size = PySequence_Fast_GET_SIZE(sequence); for (i = 0; i != size; i++) { PyObject *subexpr = PySequence_Fast_GET_ITEM(sequence, i); /* subexpr_type = type(subexpr) if subexpr_type is SQLRaw or (raw and subexpr_type is str): */ if ((PyObject *)subexpr->ob_type == (PyObject *)SQLRaw || (raw && PyUnicode_CheckExact(subexpr))) { /* statement = subexpr */ Py_INCREF(subexpr); statement = subexpr; /* elif subexpr_type is tuple or subexpr_type is list: */ } else if (PyTuple_CheckExact(subexpr) || PyList_CheckExact(subexpr)) { /* state.precedence = outer_precedence */ CATCH(-1, PyObject_SetAttrString(state, "precedence", outer_precedence)); /* statement = self(subexpr, state, join, raw, token) */ CATCH(NULL, statement = Compile_one_or_many(self, subexpr, state, join, raw, token)); /* else: */ } else { /* if token and subexpr_type is str: */ if (token && PyUnicode_CheckExact(subexpr)) { /* subexpr = SQLToken(subexpr) */ CATCH(NULL, subexpr = PyObject_CallFunctionObjArgs(SQLToken, subexpr, NULL)); } else { Py_INCREF(subexpr); } /* statement = self._compile_single(subexpr, state, outer_precedence) */ statement = Compile_single(self, subexpr, state, outer_precedence); Py_DECREF(subexpr); CATCH(NULL, statement); } /* compiled.append(statement) */ CATCH(-1, PyList_Append(compiled, statement)); Py_CLEAR(statement); } Py_CLEAR(sequence); /* statement = join.join(compiled) */ CATCH(NULL, statement = PyUnicode_Join(join, compiled)); Py_CLEAR(compiled); } else { /* statement = self._compile_single(expr, state, outer_precedence) */ CATCH(NULL, statement = Compile_single(self, expr, state, outer_precedence)); } /* state.precedence = outer_precedence */ CATCH(-1, PyObject_SetAttrString(state, "precedence", outer_precedence)); Py_CLEAR(outer_precedence); Py_DECREF(expr); return statement; error: Py_XDECREF(expr); Py_XDECREF(outer_precedence); Py_XDECREF(compiled); Py_XDECREF(sequence); Py_XDECREF(statement); return NULL; } static PyObject * Compile__call__(CompileObject *self, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"expr", "state", "join", "raw", "token", NULL}; PyObject *expr = NULL; PyObject *state = Py_None; PyObject *join; char raw = 0; char token = 0; PyObject *result = NULL; if (!initialize_globals()) return NULL; join = default_compile_join; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OObb", kwlist, &expr, &state, &join, &raw, 
&token)) { return NULL; } if (join && !PyUnicode_Check(join)) { PyErr_Format(PyExc_TypeError, "'join' argument must be a string, not %.80s", Py_TYPE(join)->tp_name); return NULL; } if (state == Py_None) { state = PyObject_CallFunctionObjArgs(State, NULL); } else { Py_INCREF(state); } if (state) { result = Compile_one_or_many(self, expr, state, join, raw, token); Py_DECREF(state); } return result; } static PyMethodDef Compile_methods[] = { {"_update_cache", (PyCFunction)Compile__update_cache, METH_NOARGS, NULL}, {"when", (PyCFunction)Compile_when, METH_VARARGS, NULL}, {"add_reserved_words", (PyCFunction)Compile_add_reserved_words, METH_O, NULL}, {"remove_reserved_words", (PyCFunction)Compile_remove_reserved_words, METH_O, NULL}, {"is_reserved_word", (PyCFunction)Compile_is_reserved_word, METH_O, NULL}, {"create_child", (PyCFunction)Compile_create_child, METH_NOARGS, NULL}, {"get_precedence", (PyCFunction)Compile_get_precedence, METH_O, NULL}, {"set_precedence", (PyCFunction)Compile_set_precedence, METH_VARARGS, NULL}, {NULL, NULL} }; #define OFFSETOF(x) offsetof(CompileObject, x) static PyMemberDef Compile_members[] = { {"_local_dispatch_table", T_OBJECT, OFFSETOF(_local_dispatch_table), 0, 0}, {"_local_precedence", T_OBJECT, OFFSETOF(_local_precedence), 0, 0}, {"_local_reserved_words", T_OBJECT, OFFSETOF(_local_reserved_words), 0, 0}, {"_dispatch_table", T_OBJECT, OFFSETOF(_dispatch_table), 0, 0}, {"_precedence", T_OBJECT, OFFSETOF(_precedence), 0, 0}, {"_reserved_words", T_OBJECT, OFFSETOF(_reserved_words), 0, 0}, {"_children", T_OBJECT, OFFSETOF(_children), 0, 0}, {"_parents", T_OBJECT, OFFSETOF(_parents), 0, 0}, {NULL} }; #undef OFFSETOF static PyTypeObject Compile_Type = { PyVarObject_HEAD_INIT(NULL, 0) "storm.variables.Compile", /*tp_name*/ sizeof(CompileObject), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)Compile_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash*/ (ternaryfunc)Compile__call__, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /*tp_flags*/ 0, /*tp_doc*/ (traverseproc)Compile_traverse, /*tp_traverse*/ (inquiry)Compile_clear, /*tp_clear*/ 0, /*tp_richcompare*/ offsetof(CompileObject, _weakreflist), /*tp_weaklistoffset*/ 0, /*tp_iter*/ 0, /*tp_iternext*/ Compile_methods, /*tp_methods*/ Compile_members, /*tp_members*/ 0, /*tp_getset*/ 0, /*tp_base*/ 0, /*tp_dict*/ 0, /*tp_descr_get*/ 0, /*tp_descr_set*/ 0, /*tp_dictoffset*/ (initproc)Compile_init, /*tp_init*/ 0, /*tp_alloc*/ 0, /*tp_new*/ 0, /*tp_free*/ 0, /*tp_is_gc*/ }; static PyObject * ObjectInfo__emit_object_deleted(ObjectInfoObject *self, PyObject *args) { /* self.event.emit("object-deleted") */ return PyObject_CallMethod(self->event, "emit", "s", "object-deleted"); } static PyMethodDef ObjectInfo_deleted_callback = {"_emit_object_deleted", (PyCFunction)ObjectInfo__emit_object_deleted, METH_O, NULL}; static int ObjectInfo_init(ObjectInfoObject *self, PyObject *args) { PyObject *self_get_obj = NULL; PyObject *empty_args = NULL; PyObject *factory_kwargs = NULL; PyObject *columns = NULL; PyObject *primary_key = NULL; PyObject *obj; Py_ssize_t i; empty_args = PyTuple_New(0); CATCH(-1, PyDict_Type.tp_init((PyObject *)self, empty_args, NULL)); CATCH(0, initialize_globals()); if (!PyArg_ParseTuple(args, "O", &obj)) goto error; /* self.cls_info = get_cls_info(type(obj)) */ CATCH(NULL, 
self->cls_info = PyObject_CallFunctionObjArgs(get_cls_info, obj->ob_type, NULL)); /* self.set_obj(obj) */ CATCH(NULL, self->_obj_ref_callback = PyCFunction_NewEx(&ObjectInfo_deleted_callback, (PyObject *)self, NULL)); CATCH(NULL, self->_obj_ref = PyWeakref_NewRef(obj, self->_obj_ref_callback)); /* self.event = event = EventSystem(self) */ CATCH(NULL, self->event = PyObject_CallFunctionObjArgs(EventSystem, self, NULL)); /* self->variables = variables = {} */ CATCH(NULL, self->variables = PyDict_New()); CATCH(NULL, self_get_obj = PyObject_GetAttrString((PyObject *)self, "get_obj")); CATCH(NULL, factory_kwargs = PyDict_New()); CATCH(-1, PyDict_SetItemString(factory_kwargs, "event", self->event)); CATCH(-1, PyDict_SetItemString(factory_kwargs, "validator_object_factory", self_get_obj)); /* for column in self.cls_info.columns: */ CATCH(NULL, columns = PyObject_GetAttrString(self->cls_info, "columns")); for (i = 0; i != PyTuple_GET_SIZE(columns); i++) { /* variables[column] = \ column.variable_factory(column=column, event=event, validator_object_factory=self.get_obj) */ PyObject *column = PyTuple_GET_ITEM(columns, i); PyObject *variable, *factory; CATCH(-1, PyDict_SetItemString(factory_kwargs, "column", column)); CATCH(NULL, factory = PyObject_GetAttrString(column, "variable_factory")); variable = PyObject_Call(factory, empty_args, factory_kwargs); Py_DECREF(factory); CATCH(NULL, variable); if (PyDict_SetItem(self->variables, column, variable) == -1) { Py_DECREF(variable); goto error; } Py_DECREF(variable); } /* self.primary_vars = tuple(variables[column] for column in self.cls_info.primary_key) */ CATCH(NULL, primary_key = PyObject_GetAttrString((PyObject *)self->cls_info, "primary_key")); /* XXX Check primary_key type here. */ CATCH(NULL, self->primary_vars = PyTuple_New(PyTuple_GET_SIZE(primary_key))); for (i = 0; i != PyTuple_GET_SIZE(primary_key); i++) { PyObject *column = PyTuple_GET_ITEM(primary_key, i); PyObject *variable = PyDict_GetItem(self->variables, column); Py_INCREF(variable); PyTuple_SET_ITEM(self->primary_vars, i, variable); } Py_DECREF(self_get_obj); Py_DECREF(empty_args); Py_DECREF(factory_kwargs); Py_DECREF(columns); Py_DECREF(primary_key); return 0; error: Py_XDECREF(self_get_obj); Py_XDECREF(empty_args); Py_XDECREF(factory_kwargs); Py_XDECREF(columns); Py_XDECREF(primary_key); return -1; } static PyObject * ObjectInfo_get_obj(ObjectInfoObject *self, PyObject *args) { PyObject *obj = PyWeakref_GET_OBJECT(self->_obj_ref); Py_INCREF(obj); return obj; } static PyObject * ObjectInfo_set_obj(ObjectInfoObject *self, PyObject *args) { PyObject *obj; /* self._ref = ref(obj, self._emit_object_deleted) */ if (!PyArg_ParseTuple(args, "O", &obj)) return NULL; Py_DECREF(self->_obj_ref); self->_obj_ref = PyWeakref_NewRef(obj, self->_obj_ref_callback); if (!self->_obj_ref) return NULL; Py_RETURN_NONE; } static PyObject * ObjectInfo_checkpoint(ObjectInfoObject *self, PyObject *args) { PyObject *column, *variable, *tmp; Py_ssize_t i = 0; /* for variable in self.variables.values(): */ while (PyDict_Next(self->variables, &i, &column, &variable)) { /* variable.checkpoint() */ CATCH(NULL, tmp = PyObject_CallMethod(variable, "checkpoint", NULL)); Py_DECREF(tmp); } Py_RETURN_NONE; error: return NULL; } static PyObject * ObjectInfo__storm_object_info__(PyObject *self, void *closure) { /* __storm_object_info__ = property(lambda self:self) */ Py_INCREF(self); return self; } static int ObjectInfo_traverse(ObjectInfoObject *self, visitproc visit, void *arg) { Py_VISIT(self->_obj_ref); 
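/* Note: every PyObject * member of ObjectInfoObject must be visited
   here, cleared in ObjectInfo_clear, and released in ObjectInfo_dealloc
   below; otherwise the cyclic garbage collector cannot break reference
   cycles that pass through this object. */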
Py_VISIT(self->_obj_ref_callback); Py_VISIT(self->cls_info); Py_VISIT(self->event); Py_VISIT(self->variables); Py_VISIT(self->primary_vars); return PyDict_Type.tp_traverse((PyObject *)self, visit, arg); } static int ObjectInfo_clear(ObjectInfoObject *self) { Py_CLEAR(self->_obj_ref); Py_CLEAR(self->_obj_ref_callback); Py_CLEAR(self->cls_info); Py_CLEAR(self->event); Py_CLEAR(self->variables); Py_CLEAR(self->primary_vars); return PyDict_Type.tp_clear((PyObject *)self); } static PyObject * ObjectInfo_richcompare(PyObject *self, PyObject *other, int op) { PyObject *res; /* Implement equality via object identity. */ switch (op) { case Py_EQ: res = (self == other) ? Py_True : Py_False; break; case Py_NE: res = (self != other) ? Py_True : Py_False; break; default: res = Py_NotImplemented; } Py_INCREF(res); return res; } static void ObjectInfo_dealloc(ObjectInfoObject *self) { if (self->_weakreflist) PyObject_ClearWeakRefs((PyObject *)self); Py_CLEAR(self->_obj_ref); Py_CLEAR(self->_obj_ref_callback); Py_CLEAR(self->cls_info); Py_CLEAR(self->event); Py_CLEAR(self->variables); Py_CLEAR(self->primary_vars); PyDict_Type.tp_dealloc((PyObject *)self); } static PyMethodDef ObjectInfo_methods[] = { {"_emit_object_deleted", (PyCFunction)ObjectInfo__emit_object_deleted, METH_O, NULL}, {"get_obj", (PyCFunction)ObjectInfo_get_obj, METH_NOARGS, NULL}, {"set_obj", (PyCFunction)ObjectInfo_set_obj, METH_VARARGS, NULL}, {"checkpoint", (PyCFunction)ObjectInfo_checkpoint, METH_VARARGS, NULL}, {NULL, NULL} }; #define OFFSETOF(x) offsetof(ObjectInfoObject, x) static PyMemberDef ObjectInfo_members[] = { {"cls_info", T_OBJECT, OFFSETOF(cls_info), 0, 0}, {"event", T_OBJECT, OFFSETOF(event), 0, 0}, {"variables", T_OBJECT, OFFSETOF(variables), 0, 0}, {"primary_vars", T_OBJECT, OFFSETOF(primary_vars), 0, 0}, {NULL} }; #undef OFFSETOF static PyGetSetDef ObjectInfo_getset[] = { {"__storm_object_info__", (getter)ObjectInfo__storm_object_info__, NULL, NULL}, {NULL} }; static PyTypeObject ObjectInfo_Type = { PyVarObject_HEAD_INIT(NULL, 0) "storm.info.ObjectInfo", /*tp_name*/ sizeof(ObjectInfoObject), /*tp_basicsize*/ 0, /*tp_itemsize*/ (destructor)ObjectInfo_dealloc, /*tp_dealloc*/ 0, /*tp_print*/ 0, /*tp_getattr*/ 0, /*tp_setattr*/ 0, /*tp_compare*/ 0, /*tp_repr*/ 0, /*tp_as_number*/ 0, /*tp_as_sequence*/ 0, /*tp_as_mapping*/ 0, /*tp_hash*/ 0, /*tp_call*/ 0, /*tp_str*/ 0, /*tp_getattro*/ 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT|Py_TPFLAGS_BASETYPE|Py_TPFLAGS_HAVE_GC, /*tp_flags*/ 0, /*tp_doc*/ (traverseproc)ObjectInfo_traverse, /*tp_traverse*/ (inquiry)ObjectInfo_clear, /*tp_clear*/ ObjectInfo_richcompare, /*tp_richcompare*/ offsetof(ObjectInfoObject, _weakreflist), /*tp_weaklistoffset*/ 0, /*tp_iter*/ 0, /*tp_iternext*/ ObjectInfo_methods, /*tp_methods*/ ObjectInfo_members, /*tp_members*/ ObjectInfo_getset, /*tp_getset*/ 0, /*tp_base*/ 0, /*tp_dict*/ 0, /*tp_descr_get*/ 0, /*tp_descr_set*/ 0, /*tp_dictoffset*/ (initproc)ObjectInfo_init, /*tp_init*/ 0, /*tp_alloc*/ 0, /*tp_new*/ 0, /*tp_free*/ 0, /*tp_is_gc*/ }; static PyObject * get_obj_info(PyObject *self, PyObject *obj) { PyObject *obj_info; if (obj->ob_type == &ObjectInfo_Type) { /* Much better than asking the ObjectInfo to return itself. 
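Returning it directly also matches the pure-Python behaviour, since
__storm_object_info__ on an ObjectInfo is a property that returns self.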
;-) */ Py_INCREF(obj); return obj; } /* try: return obj.__storm_object_info__ */ obj_info = PyObject_GetAttrString(obj, "__storm_object_info__"); /* except AttributeError: */ if (obj_info == NULL) { PyErr_Clear(); /* obj_info = ObjectInfo(obj) */ obj_info = PyObject_CallFunctionObjArgs((PyObject *)&ObjectInfo_Type, obj, NULL); if (!obj_info) return NULL; /* return obj.__dict__.setdefault("__storm_object_info__", obj_info) */ if (PyObject_SetAttrString(obj, "__storm_object_info__", obj_info) == -1) return NULL; } return obj_info; } static PyMethodDef cextensions_methods[] = { {"get_obj_info", (PyCFunction)get_obj_info, METH_O, NULL}, {NULL, NULL} }; static int prepare_type(PyTypeObject *type) { if (!type->tp_getattro && !type->tp_getattr) type->tp_getattro = PyObject_GenericGetAttr; if (!type->tp_setattro && !type->tp_setattr) type->tp_setattro = PyObject_GenericSetAttr; if (!type->tp_alloc) type->tp_alloc = PyType_GenericAlloc; /* Don't fill in tp_new if this class has a base class */ if (!type->tp_base && !type->tp_new) type->tp_new = PyType_GenericNew; if (!type->tp_free) { assert((type->tp_flags & Py_TPFLAGS_HAVE_GC) != 0); type->tp_free = PyObject_GC_Del; } return PyType_Ready(type); } static int do_init(PyObject *module) { prepare_type(&EventSystem_Type); prepare_type(&Compile_Type); ObjectInfo_Type.tp_base = &PyDict_Type; ObjectInfo_Type.tp_hash = (hashfunc)_Py_HashPointer; prepare_type(&ObjectInfo_Type); prepare_type(&Variable_Type); Py_INCREF(&Variable_Type); #define REGISTER_TYPE(name) \ do { \ Py_INCREF(&name##_Type); \ PyModule_AddObject(module, #name, (PyObject*)&name##_Type); \ } while(0) REGISTER_TYPE(Variable); REGISTER_TYPE(ObjectInfo); REGISTER_TYPE(Compile); REGISTER_TYPE(EventSystem); return 0; } static struct PyModuleDef cextensionsmodule = { PyModuleDef_HEAD_INIT, "cextensions", NULL, -1, cextensions_methods, NULL, NULL, NULL, NULL }; PyMODINIT_FUNC PyInit_cextensions(void) { PyObject *module = PyModule_Create(&cextensionsmodule); if (module == NULL) return NULL; do_init(module); return module; } /* vim:ts=4:sw=4:et */ ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/database.py0000644000175000017500000006006614645174376016333 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # """Basic database interfacing mechanisms for Storm. This is the common code for database support; specific databases are supported in modules in L{storm.databases}. """ from collections.abc import Callable from functools import wraps from storm.expr import Expr, State, compile # Circular import: imported at the end of the module. 
# from storm.tracer import trace from storm.variables import Variable from storm.xid import Xid from storm.exceptions import ( ClosedError, ConnectionBlockedError, DatabaseError, DisconnectionError, Error, ProgrammingError, wrap_exceptions) from storm.uri import URI import storm __all__ = ["Database", "Connection", "Result", "convert_param_marks", "create_database", "register_scheme"] STATE_CONNECTED = 1 STATE_DISCONNECTED = 2 STATE_RECONNECT = 3 class Result: """A representation of the results from a single SQL statement.""" _closed = False def __init__(self, connection, raw_cursor): self._connection = connection # Ensures deallocation order. self._raw_cursor = raw_cursor if raw_cursor.arraysize == 1: # Default of 1 is silly. self._raw_cursor.arraysize = 10 def __del__(self): """Close the cursor.""" try: self.close() except: pass def close(self): """Close the underlying raw cursor, if it hasn't already been closed. """ if not self._closed: self._closed = True self._raw_cursor.close() self._raw_cursor = None def get_one(self): """Fetch one result from the cursor. The result will be converted to an appropriate format via L{from_database}. @raise DisconnectionError: Raised when the connection is lost. Reconnection happens automatically on rollback. @return: A converted row or None, if no data is left. """ row = self._connection._check_disconnect(self._raw_cursor.fetchone) if row is not None: return tuple(self.from_database(row)) return None def get_all(self): """Fetch all results from the cursor. The results will be converted to an appropriate format via L{from_database}. @raise DisconnectionError: Raised when the connection is lost. Reconnection happens automatically on rollback. """ result = self._connection._check_disconnect(self._raw_cursor.fetchall) if result: return [tuple(self.from_database(row)) for row in result] return result def __iter__(self): """Yield all results, one at a time. The results will be converted to an appropriate format via L{from_database}. @raise DisconnectionError: Raised when the connection is lost. Reconnection happens automatically on rollback. """ while True: results = self._connection._check_disconnect( self._raw_cursor.fetchmany) if not results: break for result in results: yield tuple(self.from_database(result)) @property def rowcount(self): """ See PEP 249 for further details on rowcount. @return: the number of affected rows, or None if the database backend does not provide this information. Return value is undefined if all results have not yet been retrieved. """ if self._raw_cursor.rowcount == -1: return None return self._raw_cursor.rowcount def get_insert_identity(self, primary_columns, primary_variables): """Get a query which will return the row that was just inserted. This must be overridden in database-specific subclasses. @rtype: L{storm.expr.Expr} """ raise NotImplementedError @staticmethod def set_variable(variable, value): """Set the given variable's value from the database.""" variable.set(value, from_db=True) @staticmethod def from_database(row): """Convert a row fetched from the database to an agnostic format. This method is intended to be overridden in subclasses, but not called externally. If there are any peculiarities in the datatypes returned from a database backend, this method should be overridden in the backend subclass to convert them. 
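        A minimal sketch of such an override (C{memoryview} here is only
        an illustrative stand-in for a driver-specific type)::

            class MemoryviewResult(Result):

                @staticmethod
                def from_database(row):
                    for value in row:
                        if isinstance(value, memoryview):
                            yield bytes(value)
                        else:
                            yield value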
""" return row class CursorWrapper: """A DB-API cursor, wrapping exceptions as StormError instances.""" def __init__(self, cursor, database): super().__setattr__('_cursor', cursor) super().__setattr__('_database', database) def __getattr__(self, name): attr = getattr(self._cursor, name) if isinstance(attr, Callable): @wraps(attr) def wrapper(*args, **kwargs): with wrap_exceptions(self._database): return attr(*args, **kwargs) return wrapper else: return attr def __setattr__(self, name, value): return setattr(self._cursor, name, value) def __iter__(self): with wrap_exceptions(self._database): yield from self._cursor def __enter__(self): return self def __exit__(self, type_, value, tb): with wrap_exceptions(self._database): self.close() class ConnectionWrapper: """A DB-API connection, wrapping exceptions as StormError instances.""" def __init__(self, connection, database): self.__dict__['_connection'] = connection self.__dict__['_database'] = database def __getattr__(self, name): attr = getattr(self._connection, name) if isinstance(attr, Callable): @wraps(attr) def wrapper(*args, **kwargs): with wrap_exceptions(self._database): return attr(*args, **kwargs) return wrapper else: return attr def __setattr__(self, name, value): return setattr(self._connection, name, value) def __enter__(self): return self def __exit__(self, type_, value, tb): with wrap_exceptions(self._database): if type_ is None and value is None and tb is None: self.commit() else: self.rollback() def cursor(self): with wrap_exceptions(self._database): return CursorWrapper(self._connection.cursor(), self._database) class Connection: """A connection to a database. @cvar result_factory: A callable which takes this L{Connection} and the backend cursor and returns an instance of L{Result}. @type param_mark: C{str} @cvar param_mark: The dbapi paramstyle that the database backend expects. @type compile: L{storm.expr.Compile} @cvar compile: The compiler to use for connections of this type. """ result_factory = Result param_mark = "?" compile = compile _blocked = False _closed = False _two_phase_transaction = False # If True, a two-phase transaction has # been started with begin() _state = STATE_CONNECTED def __init__(self, database, event=None): self._database = database # Ensures deallocation order. self._event = event self._raw_connection = self._database.raw_connect() def __del__(self): """Close the connection.""" try: self.close() except: pass def block_access(self): """Block access to the connection. Attempts to execute statements or commit a transaction will result in a C{ConnectionBlockedError} exception. Rollbacks are permitted as that operation is often used in case of failures. """ self._blocked = True def unblock_access(self): """Unblock access to the connection.""" self._blocked = False def execute(self, statement, params=None, noresult=False): """Execute a statement with the given parameters. @type statement: L{Expr} or C{str} @param statement: The statement to execute. It will be compiled if necessary. @param noresult: If True, no result will be returned. @raise ConnectionBlockedError: Raised if access to the connection has been blocked with L{block_access}. @raise DisconnectionError: Raised when the connection is lost. Reconnection happens automatically on rollback. @return: The result of C{self.result_factory}, or None if C{noresult} is True. 
""" if self._closed: raise ClosedError("Connection is closed") if self._blocked: raise ConnectionBlockedError("Access to connection is blocked") if self._event: self._event.emit("register-transaction") self._ensure_connected() if isinstance(statement, Expr): if params is not None: raise ValueError("Can't pass parameters with expressions") state = State() statement = self.compile(statement, state) params = state.parameters statement = convert_param_marks(statement, "?", self.param_mark) raw_cursor = self.raw_execute(statement, params) if noresult: self._check_disconnect(raw_cursor.close) return None return self.result_factory(self, raw_cursor) def close(self): """Close the connection if it is not already closed.""" if not self._closed: self._closed = True if self._raw_connection is not None: self._raw_connection.close() self._raw_connection = None def begin(self, xid): """Begin a two-phase transaction.""" if self._two_phase_transaction: raise ProgrammingError("begin cannot be used inside a transaction") self._ensure_connected() raw_xid = self._raw_xid(xid) self._check_disconnect(self._raw_connection.tpc_begin, raw_xid) self._two_phase_transaction = True def prepare(self): """Run the prepare phase of a two-phase transaction.""" if not self._two_phase_transaction: raise ProgrammingError("prepare must be called inside a two-phase " "transaction") self._check_disconnect(self._raw_connection.tpc_prepare) def commit(self, xid=None): """Commit the connection. @param xid: Optionally the L{Xid} of a previously prepared transaction to commit. This form should be called outside of a transaction, and is intended for use in recovery. @raise ConnectionBlockedError: Raised if access to the connection has been blocked with L{block_access}. @raise DisconnectionError: Raised when the connection is lost. Reconnection happens automatically on rollback. """ try: self._ensure_connected() if xid: raw_xid = self._raw_xid(xid) self._check_disconnect(self._raw_connection.tpc_commit, raw_xid) elif self._two_phase_transaction: self._check_disconnect(self._raw_connection.tpc_commit) self._two_phase_transaction = False else: self._check_disconnect(self._raw_connection.commit) finally: self._check_disconnect(trace, "connection_commit", self, xid) def recover(self): """Return a list of L{Xid}\\ s representing pending transactions.""" self._ensure_connected() raw_xids = self._check_disconnect(self._raw_connection.tpc_recover) return [Xid(raw_xid[0], raw_xid[1], raw_xid[2]) for raw_xid in raw_xids] def rollback(self, xid=None): """Rollback the connection. @param xid: Optionally the L{Xid} of a previously prepared transaction to rollback. This form should be called outside of a transaction, and is intended for use in recovery. """ try: if self._state == STATE_CONNECTED: try: if xid: raw_xid = self._raw_xid(xid) self._raw_connection.tpc_rollback(raw_xid) elif self._two_phase_transaction: self._raw_connection.tpc_rollback() else: self._raw_connection.rollback() except Error as exc: if self.is_disconnection_error(exc): self._raw_connection = None self._state = STATE_RECONNECT self._two_phase_transaction = False else: raise else: self._two_phase_transaction = False else: self._two_phase_transaction = False self._state = STATE_RECONNECT finally: self._check_disconnect(trace, "connection_rollback", self, xid) @staticmethod def to_database(params): """Convert some parameters into values acceptable to a database backend. It is acceptable to override this method in subclasses, but it is not intended to be used externally. 
This delegates conversion to any L{Variable }\\ s in the parameter list, and passes through all other values untouched. """ for param in params: if isinstance(param, Variable): yield param.get(to_db=True) else: yield param def build_raw_cursor(self): """Get a new dbapi cursor object. It is acceptable to override this method in subclasses, but it is not intended to be called externally. """ return self._raw_connection.cursor() def raw_execute(self, statement, params=None): """Execute a raw statement with the given parameters. It's acceptable to override this method in subclasses, but it is not intended to be called externally. If the global C{DEBUG} is True, the statement will be printed to standard out. @return: The dbapi cursor object, as fetched from L{build_raw_cursor}. """ raw_cursor = self._check_disconnect(self.build_raw_cursor) self._prepare_execution(raw_cursor, params, statement) args = self._execution_args(params, statement) self._run_execution(raw_cursor, args, params, statement) return raw_cursor def _execution_args(self, params, statement): """Get the appropriate statement execution arguments.""" if params: args = (statement, tuple(self.to_database(params))) else: args = (statement,) return args def _run_execution(self, raw_cursor, args, params, statement): """Complete the statement execution, along with result reports.""" try: self._check_disconnect(raw_cursor.execute, *args) except Exception as error: self._check_disconnect( trace, "connection_raw_execute_error", self, raw_cursor, statement, params or (), error) raise else: self._check_disconnect( trace, "connection_raw_execute_success", self, raw_cursor, statement, params or ()) def _prepare_execution(self, raw_cursor, params, statement): """Prepare the statement execution to be run.""" try: self._check_disconnect( trace, "connection_raw_execute", self, raw_cursor, statement, params or ()) except Exception as error: self._check_disconnect( trace, "connection_raw_execute_error", self, raw_cursor, statement, params or (), error) raise def _ensure_connected(self): """Ensure that we are connected to the database. If the connection is marked as dead, or if we can't reconnect, then raise DisconnectionError. """ if self._blocked: raise ConnectionBlockedError("Access to connection is blocked") if self._state == STATE_CONNECTED: return elif self._state == STATE_DISCONNECTED: raise DisconnectionError("Already disconnected") elif self._state == STATE_RECONNECT: try: self._raw_connection = self._database.raw_connect() except DatabaseError as exc: self._state = STATE_DISCONNECTED self._raw_connection = None raise DisconnectionError(str(exc)) else: self._state = STATE_CONNECTED def is_disconnection_error(self, exc, extra_disconnection_errors=()): """Check whether an exception represents a database disconnection. This should be overridden by backends to detect whichever exception values are used to represent this condition. """ return False def _raw_xid(self, xid): """Return a raw xid from the given high-level L{Xid} object.""" return self._raw_connection.xid(xid.format_id, xid.global_transaction_id, xid.branch_qualifier) def _check_disconnect(self, function, *args, **kwargs): """Run the given function, checking for database disconnections.""" # Allow the caller to specify additional exception types that # should be treated as possible disconnection errors. 
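        # For example (an illustrative sketch; SomeDriverIOError is an
        # assumed name, not a real driver exception):
        #
        #     self._check_disconnect(
        #         raw_cursor.execute, statement,
        #         extra_disconnection_errors=(SomeDriverIOError,))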
extra_disconnection_errors = kwargs.pop( 'extra_disconnection_errors', ()) try: return function(*args, **kwargs) except Exception as exc: if self.is_disconnection_error(exc, extra_disconnection_errors): self._state = STATE_DISCONNECTED self._raw_connection = None raise DisconnectionError(str(exc)) else: raise def preset_primary_key(self, primary_columns, primary_variables): """Process primary variables before an insert happens. This method may be overwritten by backends to implement custom changes in primary variables before an insert happens. """ class Database: """A database that can be connected to. This should be subclassed for individual database backends. @cvar connection_factory: A callable which will take this database and should return an instance of L{Connection}. """ connection_factory = Connection def __init__(self, uri=None): self._uri = uri self._exception_types = {} def get_uri(self): """Return the URI object this database was created with.""" return self._uri def connect(self, event=None): """Create a connection to the database. It calls C{self.connection_factory} to allow for ease of customization. @param event: The event system to broadcast messages with. If not specified, then no events will be broadcast. @return: An instance of L{Connection}. """ return self.connection_factory(self, event) def raw_connect(self): """Create a raw database connection. This is used by L{Connection} objects to connect to the database. It should be overriden in subclasses to do any database-specific connection setup. @return: A DB-API connection object. """ raise NotImplementedError @property def _exception_module(self): """The module where appropriate DB-API exception types are defined. Subclasses should set this if they support re-raising DB-API exceptions as StormError instances. """ return None def _make_combined_exception_type(self, wrapper_type, dbapi_type): """Make a combined exception based on both DB-API and Storm. Storm historically defined its own exception types as ABCs and registered the DB-API exception types as virtual subclasses. However, this doesn't work properly in Python 3 (https://bugs.python.org/issue12029). Instead, we create and cache subclass-specific exception types that inherit from both StormError and the DB-API exception type, allowing code that catches either StormError (or subclasses) or the specific DB-API exceptions to keep working. @type wrapper_type: L{type} @param wrapper_type: The type of the wrapper exception to create; a subclass of L{StormError}. @type dbapi_type: L{type} @param dbapi_type: The type of the DB-API exception. @return: The combined exception type. """ if dbapi_type.__name__ not in self._exception_types: self._exception_types[dbapi_type.__name__] = type( dbapi_type.__name__, (dbapi_type, wrapper_type), {}) return self._exception_types[dbapi_type.__name__] def _wrap_exception(self, wrapper_type, exception): """Wrap a DB-API exception as a StormError instance. This constructs a wrapper exception with the same C{args} as the DB-API exception. Subclasses may override this to set additional attributes on the wrapper exception. @type wrapper_type: L{type} @param wrapper_type: The type of the wrapper exception to create; a subclass of L{StormError}. @type exception: L{Exception} @param exception: The DB-API exception to wrap. @return: The wrapped exception; an instance of L{StormError}. 
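        For example (an illustrative sketch; C{exc} is a caught DB-API
        C{OperationalError} instance and C{OperationalError} is the Storm
        wrapper type from L{storm.exceptions})::

            wrapped = self._wrap_exception(OperationalError, exc)
            # wrapped is an instance of both exc.__class__ and the Storm
            # OperationalError, so either style of "except" clause works.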
""" return self._make_combined_exception_type( wrapper_type, exception.__class__)(*exception.args) def convert_param_marks(statement, from_param_mark, to_param_mark): # TODO: Add support for $foo$bar$foo$ literals. if from_param_mark == to_param_mark or from_param_mark not in statement: return statement tokens = statement.split("'") for i in range(0, len(tokens), 2): tokens[i] = tokens[i].replace(from_param_mark, to_param_mark) return "'".join(tokens) _database_schemes = {} def register_scheme(scheme, factory): """Register a handler for a new database URI scheme. @param scheme: the database URI scheme @param factory: a function taking a URI instance and returning a database. """ _database_schemes[scheme] = factory def create_database(uri): """Create a database instance. @param uri: An URI instance, or a string describing the URI. Some examples: "sqlite:" An in memory sqlite database. "sqlite:example.db" A SQLite database called example.db "postgres:test" The database 'test' from the local postgres server. "postgres://user:password@host/test" The database test on machine host with supplied user credentials, using postgres. "anything:..." Where 'anything' has previously been registered with L{register_scheme}. """ if isinstance(uri, str): uri = URI(uri) if uri.scheme in _database_schemes: factory = _database_schemes[uri.scheme] else: module = __import__("%s.databases.%s" % (storm.__name__, uri.scheme), None, None, [""]) factory = module.create_from_uri return factory(uri) # Deal with circular import. from storm.tracer import trace ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4131246 storm-1.0/storm/databases/0000755000175000017500000000000014645532536016130 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/databases/__init__.py0000644000175000017500000000222114645174376020242 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # class Dummy: """Magic "infectious" class. This class simplifies nice errors on the creation of unsupported databases. """ def __getattr__(self, name): return self def __call__(self, *args, **kwargs): return self def __add__(self, other): return self def __bool__(self): return False dummy = Dummy() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/databases/mysql.py0000644000175000017500000002212714645174376017657 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # from datetime import time, timedelta from array import array import sys from storm.databases import dummy try: import MySQLdb import MySQLdb.converters except ImportError: MySQLdb = dummy from storm.database import ( Connection, ConnectionWrapper, Database, Result, ) from storm.exceptions import ( DatabaseModuleError, OperationalError, wrap_exceptions, ) from storm.expr import ( compile, compile_select, Insert, is_safe_token, Select, SQLToken, Undef, ) from storm.variables import Variable compile = compile.create_child() @compile.when(Select) def compile_select_mysql(compile, select, state): if select.offset is not Undef and select.limit is Undef: select.limit = sys.maxsize return compile_select(compile, select, state) @compile.when(SQLToken) def compile_sql_token_mysql(compile, expr, state): """MySQL uses ` as the escape character by default.""" if is_safe_token(expr) and not compile.is_reserved_word(expr): return expr return '`%s`' % expr.replace('`', '``') class MySQLResult(Result): @staticmethod def from_database(row): """Convert MySQL-specific datatypes to "normal" Python types. If there are any C{array} instances in the row, convert them to bytes. """ for value in row: if isinstance(value, array): yield value.tobytes() else: yield value class MySQLConnection(Connection): result_factory = MySQLResult param_mark = "%s" compile = compile def execute(self, statement, params=None, noresult=False): if (isinstance(statement, Insert) and statement.primary_variables is not Undef): result = Connection.execute(self, statement, params) # The lastrowid value will be set if: # - the table had an AUTO_INCREMENT column, and # - the column was not set during the insert or set to 0 # # If these conditions are met, then lastrowid will be the # value of the first such column set. We assume that it # is the first undefined primary key variable.
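            # For example (illustrative): after
            #     INSERT INTO person (name) VALUES ('Ada')
            # on a table whose AUTO_INCREMENT id column was left unset,
            # lastrowid holds the id MySQL generated for the row, and it
            # is copied into that primary key variable below.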
if result._raw_cursor.lastrowid: for variable in statement.primary_variables: if not variable.is_defined(): variable.set(result._raw_cursor.lastrowid, from_db=True) break if noresult: result = None return result return Connection.execute(self, statement, params, noresult) def to_database(self, params): for param in params: if isinstance(param, Variable): param = param.get(to_db=True) if isinstance(param, timedelta): yield str(param) else: yield param def is_disconnection_error(self, exc, extra_disconnection_errors=()): # http://dev.mysql.com/doc/refman/5.0/en/gone-away.html return (isinstance(exc, (OperationalError, extra_disconnection_errors)) and exc.args[0] in (2006, 2013)) # (SERVER_GONE_ERROR, SERVER_LOST) class MySQL(Database): connection_factory = MySQLConnection _exception_module = MySQLdb _converters = None def __init__(self, uri): super().__init__(uri) if MySQLdb is dummy: raise DatabaseModuleError("'MySQLdb' module not found") self._connect_kwargs = {} if uri.database is not None: self._connect_kwargs["db"] = uri.database if uri.host is not None: self._connect_kwargs["host"] = uri.host if uri.port is not None: self._connect_kwargs["port"] = uri.port if uri.username is not None: self._connect_kwargs["user"] = uri.username if uri.password is not None: self._connect_kwargs["passwd"] = uri.password for option in ["unix_socket"]: if option in uri.options: self._connect_kwargs[option] = uri.options.get(option) if self._converters is None: # MySQLdb returns a timedelta by default on TIME fields. converters = MySQLdb.converters.conversions.copy() converters[MySQLdb.converters.FIELD_TYPE.TIME] = _convert_time self.__class__._converters = converters self._connect_kwargs["conv"] = self._converters self._connect_kwargs["use_unicode"] = True # utf8mb3 (a.k.a. utf8) is deprecated, but it's not clear that we # can change it without breaking applications. Default to utf8mb3 # for now. self._connect_kwargs["charset"] = uri.options.get("charset", "utf8mb3") def _raw_connect(self): raw_connection = ConnectionWrapper( MySQLdb.connect(**self._connect_kwargs), self) # Here is another sad story about bad transactional behavior. MySQL # offers a feature to automatically reconnect dropped connections. # What sounds like a dream, is actually a nightmare for anyone who # is dealing with transactions. When a reconnection happens, the # currently running transaction is transparently rolled back, and # everything that was being done is lost, without notice. Not only # that, but the connection may be put back in AUTOCOMMIT mode, even # when that's not the default MySQLdb behavior. The MySQL developers # quickly understood that this is a terrible idea, and removed the # behavior in MySQL 5.0.3. Unfortunately, Debian and Ubuntu still # have a patch for the MySQLdb module which *reenables* that # behavior by default even past version 5.0.3 of MySQL. # # Some links: # http://dev.mysql.com/doc/refman/5.0/en/auto-reconnect.html # http://dev.mysql.com/doc/refman/5.0/en/mysql-reconnect.html # http://dev.mysql.com/doc/refman/5.0/en/gone-away.html # # What we do here is to explore something that is a very weird # side-effect, discovered by reading the code. When we call the # ping() with a False argument, the automatic reconnection is # disabled in a *permanent* way for this connection. The argument # to ping() is new in 1.2.2, though. 
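        # Hence the version check below: MySQLdb releases older than
        # 1.2.2 only accept ping() with no arguments, so on those the
        # workaround is simply skipped.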
if MySQLdb.version_info >= (1, 2, 2): raw_connection.ping(False) return raw_connection def raw_connect(self): with wrap_exceptions(self): return self._raw_connect() create_from_uri = MySQL def _convert_time(time_str): h, m, s = time_str.split(":") if "." in s: f = float(s) s = int(f) return time(int(h), int(m), s, (f-s)*1000000) return time(int(h), int(m), int(s), 0) # -------------------------------------------------------------------- # Reserved words, MySQL specific # The list of reserved words here are MySQL specific. SQL92 reserved words # are registered in storm.expr, near the "Reserved words, from SQL1992" # comment. The reserved words here were taken from: # # http://dev.mysql.com/doc/refman/5.4/en/reserved-words.html compile.add_reserved_words(""" accessible analyze asensitive before bigint binary blob call change condition current_user database databases day_hour day_microsecond day_minute day_second delayed deterministic distinctrow div dual each elseif enclosed escaped exit explain float4 float8 force fulltext high_priority hour_microsecond hour_minute hour_second if ignore index infile inout int1 int2 int3 int4 int8 iterate keys kill leave limit linear lines load localtime localtimestamp lock long longblob longtext loop low_priority master_ssl_verify_server_cert mediumblob mediumint mediumtext middleint minute_microsecond minute_second mod modifies no_write_to_binlog optimize optionally out outfile purge range read_write reads regexp release rename repeat replace require return rlike schemas second_microsecond sensitive separator show spatial specific sql_big_result sql_calc_found_rows sql_small_result sqlexception sqlwarning ssl starting straight_join terminated tinyblob tinyint tinytext trigger undo unlock unsigned use utc_date utc_time utc_timestamp varbinary varcharacter while xor year_month zerofill """.split()) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/databases/postgres.py0000644000175000017500000004342114645174376020360 0ustar00cjwatsoncjwatson# # Copyright (c) 2006-2009 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from datetime import datetime, date, time, timedelta import json from packaging.version import parse as parse_version from storm.databases import dummy # PostgreSQL support in Storm requires psycopg2 2.3.0 or greater, in order # to support the two-phase commit protocol. 
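# If psycopg2 is missing or too old, the name is rebound to the 'dummy'
# placeholder from storm.databases, so this module can still be imported
# and a clear error surfaces only when the backend is actually used.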
REQUIRED_PSYCOPG2_VERSION = parse_version('2.3.0') PSYCOPG2_VERSION = None try: import psycopg2 PSYCOPG2_VERSION = parse_version(psycopg2.__version__.split(' ')[0]) if PSYCOPG2_VERSION < REQUIRED_PSYCOPG2_VERSION: psycopg2 = dummy else: import psycopg2.extensions except ImportError: psycopg2 = dummy from storm.expr import ( Undef, Expr, SetExpr, Select, Insert, Alias, And, Eq, FuncExpr, SQLRaw, Sequence, Like, SQLToken, BinaryOper, COLUMN, COLUMN_NAME, COLUMN_PREFIX, TABLE, compile, compile_select, compile_insert, compile_set_expr, compile_like, compile_sql_token) from storm.variables import ( Variable, ListVariable, JSONVariable as BaseJSONVariable) from storm.properties import SimpleProperty from storm.database import Database, Connection, ConnectionWrapper, Result from storm.exceptions import ( DatabaseError, DatabaseModuleError, InterfaceError, OperationalError, ProgrammingError, TimeoutError, Error, wrap_exceptions) from storm.tracer import TimeoutTracer compile = compile.create_child() class Returning(Expr): """Appends the "RETURNING <columns>" suffix to an INSERT or UPDATE. @param expr: an L{Insert} or L{Update} expression. @param columns: The columns to return, if C{None} then C{expr.primary_columns} will be used. This is only supported in PostgreSQL 8.2+. """ def __init__(self, expr, columns=None): self.expr = expr self.columns = columns @compile.when(Returning) def compile_returning(compile, expr, state): state.push("context", COLUMN) columns = expr.columns or expr.expr.primary_columns columns = compile(columns, state) state.pop() state.push("precedence", 0) expr = compile(expr.expr, state) state.pop() return "%s RETURNING %s" % (expr, columns) class Case(Expr): """A CASE statement. @param cases: a list of tuples of (condition, result), or of (value, result) if an expression is passed as well. @param expression: the expression to compare (if the simple form is used). @param default: an optional default result if no other case matches. """ def __init__(self, cases, expression=Undef, default=Undef): self.cases = cases self.expression = expression self.default = default @compile.when(Case) def compile_case(compile, expr, state): cases = [ "WHEN %s THEN %s" % ( compile(condition, state), compile(value, state)) for condition, value in expr.cases] if expr.expression is not Undef: expression = compile(expr.expression, state) + " " else: expression = "" if expr.default is not Undef: default = " ELSE %s" % compile(expr.default, state) else: default = "" return "CASE %s%s%s END" % (expression, " ".join(cases), default) class currval(FuncExpr): name = "currval" def __init__(self, column): self.column = column @compile.when(currval) def compile_currval(compile, expr, state): """Compile a currval. This is a bit involved because we have to get escaping right. Here are a few cases to keep in mind:: currval('thetable_thecolumn_seq') currval('theschema.thetable_thecolumn_seq') currval('"the schema".thetable_thecolumn_seq') currval('theschema."the table_thecolumn_seq"') currval('theschema."thetable_the column_seq"') currval('"thetable_the column_seq"') currval('"the schema"."the table_the column_seq"') """ state.push("context", COLUMN_PREFIX) table = compile(expr.column.table, state, token=True) state.pop() column_name = compile(expr.column.name, state, token=True) if table.endswith('"'): table = table[:-1] if column_name.endswith('"'): column_name = column_name[1:-1] return "currval('%s_%s_seq\"')" % (table, column_name) elif column_name.endswith('"'): column_name = column_name[1:-1] if "."
in table: schema, table = table.rsplit(".", 1) return "currval('%s.\"%s_%s_seq\"')" % (schema, table, column_name) else: return "currval('\"%s_%s_seq\"')" % (table, column_name) else: return "currval('%s_%s_seq')" % (table, column_name) @compile.when(ListVariable) def compile_list_variable(compile, list_variable, state): elements = [] variables = list_variable.get(to_db=True) if variables is None: return "NULL" if not variables: return "'{}'" for variable in variables: elements.append(compile(variable, state)) return "ARRAY[%s]" % ",".join(elements) @compile.when(SetExpr) def compile_set_expr_postgres(compile, expr, state): if expr.order_by is not Undef: # The following statement breaks in postgres: # SELECT 1 AS id UNION SELECT 1 ORDER BY id+1 # With the error: # ORDER BY on a UNION/INTERSECT/EXCEPT result must # be on one of the result columns # So we transform it into something close to: # SELECT * FROM (SELECT 1 AS id UNION SELECT 1) AS a ORDER BY id+1 # Build new set expression without arguments (order_by, etc). new_expr = expr.__class__() new_expr.exprs = expr.exprs new_expr.all = expr.all # Make sure that state.aliases isn't None, since we want them to # compile our order_by statement below. no_aliases = state.aliases is None if no_aliases: state.push("aliases", {}) # Build set expression, collecting aliases. set_stmt = SQLRaw("(%s)" % compile_set_expr(compile, new_expr, state)) # Build order_by statement, using aliases. state.push("context", COLUMN_NAME) order_by_stmt = SQLRaw(compile(expr.order_by, state)) state.pop() # Discard aliases, if they were not being collected previously. if no_aliases: state.pop() # Build wrapping select statement. select = Select(SQLRaw("*"), tables=Alias(set_stmt), limit=expr.limit, offset=expr.offset, order_by=order_by_stmt) return compile_select(compile, select, state) else: return compile_set_expr(compile, expr, state) @compile.when(Insert) def compile_insert_postgres(compile, insert, state): # PostgreSQL fails with INSERT INTO table VALUES (), so we transform # that to INSERT INTO table (id) VALUES (DEFAULT). if not insert.map and insert.primary_columns is not Undef: insert.map.update(dict.fromkeys(insert.primary_columns, SQLRaw("DEFAULT"))) return compile_insert(compile, insert, state) @compile.when(Sequence) def compile_sequence_postgres(compile, sequence, state): return "nextval('%s')" % sequence.name @compile.when(Like) def compile_like_postgres(compile, like, state): if like.case_sensitive is False: return compile_like(compile, like, state, oper=" ILIKE ") return compile_like(compile, like, state) @compile.when(SQLToken) def compile_sql_token_postgres(compile, expr, state): if "." in expr and state.context in (TABLE, COLUMN_PREFIX): return ".".join(compile_sql_token(compile, subexpr, state) for subexpr in expr.split(".")) return compile_sql_token(compile, expr, state) class PostgresResult(Result): def get_insert_identity(self, primary_key, primary_variables): equals = [] for column, variable in zip(primary_key, primary_variables): if not variable.is_defined(): # The Select here prevents PostgreSQL from going nuts and # performing a sequential scan when there *is* an index. 
# http://tinyurl.com/2n8mv3 variable = Select(currval(column)) equals.append(Eq(column, variable)) return And(*equals) pg_connection_failure_codes = frozenset([ '08006', # CONNECTION FAILURE '08001', # SQLCLIENT UNABLE TO ESTABLISH SQLCONNECTION '08004', # SQLSERVER REJECTED ESTABLISHMENT OF SQLCONNECTION '53300', # TOO MANY CONNECTIONS '57000', # OPERATOR INTERVENTION '57P01', # ADMIN SHUTDOWN '57P02', # CRASH SHUTDOWN '57P03', # CANNOT CONNECT NOW ]) class PostgresConnection(Connection): result_factory = PostgresResult param_mark = "%s" compile = compile def execute(self, statement, params=None, noresult=False): """Execute a statement with the given parameters. This extends the L{Connection.execute} method to add support for automatic retrieval of inserted primary keys to link in-memory objects with their specific rows. """ if (isinstance(statement, Insert) and self._database._version >= 80200 and statement.primary_variables is not Undef and statement.primary_columns is not Undef): # Here we decorate the Insert statement with a Returning # expression, so that we get back in the result the values # for the primary key just inserted. This prevents a round # trip to the database for obtaining these values. result = Connection.execute(self, Returning(statement), params) for variable, value in zip(statement.primary_variables, result.get_one()): result.set_variable(variable, value) return result return Connection.execute(self, statement, params, noresult) def to_database(self, params): """ Like L{Connection.to_database}, but this converts datetime types to strings, and bytes to L{psycopg2.Binary} instances. """ for param in params: if isinstance(param, Variable): param = param.get(to_db=True) if isinstance(param, (datetime, date, time, timedelta)): yield str(param) elif isinstance(param, bytes): yield psycopg2.Binary(param) else: yield param def is_disconnection_error(self, exc, extra_disconnection_errors=()): # Attempt to use pgcode to determine the nature of the error. This is # more reliable than string matching because it is not affected by # locale settings. Fall through if pgcode is not available. if isinstance(exc, Error): pgcode = getattr(exc, "pgcode", None) if pgcode in pg_connection_failure_codes: return True disconnection_errors = ( DatabaseError, InterfaceError, OperationalError, ProgrammingError, extra_disconnection_errors) if isinstance(exc, disconnection_errors): # When the connection is closed by a termination of pgbouncer, a # DatabaseError or subclass with no message (depending on # psycopg2 version) is raised. If the raw connection is closed # we assume it's actually a disconnection. if isinstance(exc, DatabaseError): if self._raw_connection.closed: return True msg = str(exc) return ( "SSL SYSCALL error" in msg or "EOF detected" in msg or "connection already closed" in msg or "connection not open" in msg or "could not connect to server" in msg or "could not receive data from server" in msg or "could not send data to server" in msg or "losed the connection unexpectedly" in msg or "no connection to the server" in msg or "server closed the connection unexpectedly" in msg or "terminating connection due to administrator" in msg) return False class Postgres(Database): connection_factory = PostgresConnection _exception_module = psycopg2 # An integer representing the server version. If the server does # not support the server_version_num variable, this will be set to # 0. In practice, this means the variable will be 0 or greater # than or equal to 80200. 
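# (server_version_num packs the version as major*10000 + minor*100 + patch
# for servers of this era, so 8.2.0 reports 80200 and 9.6.3 reports 90603.)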
_version = None def __init__(self, uri): super().__init__(uri) if psycopg2 is dummy: raise DatabaseModuleError( "'psycopg2' >= %s not found. Found %s." % (REQUIRED_PSYCOPG2_VERSION, PSYCOPG2_VERSION)) self._dsn = make_dsn(uri) isolation = uri.options.get("isolation", "repeatable-read") isolation_mapping = { "autocommit": psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT, "serializable": psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE, "read-committed": psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED, "repeatable-read": psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ, "read-uncommitted": psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED} try: self._isolation = isolation_mapping[isolation] except KeyError: raise ValueError( "Unknown serialization level %r: expected one of " "'autocommit', 'serializable', 'read-committed'" % (isolation,)) _psycopg_error_attributes = ["pgerror", "pgcode", "cursor"] # Added in psycopg2 2.5. if hasattr(psycopg2.Error, "diag"): _psycopg_error_attributes.append("diag") _psycopg_error_attributes = tuple(_psycopg_error_attributes) def _make_combined_exception_type(self, wrapper_type, dbapi_type): combined_type = super()._make_combined_exception_type( wrapper_type, dbapi_type) for name in self._psycopg_error_attributes: setattr(combined_type, name, lambda err: getattr(err, "_" + name)) return combined_type def _wrap_exception(self, wrapper_type, exception): wrapped = super()._wrap_exception(wrapper_type, exception) for name in self._psycopg_error_attributes: setattr(wrapped, "_" + name, getattr(exception, name)) return wrapped def _raw_connect(self): raw_connection = ConnectionWrapper(psycopg2.connect(self._dsn), self) if self._version is None: cursor = raw_connection.cursor() try: cursor.execute("SHOW server_version_num") except ProgrammingError: self._version = 0 else: self._version = int(cursor.fetchone()[0]) raw_connection.rollback() raw_connection.set_client_encoding("UTF8") raw_connection.set_isolation_level(self._isolation) return raw_connection def raw_connect(self): with wrap_exceptions(self): return self._raw_connect() create_from_uri = Postgres if psycopg2 is not dummy: psycopg2.extensions.register_type(psycopg2.extensions.UNICODE) psycopg2.extensions.register_type(psycopg2._psycopg.UNICODEARRAY) def make_dsn(uri): """Convert a URI object to a PostgreSQL DSN string.""" dsn = "dbname=%s" % uri.database if uri.host is not None: dsn += " host=%s" % uri.host if uri.port is not None: dsn += " port=%d" % uri.port if uri.username is not None: dsn += " user=%s" % uri.username if uri.password is not None: dsn += " password=%s" % uri.password return dsn class PostgresTimeoutTracer(TimeoutTracer): def set_statement_timeout(self, raw_cursor, remaining_time): raw_cursor.execute("SET statement_timeout TO %d" % (remaining_time * 1000)) def connection_raw_execute_error(self, connection, raw_cursor, statement, params, error): # This should just check for # psycopg2.extensions.QueryCanceledError in the future. 
if (isinstance(error, DatabaseError) and "statement timeout" in str(error)): raise TimeoutError( statement, params, "SQL server cancelled statement") # Postgres-specific operators class JSONElement(BinaryOper): """Return an element of a JSON value (by index or field name).""" oper = "->" class JSONTextElement(BinaryOper): """Return an element of a JSON value (by index or field name) as text.""" oper = "->>" # Postgres-specific properties and variables class JSONVariable(BaseJSONVariable): __slots__ = () def _loads(self, value): if isinstance(value, str): # psycopg versions < 2.5 don't automatically convert JSON columns # to python objects, they return a string. # # Note that on newer versions, if the object contained is an actual # string, it's returned as unicode, so the check is still valid. return json.loads(value) return value class JSON(SimpleProperty): variable_class = JSONVariable ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/databases/sqlite.py0000644000175000017500000002271114645174376020012 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from datetime import datetime, date, time, timedelta from time import sleep, time as now import sys from storm.databases import dummy try: from pysqlite2 import dbapi2 as sqlite except ImportError: try: from sqlite3 import dbapi2 as sqlite except ImportError: sqlite = dummy from storm.variables import Variable, BytesVariable from storm.database import Database, Connection, ConnectionWrapper, Result from storm.exceptions import ( DatabaseModuleError, OperationalError, wrap_exceptions) from storm.expr import ( Insert, Select, SELECT, Undef, SQLRaw, Union, Except, Intersect, compile, compile_insert, compile_select) compile = compile.create_child() @compile.when(Select) def compile_select_sqlite(compile, select, state): if select.offset is not Undef and select.limit is Undef: if sys.maxsize > 2**32: # On 64-bit platforms sqlite doesn't like maxsize as LIMIT. See # also # https://lists.ubuntu.com/archives/storm/2013-June/001492.html select.limit = sys.maxsize - 1 else: select.limit = sys.maxsize statement = compile_select(compile, select, state) if state.context is SELECT: # SQLite breaks with (SELECT ...) UNION (SELECT ...), so we # do SELECT * FROM (SELECT ...) instead. This is important # because SELECT ... UNION SELECT ... ORDER BY binds the ORDER BY # to the UNION instead of SELECT. return "SELECT * FROM (%s)" % statement return statement # Considering the above, selects have a greater precedence. compile.set_precedence(5, Union, Except, Intersect) @compile.when(Insert) def compile_insert_sqlite(compile, insert, state): # SQLite fails with INSERT INTO table VALUES (), so we transform # that to INSERT INTO table (id) VALUES (NULL). 
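# (Inserting NULL into an INTEGER PRIMARY KEY column makes SQLite assign
# the next rowid automatically, so the rewritten statement behaves like a
# true empty insert.)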
if not insert.map and insert.primary_columns is not Undef: insert.map.update(dict.fromkeys(insert.primary_columns, None)) return compile_insert(compile, insert, state) class SQLiteResult(Result): def get_insert_identity(self, primary_key, primary_variables): return SQLRaw("(OID=%d)" % self._raw_cursor.lastrowid) @staticmethod def set_variable(variable, value): if isinstance(variable, BytesVariable) and isinstance(value, str): # pysqlite2 may return unicode. value = value.encode("UTF-8") variable.set(value, from_db=True) class SQLiteConnection(Connection): result_factory = SQLiteResult compile = compile _in_transaction = False @staticmethod def to_database(params): """ Like L{Connection.to_database}, but this also converts instances of L{datetime} types to strings. """ for param in params: if isinstance(param, Variable): param = param.get(to_db=True) if isinstance(param, (datetime, date, time, timedelta)): yield str(param) else: yield param def commit(self): self._ensure_connected() # See story at the end to understand why we do COMMIT manually. if self._in_transaction: self.raw_execute("COMMIT", _end=True) def rollback(self): # See story at the end to understand why we do ROLLBACK manually. if self._in_transaction: self.raw_execute("ROLLBACK", _end=True) def raw_execute(self, statement, params=None, _end=False): """Execute a raw statement with the given parameters. This method will automatically retry on locked database errors. This should be done by pysqlite, but it doesn't work with versions < 2.3.4, so we make sure the timeout is respected here. """ if _end: self._in_transaction = False elif not self._in_transaction: # See story at the end to understand why we do BEGIN manually. self._in_transaction = True self._raw_connection.execute("BEGIN") # Remember the time at which we started the operation. If pysqlite # handles the timeout correctly, we won't retry the operation, because # the timeout will have expired when the raw_execute() returns. started = now() while True: try: return Connection.raw_execute(self, statement, params) except OperationalError as e: if str(e) != "database is locked": raise elif now() - started < self._database._timeout: # pysqlite didn't handle the timeout correctly, # so we sleep a little and then retry. sleep(0.1) else: # The operation failed due to being unable to get a # lock on the database. In this case, we are still # in a transaction. if _end: self._in_transaction = True raise class SQLite(Database): connection_factory = SQLiteConnection _exception_module = sqlite def __init__(self, uri): super().__init__(uri) if sqlite is dummy: raise DatabaseModuleError("'pysqlite2' module not found") self._filename = uri.database or ":memory:" self._timeout = float(uri.options.get("timeout", 5)) self._synchronous = uri.options.get("synchronous") self._journal_mode = uri.options.get("journal_mode") self._foreign_keys = uri.options.get("foreign_keys") def _raw_connect(self): # See the story at the end to understand why we set isolation_level. 
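# Passing isolation_level=None below disables pysqlite's statement-sniffing
# transaction management; SQLiteConnection then issues BEGIN/COMMIT/ROLLBACK
# itself (see commit(), rollback() and raw_execute() above).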
raw_connection = ConnectionWrapper( sqlite.connect( self._filename, timeout=self._timeout, isolation_level=None), self) if self._synchronous is not None: raw_connection.execute("PRAGMA synchronous = %s" % (self._synchronous,)) if self._journal_mode is not None: raw_connection.execute("PRAGMA journal_mode = %s" % (self._journal_mode,)) if self._foreign_keys is not None: raw_connection.execute("PRAGMA foreign_keys = %s" % (self._foreign_keys,)) return raw_connection def raw_connect(self): with wrap_exceptions(self): return self._raw_connect() create_from_uri = SQLite # Here is a sad story about PySQLite2. # # PySQLite does some very dirty tricks to control the moment in # which transactions begin and end. It actually *changes* the # transactional behavior of SQLite. # # The real behavior of SQLite is that transactions are SERIALIZABLE # by default. That is, any reads are repeatable, and changes in # other threads or processes won't modify data for already started # transactions that have issued any reading or writing statements. # # PySQLite changes that in a very unpredictable way. First, it will # only actually begin a transaction if a INSERT/UPDATE/DELETE/REPLACE # operation is executed (yes, it will parse the statement). This # means that any SELECTs executed *before* one of the former mentioned # operations are seen, will be operating in READ COMMITTED mode. Then, # if after that a INSERT/UPDATE/DELETE/REPLACE is seen, the transaction # actually begins, and so it moves into SERIALIZABLE mode. # # Another pretty surprising behavior is that it will *commit* any # on-going transaction if any other statement besides # SELECT/INSERT/UPDATE/DELETE/REPLACE is seen. # # In an ORM we're really dealing with cached data, so working on top # of a system like that means that cache validity is pretty random. # # So what we do about that in this module is disabling all that hackery # by *pretending* to PySQLite that we'll work without transactions # (isolation_level=None), and then we actually take responsibility for # controlling the transaction. # # References: # http://www.sqlite.org/lockingv3.html # http://docs.python.org/lib/sqlite3-Controlling-Transactions.html # # -------------------------------------------------------------------- # Reserved words, SQLite specific # The list of reserved words here are SQLite specific. SQL92 reserved words # are registered in storm.expr, near the "Reserved words, from SQL1992" # comment. The reserved words here were taken from: # # http://www.sqlite.org/lang_keywords.html compile.add_reserved_words(""" abort after analyze attach autoincrement before conflict database detach each exclusive explain fail glob if ignore index indexed instead isnull limit notnull offset plan pragma query raise regexp reindex release rename replace row savepoint temp trigger vacuum virtual """.split()) ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4171247 storm-1.0/storm/docs/0000755000175000017500000000000014645532536015131 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1590495941.0 storm-1.0/storm/docs/Makefile0000644000175000017500000000110513663205305016554 0ustar00cjwatsoncjwatson# Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". 
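# For example, "make html" runs: sphinx-build -M html . _build
# and leaves the rendered pages in _build/html.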
help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1590495941.0 storm-1.0/storm/docs/__init__.py0000644000175000017500000000000013663205305017216 0ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1590495941.0 storm-1.0/storm/docs/api.rst0000644000175000017500000000534113663205305016425 0ustar00cjwatsoncjwatsonAPI === .. contents:: :local: Locals ------ The following names are re-exported from ``storm.locals`` for convenience: * :py:class:`storm.base.Storm` * :py:func:`storm.database.create_database` * :py:class:`storm.exceptions.StormError` * :py:class:`storm.expr.And` * :py:class:`storm.expr.Asc` * :py:class:`storm.expr.Count` * :py:class:`storm.expr.Delete` * :py:class:`storm.expr.Desc` * :py:class:`storm.expr.In` * :py:class:`storm.expr.Insert` * :py:class:`storm.expr.Join` * :py:class:`storm.expr.Like` * :py:class:`storm.expr.Max` * :py:class:`storm.expr.Min` * :py:class:`storm.expr.Not` * :py:class:`storm.expr.Or` * :py:class:`storm.expr.SQL` * :py:class:`storm.expr.Select` * :py:class:`storm.expr.Update` * :py:class:`storm.info.ClassAlias` * :py:class:`storm.properties.Bool` * :py:class:`storm.properties.Bytes` * :py:class:`storm.properties.Date` * :py:class:`storm.properties.DateTime` * :py:class:`storm.properties.Decimal` * :py:class:`storm.properties.Enum` * :py:class:`storm.properties.Float` * :py:class:`storm.properties.Int` * :py:class:`storm.properties.JSON` * :py:class:`storm.properties.List` * :py:class:`storm.properties.Pickle` * :py:class:`storm.properties.Time` * :py:class:`storm.properties.TimeDelta` * :py:class:`storm.properties.UUID` * :py:class:`storm.properties.Unicode` * :py:class:`storm.references.Proxy` * :py:class:`storm.references.Reference` * :py:class:`storm.references.ReferenceSet` * :py:data:`storm.store.AutoReload` * :py:class:`storm.store.Store` * :py:class:`storm.xid.Xid` Store ----- .. automodule:: storm.store .. autoclass:: storm.store.ResultSet .. autoclass:: storm.store.TableSet .. autodata:: storm.store.AutoReload :annotation: Defining tables and columns --------------------------- Base ~~~~ .. automodule:: storm.base Properties ~~~~~~~~~~ .. automodule:: storm.properties References ~~~~~~~~~~ .. automodule:: storm.references Variables ~~~~~~~~~ .. automodule:: storm.variables SQLObject emulation ~~~~~~~~~~~~~~~~~~~ .. automodule:: storm.sqlobject Expressions ----------- .. automodule:: storm.expr Databases --------- .. automodule:: storm.database PostgreSQL ~~~~~~~~~~ .. automodule:: storm.databases.postgres SQLite ~~~~~~ .. automodule:: storm.databases.sqlite Transaction identifiers ~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: storm.xid Hooks and events ---------------- Event ~~~~~ .. automodule:: storm.event Tracer ~~~~~~ .. automodule:: storm.tracer Miscellaneous ------------- Cache ~~~~~ .. automodule:: storm.cache Exceptions ~~~~~~~~~~ .. automodule:: storm.exceptions Info ~~~~~ .. automodule:: storm.info Testing ~~~~~~~ .. automodule:: storm.testing Timezone ~~~~~~~~ .. automodule:: storm.tz URIs ~~~~ .. automodule:: storm.uri WSGI ~~~~ .. 
automodule:: storm.wsgi ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/docs/conf.py0000644000175000017500000001331214645174376016434 0ustar00cjwatsoncjwatson# Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. import os import sys sys.path.insert(0, os.path.abspath('..')) # Import and document the pure-Python versions of things. os.environ['STORM_CEXTENSIONS'] = '0' # -- Project information ----------------------------------------------------- project = 'Storm' copyright = '2006-2020, Canonical Ltd.' author = 'Gustavo Niemeyer' # The short X.Y version version = '' # The full version, including alpha/beta/rc tags release = '' # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Fix missing support for @ivar and @cvar. from sphinx_epytext.process_docstring import FIELDS FIELDS.append("cvar") FIELDS.append("ivar") # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx_epytext', ] # Add any paths that contain templates here, relative to this directory. # templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The master toctree document. master_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. 
Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'Stormdoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'Storm.tex', 'Storm Documentation', 'Gustavo Niemeyer', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'storm', 'Storm Documentation', [author], 1) ] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'Storm', 'Storm Documentation', author, 'Storm', 'One line description of project.', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. # # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] intersphinx_mapping = {'https://docs.python.org/3': None} # Sphinx 1.8+ prefers this to `autodoc_default_flags`. It's documented that # either True or None mean the same thing as just setting the flag, but # only None works in 1.8 (True works in 2.0) autodoc_default_options = { 'members': None, 'show-inheritance': None, } autodoc_member_order = 'bysource' autoclass_content = 'both' ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1590495941.0 storm-1.0/storm/docs/index.rst0000644000175000017500000000073413663205305016764 0ustar00cjwatsoncjwatson.. Storm documentation master file, created by sphinx-quickstart on Sat May 23 01:56:39 2020. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to Storm's documentation! ================================= .. toctree:: :maxdepth: 2 :caption: Contents: tutorial infoheritance zope api Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1590495941.0 storm-1.0/storm/docs/infoheritance.rst0000644000175000017500000002024713663205305020474 0ustar00cjwatsoncjwatsonInfoheritance ============= Storm doesn't support classes that have columns in multiple tables. This makes using inheritance rather difficult. The infoheritance pattern described here provides a way to get the benefits of inheritance without running into the problems Storm has with multi-table classes. 
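For illustration, here is the kind of (hypothetical) multi-table mapping that cannot be expressed, since all columns of a class must live in the single table named by ``__storm_table__``:

.. code-block:: python

    from storm.locals import Int, Unicode

    class Teacher(object):
        # Broken by design: there is no way to say that ``name`` comes from
        # the "person" table while ``school`` comes from the "teacher"
        # table; only one table may be named here.
        __storm_table__ = "person"
        id = Int(primary=True)
        name = Unicode()    # would live in the person table
        school = Unicode()  # would need to live in the teacher table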
Defining a sample model ----------------------- Let's consider an inheritance hierarchy to migrate to Storm. .. code-block:: python class Person(object): def __init__(self, name): self.name = name class SecretAgent(Person): def __init__(self, name, passcode): super(SecretAgent, self).__init__(name) self.passcode = passcode class Teacher(Person): def __init__(self, name, school): super(Teacher, self).__init__(name) self.school = school We want to use three tables to store data for these objects: ``person``, ``secret_agent`` and ``teacher``. We can't simply convert instance attributes to Storm properties and add ``__storm_table__`` definitions because a single object may not have columns that come from more than one table. We can't have ``Teacher`` getting its ``name`` column from the ``person`` table and its ``school`` column from the ``teacher`` table, for example. The infoheritance pattern ------------------------- The infoheritance pattern uses composition instead of inheritance to work around the multiple table limitation. A base Storm class is used to represent all objects in the hierarchy. Each instance of this base class has an info property that yields an instance of a specific info class. An info class provides the additional data and behaviour you'd normally implement in a subclass. Following is the design from above converted to use the pattern. .. doctest:: >>> from storm.locals import Storm, Store, Int, Unicode, Reference >>> person_info_types = {} >>> def register_person_info_type(info_type, info_class): ... existing_info_class = person_info_types.get(info_type) ... if existing_info_class is not None: ... raise RuntimeError("%r has the same info_type as %r" % ... (info_class, existing_info_class)) ... person_info_types[info_type] = info_class ... info_class.info_type = info_type >>> class Person(Storm): ... ... __storm_table__ = "person" ... ... id = Int(allow_none=False, primary=True) ... name = Unicode(allow_none=False) ... info_type = Int(allow_none=False) ... _info = None ... ... def __init__(self, store, name, info_class, **kwargs): ... self.name = name ... self.info_type = info_class.info_type ... store.add(self) ... self._info = info_class(self, **kwargs) ... ... @property ... def info(self): ... if self._info is not None: ... return self._info ... assert self.id is not None ... info_class = person_info_types[self.info_type] ... if not hasattr(info_class, "__storm_table__"): ... info = info_class.__new__(info_class) ... info.person = self ... else: ... info = Store.of(self).get(info_class, self.id) ... self._info = info ... return info >>> class PersonInfo(object): ... ... def __init__(self, person): ... self.person = person >>> class StoredPersonInfo(PersonInfo): ... ... person_id = Int(allow_none=False, primary=True) ... person = Reference(person_id, Person.id) >>> class SecretAgent(StoredPersonInfo): ... ... __storm_table__ = "secret_agent" ... ... passcode = Unicode(allow_none=False) ... ... def __init__(self, person, passcode=None): ... super(SecretAgent, self).__init__(person) ... self.passcode = passcode >>> class Teacher(StoredPersonInfo): ... ... __storm_table__ = "teacher" ... ... school = Unicode(allow_none=False) ... ... def __init__(self, person, school=None): ... super(Teacher, self).__init__(person) ... self.school = school The pattern works by having a base class, ``Person``, keep a reference to an info class, ``PersonInfo``. Info classes need to be registered so that ``Person`` can discover them and load them when necessary.
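Since each key identifies exactly one class, reusing a key is reported immediately; with the definitions above, something like the following sketch would raise ``RuntimeError``:

.. code-block:: python

    register_person_info_type(1, SecretAgent)
    register_person_info_type(1, Teacher)  # raises RuntimeError: key taken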
Note that info types have the same ID as their parent object. This isn't strictly necessary, but it makes certain things easy, such as being able to look up info objects directly by ID when given a person object. ``Person`` objects are required to be in a store to ensure that an ID is available and can be used by the info class. Registering info classes ------------------------ Let's register our info classes. Each class must be registered with a unique info type key. This key is stored in the database, so be sure to use a stable value. .. doctest:: >>> register_person_info_type(1, SecretAgent) >>> register_person_info_type(2, Teacher) Let's create a database to store person objects before we continue. .. doctest:: >>> from storm.locals import create_database >>> database = create_database("sqlite:") >>> store = Store(database) >>> result = store.execute(""" ... CREATE TABLE person ( ... id INTEGER PRIMARY KEY, ... info_type INTEGER NOT NULL, ... name TEXT NOT NULL) ... """) >>> result = store.execute(""" ... CREATE TABLE secret_agent ( ... person_id INTEGER PRIMARY KEY, ... passcode TEXT NOT NULL) ... """) >>> result = store.execute(""" ... CREATE TABLE teacher ( ... person_id INTEGER PRIMARY KEY, ... school TEXT NOT NULL) ... """) Creating info classes --------------------- We can easily create person objects now. .. doctest:: >>> secret_agent = Person(store, u"Dick Tracy", ... SecretAgent, passcode=u"secret!") >>> teacher = Person(store, u"Mrs. Cohen", ... Teacher, school=u"Cameron Elementary School") >>> store.commit() And we can easily find them again. .. doctest:: >>> del secret_agent >>> del teacher >>> store.rollback() >>> [type(person.info) ... for person in store.find(Person).order_by(Person.name)] [<class 'SecretAgent'>, <class 'Teacher'>] Retrieving info classes ----------------------- Now that we have our basic hierarchy in place we're going to want to retrieve objects by info type. Let's implement a function to make finding ``Person``\ s easier. .. doctest:: >>> def get_persons(store, info_classes=None): ... where = [] ... if info_classes: ... info_types = [ ... info_class.info_type for info_class in info_classes] ... where = [Person.info_type.is_in(info_types)] ... result = store.find(Person, *where) ... result.order_by(Person.name) ... return result >>> secret_agent = get_persons(store, info_classes=[SecretAgent]).one() >>> print(secret_agent.name) Dick Tracy >>> print(secret_agent.info.passcode) secret! >>> teacher = get_persons(store, info_classes=[Teacher]).one() >>> print(teacher.name) Mrs. Cohen >>> print(teacher.info.school) Cameron Elementary School Great, we can easily find different kinds of ``Person``\ s. In-memory info objects ---------------------- This design also allows for in-memory info objects. Let's add one to our hierarchy. .. doctest:: >>> class Ghost(PersonInfo): ... ... friendly = True >>> register_person_info_type(3, Ghost) We create and load in-memory objects the same way we do stored ones. .. doctest:: >>> ghost = Person(store, u"Casper", Ghost) >>> store.commit() >>> del ghost >>> store.rollback() >>> ghost = get_persons(store, info_classes=[Ghost]).one() >>> print(ghost.name) Casper >>> print(ghost.info.friendly) True This pattern is very handy when using Storm with code that would naturally be implemented using inheritance. ..
>>> Person._storm_property_registry.clear() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1639494180.0 storm-1.0/storm/docs/tutorial.rst0000644000175000017500000006007114156131044017514 0ustar00cjwatsoncjwatsonTutorial ======== Importing --------- Let's start by importing some names into the namespace. .. doctest:: >>> from storm.locals import * Basic definition ---------------- Now we define a type with some properties describing the information we're about to map. .. doctest:: >>> class Person(object): ... __storm_table__ = "person" ... id = Int(primary=True) ... name = Unicode() Notice that this has no Storm-defined base class or constructor. Creating a database and the store --------------------------------- We still don't have anyone to talk to, so let's define an in-memory SQLite database to play with, and a store using that database. .. doctest:: >>> database = create_database("sqlite:") >>> store = Store(database) Three databases are supported at the moment: SQLite, MySQL, and PostgreSQL. The parameter passed to :py:func:`~storm.database.create_database` is a URI, as follows: .. code-block:: python # database = create_database( # "scheme://username:password@hostname:port/database_name") The ``scheme`` may be ``sqlite``, ``mysql``, or ``postgres``. Now we have to create the table that will actually hold the data for our class. .. doctest:: >>> store.execute("CREATE TABLE person " ... "(id INTEGER PRIMARY KEY, name VARCHAR)") <storm.databases.sqlite.SQLiteResult object at 0x...> We got a result back, but we don't care about it for now. We could also use ``noresult=True`` to avoid the result entirely. Creating an object ------------------ Let's create an object of the defined class. .. doctest:: >>> joe = Person() >>> joe.name = u"Joe Johnes" >>> print(joe.id) None >>> print(joe.name) Joe Johnes So far this object has no connection to a database. Let's add it to the store we've created above. .. doctest:: >>> store.add(joe) <...Person object at 0x...> >>> print(joe.id) None >>> print(joe.name) Joe Johnes Notice that the object wasn't changed, even after being added to the store. That's because it wasn't flushed yet. The store of an object ---------------------- Once an object is added to a store, or retrieved from a store, its relation to that store is known. We can easily verify to which store an object is bound. .. doctest:: >>> Store.of(joe) is store True >>> Store.of(Person()) is None True Finding an object ----------------- Now, what would happen if we actually asked the store to give us the person named "Joe Johnes"? .. doctest:: >>> person = store.find(Person, Person.name == u"Joe Johnes").one() >>> print(person.id) 1 >>> print(person.name) Joe Johnes The person is there! Yeah, ok, you were expecting it. :-) We can also retrieve the object using its primary key. .. doctest:: >>> print(store.get(Person, 1).name) Joe Johnes Caching behavior ---------------- One interesting thing is that this person is actually Joe, right? We've just added this object, so there's only one Joe, why would there be two different objects? There isn't. .. doctest:: >>> person is joe True What's going on behind the scenes is that each store has an object cache. When an object is linked to a store, it will be cached by the store for as long as there's a reference to the object somewhere, or while the object is dirty (has unflushed changes).
Storm ensures that at least a certain number of recently used objects will stay in memory inside the transaction, so that frequently used objects are not retrieved from the database too often. Flushing -------- When we tried to find Joe in the database for the first time, we've noticed that the ``id`` property was magically assigned. This happened because the object was flushed implicitly so that the operation would affect any pending changes as well. Flushes may also happen explicitly. .. doctest:: >>> mary = Person() >>> mary.name = u"Mary Margaret" >>> store.add(mary) <...Person object at 0x...> >>> print(mary.id) None >>> print(mary.name) Mary Margaret >>> store.flush() >>> print(mary.id) 2 >>> print(mary.name) Mary Margaret Changing objects with the Store ------------------------------- Besides changing objects as usual, we can also benefit from the fact that objects are tied to a database to change them using expressions. .. doctest:: >>> store.find( ... Person, Person.name == u"Mary Margaret").set(name=u"Mary Maggie") >>> print(mary.name) Mary Maggie This operation will touch every matching object in the database, and also objects that are alive in memory. Committing ---------- Everything we've done so far is inside a transaction. At this point, we can either make these changes and any pending uncommitted changes persistent by committing them, or we can undo everything by rolling them back. We'll commit them, with something as simple as .. doctest:: >>> store.commit() That was straightforward. Everything is still the way it was, but now changes are there "for real". Rolling back ------------ Aborting changes is very straightforward as well. .. doctest:: >>> joe.name = u"Tom Thomas" Let's see if these changes are really being considered by Storm and by the database. .. doctest:: >>> person = store.find(Person, Person.name == u"Tom Thomas").one() >>> person is joe True Yes, they are. Now, for the magic step (suspense music, please). .. doctest:: >>> store.rollback() Erm.. nothing happened? Actually, something happened.. with Joe. He's back! .. doctest:: >>> print(joe.id) 1 >>> print(joe.name) Joe Johnes Constructors ------------ So, we've been working for too long with people only. Let's introduce a new kind of data in our model: companies. For the company, we'll use a constructor, just for the fun of it. It will be the simplest company class you've ever seen: .. doctest:: >>> class Company(object): ... __storm_table__ = "company" ... id = Int(primary=True) ... name = Unicode() ... ... def __init__(self, name): ... self.name = name Notice that the constructor parameter isn't optional. It could be optional, if we wanted, but our companies always have names. Let's add the table for it. .. doctest:: >>> store.execute( ... "CREATE TABLE company (id INTEGER PRIMARY KEY, name VARCHAR)", ... noresult=True) Then, create a new company. .. doctest:: >>> circus = Company(u"Circus Inc.") >>> print(circus.id) None >>> print(circus.name) Circus Inc. The ``id`` is still undefined because we haven't flushed it. In fact, we haven't even **added** the company to the store. We'll do that soon. Watch out. References and subclassing -------------------------- Now we want to assign some employees to our company. Rather than redoing the Person definition, we'll keep it as it is, since it's general, and will create a new subclass of it for employees, which includes one extra field: the company id. .. doctest:: >>> class Employee(Person): ... __storm_table__ = "employee" ... company_id = Int() ...
company = Reference(company_id, Company.id) ... ... def __init__(self, name): ... self.name = name Pay attention to that definition for a moment. Notice that it doesn't define what's already in person, and introduces the ``company_id``, and a ``company`` property, which is a reference to another class. It also has a constructor, but which leaves the company alone. As usual, we need a table. SQLite has no idea of what a foreign key is, so we'll not bother to define it. .. doctest:: >>> store.execute( ... "CREATE TABLE employee " ... "(id INTEGER PRIMARY KEY, name VARCHAR, company_id INTEGER)", ... noresult=True) Let's give life to Ben now. .. doctest:: >>> ben = store.add(Employee(u"Ben Bill")) >>> print(ben.id) None >>> print(ben.name) Ben Bill >>> print(ben.company_id) None We can see that they were not flushed yet. Even then, we can say that Bill works on Circus. .. doctest:: >>> ben.company = circus >>> print(ben.company_id) None >>> print(ben.company.name) Circus Inc. Of course, we still don't know the company id since it was not flushed to the database yet, and we didn't assign an id explicitly. Storm is keeping the relationship even then. If whatever is pending is flushed to the database (implicitly or explicitly), objects will get their ids, and any references are updated as well (before being flushed!). .. doctest:: >>> store.flush() >>> print(ben.company_id) 1 >>> print(ben.company.name) Circus Inc. They're both flushed to the database. Now, notice that the Circus company wasn't added to the store explicitly in any moment. Storm will do that automatically for referenced objects, for both objects (the referenced and the referencing one). Let's create another company to check something. This time we'll flush the store just after adding it. .. doctest:: >>> sweets = store.add(Company(u"Sweets Inc.")) >>> store.flush() >>> sweets.id 2 Nice, we've already got the id of the new company. So, what would happen if we changed **just the id** for Ben's company? .. doctest:: >>> ben.company_id = 2 >>> print(ben.company.name) Sweets Inc. >>> ben.company is sweets True Hah! **That** wasn't expected, was it? ;-) Let's commit everything. .. doctest:: >>> store.commit() Many-to-one reference sets -------------------------- So, while our model says that employees work for a single company (we only design normal people here), companies may of course have multiple employees. We represent that in Storm using reference sets. We won't define the company again. Instead, we'll add a new attribute to the class. .. doctest:: >>> Company.employees = ReferenceSet(Company.id, Employee.company_id) Without any further work, we can already see which employees are working for a given company. .. doctest:: >>> sweets.employees.count() 1 >>> for employee in sweets.employees: ... print(employee.id) ... print(employee.name) ... print(employee is ben) ... 1 Ben Bill True Let's create another employee, and add him to the company, rather than setting the company in the employee (it sounds better, at least). .. doctest:: >>> mike = store.add(Employee(u"Mike Mayer")) >>> sweets.employees.add(mike) That, of course, means that Mike's working for a company, and so it should be reflected elsewhere. .. doctest:: >>> mike.company_id 2 >>> mike.company is sweets True Many-to-many reference sets and composed keys --------------------------------------------- We want to represent accountants in our model as well. 
Companies have accountants, but accountants may also attend several companies, so we'll represent that using a many-to-many relationship. Let's create a simple class to use with accountants, and the relationship class. .. doctest:: >>> class Accountant(Person): ... __storm_table__ = "accountant" ... def __init__(self, name): ... self.name = name >>> class CompanyAccountant(object): ... __storm_table__ = "company_accountant" ... __storm_primary__ = "company_id", "accountant_id" ... company_id = Int() ... accountant_id = Int() Hey, we've just declared a class with a composed key! Now, let's use it to declare the many-to-many relationship in the company. Once more, we'll just stick the new attribute in the existent object. It may easily be defined at class definition time. Later we'll see another way to do that as well. .. doctest:: >>> Company.accountants = ReferenceSet(Company.id, ... CompanyAccountant.company_id, ... CompanyAccountant.accountant_id, ... Accountant.id) Done! The order in which attributes were defined is important, but the logic should be pretty obvious. We're missing some tables, at this point. .. doctest:: >>> store.execute( ... "CREATE TABLE accountant (id INTEGER PRIMARY KEY, name VARCHAR)", ... noresult=True) >>> store.execute( ... "CREATE TABLE company_accountant " ... "(company_id INTEGER, accountant_id INTEGER," ... " PRIMARY KEY (company_id, accountant_id))", ... noresult=True) Let's give life to a couple of accountants, and register them in both companies. .. doctest:: >>> karl = Accountant(u"Karl Kent") >>> frank = Accountant(u"Frank Fourt") >>> sweets.accountants.add(karl) >>> sweets.accountants.add(frank) >>> circus.accountants.add(frank) That's it! Really! Notice that we didn't even have to add them to the store, since it happens implicitly by linking to the other object which is already in the store, and that we didn't have to declare the relationship object, since that's known to the reference set. We can now check them. .. doctest:: >>> sweets.accountants.count() 2 >>> circus.accountants.count() 1 Even though we didn't use the ``CompanyAccountant`` object explicitly, we can check it if we're really curious. .. doctest:: >>> store.get(CompanyAccountant, (sweets.id, frank.id)) <...CompanyAccountant object at 0x...> Notice that we pass a tuple for the :py:meth:`~storm.store.Store.get` method, due to the composed key. If we wanted to know for which companies accountants are working, we could easily define a reversed relationship: .. doctest:: >>> Accountant.companies = ReferenceSet(Accountant.id, ... CompanyAccountant.accountant_id, ... CompanyAccountant.company_id, ... Company.id) >>> for name in sorted(company.name for company in frank.companies): ... print(name) Circus Inc. Sweets Inc. >>> for company in karl.companies: ... print(company.name) Sweets Inc. Joins ----- Since we've got some nice data to play with, let's try to make a few interesting queries. Let's start by checking which companies have at least one employee named Ben. We have at least two ways to do it. First, with an implicit join. .. doctest:: >>> result = store.find(Company, ... Employee.company_id == Company.id, ... Employee.name.like(u"Ben %")) >>> for company in result: ... print(company.name) Sweets Inc. Then, we can also do an explicit join. This is interesting for mapping complex SQL joins to Storm queries. .. doctest:: >>> origin = [Company, Join(Employee, Employee.company_id == Company.id)] >>> result = store.using(*origin).find( ... 
Company, Employee.name.like(u"Ben %")) >>> for company in result: ... print(company.name) Sweets Inc. If we already had the company, and wanted to know which of his employees were named Ben, that'd have been easier. .. doctest:: >>> result = sweets.employees.find(Employee.name.like(u"Ben %")) >>> for employee in result: ... print(employee.name) Ben Bill Sub-selects ----------- Suppose we want to find all accountants that aren't associated with a company. We can use a sub-select to get the data we want. .. doctest:: >>> laura = Accountant(u"Laura Montgomery") >>> store.add(laura) <...Accountant ...> >>> subselect = Select(CompanyAccountant.accountant_id, distinct=True) >>> result = store.find(Accountant, Not(Accountant.id.is_in(subselect))) >>> result.one() is laura True Ordering and limiting results ----------------------------- Ordering and limiting results obtained are certainly among the simplest and yet most wanted features for such tools, so we want to make them very easy to understand and use, of course. A line of code is worth a thousand words, so here are a few examples that demonstrate how it works: .. doctest:: >>> garry = store.add(Employee(u"Garry Glare")) >>> result = store.find(Employee) >>> for employee in result.order_by(Employee.name): ... print(employee.name) Ben Bill Garry Glare Mike Mayer >>> for employee in result.order_by(Desc(Employee.name)): ... print(employee.name) Mike Mayer Garry Glare Ben Bill >>> for employee in result.order_by(Employee.name)[:2]: ... print(employee.name) Ben Bill Garry Glare Multiple types with one query ----------------------------- Sometimes, it may be interesting to retrieve more than one object involved in a given query. Imagine, for instance, that besides knowing which companies have an employee named Ben, we also want to know who is the employee. This may be achieved with a query like follows: .. doctest:: >>> result = store.find((Company, Employee), ... Employee.company_id == Company.id, ... Employee.name.like(u"Ben %")) >>> for company, employee in result: ... print(company.name) ... print(employee.name) Sweets Inc. Ben Bill The Storm base class -------------------- So far we've been defining our references and reference sets using classes and their properties. This has some advantages, like being easier to debug, but also has some disadvantages, such as requiring classes to be present in the local scope, which potentially leads to circular import issues. To prevent that kind of situation, Storm supports defining these references using the stringified version of the class and property names. The only inconvenience of doing so is that all involved classes must inherit from the :py:class:`~storm.base.Storm` base class. Let's define some new classes to show that. To expose the point, we'll refer to a class before it's actually defined. .. doctest:: >>> class Country(Storm): ... __storm_table__ = "country" ... id = Int(primary=True) ... name = Unicode() ... currency_id = Int() ... currency = Reference(currency_id, "Currency.id") >>> class Currency(Storm): ... __storm_table__ = "currency" ... id = Int(primary=True) ... symbol = Unicode() >>> store.execute( ... "CREATE TABLE country " ... "(id INTEGER PRIMARY KEY, name VARCHAR, currency_id INTEGER)", ... noresult=True) >>> store.execute( ... "CREATE TABLE currency (id INTEGER PRIMARY KEY, symbol VARCHAR)", ... noresult=True) Now, let's see if it works. .. 
doctest:: >>> real = store.add(Currency()) >>> real.id = 1 >>> real.symbol = u"BRL" >>> brazil = store.add(Country()) >>> brazil.name = u"Brazil" >>> brazil.currency_id = 1 >>> print(brazil.currency.symbol) BRL Questions!? ;-) Loading hook ------------ Storm allows classes to define a few different hooks that are called to act when certain things happen. One of the interesting hooks available is the ``__storm_loaded__`` one. Let's play with it. We'll define a temporary subclass of Person for that. .. doctest:: >>> class PersonWithHook(Person): ... def __init__(self, name): ... print("Creating %s" % name) ... self.name = name ... ... def __storm_loaded__(self): ... print("Loaded %s" % self.name) >>> earl = store.add(PersonWithHook(u"Earl Easton")) Creating Earl Easton >>> earl = store.find(PersonWithHook, name=u"Earl Easton").one() >>> store.invalidate(earl) >>> del earl >>> import gc >>> collected = gc.collect() >>> earl = store.find(PersonWithHook, name=u"Earl Easton").one() Loaded Earl Easton Note that in the first find, nothing was called, since the object was still in memory and cached. Then, we invalidated the object from Storm's internal cache and ensured that it was out-of-memory by triggering a garbage collection. After that, the object had to be retrieved from the database again, and thus the hook was called (and not the constructor!). Executing expressions --------------------- Storm also offers a way to execute expressions in a database-agnostic way, when that's necessary. For instance: .. doctest:: >>> result = store.execute(Select(Person.name, Person.id == 1)) >>> (name,) = result.get_one() >>> print(name) Joe Johnes This mechanism is used internally by Storm itself to implement the higher level features. Auto-reloading values --------------------- Storm offers some special values that may be assigned to attributes under its control. One of these values is :py:data:`~storm.store.AutoReload`. When used, it will make the object automatically reload the value from the database when touched. Even primary keys may benefit from its use, as shown below. .. doctest:: >>> from storm.locals import AutoReload >>> ruy = store.add(Person()) >>> ruy.name = u"Ruy" >>> print(ruy.id) None >>> ruy.id = AutoReload >>> print(ruy.id) 4 This may be set as the default value for any attribute, making the object be automatically flushed if necessary. Expression values ----------------- Besides auto-reloading, it's also possible to assign what we call a "lazy expression" to an attribute. Such expressions are flushed to the database when the attribute is accessed, or when the object is flushed to the database (INSERT/UPDATE time). For instance: .. doctest:: >>> ruy.name = SQL( ... "(SELECT name || ? FROM person WHERE id=4)", (" Ritcher",)) >>> print(ruy.name) Ruy Ritcher Notice that this is just an example of what **may** be done. There's no need to write SQL statements this way, if you don't want to. You may also use class-based SQL expressions provided in Storm, or even not use lazy expressions at all. Aliases ------- So now let's say that we want to find every pair of people that work for the same company. I have no idea about why one would *want* to do that, but that's a good case for us to exercise aliases. First, we create an alias for the ``Employee`` class. .. doctest:: >>> from storm.info import ClassAlias >>> AnotherEmployee = ClassAlias(Employee) Nice, isn't it? Now we can easily make the query we want, in a straightforward way: .. doctest:: >>> result = store.find((Employee, AnotherEmployee), ...
Employee.company_id == AnotherEmployee.company_id, ... Employee.id > AnotherEmployee.id) >>> for employee1, employee2 in result: ... print(employee1.name) ... print(employee2.name) Mike Mayer Ben Bill Whoa! Mike and Ben work for the same company! (Quiz for the attentive reader: why is *greater than* being used in the query above?) Debugging --------- Sometimes you just need to see which statements Storm is executing. A debug tracer built on top of Storm's tracing system can be used to see what's going on under the hood. A tracer is an object that gets notified when interesting events occur, such as when Storm executes a statement. A function to enable and disable statement tracing is provided. Statements are logged to sys.stderr by default, but a custom stream may also be used. .. doctest:: >>> import sys >>> from storm.tracer import debug >>> debug(True, stream=sys.stdout) >>> result = store.find((Employee, AnotherEmployee), ... Employee.company_id == AnotherEmployee.company_id, ... Employee.id > AnotherEmployee.id) >>> list(result) [...] EXECUTE: ...'SELECT employee.company_id, employee.id, employee.name, "...".company_id, "...".id, "...".name FROM employee, employee AS "..." WHERE employee.company_id = "...".company_id AND employee.id > "...".id', () [...] DONE [(<...Employee object at ...>, <...Employee object at ...>)] >>> debug(False) >>> list(result) [(<...Employee object at ...>, <...Employee object at ...>)] Much more! ---------- There's a lot more to Storm than this tutorial can show; it is just a way to get started with some of the concepts. If your questions are not answered somewhere else, feel free to ask them on the mailing list. .. >>> Currency._storm_property_registry.clear() >>> Country._storm_property_registry.clear() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1590495941.0 storm-1.0/storm/docs/zope.rst0000644000175000017500000002337413663205305016637 0ustar00cjwatsoncjwatson.. Copyright (c) 2006, 2007 Canonical Written by Jamshed Kakar This file is part of Storm Object Relational Mapper. Storm is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. Storm is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this program. If not, see <http://www.gnu.org/licenses/>. Zope integration ================ The ``storm.zope`` package contains the ZStorm utility which provides seamless integration between Storm and Zope 3's transaction system. Setting up ZStorm is quite easy. In most cases, you want to include ``storm/zope/configure.zcml`` in your application, which you would normally do in ZCML as follows: .. code-block:: xml <include package="storm.zope" /> For the purposes of this doctest we'll register ZStorm manually. >>> from zope.component import provideUtility, getUtility >>> import transaction >>> from storm.zope.interfaces import IZStorm >>> from storm.zope.zstorm import global_zstorm >>> provideUtility(global_zstorm, IZStorm) >>> zstorm = getUtility(IZStorm) >>> zstorm <storm.zope.zstorm.ZStorm object at ...> Awesome, now that the utility is in place we can start to use it! Getting stores -------------- The ZStorm utility allows us to work with named stores.
>>> zstorm.set_default_uri("test", "sqlite:") Setting a default URI for stores isn't strictly required. We could pass it as the second argument to ``zstorm.get``. Providing a default URI makes it possible to use ``zstorm.get`` more easily; this is especially handy when multiple threads are used, as we'll see further on. >>> store = zstorm.get("test") >>> store <storm.store.Store object at ...> ZStorm has automatically created a store instance for us. If we ask for a store by name again, we should get the same instance. >>> same_store = zstorm.get("test") >>> same_store is store True The stores provided by ZStorm are per-thread. If we ask for the named store in a different thread, we should get a different instance. >>> import threading >>> thread_store = [] >>> def get_thread_store(): ... thread_store.append(zstorm.get("test")) >>> thread = threading.Thread(target=get_thread_store) >>> thread.start() >>> thread.join() >>> thread_store != [store] True Great! ZStorm abstracts away the process of creating and managing named stores. Let's move on and use the stores with Zope's transaction system. Committing transactions ----------------------- The primary purpose of ZStorm is to integrate with Zope's transaction system. Let's create a schema so we can play with some real data and see how it works. >>> result = store.execute(""" ... CREATE TABLE person ( ... id INTEGER PRIMARY KEY, ... name TEXT) ... """) >>> store.commit() We'll need a ``Person`` class to use with this database. >>> from storm.locals import Storm, Int, Unicode >>> class Person(Storm): ... ... __storm_table__ = "person" ... ... id = Int(primary=True) ... name = Unicode() ... ... def __init__(self, name): ... self.name = name Great! Let's try it out. >>> person = Person(u"John Doe") >>> store.add(person) <...Person object at ...> >>> transaction.commit() Notice that we're not using ``store.commit`` directly; we're using Zope's transaction system. Let's make sure it worked. >>> store.rollback() >>> same_person = store.find(Person).one() >>> same_person is person True Awesome! Aborting transactions --------------------- Let's make sure aborting transactions works, too. >>> store.add(Person(u"Imposter!")) <...Person object at ...> At this point a ``store.find`` should return the new object. >>> for name in sorted(person.name for person in store.find(Person)): ... print(name) Imposter! John Doe All this means is that the data has been flushed to the database; it's still not committed. If we abort the transaction, the new ``Person`` object should disappear. >>> transaction.abort() >>> for person in store.find(Person): ... print(person.name) John Doe Excellent! As you can see, ZStorm makes working with SQL databases and Zope 3 very natural. ZCML ---- In the examples above we set up our stores manually. In many cases, setting up named stores via ZCML directives is more desirable. Add a stanza similar to the following to your ZCML configuration to set up a named store. .. code-block:: xml <store name="test" uri="sqlite:" /> With that in place ``getUtility(IZStorm).get("test")`` will return the store named "test". Security Wrappers ----------------- Storm knows how to deal with "wrapped" objects -- the identity of any Storm-managed object does not need to be the same as the original object, by way of the "object info" system. As long as the object info can be retrieved from the wrapped objects, things work fine. To interoperate with the Zope security wrapper system, storm.zope tells Zope to expose certain Storm-internal attributes which appear on Storm-managed objects.
>>> from storm.info import get_obj_info, ObjectInfo >>> from zope.security.checker import ProxyFactory >>> from pprint import pprint >>> person = store.find(Person).one() >>> type(get_obj_info(person)) is ObjectInfo True >>> type(get_obj_info(ProxyFactory(person))) is ObjectInfo True Security-wrapped result sets can be used in the same way as unwrapped ones. >>> from zope.component.testing import ( ... setUp, ... tearDown, ... ) >>> from zope.configuration import xmlconfig >>> from zope.security.protectclass import protectName >>> import storm.zope >>> setUp() >>> _ = xmlconfig.file("configure.zcml", package=storm.zope) >>> protectName(Person, "name", "zope.Public") >>> another_person = Person(u"Jane Doe") >>> store.add(another_person) <...Person object at ...> >>> result = ProxyFactory(store.find(Person).order_by(Person.name)) >>> for person in result: ... print(person.name) Jane Doe John Doe >>> print(result[0].name) Jane Doe >>> for person in result[:1]: ... print(person.name) Jane Doe >>> another_person in result True >>> result.is_empty() False >>> result.any() <...Person object at ...> >>> print(result.first().name) Jane Doe >>> print(result.last().name) John Doe >>> print(result.count()) 2 Check ``list()`` as well as ordinary iteration: on Python 3, this tries to call ``__len__`` first (which doesn't exist, but is nevertheless allowed by the security wrapper). >>> for person in list(result): ... print(person.name) Jane Doe John Doe >>> result = ProxyFactory( ... store.find(Person, Person.name.startswith(u"John"))) >>> print(result.one().name) John Doe Security-wrapped reference sets work too. >>> _ = store.execute(""" ... CREATE TABLE team ( ... id INTEGER PRIMARY KEY, ... name TEXT) ... """) >>> _ = store.execute(""" ... CREATE TABLE teammembership ( ... id INTEGER PRIMARY KEY, ... person INTEGER NOT NULL REFERENCES person, ... team INTEGER NOT NULL REFERENCES team) ... """) >>> store.commit() >>> from storm.locals import Reference, ReferenceSet, Store >>> class TeamMembership(Storm): ... ... __storm_table__ = "teammembership" ... ... id = Int(primary=True) ... ... person_id = Int(name="person", allow_none=False) ... person = Reference(person_id, "Person.id") ... ... team_id = Int(name="team", allow_none=False) ... team = Reference(team_id, "Team.id") ... ... def __init__(self, person, team): ... self.person = person ... self.team = team >>> class Team(Storm): ... ... __storm_table__ = "team" ... ... id = Int(primary=True) ... name = Unicode() ... ... def __init__(self, name): ... self.name = name ... ... members = ReferenceSet( ... "id", "TeamMembership.team_id", ... "TeamMembership.person_id", "Person.id", ... order_by="Person.name") ... ... def addMember(self, person): ... Store.of(self).add(TeamMembership(person, self)) >>> protectName(Team, "members", "zope.Public") >>> protectName(Team, "addMember", "zope.Public") >>> doe_family = Team(u"does") >>> store.add(doe_family) <...Team object at ...> >>> doe_family = ProxyFactory(doe_family) >>> doe_family.addMember(person) >>> doe_family.addMember(another_person) >>> for member in doe_family.members: ... print(member.name) Jane Doe John Doe >>> for person in doe_family.members[:1]: ... print(person.name) Jane Doe >>> print(doe_family.members[0].name) Jane Doe >>> tearDown() ResultSet interfaces -------------------- Query results provide ``IResultSet`` (or ``ISQLObjectResultSet`` if SQLObject's compatibility layer is used).
>>> from storm.zope.interfaces import IResultSet, ISQLObjectResultSet >>> from storm.store import EmptyResultSet, ResultSet >>> from storm.sqlobject import SQLObjectResultSet >>> IResultSet.implementedBy(ResultSet) True >>> IResultSet.implementedBy(EmptyResultSet) True >>> ISQLObjectResultSet.implementedBy(SQLObjectResultSet) True .. >>> Team._storm_property_registry.clear() >>> TeamMembership._storm_property_registry.clear() >>> Person._storm_property_registry.clear() >>> transaction.abort() >>> zstorm._reset() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/event.py0000644000175000017500000000650714645174376015710 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import weakref from storm import has_cextensions __all__ = ["EventSystem"] class EventSystem: """A system for managing hooks that are called when events are emitted. Hooks are callables that take the event system C{owner} as their first argument, followed by the arguments passed when emitting the event, followed by any additional C{data} arguments given when registering the hook. Hooks registered for a given event C{name} are stored without ordering: no particular call order may be assumed when an event is emitted. """ def __init__(self, owner): """ @param owner: The object that owns this event system. It is passed as the first argument to each hook function. """ self._owner_ref = weakref.ref(owner) self._hooks = {} def hook(self, name, callback, *data): """Register a hook. @param name: The name of the event for which this hook should be called. @param callback: A callable which should be called when the event is emitted. @param data: Additional arguments to pass to the callable, after the C{owner} and any arguments passed when emitting the event. """ callbacks = self._hooks.get(name) if callbacks is None: self._hooks.setdefault(name, set()).add((callback, data)) else: callbacks.add((callback, data)) def unhook(self, name, callback, *data): """Unregister a hook. This ignores attempts to unregister hooks that were not already registered. @param name: The name of the event for which this hook should no longer be called. @param callback: The callable to unregister. @param data: Additional arguments that were passed when registering the callable. """ callbacks = self._hooks.get(name) if callbacks is not None: callbacks.discard((callback, data)) def emit(self, name, *args): """Emit an event, calling any registered hooks. @param name: The name of the event. @param args: Additional arguments to pass to hooks. 
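Example (an illustrative sketch, not part of the original module): given C{system = EventSystem(owner)} and C{system.hook("changed", callback, extra)}, calling C{system.emit("changed", 42)} invokes C{callback(owner, 42, extra)}; a hook that returns C{False} is unregistered.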
""" owner = self._owner_ref() if owner is not None: callbacks = self._hooks.get(name) if callbacks: for callback, data in tuple(callbacks): if callback(owner, *(args+data)) is False: callbacks.discard((callback, data)) if has_cextensions: from storm.cextensions import EventSystem ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/exceptions.py0000644000175000017500000000762514571373456016750 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from contextlib import contextmanager import sys class StormError(Exception): pass class CompileError(StormError): pass class NoTableError(CompileError): pass class ExprError(StormError): pass class NoneError(StormError): pass class PropertyPathError(StormError): pass class ClassInfoError(StormError): pass class URIError(StormError): pass class ClosedError(StormError): pass class FeatureError(StormError): pass class DatabaseModuleError(StormError): pass class StoreError(StormError): pass class NoStoreError(StormError): pass class WrongStoreError(StoreError): pass class NotFlushedError(StoreError): pass class OrderLoopError(StoreError): pass class NotOneError(StoreError): pass class UnorderedError(StoreError): pass class LostObjectError(StoreError): pass class Error(StormError): pass class Warning(StormError): pass class InterfaceError(Error): pass class DatabaseError(Error): pass class InternalError(DatabaseError): pass class OperationalError(DatabaseError): pass class ProgrammingError(DatabaseError): pass class IntegrityError(DatabaseError): pass class DataError(DatabaseError): pass class NotSupportedError(DatabaseError): pass class DisconnectionError(OperationalError): pass class TimeoutError(StormError): """Raised by timeout tracers when remining time is over.""" def __init__(self, statement, params, message=None): self.statement = statement self.params = params self.message = message def __str__(self): return ', '.join( [repr(element) for element in (self.message, self.statement, self.params) if element is not None]) class ConnectionBlockedError(StormError): """Raised when an attempt is made to use a blocked connection.""" # More generic exceptions must come later. For convenience, use the order # of definition above and then reverse it. _wrapped_exception_types = tuple(reversed(( Error, Warning, InterfaceError, DatabaseError, InternalError, OperationalError, ProgrammingError, IntegrityError, DataError, NotSupportedError, ))) @contextmanager def wrap_exceptions(database): """Context manager that re-raises DB exceptions as StormError instances.""" try: yield except Exception as e: module = database._exception_module if module is None: # This backend does not support re-raising database exceptions. 
raise for wrapper_type in _wrapped_exception_types: dbapi_type = getattr(module, wrapper_type.__name__, None) if (dbapi_type is not None and isinstance(dbapi_type, type) and isinstance(e, dbapi_type)): wrapped = database._wrap_exception(wrapper_type, e) tb = sys.exc_info()[2] try: raise wrapped.with_traceback(tb) from e finally: # Avoid traceback reference cycles. del wrapped, tb raise ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/expr.py0000644000175000017500000014712214645174376015544 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from decimal import Decimal from datetime import datetime, date, time, timedelta from weakref import WeakKeyDictionary from copy import copy import re from storm.exceptions import CompileError, NoTableError, ExprError from storm.variables import ( Variable, BytesVariable, UnicodeVariable, LazyValue, DateTimeVariable, DateVariable, TimeVariable, TimeDeltaVariable, BoolVariable, IntVariable, FloatVariable, DecimalVariable) from storm import Undef, has_cextensions # -------------------------------------------------------------------- # Basic compiler infrastructure def _when(self, types): """Check Compile.when. Defined here to ease the work of cextensions.""" def decorator(method): for type in types: self._local_dispatch_table[type] = method self._update_cache() return method return decorator class Compile: """Compiler based on the concept of generic functions.""" def __init__(self, parent=None): self._local_dispatch_table = {} self._local_precedence = {} self._local_reserved_words = {} self._dispatch_table = {} self._precedence = {} self._reserved_words = {} self._children = WeakKeyDictionary() self._parents = [] if parent: self._parents.extend(parent._parents) self._parents.append(parent) parent._children[self] = True self._update_cache() def _update_cache(self): for parent in self._parents: self._dispatch_table.update(parent._local_dispatch_table) self._precedence.update(parent._local_precedence) self._reserved_words.update(parent._local_reserved_words) self._dispatch_table.update(self._local_dispatch_table) self._precedence.update(self._local_precedence) self._reserved_words.update(self._local_reserved_words) for child in self._children: child._update_cache() def when(self, *types): """Decorator to include a type handler in this compiler. Use this as: >>> @compile.when(TypeA, TypeB) >>> def compile_type_a_or_b(compile, expr, state): >>> ... >>> return "THE COMPILED SQL STATEMENT" """ return _when(self, types) def add_reserved_words(self, words): """Include words to be considered reserved and thus escaped. Reserved words are escaped during compilation when they're seen in a SQLToken expression. 
""" self._local_reserved_words.update((word.lower(), True) for word in words) self._update_cache() def remove_reserved_words(self, words): self._local_reserved_words.update((word.lower(), None) for word in words) self._update_cache() def is_reserved_word(self, word): return self._reserved_words.get(word.lower()) is not None def create_child(self): """Create a new instance of L{Compile} which inherits from this one. This is most commonly used to customize a compiler for database-specific compilation strategies. """ return self.__class__(self) def get_precedence(self, type): return self._precedence.get(type, MAX_PRECEDENCE) def set_precedence(self, precedence, *types): for type in types: self._local_precedence[type] = precedence self._update_cache() def _compile_single(self, expr, state, outer_precedence): # FASTPATH This method is part of the fast path. Be careful when # changing it (try to profile any changes). cls = expr.__class__ dispatch_table = self._dispatch_table if cls in dispatch_table: handler = dispatch_table[cls] else: for mro_cls in cls.__mro__: # First iteration will always fail because we've already # tested that the class itself isn't in the dispatch table. if mro_cls in dispatch_table: handler = dispatch_table[mro_cls] break else: raise CompileError("Don't know how to compile type %r of %r" % (expr.__class__, expr)) inner_precedence = state.precedence = \ self._precedence.get(cls, MAX_PRECEDENCE) statement = handler(self, expr, state) if inner_precedence < outer_precedence: return "(%s)" % statement return statement def __call__(self, expr, state=None, join=", ", raw=False, token=False): """Compile the given expression into a SQL statement. @param expr: The expression to compile. @param state: An instance of State, or None, in which case it's created internally (and thus can't be accessed). @param join: The string token to use to put between subexpressions. Defaults to ", ". @param raw: If true, any string expression or subexpression will not be further compiled. @param token: If true, any string expression will be considered as a SQLToken, and quoted properly. """ # FASTPATH This method is part of the fast path. Be careful when # changing it (try to profile any changes). 
expr_type = type(expr) if expr_type is SQLRaw or (raw and expr_type is str): return expr if token and expr_type is str: expr = SQLToken(expr) if state is None: state = State() outer_precedence = state.precedence if expr_type is tuple or expr_type is list: compiled = [] for subexpr in expr: subexpr_type = type(subexpr) if subexpr_type is SQLRaw or (raw and subexpr_type is str): statement = subexpr elif subexpr_type is tuple or subexpr_type is list: state.precedence = outer_precedence statement = self(subexpr, state, join, raw, token) else: if token and subexpr_type is str: subexpr = SQLToken(subexpr) statement = self._compile_single(subexpr, state, outer_precedence) compiled.append(statement) statement = join.join(compiled) else: statement = self._compile_single(expr, state, outer_precedence) state.precedence = outer_precedence return statement if has_cextensions: from storm.cextensions import Compile class CompilePython(Compile): def get_matcher(self, expr): state = State() source = self(expr, state) namespace = {} code = ("def closure(parameters, bool):\n" " [%s] = parameters\n" " def match(get_column):\n" " return bool(%s)\n" " return match" % (",".join("_%d" % i for i in range(len(state.parameters))), source)) exec(code, namespace) return namespace['closure'](state.parameters, bool) class State: """All the data necessary during compilation of an expression. @ivar aliases: Dict of L{Column} instances to L{Alias} instances, specifying how columns should be compiled as aliases in very specific situations. This is typically used to work around strange deficiencies in various databases. @ivar auto_tables: The list of all implicitly-used tables. e.g., in store.find(Foo, Foo.attr==Bar.id), the tables of Bar and Foo are implicitly used because columns in them are referenced. This is used when building tables. @ivar join_tables: If not None, when Join expressions are compiled, tables seen will be added to this set. This acts as a blacklist against auto_tables when compiling Joins, because the generated statements should not refer to the table twice. @ivar context: an instance of L{Context}, specifying the context of the expression currently being compiled. @ivar precedence: Current precedence, automatically set and restored by the compiler. If an inner precedence is lower than an outer precedence, parenthesis around the inner expression are automatically emitted. """ def __init__(self): self._stack = [] self.precedence = 0 self.parameters = [] self.auto_tables = [] self.join_tables = None self.context = None self.aliases = None def push(self, attr, new_value=Undef): """Set an attribute in a way that can later be reverted with L{pop}. """ old_value = getattr(self, attr, None) self._stack.append((attr, old_value)) if new_value is Undef: new_value = copy(old_value) setattr(self, attr, new_value) return old_value def pop(self): """Revert the topmost L{push}. """ setattr(self, *self._stack.pop(-1)) compile = Compile() compile_python = CompilePython() # -------------------------------------------------------------------- # Expression contexts class Context: """ An object used to specify the nature of expected SQL expressions being compiled in a given context. 
""" def __init__(self, name): self._name = name def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self._name) TABLE = Context("TABLE") EXPR = Context("EXPR") COLUMN = Context("COLUMN") COLUMN_PREFIX = Context("COLUMN_PREFIX") COLUMN_NAME = Context("COLUMN_NAME") SELECT = Context("SELECT") # -------------------------------------------------------------------- # Builtin type support @compile.when(bytes) def compile_bytes(compile, expr, state): state.parameters.append(BytesVariable(expr)) return "?" @compile.when(str) def compile_text(compile, expr, state): state.parameters.append(UnicodeVariable(expr)) return "?" @compile.when(int) def compile_int(compile, expr, state): state.parameters.append(IntVariable(expr)) return "?" @compile.when(float) def compile_float(compile, expr, state): state.parameters.append(FloatVariable(expr)) return "?" @compile.when(Decimal) def compile_decimal(compile, expr, state): state.parameters.append(DecimalVariable(expr)) return "?" @compile.when(bool) def compile_bool(compile, expr, state): state.parameters.append(BoolVariable(expr)) return "?" @compile.when(datetime) def compile_datetime(compile, expr, state): state.parameters.append(DateTimeVariable(expr)) return "?" @compile.when(date) def compile_date(compile, expr, state): state.parameters.append(DateVariable(expr)) return "?" @compile.when(time) def compile_time(compile, expr, state): state.parameters.append(TimeVariable(expr)) return "?" @compile.when(timedelta) def compile_timedelta(compile, expr, state): state.parameters.append(TimeDeltaVariable(expr)) return "?" @compile.when(type(None)) def compile_none(compile, expr, state): return "NULL" @compile_python.when(bytes, str, int, float, type(None)) def compile_python_builtin(compile, expr, state): return repr(expr) @compile_python.when(bool, datetime, date, time, timedelta) def compile_python_bool_and_dates(compile, expr, state): index = len(state.parameters) state.parameters.append(expr) return "_%d" % index @compile.when(Variable) def compile_variable(compile, variable, state): state.parameters.append(variable) return "?" @compile_python.when(Variable) def compile_python_variable(compile, variable, state): index = len(state.parameters) state.parameters.append(variable.get()) return "_%d" % index # -------------------------------------------------------------------- # Base classes for expressions MAX_PRECEDENCE = 1000 class Expr(LazyValue): __slots__ = () @compile_python.when(Expr) def compile_python_unsupported(compile, expr, state): raise CompileError("Can't compile python expressions with %r" % type(expr)) # A translation table that can escape a unicode string for use in a # Like() expression that uses "!" as the escape character. 
like_escape = { ord("!"): "!!", ord("_"): "!_", ord("%"): "!%" } class Comparable: __slots__ = () __hash__ = object.__hash__ def __eq__(self, other): if other is not None and not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Eq(self, other) def __ne__(self, other): if other is not None and not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Ne(self, other) def __gt__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Gt(self, other) def __ge__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Ge(self, other) def __lt__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Lt(self, other) def __le__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Le(self, other) def __rshift__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return RShift(self, other) def __lshift__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return LShift(self, other) def __and__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return And(self, other) def __or__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Or(self, other) def __add__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Add(self, other) def __sub__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Sub(self, other) def __mul__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Mul(self, other) def __div__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Div(self, other) __floordiv__ = __div__ __truediv__ = __div__ def __mod__(self, other): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Mod(self, other) def __neg__(self): return Neg(self) def is_in(self, others): if not isinstance(others, Expr): others = list(others) if not others: return False variable_factory = getattr(self, "variable_factory", Variable) for i, other in enumerate(others): if not isinstance(other, (Expr, Variable)): others[i] = variable_factory(value=other) return In(self, others) def like(self, other, escape=Undef, case_sensitive=None): if not isinstance(other, (Expr, Variable)): other = getattr(self, "variable_factory", Variable)(value=other) return Like(self, other, escape, case_sensitive) def lower(self): return Lower(self) def upper(self): return Upper(self) def startswith(self, prefix, case_sensitive=None): if not isinstance(prefix, str): raise ExprError("Expected text argument, got %r" % type(prefix)) pattern = prefix.translate(like_escape) + "%" return Like(self, pattern, "!", case_sensitive) def endswith(self, suffix, case_sensitive=None): if not isinstance(suffix, 
str): raise ExprError("Expected text argument, got %r" % type(suffix)) pattern = "%" + suffix.translate(like_escape) return Like(self, pattern, "!", case_sensitive) def contains_string(self, substring, case_sensitive=None): if not isinstance(substring, str): raise ExprError("Expected text argument, got %r" % type(substring)) pattern = "%" + substring.translate(like_escape) + "%" return Like(self, pattern, "!", case_sensitive) class ComparableExpr(Expr, Comparable): __slots__ = () class BinaryExpr(ComparableExpr): __slots__ = ("expr1", "expr2") def __init__(self, expr1, expr2): self.expr1 = expr1 self.expr2 = expr2 class CompoundExpr(ComparableExpr): __slots__ = ("exprs",) def __init__(self, *exprs): self.exprs = exprs # -------------------------------------------------------------------- # Statement expressions def has_tables(state, expr): return (expr.tables is not Undef or expr.default_tables is not Undef or state.auto_tables) def build_tables(compile, tables, default_tables, state): """Compile provided tables. Tables will be built from either C{tables}, C{state.auto_tables}, or C{default_tables}. If C{tables} is not C{Undef}, it will be used. If C{tables} is C{Undef} and C{state.auto_tables} is available, that's used instead. If neither C{tables} nor C{state.auto_tables} are available, C{default_tables} is tried as a last resort. If none of them are available, C{NoTableError} is raised. """ if tables is Undef: if state.auto_tables: tables = state.auto_tables elif default_tables is not Undef: tables = default_tables else: tables = None # If we have no elements, it's an error. if not tables: raise NoTableError("Couldn't find any tables") # If it's a single element, it's trivial. if type(tables) not in (list, tuple) or len(tables) == 1: return compile(tables, state, token=True) # If we have no joins, it's trivial as well. for elem in tables: if isinstance(elem, JoinExpr): break else: if tables is state.auto_tables: tables = {compile(table, state, token=True) for table in tables} return ", ".join(sorted(tables)) else: return compile(tables, state, token=True) # Ok, now we have to be careful. # If we're dealing with auto_tables, we have to take care of # duplicated tables, join ordering, and so on. if tables is state.auto_tables: table_stmts = set() join_stmts = set() half_join_stmts = set() # push a join_tables onto the state: compile calls below will # populate this set so that we know what tables not to include. state.push("join_tables", set()) for elem in tables: statement = compile(elem, state, token=True) if isinstance(elem, JoinExpr): if elem.left is Undef: half_join_stmts.add(statement) else: join_stmts.add(statement) else: table_stmts.add(statement) # Remove tables that were seen in join statements. table_stmts -= state.join_tables state.pop() result = ", ".join(sorted(table_stmts)+sorted(join_stmts)) if half_join_stmts: result += " " + " ".join(sorted(half_join_stmts)) return "".join(result) # Otherwise, it's just a matter of putting it together. 
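# Illustrative example (not in the original source): for an explicit # table list such as [Table("foo"), LeftJoin(Table("bar"), on_expr)], # the loop below emits "foo LEFT JOIN bar ON ..."; a half-join (one # with no left side) is appended with a space rather than a comma. # Here on_expr stands for any boolean expression, e.g. Column("x") == 1.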
result = [] for elem in tables: if result: if isinstance(elem, JoinExpr) and elem.left is Undef: #half-join result.append(" ") else: result.append(", ") result.append(compile(elem, state, token=True)) return "".join(result) class Select(Expr): __slots__ = ("columns", "where", "tables", "default_tables", "order_by", "group_by", "limit", "offset", "distinct", "having") def __init__(self, columns, where=Undef, tables=Undef, default_tables=Undef, order_by=Undef, group_by=Undef, limit=Undef, offset=Undef, distinct=False, having=Undef): self.columns = columns self.where = where self.tables = tables self.default_tables = default_tables self.order_by = order_by self.group_by = group_by self.limit = limit self.offset = offset self.distinct = distinct self.having = having @compile.when(Select) def compile_select(compile, select, state): tokens = ["SELECT "] state.push("auto_tables", []) state.push("context", COLUMN) if select.distinct: tokens.append("DISTINCT ") if isinstance(select.distinct, (tuple, list)): tokens.append( "ON (%s) " % compile(select.distinct, state, raw=True)) tokens.append(compile(select.columns, state)) tables_pos = len(tokens) parameters_pos = len(state.parameters) state.context = EXPR if select.where is not Undef: tokens.append(" WHERE ") tokens.append(compile(select.where, state, raw=True)) if select.group_by is not Undef: tokens.append(" GROUP BY ") tokens.append(compile(select.group_by, state, raw=True)) if select.having is not Undef: tokens.append(" HAVING ") tokens.append(compile(select.having, state, raw=True)) if select.order_by is not Undef: tokens.append(" ORDER BY ") tokens.append(compile(select.order_by, state, raw=True)) if select.limit is not Undef: tokens.append(" LIMIT %d" % select.limit) if select.offset is not Undef: tokens.append(" OFFSET %d" % select.offset) if has_tables(state, select): state.context = TABLE state.push("parameters", []) tokens.insert(tables_pos, " FROM ") tokens.insert(tables_pos+1, build_tables(compile, select.tables, select.default_tables, state)) parameters = state.parameters state.pop() state.parameters[parameters_pos:parameters_pos] = parameters state.pop() state.pop() return "".join(tokens) class Insert(Expr): """Expression representing an insert statement. @ivar map: Dictionary mapping columns to values, or a sequence of columns for a bulk insert. @ivar table: Table where the row should be inserted. @ivar default_table: Table to use if no table is explicitly provided, and no tables may be inferred from provided columns. @ivar primary_columns: Tuple of columns forming the primary key of the table where the row will be inserted. This is a hint used by backends to process the insertion of rows. @ivar primary_variables: Tuple of variables with values for the primary key of the table where the row will be inserted. This is a hint used by backends to process the insertion of rows. @ivar values: Expression or sequence of tuples of values for bulk insertion. 
""" __slots__ = ("map", "table", "default_table", "primary_columns", "primary_variables", "values") def __init__(self, map, table=Undef, default_table=Undef, primary_columns=Undef, primary_variables=Undef, values=Undef): self.map = map self.table = table self.default_table = default_table self.primary_columns = primary_columns self.primary_variables = primary_variables self.values = values @compile.when(Insert) def compile_insert(compile, insert, state): state.push("context", COLUMN_NAME) columns = compile(tuple(insert.map), state, token=True) state.context = TABLE table = build_tables(compile, insert.table, insert.default_table, state) state.context = EXPR values = insert.values if values is Undef: values = [tuple(insert.map.values())] if isinstance(values, Expr): compiled_values = compile(values, state) else: compiled_values = ( "VALUES (%s)" % "), (".join(compile(value, state) for value in values)) state.pop() return "".join( ["INSERT INTO ", table, " (", columns, ") ", compiled_values]) class Update(Expr): __slots__ = ("map", "where", "table", "default_table", "primary_columns") def __init__(self, map, where=Undef, table=Undef, default_table=Undef, primary_columns=Undef): self.map = map self.where = where self.table = table self.default_table = default_table self.primary_columns = primary_columns @compile.when(Update) def compile_update(compile, update, state): map = update.map state.push("context", COLUMN_NAME) sets = ["%s=%s" % (compile(col, state, token=True), compile(map[col], state)) for col in map] state.context = TABLE tokens = ["UPDATE ", build_tables(compile, update.table, update.default_table, state), " SET ", ", ".join(sets)] if update.where is not Undef: state.context = EXPR tokens.append(" WHERE ") tokens.append(compile(update.where, state, raw=True)) state.pop() return "".join(tokens) class Delete(Expr): __slots__ = ("where", "table", "default_table") def __init__(self, where=Undef, table=Undef, default_table=Undef): self.where = where self.table = table self.default_table = default_table @compile.when(Delete) def compile_delete(compile, delete, state): tokens = ["DELETE FROM ", None] state.push("context", EXPR) if delete.where is not Undef: tokens.append(" WHERE ") tokens.append(compile(delete.where, state, raw=True)) # Compile later for auto_tables support. state.context = TABLE tokens[1] = build_tables(compile, delete.table, delete.default_table, state) state.pop() return "".join(tokens) # -------------------------------------------------------------------- # Columns class Column(ComparableExpr): """Representation of a column in some table. @ivar name: Column name. @ivar table: Column table (maybe another expression). @ivar primary: Integer representing the primary key position of this column, or 0 if it's not a primary key. May be provided as a bool. @ivar variable_factory: Factory producing C{Variable} instances typed according to this column. """ __slots__ = ("name", "table", "primary", "variable_factory", "compile_cache", "compile_id") def __init__(self, name=Undef, table=Undef, primary=False, variable_factory=None): self.name = name self.table = table self.primary = int(primary) self.variable_factory = variable_factory or Variable self.compile_cache = None self.compile_id = None @compile.when(Column) def compile_column(compile, column, state): if column.table is not Undef: state.auto_tables.append(column.table) if column.table is Undef or state.context is COLUMN_NAME: if state.aliases is not None: # See compile_set_expr(). 
alias = state.aliases.get(column) if alias is not None: return compile(alias.name, state, token=True) if column.compile_id != id(compile): column.compile_cache = compile(column.name, state, token=True) column.compile_id = id(compile) return column.compile_cache state.push("context", COLUMN_PREFIX) table = compile(column.table, state, token=True) state.pop() if column.compile_id != id(compile): column.compile_cache = compile(column.name, state, token=True) column.compile_id = id(compile) return "%s.%s" % (table, column.compile_cache) @compile_python.when(Column) def compile_python_column(compile, column, state): index = len(state.parameters) state.parameters.append(column) return "get_column(_%d)" % index # -------------------------------------------------------------------- # Alias expressions class Alias(ComparableExpr): """A representation of "AS" alias clauses. e.g., SELECT foo AS bar. """ __slots__ = ("expr", "name") auto_counter = 0 def __init__(self, expr, name=Undef): """Create alias of C{expr} AS C{name}. If C{name} is not given, then a name will automatically be generated. """ self.expr = expr if name is Undef: Alias.auto_counter += 1 name = "_%x" % Alias.auto_counter self.name = name @compile.when(Alias) def compile_alias(compile, alias, state): name = compile(alias.name, state, token=True) if state.context is COLUMN or state.context is TABLE: return "%s AS %s" % (compile(alias.expr, state), name) return name # -------------------------------------------------------------------- # From expressions class FromExpr(Expr): __slots__ = () class Table(FromExpr): __slots__ = ("name", "compile_cache", "compile_id") def __init__(self, name): self.name = name self.compile_cache = None self.compile_id = None @compile.when(Table) def compile_table(compile, table, state): if table.compile_id != id(compile): table.compile_cache = compile(table.name, state, token=True) table.compile_id = id(compile) return table.compile_cache class JoinExpr(FromExpr): __slots__ = ("left", "right", "on") oper = "(unknown)" def __init__(self, arg1, arg2=Undef, on=Undef): # http://www.postgresql.org/docs/8.1/interactive/explicit-joins.html if arg2 is Undef: self.left = Undef self.right = arg1 self.on = on elif not isinstance(arg2, Expr) or isinstance(arg2, (FromExpr, Alias)): self.left = arg1 self.right = arg2 self.on = on else: self.left = Undef self.right = arg1 self.on = arg2 if on is not Undef: raise ExprError("Improper join arguments: (%r, %r, %r)" % (arg1, arg2, on)) @compile.when(JoinExpr) def compile_join(compile, join, state): result = [] if join.left is not Undef: statement = compile(join.left, state, token=True) result.append(statement) if state.join_tables is not None: state.join_tables.add(statement) result.append(join.oper) # Joins are left associative, so ensure joins in the right hand # argument get parentheses. 
state.precedence += 0.5 statement = compile(join.right, state, token=True) result.append(statement) if state.join_tables is not None: state.join_tables.add(statement) if join.on is not Undef: state.push("context", EXPR) result.append("ON") result.append(compile(join.on, state, raw=True)) state.pop() return " ".join(result) class Join(JoinExpr): __slots__ = () oper = "JOIN" class LeftJoin(JoinExpr): __slots__ = () oper = "LEFT JOIN" class RightJoin(JoinExpr): __slots__ = () oper = "RIGHT JOIN" class NaturalJoin(JoinExpr): __slots__ = () oper = "NATURAL JOIN" class NaturalLeftJoin(JoinExpr): __slots__ = () oper = "NATURAL LEFT JOIN" class NaturalRightJoin(JoinExpr): __slots__ = () oper = "NATURAL RIGHT JOIN" # -------------------------------------------------------------------- # Distinct expressions class Distinct(Expr): """Add the 'DISTINCT' prefix to an expression.""" __slots__ = ("expr") def __init__(self, expr): self.expr = expr @compile.when(Distinct) def compile_distinct(compile, distinct, state): return "DISTINCT %s" % compile(distinct.expr, state) # -------------------------------------------------------------------- # Operators class BinaryOper(BinaryExpr): __slots__ = () oper = " (unknown) " @compile.when(BinaryOper) @compile_python.when(BinaryOper) def compile_binary_oper(compile, expr, state): return "%s%s%s" % (compile(expr.expr1, state), expr.oper, compile(expr.expr2, state)) class NonAssocBinaryOper(BinaryOper): __slots__ = () oper = " (unknown) " @compile.when(NonAssocBinaryOper) @compile_python.when(NonAssocBinaryOper) def compile_non_assoc_binary_oper(compile, expr, state): expr1 = compile(expr.expr1, state) state.precedence += 0.5 # Enforce parentheses. expr2 = compile(expr.expr2, state) return "%s%s%s" % (expr1, expr.oper, expr2) class CompoundOper(CompoundExpr): __slots__ = () oper = " (unknown) " @compile.when(CompoundOper) def compile_compound_oper(compile, expr, state): return compile(expr.exprs, state, join=expr.oper) @compile_python.when(CompoundOper) def compile_compound_oper(compile, expr, state): return compile(expr.exprs, state, join=expr.oper.lower()) class Is(BinaryOper): """The SQL C{IS ...} operators, e.g. C{IS NULL}. C{Is(expr, None)} is synonymous with C{expr == None}, but is less likely to trip up linters. Unlike C{expr} or C{expr == True}, C{Is(expr, True)} returns C{FALSE} when C{expr} is C{NULL}. Unlike C{Not(expr)} or C{expr == False}, C{Is(expr, False)} returns C{FALSE} when C{expr} is C{NULL}. """ __slots__ = () oper = " IS " @compile.when(Is) def compile_is(compile, is_, state): tokens = [compile(is_.expr1, state), "IS"] if is_.expr2 is None: tokens.append("NULL") elif is_.expr2 is True: tokens.append("TRUE") elif is_.expr2 is False: tokens.append("FALSE") else: raise CompileError("expr2 must be None, True, or False") return " ".join(tokens) @compile_python.when(Is) def compile_is(compile, is_, state): return "%s is %s" % (compile(is_.expr1, state), compile(is_.expr2, state)) class IsNot(BinaryOper): """The SQL C{IS NOT ...} operators, e.g. C{IS NOT NULL}. C{IsNot(expr, None)} is synonymous with C{expr != None}, but is less likely to trip up linters. Unlike C{Not(expr)} or C{expr != True}, C{IsNot(expr, True)} returns C{TRUE} when C{expr} is C{NULL}. Unlike C{expr} or C{expr != False}, C{IsNot(expr, False)} returns C{TRUE} when C{expr} is C{NULL}. 
""" __slots__ = () oper = " IS NOT " @compile.when(IsNot) def compile_is_not(compile, is_not, state): tokens = [compile(is_not.expr1, state), "IS NOT"] if is_not.expr2 is None: tokens.append("NULL") elif is_not.expr2 is True: tokens.append("TRUE") elif is_not.expr2 is False: tokens.append("FALSE") else: raise CompileError("expr2 must be None, True, or False") return " ".join(tokens) @compile_python.when(IsNot) def compile_is_not(compile, is_not, state): return "%s is not %s" % ( compile(is_not.expr1, state), compile(is_not.expr2, state) ) class Eq(BinaryOper): __slots__ = () oper = " = " @compile.when(Eq) def compile_eq(compile, eq, state): if eq.expr2 is None: return "%s IS NULL" % compile(eq.expr1, state) return "%s = %s" % (compile(eq.expr1, state), compile(eq.expr2, state)) @compile_python.when(Eq) def compile_eq(compile, eq, state): return "%s == %s" % (compile(eq.expr1, state), compile(eq.expr2, state)) class Ne(BinaryOper): __slots__ = () oper = " != " @compile.when(Ne) def compile_ne(compile, ne, state): if ne.expr2 is None: return "%s IS NOT NULL" % compile(ne.expr1, state) return "%s != %s" % (compile(ne.expr1, state), compile(ne.expr2, state)) class Gt(BinaryOper): __slots__ = () oper = " > " class Ge(BinaryOper): __slots__ = () oper = " >= " class Lt(BinaryOper): __slots__ = () oper = " < " class Le(BinaryOper): __slots__ = () oper = " <= " class RShift(BinaryOper): __slots__ = () oper = " >> " class LShift(BinaryOper): __slots__ = () oper = " << " class Like(BinaryOper): __slots__ = ("escape", "case_sensitive") oper = " LIKE " def __init__(self, expr1, expr2, escape=Undef, case_sensitive=None): self.expr1 = expr1 self.expr2 = expr2 self.escape = escape self.case_sensitive = case_sensitive @compile.when(Like) def compile_like(compile, like, state, oper=None): statement = "%s%s%s" % (compile(like.expr1, state), oper or like.oper, compile(like.expr2, state)) if like.escape is not Undef: statement = "%s ESCAPE %s" % (statement, compile(like.escape, state)) return statement # It's easy to support it. Later. compile_python.when(Like)(compile_python_unsupported) class In(BinaryOper): __slots__ = () oper = " IN " @compile.when(In) def compile_in(compile, expr, state): expr1 = compile(expr.expr1, state) state.precedence = 0 # We're forcing parenthesis here. return "%s IN (%s)" % (expr1, compile(expr.expr2, state)) @compile_python.when(In) def compile_in(compile, expr, state): expr1 = compile(expr.expr1, state) state.precedence = 0 # We're forcing parenthesis here. return "%s in (%s,)" % (expr1, compile(expr.expr2, state)) class Add(CompoundOper): __slots__ = () oper = " + " class Sub(NonAssocBinaryOper): __slots__ = () oper = " - " class Mul(CompoundOper): __slots__ = () oper = " * " class Div(NonAssocBinaryOper): __slots__ = () oper = " / " class Mod(NonAssocBinaryOper): __slots__ = () oper = " % " class And(CompoundOper): __slots__ = () oper = " AND " class Or(CompoundOper): __slots__ = () oper = " OR " @compile.when(And, Or) def compile_compound_oper(compile, expr, state): return compile(expr.exprs, state, join=expr.oper, raw=True) # -------------------------------------------------------------------- # Set expressions. 
class SetExpr(Expr): __slots__ = ("exprs", "all", "order_by", "limit", "offset") oper = " (unknown) " def __init__(self, *exprs, **kwargs): self.exprs = exprs self.all = kwargs.get("all", False) self.order_by = kwargs.get("order_by", Undef) self.limit = kwargs.get("limit", Undef) self.offset = kwargs.get("offset", Undef) # If the first expression is of a compatible type, directly # include its sub expressions. if len(self.exprs) > 0: first = self.exprs[0] if (isinstance(first, self.__class__) and first.all == self.all and first.limit is Undef and first.offset is Undef): self.exprs = first.exprs + self.exprs[1:] @compile.when(SetExpr) def compile_set_expr(compile, expr, state): if expr.order_by is not Undef: # When ORDER BY is present, databases usually have trouble using # fully qualified column names. Because of that, we transform # pure column names into aliases, and use them in the ORDER BY. aliases = {} for subexpr in expr.exprs: if isinstance(subexpr, Select): columns = subexpr.columns if not isinstance(columns, (tuple, list)): columns = [columns] else: columns = list(columns) for i, column in enumerate(columns): if column not in aliases: if isinstance(column, Column): aliases[column] = columns[i] = Alias(column) elif isinstance(column, Alias): aliases[column.expr] = column subexpr.columns = columns state.push("context", SELECT) # In the statement: # SELECT foo UNION SELECT bar LIMIT 1 # The LIMIT 1 applies to the union results, not the SELECT bar # This ensures that parentheses will be placed around the # sub-selects in the expression. state.precedence += 0.5 oper = expr.oper if expr.all: oper += "ALL " statement = compile(expr.exprs, state, join=oper) state.precedence -= 0.5 if expr.order_by is not Undef: state.context = COLUMN_NAME if state.aliases is None: state.push("aliases", aliases) else: # Previously defined aliases have precedence. 
aliases.update(state.aliases) state.aliases = aliases aliases = None statement += " ORDER BY " + compile(expr.order_by, state) if aliases is not None: state.pop() if expr.limit is not Undef: statement += " LIMIT %d" % expr.limit if expr.offset is not Undef: statement += " OFFSET %d" % expr.offset state.pop() return statement class Union(SetExpr): __slots__ = () oper = " UNION " class Except(SetExpr): __slots__ = () oper = " EXCEPT " class Intersect(SetExpr): __slots__ = () oper = " INTERSECT " # -------------------------------------------------------------------- # Functions class FuncExpr(ComparableExpr): __slots__ = () name = "(unknown)" class Count(FuncExpr): __slots__ = ("column", "distinct") name = "COUNT" def __init__(self, column=Undef, distinct=False): if distinct and column is Undef: raise ValueError("Must specify column when using distinct count") self.column = column self.distinct = distinct @compile.when(Count) def compile_count(compile, count, state): if count.column is not Undef: state.push("context", EXPR) column = compile(count.column, state) state.pop() if count.distinct: return "COUNT(DISTINCT %s)" % column return "COUNT(%s)" % column return "COUNT(*)" class Func(FuncExpr): __slots__ = ("name", "args") def __init__(self, name, *args): self.name = name self.args = args class NamedFunc(FuncExpr): __slots__ = ("args",) def __init__(self, *args): self.args = args @compile.when(Func, NamedFunc) def compile_func(compile, func, state): state.push("context", EXPR) args = compile(func.args, state) state.pop() return "%s(%s)" % (func.name, args) class Max(NamedFunc): __slots__ = () name = "MAX" class Min(NamedFunc): __slots__ = () name = "MIN" class Avg(NamedFunc): __slots__ = () name = "AVG" class Sum(NamedFunc): __slots__ = () name = "SUM" class Lower(NamedFunc): __slots__ = () name = "LOWER" class Upper(NamedFunc): __slots__ = () name = "UPPER" class Coalesce(NamedFunc): __slots__ = () name = "COALESCE" class Row(NamedFunc): __slots__ = () name = "ROW" class Cast(FuncExpr): """A representation of C{CAST} clauses. e.g., C{CAST(bar AS TEXT)}.""" __slots__ = ("column", "type") name = "CAST" def __init__(self, column, type): """Create a cast of C{column} as C{type}.""" self.column = column self.type = type @compile.when(Cast) def compile_cast(compile, cast, state): """Compile L{Cast} expressions.""" state.push("context", EXPR) column = compile(cast.column, state) state.pop() return "CAST(%s AS %s)" % (column, cast.type) # -------------------------------------------------------------------- # Prefix and suffix expressions class PrefixExpr(Expr): __slots__ = ("expr",) prefix = "(unknown)" def __init__(self, expr): self.expr = expr @compile.when(PrefixExpr) def compile_prefix_expr(compile, expr, state): return "%s %s" % (expr.prefix, compile(expr.expr, state)) class SuffixExpr(Expr): __slots__ = ("expr",) suffix = "(unknown)" def __init__(self, expr): self.expr = expr @compile.when(SuffixExpr) def compile_suffix_expr(compile, expr, state): return "%s %s" % (compile(expr.expr, state, raw=True), expr.suffix) class Not(PrefixExpr): __slots__ = () prefix = "NOT" class Exists(PrefixExpr): __slots__ = () prefix = "EXISTS" class Neg(PrefixExpr): __slots__ = () prefix = "-" @compile_python.when(Neg) def compile_neg_expr(compile, expr, state): return "-%s" % compile(expr.expr, state, raw=True) class Asc(SuffixExpr): __slots__ = () suffix = "ASC" class Desc(SuffixExpr): __slots__ = () suffix = "DESC" # -------------------------------------------------------------------- # Plain SQL expressions. 
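# Illustrative example (not in the original source): # compile(SQLToken("select")) returns '"select"' since "select" is a # reserved word, compile(SQLToken("foo")) returns plain foo, and # SQLRaw strings are returned by the compiler verbatim.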
class SQLRaw(str): """Subtype to mark a string as something that shouldn't be compiled. This is handled internally by the compiler. """ __slots__ = () class SQLToken(str): """Marker for strings that should be considered as a single SQL token. These strings will be quoted, when needed. """ __slots__ = () is_safe_token = re.compile("^[a-zA-Z][a-zA-Z0-9_]*$").match @compile.when(SQLToken) def compile_sql_token(compile, expr, state): if is_safe_token(expr) and not compile.is_reserved_word(expr): return expr return '"%s"' % expr.replace('"', '""') @compile_python.when(SQLToken) def compile_python_sql_token(compile, expr, state): return expr class SQL(ComparableExpr): __slots__ = ("expr", "params", "tables") def __init__(self, expr, params=Undef, tables=Undef): self.expr = expr self.params = params self.tables = tables @compile.when(SQL) def compile_sql(compile, expr, state): if expr.params is not Undef: if type(expr.params) not in (tuple, list): raise CompileError("Parameters should be a list or a tuple, " "not %r" % type(expr.params)) for param in expr.params: state.parameters.append(param) if expr.tables is not Undef: state.auto_tables.append(expr.tables) return expr.expr # -------------------------------------------------------------------- # Sequences. class Sequence(Expr): """Expression representing auto-incrementing support from the database. This should be translated into the *next* value of the named auto-incrementing sequence. There's no standard way to compile a sequence, since it's very database-dependent. This may be used as follows:: class Class(object): (...) id = Int(default=Sequence("my_sequence_name")) """ __slots__ = ("name",) def __init__(self, name): self.name = name # -------------------------------------------------------------------- # Utility functions. def compare_columns(columns, values): if not columns: return Undef equals = [] if len(columns) == 1: value = values[0] if not isinstance(value, (Expr, Variable)) and value is not None: value = columns[0].variable_factory(value=value) return Eq(columns[0], value) else: for column, value in zip(columns, values): if not isinstance(value, (Expr, Variable)) and value is not None: value = column.variable_factory(value=value) equals.append(Eq(column, value)) return And(*equals) # -------------------------------------------------------------------- # Auto table class AutoTables(Expr): """This class will inject one or more entries in state.auto_tables. If the constructor is passed replace=True, it will also discard any auto_table entries injected by compiling the given expression. """ __slots__ = ("expr", "tables", "replace") def __init__(self, expr, tables, replace=False): assert type(tables) in (list, tuple) self.expr = expr self.tables = tables self.replace = replace @compile.when(AutoTables) def compile_auto_tables(compile, expr, state): if expr.replace: state.push("auto_tables", []) statement = compile(expr.expr, state) if expr.replace: state.pop() state.auto_tables.extend(expr.tables) return statement # -------------------------------------------------------------------- # Set operator precedences. 
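# Illustrative example (not in the original source): Or has precedence # 30 and And 40, so compiling And(Or(A, B), C) emits "(A OR B) AND C"; # the inner Or is parenthesized because its precedence is lower than # the enclosing And's.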
compile.set_precedence(10, Select, Insert, Update, Delete)
compile.set_precedence(10, Join, LeftJoin, RightJoin)
compile.set_precedence(10, NaturalJoin, NaturalLeftJoin, NaturalRightJoin)
compile.set_precedence(10, Union, Except, Intersect)
compile.set_precedence(20, SQL)
compile.set_precedence(30, Or)
compile.set_precedence(40, And)
compile.set_precedence(45, Is, IsNot)
compile.set_precedence(50, Eq, Ne, Gt, Ge, Lt, Le, Like, In)
compile.set_precedence(60, LShift, RShift)
compile.set_precedence(70, Add, Sub)
compile.set_precedence(80, Mul, Div, Mod)

compile_python.set_precedence(10, Or)
compile_python.set_precedence(20, And)
compile_python.set_precedence(25, Is, IsNot)
compile_python.set_precedence(30, Eq, Ne, Gt, Ge, Lt, Le, Like, In)
compile_python.set_precedence(40, LShift, RShift)
compile_python.set_precedence(50, Add, Sub)
compile_python.set_precedence(60, Mul, Div, Mod)


# --------------------------------------------------------------------
# Reserved words, from SQL1992

compile.add_reserved_words(
    """
    absolute action add all allocate alter and any are as asc assertion at
    authorization avg begin between bit bit_length both by cascade cascaded
    case cast catalog char character char_length character_length check
    close coalesce collate collation column commit connect connection
    constraint constraints continue convert corresponding count create cross
    current current_date current_time current_timestamp current_user cursor
    date day deallocate dec decimal declare default deferrable deferred
    delete desc describe descriptor diagnostics disconnect distinct domain
    double drop else end end-exec escape except exception exec execute
    exists external extract false fetch first float for foreign found from
    full get global go goto grant group having hour identity immediate in
    indicator initially inner input insensitive insert int integer intersect
    interval into is isolation join key language last leading left level
    like local lower match max min minute module month names national
    natural nchar next no not null nullif numeric octet_length of on only
    open option or order outer output overlaps pad partial position
    precision prepare preserve primary prior privileges procedure public
    read real references relative restrict revoke right rollback rows
    schema scroll second section select session session_user set size
    smallint some space sql sqlcode sqlerror sqlstate substring sum
    system_user table temporary then time timestamp timezone_hour
    timezone_minute to trailing transaction translate translation trim true
    union unique unknown update upper usage user using value values varchar
    varying view when whenever where with work write year zone
    """.split())

storm-1.0/storm/info.py

#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
# # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from weakref import ref from storm.exceptions import ClassInfoError from storm.expr import Column, Desc, TABLE from storm.expr import compile, Table from storm.event import EventSystem from storm import Undef, has_cextensions __all__ = ["get_obj_info", "set_obj_info", "get_cls_info", "ClassInfo", "ObjectInfo", "ClassAlias"] def get_obj_info(obj): try: return obj.__storm_object_info__ except AttributeError: # Instantiate ObjectInfo first, so that it breaks gracefully, # in case the object isn't a storm object. obj_info = ObjectInfo(obj) return obj.__dict__.setdefault("__storm_object_info__", obj_info) def set_obj_info(obj, obj_info): obj.__dict__["__storm_object_info__"] = obj_info def get_cls_info(cls): if "__storm_class_info__" in cls.__dict__: # Can't use attribute access here, otherwise subclassing won't work. return cls.__dict__["__storm_class_info__"] else: cls.__storm_class_info__ = ClassInfo(cls) return cls.__storm_class_info__ class ClassInfo(dict): """Persistent Storm-related information of a class. The following attributes are defined: @ivar table: Expression from where columns will be looked up. @ivar cls: Class which should be used to build objects. @ivar columns: Tuple of column properties found in the class. @ivar primary_key: Tuple of column properties used to form the primary key @ivar primary_key_pos: Position of primary_key items in the columns tuple. """ def __init__(self, cls): self.table = getattr(cls, "__storm_table__", None) if self.table is None: raise ClassInfoError("%s.__storm_table__ missing" % repr(cls)) self.cls = cls if isinstance(self.table, str): self.table = Table(self.table) pairs = [] for attr in dir(cls): column = getattr(cls, attr, None) if isinstance(column, Column): pairs.append((attr, column)) pairs.sort() self.columns = tuple(pair[1] for pair in pairs) self.attributes = dict(pairs) storm_primary = getattr(cls, "__storm_primary__", None) if storm_primary is not None: if type(storm_primary) is not tuple: storm_primary = (storm_primary,) self.primary_key = tuple(self.attributes[attr] for attr in storm_primary) else: primary = [] primary_attrs = {} for attr, column in pairs: if column.primary != 0: if column.primary in primary_attrs: raise ClassInfoError( "%s has two columns with the same primary id: " "%s and %s" % (repr(cls), attr, primary_attrs[column.primary])) primary.append((column.primary, column)) primary_attrs[column.primary] = attr primary.sort() self.primary_key = tuple(column for i, column in primary) if not self.primary_key: raise ClassInfoError("%s has no primary key information" % repr(cls)) # columns have __eq__ implementations that do things we don't want - we # want to look these up in a dict and use identity semantics id_positions = {id(column): i for i, column in enumerate(self.columns)} self.primary_key_idx = {id(column): i for i, column in enumerate(self.primary_key)} self.primary_key_pos = tuple(id_positions[id(column)] for column in self.primary_key) __order__ = getattr(cls, "__storm_order__", None) if __order__ is None: self.default_order = Undef else: if type(__order__) is not tuple: __order__ = (__order__,) self.default_order = [] for item in __order__: if isinstance(item, str): if item.startswith("-"): prop = Desc(getattr(cls, item[1:])) else: prop = getattr(cls, item) else: prop = item self.default_order.append(prop) def __eq__(self, other): return self is other def __ne__(self, other): return self is not 
other __hash__ = object.__hash__ class ObjectInfo(dict): __hash__ = object.__hash__ # For get_obj_info(), an ObjectInfo is its own obj_info. __storm_object_info__ = property(lambda self: self) def __init__(self, obj): # FASTPATH This method is part of the fast path. Be careful when # changing it (try to profile any changes). # First thing, try to create a ClassInfo for the object's class. # This ensures that obj is the kind of object we expect. self.cls_info = get_cls_info(type(obj)) self.set_obj(obj) self.event = event = EventSystem(self) self.variables = variables = {} for column in self.cls_info.columns: variables[column] = \ column.variable_factory(column=column, event=event, validator_object_factory=self.get_obj) self.primary_vars = tuple(variables[column] for column in self.cls_info.primary_key) def __eq__(self, other): return self is other def __ne__(self, other): return self is not other def set_obj(self, obj): self._ref = ref(obj, self._emit_object_deleted) def get_obj(self): return self._ref() def _emit_object_deleted(self, obj_ref): self.event.emit("object-deleted") def checkpoint(self): for variable in self.variables.values(): variable.checkpoint() if has_cextensions: from storm.cextensions import ObjectInfo, get_obj_info class ClassAlias: """Create a named alias for a Storm class for use in queries. This is useful basically when the SQL 'AS' feature is desired in code using Storm queries. ClassAliases which are explicitly named (i.e., when 'name' is passed) are cached for as long as the class exists, such that the alias returned from C{ClassAlias(Foo, 'foo_alias')} will be the same object no matter how many times it's called. @param cls: The class to create the alias of. @param name: If provided, specify the name of the alias to create. """ alias_count = 0 def __new__(self_cls, cls, name=Undef): if name is Undef: use_cache = False ClassAlias.alias_count += 1 name = "_%x" % ClassAlias.alias_count else: use_cache = True cache = cls.__dict__.get("_storm_alias_cache") if cache is None: cls._storm_alias_cache = {} elif name in cache: return cache[name] alias_cls = type(cls.__name__ + "Alias", (self_cls,), {"__storm_table__": name}) alias_cls.__bases__ = (cls, self_cls) alias_cls_info = get_cls_info(alias_cls) alias_cls_info.cls = cls if use_cache: cls._storm_alias_cache[name] = alias_cls return alias_cls @compile.when(type) def compile_type(compile, expr, state): cls_info = get_cls_info(expr) table = compile(cls_info.table, state) if state.context is TABLE and issubclass(expr, ClassAlias): return "%s AS %s" % (compile(cls_info.cls, state), table) return table ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/locals.py0000644000175000017500000000261214571373456016033 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
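# --------------------------------------------------------------------
# Illustrative sketch (not part of the original module; assumes an
# in-memory SQLite store and a hypothetical "book" table): a minimal
# model built from the names this module re-exports below, including a
# ClassAlias self-join. Kept as an uncalled helper so importing the
# module stays side-effect free.

def _example_locals_usage():
    from storm.locals import (
        Storm, Int, Unicode, Store, create_database, ClassAlias)

    class Book(Storm):
        __storm_table__ = "book"
        id = Int(primary=True)
        title = Unicode()

    store = Store(create_database("sqlite:"))
    store.execute("CREATE TABLE book (id INTEGER PRIMARY KEY, title TEXT)")
    book = Book()
    book.title = "Storm by example"
    store.add(book)
    store.flush()

    # ClassAlias(Book, "book_alias") compiles to 'book AS book_alias'
    # in the FROM clause.
    BookAlias = ClassAlias(Book, "book_alias")
    return store.find((Book, BookAlias), Book.id == BookAlias.id).one()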
# from storm.properties import Bool, Int, Float, Bytes, RawStr, Chars, Unicode from storm.properties import List, Decimal, DateTime, Date, Time, Enum, UUID from storm.properties import TimeDelta, Pickle, JSON from storm.references import Reference, ReferenceSet, Proxy from storm.database import create_database from storm.exceptions import StormError from storm.store import Store, AutoReload from storm.expr import Select, Insert, Update, Delete, Join, SQL from storm.expr import Like, In, Asc, Desc, And, Or, Min, Max, Count, Not from storm.info import ClassAlias from storm.base import Storm from storm.xid import Xid ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/properties.py0000644000175000017500000004167114645174376016764 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from bisect import insort_left, bisect_left import weakref import sys from storm.exceptions import PropertyPathError from storm.info import get_obj_info, get_cls_info from storm.expr import Column, Undef from storm.variables import ( Variable, VariableFactory, BoolVariable, IntVariable, FloatVariable, DecimalVariable, BytesVariable, UnicodeVariable, DateTimeVariable, DateVariable, TimeVariable, TimeDeltaVariable, UUIDVariable, PickleVariable, JSONVariable, ListVariable, EnumVariable) __all__ = ["Property", "SimpleProperty", "Bool", "Int", "Float", "Decimal", "Bytes", "RawStr", "Unicode", "DateTime", "Date", "Time", "TimeDelta", "UUID", "Enum", "Pickle", "JSON", "List", "PropertyRegistry"] class Property: """A property representing a database column. Properties can be set as attributes of classes that have a C{__storm_table__}, and can then be used like ordinary Python properties on instances of the class, corresponding to database columns. """ def __init__(self, name=None, primary=False, variable_class=Variable, variable_kwargs={}): """ @param name: The name of this property. @param primary: A boolean indicating whether this property is a primary key. @param variable_class: The type of L{storm.variables.Variable} corresponding to this property. @param variable_kwargs: Dictionary of keyword arguments to be passed when constructing the underlying variable. """ self._name = name self._primary = primary self._variable_class = variable_class self._variable_kwargs = variable_kwargs def __get__(self, obj, cls=None): if obj is None: return self._get_column(cls) obj_info = get_obj_info(obj) if cls is None: # Don't get obj.__class__ because we don't trust it # (might be proxied or whatever). cls = obj_info.cls_info.cls column = self._get_column(cls) return obj_info.variables[column].get() def __set__(self, obj, value): obj_info = get_obj_info(obj) # Don't get obj.__class__ because we don't trust it # (might be proxied or whatever). 
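        # Resolve the Column bound to this property for the instance's
        # class, then delegate the assignment to that column's
        # per-instance variable.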
column = self._get_column(obj_info.cls_info.cls) obj_info.variables[column].set(value) def __delete__(self, obj): obj_info = get_obj_info(obj) # Don't get obj.__class__ because we don't trust it # (might be proxied or whatever). column = self._get_column(obj_info.cls_info.cls) obj_info.variables[column].delete() def _detect_attr_name(self, used_cls): self_id = id(self) for cls in used_cls.__mro__: for attr, prop in cls.__dict__.items(): if id(prop) == self_id: return attr raise RuntimeError("Property used in an unknown class") def _get_column(self, cls): # Cache per-class column values in the class itself, to avoid # holding a strong reference to it here, and thus rendering # classes uncollectable in certain situations (e.g. subclasses # where the property is stored in the base). try: # Use class dictionary explicitly to get sensible # results on subclasses. column = cls.__dict__["_storm_columns"].get(self) except KeyError: cls._storm_columns = {} column = None if column is None: attr = self._detect_attr_name(cls) if self._name is None: name = attr else: name = self._name column = PropertyColumn(self, cls, attr, name, self._primary, self._variable_class, self._variable_kwargs) cls._storm_columns[self] = column return column class PropertyColumn(Column): def __init__(self, prop, cls, attr, name, primary, variable_class, variable_kwargs): Column.__init__(self, name, cls, primary, VariableFactory(variable_class, column=self, validator_attribute=attr, **variable_kwargs)) self.cls = cls # Used by references # Copy attributes from the property to avoid one additional # function call on each access. for attr in ["__get__", "__set__", "__delete__"]: setattr(self, attr, getattr(prop, attr)) class SimpleProperty(Property): variable_class = None def __init__(self, name=None, primary=False, **kwargs): """ @param name: The name of this property. @param primary: A boolean indicating whether this property is a primary key. @param default: The initial value of this variable. The default behavior is for the value to stay undefined until it is set with L{set}. @param default_factory: If specified, this will immediately be called to get the initial value. @param allow_none: A boolean indicating whether None should be allowed to be set as the value of this variable. @param validator: Validation function called whenever trying to set the variable to a non-db value. The function should look like validator(object, attr, value), where the first and second arguments are the result of validator_object_factory() (or None, if this parameter isn't provided) and the value of validator_attribute, respectively. When called, the function should raise an error if the value is unacceptable, or return the value to be used in place of the original value otherwise. @param kwargs: Other keyword arguments passed through when constructing the underlying variable. """ kwargs["value"] = kwargs.pop("default", Undef) kwargs["value_factory"] = kwargs.pop("default_factory", Undef) Property.__init__(self, name, primary, self.variable_class, kwargs) class Bool(SimpleProperty): """Boolean property. This accepts integer, L{float}, or L{decimal.Decimal} values, and stores them as booleans. """ variable_class = BoolVariable class Int(SimpleProperty): """Integer property. This accepts integer, L{float}, or L{decimal.Decimal} values, and stores them as integers. """ variable_class = IntVariable class Float(SimpleProperty): """Float property. This accepts integer, L{float}, or L{decimal.Decimal} values, and stores them as floating-point values. 
""" variable_class = FloatVariable class Decimal(SimpleProperty): """Decimal property. This accepts integer or L{decimal.Decimal} values, and stores them as text strings containing their decimal representation. """ variable_class = DecimalVariable class Bytes(SimpleProperty): """Bytes property. This accepts L{bytes} or L{memoryview} objects, and stores them as byte strings. Deprecated aliases: L{Chars}, L{RawStr}. """ variable_class = BytesVariable # OBSOLETE: Bytes was Chars in 0.9. This will die soon. Chars = Bytes # DEPRECATED: Bytes was RawStr until 0.22. RawStr = Bytes class Unicode(SimpleProperty): """Unicode property. This accepts L{str} objects, and stores them as text strings. """ variable_class = UnicodeVariable class DateTime(SimpleProperty): """Date and time property. This accepts aware L{datetime.datetime} objects and stores them as timestamps; it also accepts integer or L{float} objects, converting them using L{datetime.utcfromtimestamp}. Note that it does not accept naive L{datetime.datetime} objects (those that do not have timezone information). """ variable_class = DateTimeVariable class Date(SimpleProperty): """Date property. This accepts L{datetime.date} objects and stores them as datestamps; it also accepts L{datetime.datetime} objects, converting them using L{datetime.datetime.date}. """ variable_class = DateVariable class Time(SimpleProperty): """Time property. This accepts L{datetime.time} objects and stores them as datestamps; it also accepts L{datetime.datetime} objects, converting them using L{datetime.datetime.time}. """ variable_class = TimeVariable class TimeDelta(SimpleProperty): """Time delta property. This accepts L{datetime.timedelta} objects and stores them as time intervals. """ variable_class = TimeDeltaVariable class UUID(SimpleProperty): """UUID property. This accepts L{uuid.UUID} objects and stores them as their text representation. """ variable_class = UUIDVariable class Pickle(SimpleProperty): """Pickle property. This accepts any object that can be serialized using L{pickle}, and stores it as a byte string containing its pickled representation. """ variable_class = PickleVariable class JSON(SimpleProperty): """JSON property. This accepts any object that can be serialized using L{json}, and stores it as a text string containing its JSON representation. """ variable_class = JSONVariable class List(SimpleProperty): """List property. This accepts iterable objects and stores them as a list where each element is an object of the given value type. """ variable_class = ListVariable def __init__(self, name=None, **kwargs): """ @param name: The name of this property. @param type: An instance of L{Property} defining the type of each element of this list. @param default_factory: If specified, this will immediately be called to get the initial value. @param validator: Validation function called whenever trying to set the variable to a non-db value. The function should look like validator(object, attr, value), where the first and second arguments are the result of validator_object_factory() (or None, if this parameter isn't provided) and the value of validator_attribute, respectively. When called, the function should raise an error if the value is unacceptable, or return the value to be used in place of the original value otherwise. @param kwargs: Other keyword arguments passed through when constructing the underlying variable. """ if "default" in kwargs: raise ValueError("'default' not allowed for List. 
" "Use 'default_factory' instead.") type = kwargs.pop("type", None) if type is None: type = Property() kwargs["item_factory"] = VariableFactory(type._variable_class, **type._variable_kwargs) SimpleProperty.__init__(self, name, **kwargs) class Enum(SimpleProperty): """Enumeration property, allowing used values to differ from stored ones. For instance:: class Class(Storm): prop = Enum(map={"one": 1, "two": 2}) obj.prop = "one" assert obj.prop == "one" obj.prop = 1 # Raises error. Another example:: class Class(Storm): prop = Enum(map={"one": 1, "two": 2}, set_map={"um": 1}) obj.prop = "um" assert obj.prop is "one" obj.prop = "one" # Raises error. """ variable_class = EnumVariable def __init__(self, name=None, primary=False, **kwargs): set_map = dict(kwargs.pop("map")) get_map = {value: key for key, value in set_map.items()} if "set_map" in kwargs: set_map = dict(kwargs.pop("set_map")) kwargs["get_map"] = get_map kwargs["set_map"] = set_map SimpleProperty.__init__(self, name, primary, **kwargs) class PropertyRegistry: """ An object which remembers the Storm properties specified on classes, and is able to translate names to these properties. """ def __init__(self): self._properties = [] def get(self, name, namespace=None): """Translate a property name path to the actual property. This method accepts a property name like C{"id"} or C{"Class.id"} or C{"module.path.Class.id"}, and tries to find a unique class/property with the given name. When the C{namespace} argument is given, the registry will be able to disambiguate names by choosing the one that is closer to the given namespace. For instance C{get("Class.id", "a.b.c")} will choose C{a.Class.id} rather than C{d.Class.id}. """ key = ".".join(reversed(name.split(".")))+"." i = bisect_left(self._properties, (key,)) l = len(self._properties) best_props = [] if namespace is None: while i < l and self._properties[i][0].startswith(key): path, prop_ref = self._properties[i] prop = prop_ref() if prop is not None: best_props.append((path, prop)) i += 1 else: namespace_parts = ("." + namespace).split(".") best_path_info = (0, sys.maxsize) while i < l and self._properties[i][0].startswith(key): path, prop_ref = self._properties[i] prop = prop_ref() if prop is None: i += 1 continue path_parts = path.split(".") path_parts.reverse() common_prefix = 0 for part, ns_part in zip(path_parts, namespace_parts): if part == ns_part: common_prefix += 1 else: break path_info = (-common_prefix, len(path_parts)-common_prefix) if path_info < best_path_info: best_path_info = path_info best_props = [(path, prop)] elif path_info == best_path_info: best_props.append((path, prop)) i += 1 if not best_props: raise PropertyPathError("Path '%s' matches no known property." % name) elif len(best_props) > 1: paths = [".".join(reversed(path.split(".")[:-1])) for path, prop in best_props] raise PropertyPathError("Path '%s' matches multiple " "properties: %s" % (name, ", ".join(paths))) return best_props[0][1] def add_class(self, cls): """Register properties of C{cls} so that they may be found by C{get()}. """ suffix = cls.__module__.split(".") suffix.append(cls.__name__) suffix.reverse() suffix = ".%s." % ".".join(suffix) cls_info = get_cls_info(cls) for attr in cls_info.attributes: prop = cls_info.attributes[attr] prop_ref = weakref.KeyedRef(prop, self._remove, None) pair = (attr+suffix, prop_ref) prop_ref.key = pair insort_left(self._properties, pair) def add_property(self, cls, prop, attr_name): """Register property of C{cls} so that it may be found by C{get()}. 
""" suffix = cls.__module__.split(".") suffix.append(cls.__name__) suffix.reverse() suffix = ".%s." % ".".join(suffix) prop_ref = weakref.KeyedRef(prop, self._remove, None) pair = (attr_name+suffix, prop_ref) prop_ref.key = pair insort_left(self._properties, pair) def clear(self): """Clean up all properties in the registry. Used by tests. """ del self._properties[:] def _remove(self, ref): self._properties.remove(ref.key) class PropertyPublisherMeta(type): """A metaclass that associates subclasses with Storm L{PropertyRegistry}s. """ def __init__(self, name, bases, dict): if not hasattr(self, "_storm_property_registry"): self._storm_property_registry = PropertyRegistry() elif hasattr(self, "__storm_table__"): self._storm_property_registry.add_class(self) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/references.py0000644000175000017500000011720014645174376016701 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import weakref from storm.exceptions import ( ClassInfoError, FeatureError, NoStoreError, WrongStoreError) from storm.store import Store, get_where_for_args, LostObjectError from storm.variables import LazyValue from storm.expr import ( Select, Column, Exists, ComparableExpr, SuffixExpr, LeftJoin, Not, SQLRaw, compare_columns, compile) from storm.info import get_cls_info, get_obj_info __all__ = ["Reference", "ReferenceSet", "Proxy"] class LazyAttribute: """ This descriptor will call the named attribute builder to initialize the given attribute on first access. It avoids having a test at every single place where the attribute is touched when lazy initialization is wanted, and prevents paying the price of a normal property when classes are seldomly instantiated (the case of references). """ def __init__(self, attr, attr_builder): self._attr = attr self._attr_builder = attr_builder def __get__(self, obj, cls=None): getattr(obj, self._attr_builder)() return getattr(obj, self._attr) class PendingReferenceValue(LazyValue): """Lazy value to be used as a marker for unflushed foreign keys. When a reference is set to an object which is still unflushed, the foreign key in the local object remains set to this value until the object is flushed. """ PendingReferenceValue = PendingReferenceValue() class Reference: """Descriptor for one-to-one relationships. This is typically used when the class that it is being defined on has a foreign key onto another table:: class OtherGuy(object): ... id = Int() class MyGuy(object): ... other_guy_id = Int() other_guy = Reference(other_guy_id, OtherGuy.id) but can also be used for backwards references, where OtherGuy's table has a foreign key onto the class that you want this property on:: class OtherGuy(object): ... 
my_guy_id = Int() # in the database, a foreign key to my_guy.id class MyGuy(object): ... id = Int() other_guy = Reference(id, OtherGuy.my_guy_id, on_remote=True) In both cases, C{MyGuy().other_guy} will resolve to the C{OtherGuy} instance which is linked to it. In the first case, it will be the C{OtherGuy} instance whose C{id} is equivalent to the C{MyGuy}'s C{other_guy_id}; in the second, it'll be the C{OtherGuy} instance whose C{my_guy_id} is equivalent to the C{MyGuy}'s C{id}. Assigning to the property, for example with C{MyGuy().other_guy = OtherGuy()}, will link the objects and update either C{MyGuy.other_guy_id} or C{OtherGuy.my_guy_id} accordingly. String references may be used in place of L{storm.expr.Column} objects throughout, and are resolved to columns using L{PropertyResolver}. """ # Must initialize _relation later because we don't want to resolve # string references at definition time, since classes refered to might # not be available yet. Notice that this attribute is "public" to the # Proxy class and the SQLObject wrapper. It's still underlined because # it's *NOT* part of the public API of Storm (we'll modify it without # warnings!). _relation = LazyAttribute("_relation", "_build_relation") def __init__(self, local_key, remote_key, on_remote=False): """ Create a Reference property. @param local_key: The sibling column which is the foreign key onto C{remote_key}. (unless C{on_remote} is passed; see below). @param remote_key: The column on the referred-to object which will have the same value as that for C{local_key} when resolved on an instance. @param on_remote: If specified, then the reference is backwards: It is the C{remote_key} which is a foreign key onto C{local_key}. """ self._local_key = local_key self._remote_key = remote_key self._on_remote = on_remote self._cls = None def __get__(self, local, cls=None): if local is not None: # Don't use local here, as it might be security proxied. local = get_obj_info(local).get_obj() if self._cls is None: self._cls = _find_descriptor_class(cls or local.__class__, self) if local is None: return self remote = self._relation.get_remote(local) if remote is not None: return remote if self._relation.local_variables_are_none(local): return None store = Store.of(local) if store is None: return None if self._relation.remote_key_is_primary: remote = store.get(self._relation.remote_cls, self._relation.get_local_variables(local)) else: where = self._relation.get_where_for_remote(local) result = store.find(self._relation.remote_cls, where) remote = result.one() if remote is not None: self._relation.link(local, remote) return remote def __set__(self, local, remote): # Don't use local here, as it might be security proxied or something. local = get_obj_info(local).get_obj() if self._cls is None: self._cls = _find_descriptor_class(local.__class__, self) if remote is None: if self._on_remote: remote = self.__get__(local) if remote is None: return else: remote = self._relation.get_remote(local) if remote is None: remote_info = None else: remote_info = get_obj_info(remote) self._relation.unlink(get_obj_info(local), remote_info, True) else: # Don't use remote here, as it might be # security proxied or something. try: remote = get_obj_info(remote).get_obj() except ClassInfoError: pass # It might fail when remote is a tuple or a raw value. 
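        # Link local to the resolved remote (or to the raw key value):
        # this sets the foreign-key variables and installs the event
        # hooks that keep both sides of the relation in sync.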
self._relation.link(local, remote, True) def _build_relation(self): resolver = PropertyResolver(self, self._cls) self._local_key = resolver.resolve(self._local_key) self._remote_key = resolver.resolve(self._remote_key) self._relation = Relation(self._local_key, self._remote_key, False, self._on_remote) def __eq__(self, other): return self._relation.get_where_for_local(other) def __ne__(self, other): return Not(self == other) __hash__ = object.__hash__ class ReferenceSet: """Descriptor for many-to-one and many-to-many reference sets. This is typically used when another class has a foreign key onto the class being defined, either directly (the many-to-one case) or via an intermediate table (the many-to-many case). For instance:: class Person(Storm): ... id = Int(primary=True) email_addresses = ReferenceSet("id", "EmailAddress.owner_id") class EmailAddress(Storm): ... owner_id = Int(name="owner", allow_none=False) owner = Reference(owner_id, "Person.id") class TeamMembership(Storm): ... person_id = Int(name="person", allow_none=False) person = Reference(person_id, "Person.id") team_id = Int(name="team", allow_none=False) team = Reference(team_id, "Team.id") class Team(Storm): ... id = Int(primary=True) members = ReferenceSet( "id", "TeamMembership.team_id", "TeamMembership.person_id", "Person.id", order_by="Person.name") In this case, C{Person().email_addresses} resolves to a L{BoundReferenceSet} of all the email addresses linked to that person (a many-to-one relationship), while C{Team().members} resolves to a L{BoundIndirectReferenceSet} of all the members of that team (a many-to-many relationship). These can be used in a somewhat similar way to L{ResultSet } objects. String references may be used in place of L{storm.expr.Column} objects throughout, and are resolved to columns using L{PropertyResolver}. """ # Must initialize later because we don't want to resolve string # references at definition time, since classes refered to might # not be available yet. _relation1 = LazyAttribute("_relation1", "_build_relations") _relation2 = LazyAttribute("_relation2", "_build_relations") _order_by = LazyAttribute("_order_by", "_build_relations") def __init__(self, local_key1, remote_key1, remote_key2=None, local_key2=None, order_by=None): """ @param local_key1: The sibling column which has the same value as that for C{remote_key1} when resolved on an instance. @param remote_key1: The column on the referring object (in the case of a many-to-one relation) or on the intermediate table (in the case of a many-to-many relation) which is the foreign key onto C{local_key1}. @param remote_key2: In the case of a many-to-many relation, the column on the intermediate table which is the foreign key onto C{local_key2}. @param local_key2: In the case of a many-to-many relation, the column on the referred-to object which has the same value as C{remote_key2} when resolved on an instance. @param order_by: If not C{None}, order the resolved L{BoundReferenceSet} or L{BoundIndirectReferenceSet} by these columns, as in L{storm.store.ResultSet.order_by}. """ self._local_key1 = local_key1 self._remote_key1 = remote_key1 self._remote_key2 = remote_key2 self._local_key2 = local_key2 self._default_order_by = order_by self._cls = None def __get__(self, local, cls=None): if local is not None: # Don't use local here, as it might be security proxied. 
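            # get_obj_info() raises ClassInfoError for non-Storm objects,
            # so this also validates the argument while unwrapping any
            # proxy.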
local = get_obj_info(local).get_obj() if self._cls is None: self._cls = _find_descriptor_class(cls or local.__class__, self) if local is None: return self #store = Store.of(local) #if store is None: # return None if self._relation2 is None: return BoundReferenceSet(self._relation1, local, self._order_by) else: return BoundIndirectReferenceSet(self._relation1, self._relation2, local, self._order_by) def __set__(self, local, value): raise FeatureError("Assigning to ReferenceSets not supported") def _build_relations(self): resolver = PropertyResolver(self, self._cls) if self._default_order_by is not None: self._order_by = resolver.resolve(self._default_order_by) else: self._order_by = None self._local_key1 = resolver.resolve(self._local_key1) self._remote_key1 = resolver.resolve(self._remote_key1) self._relation1 = Relation(self._local_key1, self._remote_key1, True, True) if self._local_key2 and self._remote_key2: self._local_key2 = resolver.resolve(self._local_key2) self._remote_key2 = resolver.resolve(self._remote_key2) self._relation2 = Relation(self._local_key2, self._remote_key2, True, True) else: self._relation2 = None class BoundReferenceSetBase: def find(self, *args, **kwargs): store = Store.of(self._local) if store is None: raise NoStoreError("Can't perform operation without a store") where = self._get_where_clause() result = store.find(self._target_cls, where, *args, **kwargs) if self._order_by is not None: result.order_by(*self._order_by) return result def __iter__(self): return self.find().__iter__() def __getitem__(self, index): return self.find().__getitem__(index) def __contains__(self, item): return item in self.find() def is_empty(self): return self.find().is_empty() def first(self, *args, **kwargs): return self.find(*args, **kwargs).first() def last(self, *args, **kwargs): return self.find(*args, **kwargs).last() def any(self, *args, **kwargs): return self.find(*args, **kwargs).any() def one(self, *args, **kwargs): return self.find(*args, **kwargs).one() def values(self, *columns): return self.find().values(*columns) def order_by(self, *args): return self.find().order_by(*args) def count(self): return self.find().count() class BoundReferenceSet(BoundReferenceSetBase): """An instance of a many-to-one relation.""" def __init__(self, relation, local, order_by): self._relation = relation self._local = local self._target_cls = self._relation.remote_cls self._order_by = order_by def _get_where_clause(self): return self._relation.get_where_for_remote(self._local) def clear(self, *args, **kwargs): set_kwargs = {} for remote_column in self._relation.remote_key: set_kwargs[remote_column.name] = None store = Store.of(self._local) if store is None: raise NoStoreError("Can't perform operation without a store") where = self._relation.get_where_for_remote(self._local) store.find(self._target_cls, where, *args, **kwargs).set(**set_kwargs) def add(self, remote): self._relation.link(self._local, remote, True) def remove(self, remote): self._relation.unlink(get_obj_info(self._local), get_obj_info(remote), True) class BoundIndirectReferenceSet(BoundReferenceSetBase): """An instance of a many-to-many relation.""" def __init__(self, relation1, relation2, local, order_by): self._relation1 = relation1 self._relation2 = relation2 self._local = local self._order_by = order_by self._target_cls = relation2.local_cls self._link_cls = relation1.remote_cls def _get_where_clause(self): return (self._relation1.get_where_for_remote(self._local) & self._relation2.get_where_for_join()) def clear(self, *args, 
**kwargs): store = Store.of(self._local) if store is None: raise NoStoreError("Can't perform operation without a store") where = self._relation1.get_where_for_remote(self._local) if args or kwargs: filter = get_where_for_args(args, kwargs, self._target_cls) join = self._relation2.get_where_for_join() table = get_cls_info(self._target_cls).table where &= Exists(Select(SQLRaw("*"), join & filter, tables=table)) store.find(self._link_cls, where).remove() def add(self, remote): link = self._link_cls() self._relation1.link(self._local, link, True) # Don't use remote here, as it might be security proxied or something. remote = get_obj_info(remote).get_obj() self._relation2.link(remote, link, True) def remove(self, remote): store = Store.of(self._local) if store is None: raise NoStoreError("Can't perform operation without a store") # Don't use remote here, as it might be security proxied or something. remote = get_obj_info(remote).get_obj() where = (self._relation1.get_where_for_remote(self._local) & self._relation2.get_where_for_remote(remote)) store.find(self._link_cls, where).remove() class Proxy(ComparableExpr): """Proxy exposes a referred object's column as a local column. For example:: class Foo(object): bar_id = Int() bar = Reference(bar_id, Bar.id) bar_title = Proxy(bar, Bar.title) For most uses, C{Foo.bar_title} should behave as if it were a native property of C{Foo}. """ class RemoteProp: """ This descriptor will resolve and set the _remote_prop attribute when it's first used. It avoids having a test at every single place where the attribute is touched. """ def __get__(self, obj, cls=None): resolver = PropertyResolver(obj, obj._cls) obj._remote_prop = resolver.resolve_one(obj._unresolved_prop) return obj._remote_prop _remote_prop = RemoteProp() def __init__(self, reference, remote_prop): self._reference = reference self._unresolved_prop = remote_prop self._cls = None def __get__(self, obj, cls=None): if self._cls is None: self._cls = _find_descriptor_class(cls, self) if obj is None: return self # Have you counted how many descriptors we're dealing with here? ;-) return self._remote_prop.__get__(self._reference.__get__(obj)) def __set__(self, obj, value): return self._remote_prop.__set__(self._reference.__get__(obj), value) @property def variable_factory(self): return self._remote_prop.variable_factory @compile.when(Proxy) def compile_proxy(compile, proxy, state): # Inject the join between the table of the class holding the proxy # and the table of the class which is the target of the reference. left_join = LeftJoin(proxy._reference._relation.local_cls, proxy._remote_prop.table, proxy._reference._relation.get_where_for_join()) state.auto_tables.append(left_join) # And compile the remote property normally. return compile(proxy._remote_prop, state) class Relation: def __init__(self, local_key, remote_key, many, on_remote): assert type(local_key) is tuple and type(remote_key) is tuple self.local_key = local_key self.remote_key = remote_key self.local_cls = getattr(self.local_key[0], "cls", None) self.remote_cls = self.remote_key[0].cls self.remote_key_is_primary = False primary_key = get_cls_info(self.remote_cls).primary_key if len(primary_key) == len(self.remote_key): for column1, column2 in zip(self.remote_key, primary_key): if column1.name != column2.name: break else: self.remote_key_is_primary = True self.many = many self.on_remote = on_remote # XXX These should probably be weak dictionaries. 
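        # Per-class caches: the local/remote columns resolved for each
        # concrete class, plus local<->remote column maps used by the
        # change-tracking hooks below.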
self._local_columns = {} self._remote_columns = {} self._l_to_r = {} self._r_to_l = {} def get_remote(self, local): """Return the remote object for this relation, using the local cache. If the object in the cache is invalidated, we validate it again to check if it's still in the database. """ local_info = get_obj_info(local) try: obj = local_info[self]["remote"] except KeyError: return None remote_info = get_obj_info(obj) if remote_info.get("invalidated"): try: Store.of(obj)._validate_alive(remote_info) except LostObjectError: return None return obj def get_where_for_remote(self, local): """Generate a column comparison expression for reference properties. The returned expression may be used to find objects of the I{remote} type referring to C{local}. """ local_variables = self.get_local_variables(local) for variable in local_variables: if not variable.is_defined(): Store.of(local).flush() break return compare_columns(self.remote_key, local_variables) def get_where_for_local(self, other): """Generate a column comparison expression for reference properties. The returned expression may be used to find objects of the I{local} type referring to C{other}. It handles the following cases:: Class.reference == obj Class.reference == obj.id Class.reference == (obj.id1, obj.id2) Where the right-hand side is the C{other} object given. """ try: obj_info = get_obj_info(other) except ClassInfoError: if type(other) is not tuple: remote_variables = (other,) else: remote_variables = other else: # Don't use other here, as it might be # security proxied or something. other = get_obj_info(other).get_obj() remote_variables = self.get_remote_variables(other) return compare_columns(self.local_key, remote_variables) def get_where_for_join(self): return compare_columns(self.local_key, self.remote_key) def get_local_variables(self, local): local_info = get_obj_info(local) return tuple(local_info.variables[column] for column in self._get_local_columns(local.__class__)) def local_variables_are_none(self, local): """Return true if all variables of the local key have None values.""" local_info = get_obj_info(local) for column in self._get_local_columns(local.__class__): if local_info.variables[column].get() is not None: return False return True def get_remote_variables(self, remote): remote_info = get_obj_info(remote) return tuple(remote_info.variables[column] for column in self._get_remote_columns(remote.__class__)) def link(self, local, remote, setting=False): """Link objects to represent their relation. @param local: Object representing the I{local} side of the reference. @param remote: Object representing the I{remote} side of the reference, or the actual value to be set as the local key. @param setting: Pass true when the relationship is being newly created. """ local_info = get_obj_info(local) try: remote_info = get_obj_info(remote) except ClassInfoError: # Must be a plain key. Just set it. # XXX I guess this is broken if self.on_remote is True. 
local_variables = self.get_local_variables(local) if type(remote) is not tuple: remote = (remote,) assert len(remote) == len(local_variables) for variable, value in zip(local_variables, remote): variable.set(value) return local_store = Store.of(local) remote_store = Store.of(remote) if setting: if local_store is None: if remote_store is None: local_info.event.hook("added", self._add_all, local_info) remote_info.event.hook("added", self._add_all, local_info) else: remote_store.add(local) local_store = remote_store elif remote_store is None: local_store.add(remote) elif local_store is not remote_store: raise WrongStoreError("%r and %r cannot be linked because they " "are in different stores." % (local, remote)) # In cases below, we maintain a reference to the remote object # to make sure it won't get deallocated while the link is active. relation_data = local_info.get(self) if self.many: if relation_data is None: relation_data = local_info[self] = {"remote": {remote_info: remote}} else: relation_data["remote"][remote_info] = remote else: if relation_data is None: relation_data = local_info[self] = {"remote": remote} else: old_remote = relation_data.get("remote") if old_remote is not None: self.unlink(local_info, get_obj_info(old_remote)) relation_data["remote"] = remote if setting: local_vars = local_info.variables remote_vars = remote_info.variables pairs = zip(self._get_local_columns(local.__class__), self.remote_key) if self.on_remote: local_has_changed = False for local_column, remote_column in pairs: local_var = local_vars[local_column] if not local_var.is_defined(): remote_vars[remote_column].set(PendingReferenceValue) else: remote_vars[remote_column].set(local_var.get()) if local_var.has_changed(): local_has_changed = True if local_has_changed: self._add_flush_order(local_info, remote_info) local_info.event.hook("changed", self._track_local_changes, remote_info) local_info.event.hook("flushed", self._break_on_local_flushed, remote_info) #local_info.event.hook("removed", self._break_on_local_removed, # remote_info) remote_info.event.hook("removed", self._break_on_remote_removed, weakref.ref(local_info)) else: remote_has_changed = False for local_column, remote_column in pairs: remote_var = remote_vars[remote_column] if not remote_var.is_defined(): local_vars[local_column].set(PendingReferenceValue) else: local_vars[local_column].set(remote_var.get()) if remote_var.has_changed(): remote_has_changed = True if remote_has_changed: self._add_flush_order(local_info, remote_info, remote_first=True) remote_info.event.hook("changed", self._track_remote_changes, local_info) remote_info.event.hook("flushed", self._break_on_remote_flushed, local_info) #local_info.event.hook("removed", self._break_on_remote_removed, # local_info) local_info.event.hook("changed", self._break_on_local_diverged, remote_info) else: local_info.event.hook("changed", self._break_on_local_diverged, remote_info) remote_info.event.hook("changed", self._break_on_remote_diverged, weakref.ref(local_info)) if self.on_remote: remote_info.event.hook("removed", self._break_on_remote_removed, weakref.ref(local_info)) def unlink(self, local_info, remote_info, setting=False): """Break the relation between the local and remote objects. @param setting: If true objects will be changed to persist breakage. 
""" unhook = False relation_data = local_info.get(self) if relation_data is not None: if self.many: remote_infos = relation_data["remote"] if remote_info in remote_infos: remote_infos.pop(remote_info, None) unhook = True else: if relation_data.pop("remote", None) is not None: unhook = True if unhook: local_store = Store.of(local_info) local_info.event.unhook("changed", self._track_local_changes, remote_info) local_info.event.unhook("changed", self._break_on_local_diverged, remote_info) local_info.event.unhook("flushed", self._break_on_local_flushed, remote_info) remote_info.event.unhook("changed", self._track_remote_changes, local_info) remote_info.event.unhook("changed", self._break_on_remote_diverged, weakref.ref(local_info)) remote_info.event.unhook("flushed", self._break_on_remote_flushed, local_info) remote_info.event.unhook("removed", self._break_on_remote_removed, weakref.ref(local_info)) if local_store is None: if not self.many or not remote_infos: local_info.event.unhook("added", self._add_all, local_info) remote_info.event.unhook("added", self._add_all, local_info) else: flush_order = relation_data.get("flush_order") if flush_order is not None and remote_info in flush_order: if self.on_remote: local_store.remove_flush_order(local_info, remote_info) else: local_store.remove_flush_order(remote_info, local_info) flush_order.remove(remote_info) if setting: if self.on_remote: remote_vars = remote_info.variables for remote_column in self.remote_key: remote_vars[remote_column].set(None) else: local_vars = local_info.variables local_cols = self._get_local_columns(local_info.cls_info.cls) for local_column in local_cols: local_vars[local_column].set(None) def _add_flush_order(self, local_info, remote_info, remote_first=False): """Tell the Store to flush objects in the specified order. We need to conditionally remove the flush order in unlink() only if we added it here. Note that we can't just check if the Store has ordering on the (local, remote) pair, since it may have more than one request for ordering it, from different relations. @param local_info: The object info for the local object. @param remote_info: The object info for the remote object. @param remote_first: If True, remote_info will be flushed before local_info. """ local_store = Store.of(local_info) if local_store is not None: flush_order = local_info[self].setdefault("flush_order", set()) if remote_info not in flush_order: flush_order.add(remote_info) if remote_first: local_store.add_flush_order(remote_info, local_info) else: local_store.add_flush_order(local_info, remote_info) def _track_local_changes(self, local_info, local_variable, old_value, new_value, fromdb, remote_info): """Deliver changes in local to remote. This hook ensures that the remote object will keep track of changes done in the local object, either manually or at flushing time. """ remote_column = self._get_remote_column(local_info.cls_info.cls, local_variable.column) if remote_column is not None: remote_info.variables[remote_column].set(new_value) self._add_flush_order(local_info, remote_info) def _track_remote_changes(self, remote_info, remote_variable, old_value, new_value, fromdb, local_info): """Deliver changes in remote to local. This hook ensures that the local object will keep track of changes done in the remote object, either manually or at flushing time. 
""" local_column = self._get_local_column(local_info.cls_info.cls, remote_variable.column) if local_column is not None: local_info.variables[local_column].set(new_value) self._add_flush_order(local_info, remote_info, remote_first=True) def _break_on_local_diverged(self, local_info, local_variable, old_value, new_value, fromdb, remote_info): """Break the remote/local relationship on diverging changes. This hook ensures that if the local object has an attribute changed by hand in a way that diverges from the remote object, it stops tracking changes. """ remote_column = self._get_remote_column(local_info.cls_info.cls, local_variable.column) if remote_column is not None: variable = remote_info.variables[remote_column] if variable.get_lazy() is None and variable.get() != new_value: self.unlink(local_info, remote_info) def _break_on_remote_diverged(self, remote_info, remote_variable, old_value, new_value, fromdb, local_info_ref): """Break the remote/local relationship on diverging changes. This hook ensures that if the remote object has an attribute changed by hand in a way that diverges from the local object, the relationship is undone. """ local_info = local_info_ref() if local_info is None: return local_column = self._get_local_column(local_info.cls_info.cls, remote_variable.column) if local_column is not None: local_value = local_info.variables[local_column].get() if local_value != new_value: self.unlink(local_info, remote_info) def _break_on_local_flushed(self, local_info, remote_info): """Break the remote/local relationship on flush.""" self.unlink(local_info, remote_info) def _break_on_remote_flushed(self, remote_info, local_info): """Break the remote/local relationship on flush.""" self.unlink(local_info, remote_info) def _break_on_remote_removed(self, remote_info, local_info_ref): """Break the remote relationship when the remote object is removed.""" local_info = local_info_ref() if local_info is not None: self.unlink(local_info, remote_info) def _add_all(self, obj_info, local_info): store = Store.of(obj_info) store.add(local_info) local_info.event.unhook("added", self._add_all, local_info) def add(remote_info): remote_info.event.unhook("added", self._add_all, local_info) store.add(remote_info) self._add_flush_order(local_info, remote_info, remote_first=(not self.on_remote)) if self.many: for remote_info in local_info[self]["remote"]: add(remote_info) else: add(get_obj_info(local_info[self]["remote"])) def _get_remote_columns(self, remote_cls): try: return self._remote_columns[remote_cls] except KeyError: columns = tuple(prop.__get__(None, remote_cls) for prop in self.remote_key) self._remote_columns[remote_cls] = columns return columns def _get_local_columns(self, local_cls): try: return self._local_columns[local_cls] except KeyError: columns = tuple(prop.__get__(None, local_cls) for prop in self.local_key) self._local_columns[local_cls] = columns return columns def _get_remote_column(self, local_cls, local_column): try: return self._l_to_r[local_cls].get(local_column) except KeyError: map = {} for local_prop, _remote_column in zip(self.local_key, self.remote_key): map[local_prop.__get__(None, local_cls)] = _remote_column return self._l_to_r.setdefault(local_cls, map).get(local_column) def _get_local_column(self, local_cls, remote_column): try: return self._r_to_l[local_cls].get(remote_column) except KeyError: map = {} for local_prop, _remote_column in zip(self.local_key, self.remote_key): map[_remote_column] = local_prop.__get__(None, local_cls) return 
self._r_to_l.setdefault(local_cls, map).get(remote_column) class PropertyResolver: """Transform strings and pure properties (non-columns) into columns.""" def __init__(self, reference, used_cls): self._reference = reference self._used_cls = used_cls self._registry = None self._namespace = None def resolve(self, properties): if not type(properties) is tuple: return (self.resolve_one(properties),) return tuple(self.resolve_one(property) for property in properties) def resolve_one(self, property): if type(property) is tuple: return self.resolve(property) elif isinstance(property, str): return self._resolve_string(property) elif isinstance(property, SuffixExpr): # XXX This covers cases like order_by=Desc("Bar.id"), see #620369. # Eventually we might want to add support for other types of # expressions property.expr = self.resolve(property.expr) return property elif not isinstance(property, Column): return _find_descriptor_obj(self._used_cls, property) return property def _resolve_string(self, property_path): if self._registry is None: try: self._registry = self._used_cls._storm_property_registry except AttributeError: raise RuntimeError("When using strings on references, " "classes involved must be subclasses " "of 'Storm'") cls = _find_descriptor_class(self._used_cls, self._reference) self._namespace = "%s.%s" % (cls.__module__, cls.__name__) return self._registry.get(property_path, self._namespace) def _find_descriptor_class(used_cls, descr): for cls in used_cls.__mro__: for attr, _descr in cls.__dict__.items(): if _descr is descr: return cls raise RuntimeError("Reference used in an unknown class") def _find_descriptor_obj(used_cls, descr): for cls in used_cls.__mro__: for attr, _descr in cls.__dict__.items(): if _descr is descr: return getattr(cls, attr) raise RuntimeError("Reference used in an unknown class") ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4171247 storm-1.0/storm/schema/0000755000175000017500000000000014645532536015441 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/schema/__init__.py0000644000175000017500000000150514571373456017555 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from storm.schema.schema import Schema ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/schema/patch.py0000644000175000017500000002227514645174376017126 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
"""Apply database patches.

The L{PatchApplier} class can be used to apply and keep track of a series
of database patches.

To create a patch series, all that is needed is to add Python files under a
module of choice, and name them 'patch_N.py', where 'N' is the version of
the patch in the series.

Each patch file must define an C{apply} callable taking a L{Store} instance
as its only argument. This function will be called when the patch gets
applied.

The L{PatchApplier} can then be used to apply all the available patches to
a L{Store}. After a patch has been applied, its version is recorded in a
special 'patch' table in the given L{Store}, and it won't be applied again.
"""

import sys
import os
import re
import types

from storm.locals import StormError, Int


class UnknownPatchError(Exception):
    """
    Raised if a patch is found in the database that doesn't exist in the
    local patch directory.
    """

    def __init__(self, store, patches):
        self._store = store
        self._patches = patches

    def __str__(self):
        return "store has patches the code doesn't know about: %s" % (
            ", ".join([str(version) for version in self._patches]))


class BadPatchError(Exception):
    """Raised when a patch fails with an arbitrary exception."""


class Patch:
    """Database object representing an applied patch.

    @ivar version: The version of the patch associated with this object.
    """

    __storm_table__ = "patch"

    version = Int(primary=True, allow_none=False)

    def __init__(self, version):
        self.version = version


class PatchApplier:
    """Apply to a L{Store} the database patches from a given Python package.

    @param store: The L{Store} to apply the patches to.
    @param patch_set: The L{PatchSet} containing the patches to apply.
    @param committer: Optionally an object implementing 'commit()' and
        'rollback()' methods, to be used to commit or rollback the changes
        after applying a patch. If C{None} is given, the C{store} itself is
        used.
    """

    def __init__(self, store, patch_set, committer=None):
        self._store = store
        if isinstance(patch_set, types.ModuleType):
            # Up to version 0.20.0 the second positional parameter used to
            # be a raw module containing the patches. We wrap it with
            # PatchSet to keep backward compatibility.
            patch_set = PatchSet(patch_set)
        self._patch_set = patch_set
        if committer is None:
            committer = store
        self._committer = committer

    def apply(self, version):
        """Execute the patch with the given version.

        This will call the 'apply' function defined in the patch file with
        the given version, passing it our L{Store}.

        @param version: The version of the patch to execute.
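
        A hedged sketch of a patch file and of applying it by hand (the
        'mypatches' package and its patch_1.py are assumptions for
        illustration, not part of this codebase)::

            # mypatches/patch_1.py
            def apply(store):
                store.execute("ALTER TABLE person ADD COLUMN age INTEGER")

        and then::

            >>> import mypatches
            >>> applier = PatchApplier(store, PatchSet(mypatches))
            >>> applier.apply(1)  # runs mypatches.patch_1.apply(store)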
""" patch = Patch(version) self._store.add(patch) module = None try: module = self._patch_set.get_patch_module(version) module.apply(self._store) except StormError: self._committer.rollback() raise except: type, value, traceback = sys.exc_info() patch_repr = getattr(module, "__file__", version) raise BadPatchError( "Patch %s failed: %s: %s" % (patch_repr, type.__name__, str(value)) ).with_traceback(traceback) self._committer.commit() def apply_all(self): """Execute all unapplied patches. @raises UnknownPatchError: If the patch table has versions for which no patch file actually exists. """ self.check_unknown() for version in self.get_unapplied_versions(): self.apply(version) def mark_applied(self, version): """Mark the patch with the given version as applied.""" self._store.add(Patch(version)) self._committer.commit() def mark_applied_all(self): """Mark all unapplied patches as applied.""" for version in self.get_unapplied_versions(): self.mark_applied(version) def has_pending_patches(self): """Return C{True} if there are unapplied patches, C{False} if not.""" for version in self.get_unapplied_versions(): return True return False def get_unknown_patch_versions(self): """ Return the list of Patch versions that have been applied to the database, but don't appear in the schema's patches module. """ applied = self._get_applied_patches() known_patches = self._patch_set.get_patch_versions() unknown_patches = set() for patch in applied: if not patch in known_patches: unknown_patches.add(patch) return unknown_patches def check_unknown(self): """Look for patches that we don't know about. @raises UnknownPatchError: If the store has applied patch versions this schema doesn't know about. """ unknown_patches = self.get_unknown_patch_versions() if unknown_patches: raise UnknownPatchError(self._store, unknown_patches) def get_unapplied_versions(self): """Return the versions of all unapplied patches.""" applied = self._get_applied_patches() for version in self._patch_set.get_patch_versions(): if version not in applied: yield version def _get_applied_patches(self): """Return the versions of all applied patches.""" applied = set() for patch in self._store.find(Patch): applied.add(patch.version) return applied class PatchSet: """A collection of patch modules. Each patch module lives in a regular Python module file, contained in a sub-directory named against the patch version. For example, given a directory tree like: mypackage/ __init__.py patch_1/ __init__.py foo.py the following code will return a patch module object for foo.py: >>> import mypackage >>> patch_set = PackagePackage(mypackage, sub_level="foo") >>> patch_module = patch_set.get_patch_module(1) >>> print(patch_module.__name__) 'mypackage.patch_1.foo' Different sub-levels can be used to apply different patches to different stores (see L{Sharding}). Alternatively if no sub-level is provided, the structure will be flat: mypackage/ __init__.py patch_1.py >>> import mypackage >>> patch_set = PackagePackage(mypackage) >>> patch_module = patch_set.get_patch_module(1) >>> print(patch_module.__name__) 'mypackage.patch_1' This simpler structure can be used if you have just one store to patch or you don't care to co-ordinate the patches across your stores. 
""" def __init__(self, package, sub_level=None): self._package = package self._sub_level = sub_level def get_patch_versions(self): """Return the versions of all available patches.""" pattern = r"^patch_(\d+)" if not self._sub_level: pattern += ".py" pattern += "$" format = re.compile(pattern) patch_directory = self._get_patch_directory() filenames = os.listdir(patch_directory) matches = [(format.match(fn), fn) for fn in filenames] matches = sorted(filter(lambda x: x[0], matches), key=lambda x: int(x[0].group(1))) return [int(match.group(1)) for match, filename in matches] def get_patch_module(self, version): """Import the Python module of the patch file with the given version. @param: The version of the module patch to import. @return: The imported module. """ name = "patch_%d" % version levels = [self._package.__name__, name] if self._sub_level: directory = self._get_patch_directory() path = os.path.join(directory, name, self._sub_level + ".py") if not os.path.exists(path): return _EmptyPatchModule() levels.append(self._sub_level) return __import__(".".join(levels), None, None, ['']) def _get_patch_directory(self): """Get the path to the directory of the patch package.""" return os.path.dirname(self._package.__file__) class _EmptyPatchModule: """Fake module object with a no-op C{apply} function.""" def apply(self, store): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/schema/schema.py0000644000175000017500000001632514645174376017266 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # """Manage database shemas. The L{Schema} class can be used to create, drop, clean and upgrade database schemas. A database L{Schema} is defined by the series of SQL statements that should be used to create, drop and clear the schema, respectively and by a patch module used to upgrade it (see also L{PatchApplier}). For example: >>> store = Store(create_database('sqlite:')) >>> creates = ['CREATE TABLE person (id INTEGER, name TEXT)'] >>> drops = ['DROP TABLE person'] >>> deletes = ['DELETE FROM person'] >>> import patch_module >>> patch_set = PatchSet(patch_module) >>> schema = Schema(creates, drops, deletes, patch_set) >>> schema.create(store) where patch_module is a Python module containing database patches used to upgrade the schema over time. """ import types from storm.locals import StormError from storm.schema.patch import PatchApplier, PatchSet class SchemaMissingError(Exception): """Raised when a L{Store} has no schema at all.""" class UnappliedPatchesError(Exception): """Raised when a L{Store} has unapplied schema patches. @ivar unapplied_versions: A list containing all unapplied patch versions. 
""" def __init__(self, unapplied_versions): self.unapplied_versions = unapplied_versions class Schema: """Create, drop, clean and patch table schemas. @param creates: A list of C{CREATE TABLE} statements. @param drops: A list of C{DROP TABLE} statements. @param deletes: A list of C{DELETE FROM} statements. @param patch_set: The L{PatchSet} containing patch modules to apply. @param committer: Optionally a committer to pass to the L{PatchApplier}. @see: L{PatchApplier}. """ _create_patch = "CREATE TABLE patch (version INTEGER NOT NULL PRIMARY KEY)" _drop_patch = "DROP TABLE IF EXISTS patch" _autocommit = True def __init__(self, creates, drops, deletes, patch_set, committer=None): self._creates = creates self._drops = drops self._deletes = deletes if isinstance(patch_set, types.ModuleType): # Up to version 0.20.0 the fourth positional parameter used to # be a raw module containing the patches. We wrap it with PatchSet # for keeping backward-compatibility. patch_set = PatchSet(patch_set) self._patch_set = patch_set self._committer = committer def _execute_statements(self, store, statements): """Execute the given statements in the given store.""" for statement in statements: try: store.execute(statement) except Exception: print("Error running %s" % statement) raise if self._autocommit: store.commit() def autocommit(self, flag): """Control whether to automatically commit/rollback schema changes. The default is C{True}, if set to C{False} it's up to the calling code to handle commits and rollbacks. @note: In case of rollback the exception will just be propagated, and no rollback on the store will be performed. """ self._autocommit = flag def check(self, store): """Check that the given L{Store} is compliant with this L{Schema}. @param store: The L{Store} to check. @raises SchemaMissingError: If there is no schema at all. @raises UnappliedPatchesError: If there are unapplied schema patches. @raises UnknownPatchError: If the store has patches the schema doesn't. """ # Let's create a savepoint here: the select statement below is just # used to test if the patch exists and we don't want to rollback # the whole transaction in case it fails. store.execute("SAVEPOINT schema") try: store.execute("SELECT * FROM patch WHERE version=0") except StormError: # No schema at all. Create it from the ground. store.execute("ROLLBACK TO SAVEPOINT schema") raise SchemaMissingError() else: store.execute("RELEASE SAVEPOINT schema") patch_applier = self._build_patch_applier(store) patch_applier.check_unknown() unapplied_versions = list(patch_applier.get_unapplied_versions()) if unapplied_versions: raise UnappliedPatchesError(unapplied_versions) def create(self, store): """Run C{CREATE TABLE} SQL statements with C{store}. @raises SchemaAlreadyCreatedError: If the schema for this store was already created. """ self._execute_statements(store, [self._create_patch]) self._execute_statements(store, self._creates) patch_applier = self._build_patch_applier(store) patch_applier.mark_applied_all() def drop(self, store): """Run C{DROP TABLE} SQL statements with C{store}.""" self._execute_statements(store, self._drops) self._execute_statements(store, [self._drop_patch]) def delete(self, store): """Run C{DELETE FROM} SQL statements with C{store}.""" self._execute_statements(store, self._deletes) def upgrade(self, store): """Upgrade C{store} to have the latest schema. If a schema isn't present a new one will be created. Unapplied patches will be applied to an existing schema. 
""" patch_applier = self._build_patch_applier(store) try: self.check(store) except SchemaMissingError: # No schema at all. Create it from the ground. self.create(store) except UnappliedPatchesError as error: patch_applier.check_unknown() for version in error.unapplied_versions: self.advance(store, version) def advance(self, store, version): """Advance the schema of C{store} by applying the next unapplied patch. @return: The version of patch that has been applied or C{None} if no patch was applied (i.e. the schema is fully upgraded). """ patch_applier = self._build_patch_applier(store) patch_applier.apply(version) def _build_patch_applier(self, store): """Build a L{PatchApplier} to use for the given C{store}.""" committer = self._committer if not self._autocommit: committer = _NoopCommitter() return PatchApplier(store, self._patch_set, committer) class _NoopCommitter: """Dummy committer that does nothing.""" def commit(self): pass def rollback(self): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/schema/sharding.py0000644000175000017500000001015014645174376017613 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2014 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # """Manage L{Schema}s across a set of L{Store} shards. The L{Sharding} class can be used to perform schema operations (create, upgrade, delete) against a set of L{Store}s. For example, let's say we have two L{Schema}s and two L{Store}s we want to apply them to. We can setup our L{Sharding} instance like this: >>> schema1 = Schema(...) >>> schema1 = Schema(...) >>> store1 = Store(...) >>> store2 = Store(...) >>> sharding = Sharding() >>> sharding.add(store1, schema1) >>> sharding.add(store2, schema2) And then perform schema maintenance operations across all shards: >>> sharding.upgrade() Patches will be applied "horizontally", meaning that the stores will always be at the same patch level. See L{storm.schema.patch.PatchSet}. """ from storm.schema.schema import SchemaMissingError, UnappliedPatchesError class PatchLevelMismatchError(Exception): """Raised when stores don't have all the same patch level.""" class Sharding: """Manage L{Shema}s over a collection of L{Store}s.""" def __init__(self): self._stores = [] # Sequence of Stores with their Schemas def add(self, store, schema): """Add a new L{Store} shard. @param store: The L{Store} to add. @param schema: The L{Schema} the L{Store} is meant to have. 
""" self._stores.append((store, schema)) def create(self): """Create all schemas from scratch across all L{Store} shards.""" for store, schema in self._stores: schema.create(store) def drop(self): """Drop all tables across all L{Store} shards.""" for store, schema in self._stores: schema.drop(store) def delete(self): """Delete all table rows across all L{Store} shards.""" for store, schema in self._stores: schema.delete(store) def upgrade(self): """Perform a schema upgrade. Pristine L{Store}s without any schema applied yet, will be initialized using L{Schema.create}. All other L{Store}s will be upgraded to the latest version of their L{Schema}s by applying all pending patches. The patching strategy is "horizontal", meaning that patch numbering must be the same across all L{Schema}s, and all L{Store}s must be at the same patch number. For example if the common patch number is N, the upgrade will start by applying patch number N + 1 to all non-pristine stores, following the order in which the stores were added to the L{Sharding}. Then the upgrade will apply all patches with number N + 2, etc. """ stores_to_upgrade = [] unapplied_versions = [] for store, schema in self._stores: try: schema.check(store) except SchemaMissingError: schema.create(store) except UnappliedPatchesError as error: if not unapplied_versions: unapplied_versions = error.unapplied_versions elif unapplied_versions != error.unapplied_versions: raise PatchLevelMismatchError( "Some stores have different patch levels") stores_to_upgrade.append((store, schema)) for version in unapplied_versions: for store, schema in stores_to_upgrade: schema.advance(store, version) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/sqlobject.py0000644000175000017500000006377214645174376016564 0ustar00cjwatsoncjwatson# # Copyright (c) 2006-2010 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # """A SQLObject emulation layer for Storm. L{SQLObjectBase} is the central point of compatibility. 
""" import re import warnings from storm.properties import ( Bytes, Int, Bool, Float, DateTime, Date, TimeDelta) from storm.references import Reference, ReferenceSet from storm.properties import SimpleProperty, PropertyPublisherMeta from storm.variables import UnicodeVariable from storm.exceptions import StormError, NotOneError from storm.info import get_cls_info, ClassAlias from storm.store import AutoReload, Store from storm.base import Storm from storm.expr import ( SQL, SQLRaw, Desc, And, Or, Not, In, Like, AutoTables, LeftJoin, Column, compare_columns) from storm.tz import tzutc from storm import Undef __all__ = [ "SQLObjectBase", "StringCol", "IntCol", "BoolCol", "FloatCol", "DateCol", "UtcDateTimeCol", "IntervalCol", "ForeignKey", "SQLMultipleJoin", "SQLRelatedJoin", "SingleJoin", "DESC", "AND", "OR", "NOT", "IN", "LIKE", "SQLConstant", "CONTAINSSTRING", "SQLObjectMoreThanOneResultError", "SQLObjectNotFound", "SQLObjectResultSet"] DESC, AND, OR, NOT, IN, LIKE, SQLConstant = Desc, And, Or, Not, In, Like, SQL SQLObjectMoreThanOneResultError = NotOneError _IGNORED = object() class SQLObjectNotFound(StormError): pass class SQLObjectStyle: longID = False def idForTable(self, table_name): if self.longID: return self.tableReference(table_name) else: return "id" def pythonClassToAttr(self, class_name): return self._lowerword(class_name) def instanceAttrToIDAttr(self, attr_name): return attr_name + "ID" def pythonAttrToDBColumn(self, attr_name): return self._mixed_to_under(attr_name) def dbColumnToPythonAttr(self, column_name): return self._under_to_mixed(column_name) def pythonClassToDBTable(self, class_name): return class_name[0].lower()+self._mixed_to_under(class_name[1:]) def dbTableToPythonClass(self, table_name): return table_name[0].upper()+self._under_to_mixed(table_name[1:]) def pythonClassToDBTableReference(self, class_name): return self.tableReference(self.pythonClassToDBTable(class_name)) def tableReference(self, table_name): return table_name+"_id" def _mixed_to_under(self, name, _re=re.compile("[A-Z]+")): if name.endswith("ID"): return self._mixed_to_under(name[:-2]+"_id") name = _re.sub(self._mixed_to_under_sub, name) if name.startswith("_"): return name[1:] return name def _mixed_to_under_sub(self, match): m = match.group(0).lower() if len(m) > 1: return "_%s_%s" % (m[:-1], m[-1]) else: return "_%s" % m def _under_to_mixed(self, name, _re=re.compile("_.")): if name.endswith("_id"): return self._under_to_mixed(name[:-3] + "ID") return _re.sub(self._under_to_mixed_sub, name) def _under_to_mixed_sub(self, match): return match.group(0)[1].upper() @staticmethod def _capword(s): return s[0].upper() + s[1:] @staticmethod def _lowerword(s): return s[0].lower() + s[1:] class SQLObjectMeta(PropertyPublisherMeta): @staticmethod def _get_attr(attr, bases, dict): value = dict.get(attr) if value is None: for base in bases: value = getattr(base, attr, None) if value is not None: break return value def __new__(cls, name, bases, dict): if Storm in bases or SQLObjectBase in bases: # Do not parse abstract base classes. return type.__new__(cls, name, bases, dict) style = cls._get_attr("_style", bases, dict) if style is None: dict["_style"] = style = SQLObjectStyle() table_name = cls._get_attr("_table", bases, dict) if table_name is None: table_name = style.pythonClassToDBTable(name) id_name = cls._get_attr("_idName", bases, dict) if id_name is None: id_name = style.idForTable(table_name) # Handle this later to call _parse_orderBy() on the created class. 
default_order = cls._get_attr("_defaultOrder", bases, dict) dict["__storm_table__"] = table_name attr_to_prop = {} for attr, prop in list(dict.items()): if attr == "__classcell__": # Python >= 3.6 continue attr_to_prop[attr] = attr if isinstance(prop, ForeignKey): db_name = prop.kwargs.get("dbName", attr) local_prop_name = style.instanceAttrToIDAttr(attr) dict[local_prop_name] = local_prop = Int( db_name, allow_none=not prop.kwargs.get("notNull", False), validator=prop.kwargs.get("storm_validator", None)) dict[attr] = Reference(local_prop, "%s." % prop.foreignKey) attr_to_prop[attr] = local_prop_name elif isinstance(prop, PropertyAdapter): db_name = prop.dbName or attr method_name = prop.alternateMethodName if method_name is None and prop.alternateID: method_name = "by" + db_name[0].upper() + db_name[1:] if method_name is not None: def func(cls, key, attr=attr): store = cls._get_store() obj = store.find(cls, getattr(cls, attr) == key).one() if obj is None: raise SQLObjectNotFound return obj func.func_name = method_name dict[method_name] = classmethod(func) elif isinstance(prop, SQLMultipleJoin): # Generate addFoo/removeFoo names. def define_add_remove(dict, prop): capitalised_name = (prop._otherClass[0].capitalize() + prop._otherClass[1:]) def add(self, obj): prop._get_bound_reference_set(self).add(obj) add.__name__ = "add" + capitalised_name dict.setdefault(add.__name__, add) def remove(self, obj): prop._get_bound_reference_set(self).remove(obj) remove.__name__ = "remove" + capitalised_name dict.setdefault(remove.__name__, remove) define_add_remove(dict, prop) id_type = dict.setdefault("_idType", int) id_cls = {int: Int, bytes: Bytes, str: AutoUnicode}[id_type] dict["id"] = id_cls(id_name, primary=True, default=AutoReload) attr_to_prop[id_name] = "id" # Notice that obj is the class since this is the metaclass. obj = super().__new__(cls, name, bases, dict) property_registry = obj._storm_property_registry property_registry.add_property(obj, getattr(obj, "id"), "") # Let's explore this same mechanism to register table names, # so that we can find them to handle prejoinClauseTables. property_registry.add_property(obj, getattr(obj, "id"), "" % table_name) for fake_name, real_name in list(attr_to_prop.items()): prop = getattr(obj, real_name) if fake_name != real_name: property_registry.add_property(obj, prop, fake_name) attr_to_prop[fake_name] = prop obj._attr_to_prop = attr_to_prop if default_order is not None: cls_info = get_cls_info(obj) cls_info.default_order = obj._parse_orderBy(default_order) return obj class DotQ: """A descriptor that mimics the SQLObject 'Table.q' syntax""" def __get__(self, obj, cls=None): return BoundDotQ(cls) class BoundDotQ: def __init__(self, cls): self._cls = cls def __getattr__(self, attr): if attr.startswith("__"): raise AttributeError(attr) elif attr == "id": cls_info = get_cls_info(self._cls) return cls_info.primary_key[0] else: return getattr(self._cls, attr) class SQLObjectBase(Storm, metaclass=SQLObjectMeta): """The root class of all SQLObject-emulating classes in your application. The general strategy for using Storm's SQLObject emulation layer is to create an application-specific subclass of SQLObjectBase (probably named "SQLObject") that provides an implementation of _get_store to return an instance of L{storm.store.Store}. It may even be implemented as returning a global L{Store} instance. Then all database classes should subclass that class. 
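
    A hedged sketch of one such database class (the 'SQLObject' base class
    and the 'Person' table are assumptions, and the table must already
    exist in the database)::

        >>> class Person(SQLObject):
        ...     name = StringCol()
        >>> person = Person(name=u"Anne")
        >>> print(Person.selectOneBy(name=u"Anne").name)
        Anne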
""" q = DotQ() _SO_creating = False def __init__(self, *args, **kwargs): store = self._get_store() store.add(self) try: self._create(None, **kwargs) except: store.remove(self) raise def __storm_loaded__(self): self._init(None) def _init(self, id, *args, **kwargs): pass def _create(self, _id_, **kwargs): self._SO_creating = True self.set(**kwargs) del self._SO_creating self._init(None) def set(self, **kwargs): for attr, value in kwargs.items(): setattr(self, attr, value) def destroySelf(self): Store.of(self).remove(self) @staticmethod def _get_store(): raise NotImplementedError("SQLObjectBase._get_store() " "must be implemented") @classmethod def delete(cls, id): # destroySelf() should be extended to support cascading, so # we'll mimic what SQLObject does here, even if more expensive. obj = cls.get(id) obj.destroySelf() @classmethod def get(cls, id): id = cls._idType(id) store = cls._get_store() obj = store.get(cls, id) if obj is None: raise SQLObjectNotFound("Object not found") return obj @classmethod def _parse_orderBy(cls, orderBy): result = [] if not isinstance(orderBy, (tuple, list)): orderBy = (orderBy,) for item in orderBy: if isinstance(item, str): desc = item.startswith("-") if desc: item = item[1:] item = cls._attr_to_prop.get(item, item) if desc: item = Desc(item) result.append(item) return tuple(result) @classmethod def select(cls, *args, **kwargs): return SQLObjectResultSet(cls, *args, **kwargs) @classmethod def selectBy(cls, orderBy=None, **kwargs): return SQLObjectResultSet(cls, orderBy=orderBy, by=kwargs) @classmethod def selectOne(cls, *args, **kwargs): return SQLObjectResultSet(cls, *args, **kwargs)._one() @classmethod def selectOneBy(cls, **kwargs): return SQLObjectResultSet(cls, by=kwargs)._one() @classmethod def selectFirst(cls, *args, **kwargs): return SQLObjectResultSet(cls, *args, **kwargs)._first() @classmethod def selectFirstBy(cls, orderBy=None, **kwargs): result = SQLObjectResultSet(cls, orderBy=orderBy, by=kwargs) return result._first() def syncUpdate(self): self._get_store().flush() def sync(self): store = self._get_store() store.flush() store.autoreload(self) class SQLObjectResultSet: """SQLObject-equivalent of the ResultSet class in Storm. Storm handles joins in the Store interface, while SQLObject does that in the result one. To offer support for prejoins, we can't simply wrap our ResultSet instance, and instead have to postpone the actual find until the very last moment. 
""" def __init__(self, cls, clause=None, clauseTables=None, orderBy=None, limit=None, distinct=None, prejoins=None, prejoinClauseTables=None, selectAlso=None, by={}, prepared_result_set=None, slice=None): self._cls = cls self._clause = clause self._clauseTables = clauseTables self._orderBy = orderBy self._limit = limit self._distinct = distinct self._prejoins = prejoins self._prejoinClauseTables = prejoinClauseTables self._selectAlso = selectAlso # Parameters not mapping SQLObject: self._by = by self._slice = slice self._prepared_result_set = prepared_result_set self._finished_result_set = None def _copy(self, **kwargs): copy = self.__class__(self._cls, **kwargs) for name, value in self.__dict__.items(): if name[1:] not in kwargs and name != "_finished_result_set": setattr(copy, name, value) return copy def _prepare_result_set(self): store = self._cls._get_store() args = [] if self._clause: args.append(self._clause) for key, value in self._by.items(): args.append(getattr(self._cls, key) == value) tables = [] if self._clauseTables is not None: tables.extend(self._clauseTables) if not (self._prejoins or self._prejoinClauseTables): find_spec = self._cls else: find_spec = [self._cls] if self._prejoins: already_prejoined = {} last_prejoin = 0 join = self._cls for prejoin_path in self._prejoins: local_cls = self._cls path = () for prejoin_attr in prejoin_path.split("."): path += (prejoin_attr,) # If we've already prejoined this column, we're done. if path in already_prejoined: local_cls = already_prejoined[path] continue # Otherwise, join the table relation = getattr(local_cls, prejoin_attr)._relation last_prejoin += 1 remote_cls = ClassAlias(relation.remote_cls, '_prejoin%d' % last_prejoin) join_expr = join_aliased_relation( local_cls, remote_cls, relation) join = LeftJoin(join, remote_cls, join_expr) find_spec.append(remote_cls) already_prejoined[path] = remote_cls local_cls = remote_cls if join is not self._cls: tables.append(join) if self._prejoinClauseTables: property_registry = self._cls._storm_property_registry for table in self._prejoinClauseTables: cls = property_registry.get("
" % table).cls find_spec.append(cls) find_spec = tuple(find_spec) if tables: # If we are adding extra tables, make sure the main table # is included. tables.insert(0, self._cls.__storm_table__) # Inject an AutoTables expression with a dummy true value to # be ANDed in the WHERE clause, so that we can introduce our # tables into the dynamic table handling of Storm without # disrupting anything else. args.append(AutoTables(SQL("1=1"), tables)) if self._selectAlso is not None: if type(find_spec) is not tuple: find_spec = (find_spec, SQL(self._selectAlso)) else: find_spec += (SQL(self._selectAlso),) return store.find(find_spec, *args) def _finish_result_set(self): if self._prepared_result_set is not None: result = self._prepared_result_set else: result = self._prepare_result_set() if self._orderBy is not None: result.order_by(*self._cls._parse_orderBy(self._orderBy)) if self._limit is not None or self._distinct is not None: result.config(limit=self._limit, distinct=self._distinct) if self._slice is not None: result = result[self._slice] return result @property def _result_set(self): if self._finished_result_set is None: self._finished_result_set = self._finish_result_set() return self._finished_result_set def _without_prejoins(self, always_copy=False): if always_copy or self._prejoins or self._prejoinClauseTables: return self._copy(prejoins=None, prejoinClauseTables=None) else: return self def _one(self): """Internal API for the base class.""" return detuplelize(self._result_set.one()) def _first(self): """Internal API for the base class.""" return detuplelize(self._result_set.first()) def __iter__(self): for item in self._result_set: yield detuplelize(item) def __getitem__(self, index): if isinstance(index, slice): if index.start and index.start < 0 or ( index.stop and index.stop < 0): L = list(self) if len(L) > 100: warnings.warn('Negative indices when slicing are slow: ' 'fetched %d rows.' % (len(L),)) start, stop, step = index.indices(len(L)) assert step == 1, "slice step must be 1" index = slice(start, stop) return self._copy(slice=index) else: if index < 0: L = list(self) if len(L) > 100: warnings.warn('Negative indices are slow: ' 'fetched %d rows.' % (len(L),)) return detuplelize(L[index]) return detuplelize(self._result_set[index]) def __contains__(self, item): result_set = self._without_prejoins()._result_set return item in result_set def __bool__(self): """Return C{True} if this result set contains any results. @note: This method is provided for compatibility with SQL Object. For new code, prefer L{is_empty}. It's compatible with L{ResultSet} which doesn't have a C{__bool__} implementation. 
""" return not self.is_empty() def is_empty(self): """Return C{True} if this result set doesn't contain any results.""" result_set = self._without_prejoins()._result_set return result_set.is_empty() def count(self): result_set = self._without_prejoins()._result_set return result_set.count() def orderBy(self, orderBy): return self._copy(orderBy=orderBy) def limit(self, limit): return self._copy(limit=limit) def distinct(self): return self._copy(distinct=True, orderBy=None) def union(self, otherSelect, unionAll=False, orderBy=()): result1 = self._without_prejoins(True)._result_set.order_by() result2 = otherSelect._without_prejoins(True)._result_set.order_by() result_set = result1.union(result2, all=unionAll) return self._copy( prepared_result_set=result_set, distinct=False, orderBy=orderBy) def except_(self, otherSelect, exceptAll=False, orderBy=()): result1 = self._without_prejoins(True)._result_set.order_by() result2 = otherSelect._without_prejoins(True)._result_set.order_by() result_set = result1.difference(result2, all=exceptAll) return self._copy( prepared_result_set=result_set, distinct=False, orderBy=orderBy) def intersect(self, otherSelect, intersectAll=False, orderBy=()): result1 = self._without_prejoins(True)._result_set.order_by() result2 = otherSelect._without_prejoins(True)._result_set.order_by() result_set = result1.intersection(result2, all=intersectAll) return self._copy( prepared_result_set=result_set, distinct=False, orderBy=orderBy) def prejoin(self, prejoins): return self._copy(prejoins=prejoins) def prejoinClauseTables(self, prejoinClauseTables): return self._copy(prejoinClauseTables=prejoinClauseTables) def sum(self, attribute): if isinstance(attribute, str): attribute = SQL(attribute) result_set = self._without_prejoins()._result_set return result_set.sum(attribute) def detuplelize(item): """If item is a tuple, return first element, otherwise the item itself. The tuple syntax is used to implement prejoins, so we have to hide from the user the fact that more than a single object are being selected at once. """ if type(item) is tuple: return item[0] return item def join_aliased_relation(local_cls, remote_cls, relation): """Build a join expression between local_cls and remote_cls. This is equivalent to relation.get_where_for_join(), except that the join expression is changed to be relative to the given local_cls and remote_cls (which may be aliases). The result is the join expression. 
""" remote_key = tuple(Column(column.name, remote_cls) for column in relation.remote_key) local_key = tuple(Column(column.name, local_cls) for column in relation.local_key) return compare_columns(local_key, remote_key) class PropertyAdapter: _kwargs = {} def __init__(self, dbName=None, notNull=False, default=Undef, alternateID=None, unique=_IGNORED, name=_IGNORED, alternateMethodName=None, length=_IGNORED, immutable=None, storm_validator=None): if default is None and notNull: raise RuntimeError("Can't use default=None and notNull=True") self.dbName = dbName self.alternateID = alternateID self.alternateMethodName = alternateMethodName # XXX Implement handler for: # # - immutable (causes setting the attribute to fail) # # XXX Implement tests for ignored parameters: # # - unique (for tablebuilder) # - length (for tablebuilder for StringCol) # - name (for _columns stuff) if callable(default): default_factory = default default = Undef else: default_factory = Undef super().__init__(dbName, allow_none=not notNull, default_factory=default_factory, default=default, validator=storm_validator, **self._kwargs) # DEPRECATED: On Python 2, this used to be a more relaxed version of # UnicodeVariable that accepted both bytes and text. On Python 3, it # accepts only text and is thus the same as UnicodeVariable. It exists only # for compatibility. AutoUnicodeVariable = UnicodeVariable # DEPRECATED: Use storm.properties.Unicode instead. class AutoUnicode(SimpleProperty): variable_class = AutoUnicodeVariable class StringCol(PropertyAdapter, AutoUnicode): pass class IntCol(PropertyAdapter, Int): pass class BoolCol(PropertyAdapter, Bool): pass class FloatCol(PropertyAdapter, Float): pass class UtcDateTimeCol(PropertyAdapter, DateTime): _kwargs = {"tzinfo": tzutc()} class DateCol(PropertyAdapter, Date): pass class IntervalCol(PropertyAdapter, TimeDelta): pass class ForeignKey: def __init__(self, foreignKey, **kwargs): self.foreignKey = foreignKey self.kwargs = kwargs class SQLMultipleJoin(ReferenceSet): def __init__(self, otherClass=None, joinColumn=None, intermediateTable=None, otherColumn=None, orderBy=None, prejoins=None): if intermediateTable: args = ("", "%s.%s" % (intermediateTable, joinColumn), "%s.%s" % (intermediateTable, otherColumn), "%s." 
% otherClass) else: args = ("", "%s.%s" % (otherClass, joinColumn)) ReferenceSet.__init__(self, *args) self._orderBy = orderBy self._otherClass = otherClass self._prejoins = prejoins def __get__(self, obj, cls=None): if obj is None: return self bound_reference_set = ReferenceSet.__get__(self, obj) target_cls = bound_reference_set._target_cls where_clause = bound_reference_set._get_where_clause() return SQLObjectResultSet(target_cls, where_clause, orderBy=self._orderBy, prejoins=self._prejoins) def _get_bound_reference_set(self, obj): assert obj is not None return ReferenceSet.__get__(self, obj) SQLRelatedJoin = SQLMultipleJoin class SingleJoin(Reference): def __init__(self, otherClass, joinColumn, prejoins=_IGNORED): super().__init__( "", "%s.%s" % (otherClass, joinColumn), on_remote=True) class CONTAINSSTRING(Like): def __init__(self, expr, string): string = string.replace("!", "!!") \ .replace("_", "!_") \ .replace("%", "!%") Like.__init__(self, expr, "%"+string+"%", SQLRaw("'!'")) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/store.py0000644000175000017500000021471314645174376015723 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # """The Store interface to a database. This module contains the highest-level ORM interface in Storm. """ from copy import copy from weakref import WeakValueDictionary from operator import itemgetter from storm.info import get_cls_info, get_obj_info, set_obj_info from storm.variables import Variable, LazyValue from storm.expr import ( Expr, Select, Insert, Update, Delete, Column, Count, Max, Min, Avg, Sum, Eq, And, Asc, Desc, compile_python, compare_columns, SQLRaw, Union, Except, Intersect, Alias, SetExpr) from storm.exceptions import ( WrongStoreError, NotFlushedError, OrderLoopError, UnorderedError, NotOneError, FeatureError, CompileError, LostObjectError, ClassInfoError) from storm.properties import PropertyColumn from storm import Undef from storm.cache import Cache from storm.event import EventSystem __all__ = [ "AutoReload", "block_access", "EmptyResultSet", "Store", ] PENDING_ADD = 1 PENDING_REMOVE = 2 class Store: """The Storm Store. This is the highest-level interface to a database. It manages transactions with L{commit} and L{rollback}, caching, high-level querying with L{find}, and more. Note that Store objects are not threadsafe. You should create one Store per thread in your application, passing them the same backend L{Database} object. """ _result_set_factory = None def __init__(self, database, cache=None): """ @param database: The L{storm.database.Database} instance to use. @param cache: The cache to use. Defaults to a L{Cache} instance. 
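
        A short sketch of constructing a store (the sqlite URI is just an
        example)::

            >>> from storm.database import create_database
            >>> store = Store(create_database("sqlite:"))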
""" self._database = database self._event = EventSystem(self) self._connection = database.connect(self._event) self._alive = WeakValueDictionary() self._dirty = {} self._order = {} # (info, info) = count if cache is None: self._cache = Cache() else: self._cache = cache self._implicit_flush_block_count = 0 self._sequence = 0 # Advisory ordering. def get_database(self): """Return this Store's Database object.""" return self._database @staticmethod def of(obj): """Get the Store that the object is associated with. If the given object has not yet been associated with a store, return None. """ try: return get_obj_info(obj).get("store") except (AttributeError, ClassInfoError): return None def execute(self, statement, params=None, noresult=False): """Execute a basic query. This is just like L{storm.database.Connection.execute}, except that a flush is performed first. """ if self._implicit_flush_block_count == 0: self.flush() return self._connection.execute(statement, params, noresult) def close(self): """Close the connection.""" self._connection.close() def begin(self, xid): """Start a new two-phase transaction. @param xid: A L{Xid } instance holding identification data for the new transaction. """ self._connection.begin(xid) def prepare(self): """Prepare a two-phase transaction for the final commit. @note: It must be called inside a two-phase transaction started with L{begin}. """ self._connection.prepare() def commit(self): """Commit all changes to the database. This invalidates the cache, so all live objects will have data reloaded next time they are touched. """ self.flush() self.invalidate() self._connection.commit() def rollback(self): """Roll back all outstanding changes, reverting to database state.""" for obj_info in self._dirty: pending = obj_info.pop("pending", None) if pending is PENDING_ADD: # Object never got in the cache, so being "in the store" # has no actual meaning for it. del obj_info["store"] elif pending is PENDING_REMOVE: # Object never got removed, so it's still in the cache, # and thus should continue to resolve from now on. self._enable_lazy_resolving(obj_info) self._dirty.clear() self.invalidate() self._connection.rollback() def get(self, cls, key): """Get object of type cls with the given primary key from the database. If the object is alive the database won't be touched. @param cls: Class of the object to be retrieved. @param key: Primary key of object. May be a tuple for composed keys. @return: The object found with the given primary key, or None if no object is found. """ if self._implicit_flush_block_count == 0: self.flush() if type(key) != tuple: key = (key,) cls_info = get_cls_info(cls) assert len(key) == len(cls_info.primary_key) primary_vars = [] for column, variable in zip(cls_info.primary_key, key): if not isinstance(variable, Variable): variable = column.variable_factory(value=variable) primary_vars.append(variable) primary_values = tuple(var.get(to_db=True) for var in primary_vars) obj_info = self._alive.get((cls_info.cls, primary_values)) if obj_info is not None and not obj_info.get("invalidated"): return self._get_object(obj_info) where = compare_columns(cls_info.primary_key, primary_vars) select = Select(cls_info.columns, where, default_tables=cls_info.table, limit=1) result = self._connection.execute(select) values = result.get_one() if values is None: return None return self._load_object(cls_info, result, values) def find(self, cls_spec, *args, **kwargs): """Perform a query. 
Some examples:: store.find(Person, Person.name == u"Joe") --> all Persons named Joe store.find(Person, name=u"Joe") --> same store.find((Company, Person), Person.company_id == Company.id) --> iterator of tuples of Company and Person instances which are associated via the company_id -> Company relation. @param cls_spec: The class or tuple of classes whose associated tables will be queried. @param args: Instances of L{Expr}. @param kwargs: Mapping of simple column names to values or expressions to query for. @return: A L{ResultSet} of instances C{cls_spec}. If C{cls_spec} was a tuple, then an iterator of tuples of such instances. """ if self._implicit_flush_block_count == 0: self.flush() find_spec = FindSpec(cls_spec) where = get_where_for_args(args, kwargs, find_spec.default_cls) return self._result_set_factory(self, find_spec, where) def using(self, *tables): """Specify tables to use explicitly. The L{find} method generally does a good job at figuring out the tables to query by itself, but in some cases it's useful to specify them explicitly. This is most often necessary when an explicit SQL join is required. An example follows:: join = LeftJoin(Person, Person.id == Company.person_id) print(list(store.using(Company, join).find((Company, Person)))) The previous code snippet will produce an SQL statement somewhat similar to this, depending on your backend:: SELECT company.id, employee.company_id, employee.id FROM company LEFT JOIN employee ON employee.company_id = company.id; @return: A L{TableSet}, which has a C{find} method similar to L{Store.find}. """ return self._table_set(self, tables) def add(self, obj): """Add the given object to the store. The object will be inserted into the database if it has not yet been added. The C{added} event will be fired on the object info's event system. """ self._event.emit("register-transaction") obj_info = get_obj_info(obj) store = obj_info.get("store") if store is not None and store is not self: raise WrongStoreError("%s is part of another store" % repr(obj)) pending = obj_info.get("pending") if pending is PENDING_ADD: pass elif pending is PENDING_REMOVE: del obj_info["pending"] self._enable_lazy_resolving(obj_info) # obj_info.event.emit("added") elif store is None: obj_info["store"] = self obj_info["pending"] = PENDING_ADD self._set_dirty(obj_info) self._enable_lazy_resolving(obj_info) obj_info.event.emit("added") return obj def remove(self, obj): """Remove the given object from the store. The associated row will be deleted from the database. """ self._event.emit("register-transaction") obj_info = get_obj_info(obj) if obj_info.get("store") is not self: raise WrongStoreError("%s is not in this store" % repr(obj)) pending = obj_info.get("pending") if pending is PENDING_REMOVE: pass elif pending is PENDING_ADD: del obj_info["store"] del obj_info["pending"] self._set_clean(obj_info) self._disable_lazy_resolving(obj_info) obj_info.event.emit("removed") else: obj_info["pending"] = PENDING_REMOVE self._set_dirty(obj_info) self._disable_lazy_resolving(obj_info) obj_info.event.emit("removed") def reload(self, obj): """Reload the given object. The object will immediately have all of its data reset from the database. Any pending changes will be thrown away. 
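
        A sketch (assuming a 'person' object previously loaded from this
        store)::

            >>> person.name = u"Changed"
            >>> store.reload(person)  # 'name' reverts to the database value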
""" obj_info = get_obj_info(obj) cls_info = obj_info.cls_info if obj_info.get("store") is not self: raise WrongStoreError("%s is not in this store" % repr(obj)) if "primary_vars" not in obj_info: raise NotFlushedError("Can't reload an object if it was " "never flushed") where = compare_columns(cls_info.primary_key, obj_info["primary_vars"]) select = Select(cls_info.columns, where, default_tables=cls_info.table, limit=1) result = self._connection.execute(select) values = result.get_one() self._set_values(obj_info, cls_info.columns, result, values, replace_unknown_lazy=True) self._set_clean(obj_info) def autoreload(self, obj=None): """Set an object or all objects to be reloaded automatically on access. When a database-backed attribute of one of the objects is accessed, the object will be reloaded entirely from the database. @param obj: If passed, only mark the given object for autoreload. Otherwise, all cached objects will be marked for autoreload. """ self._mark_autoreload(obj, False) def invalidate(self, obj=None): """Set an object or all objects to be invalidated. This prevents Storm from returning the cached object without first verifying that the object is still available in the database. This should almost never be called by application code; it is only necessary if it is possible that an object has disappeared through some mechanism that Storm was unable to detect, like direct SQL statements within the current transaction that bypassed the ORM layer. The Store automatically invalidates all cached objects on transaction boundaries. """ if obj is None: self._cache.clear() else: self._cache.remove(get_obj_info(obj)) self._mark_autoreload(obj, True) def reset(self): """Reset this store, causing all future queries to return new objects. Beware this method: it breaks the assumption that there will never be two objects in memory which represent the same database object. This is useful if you've got in-memory changes to an object that you want to "throw out"; next time they're fetched the objects will be recreated, so in-memory modifications will not be in effect for future queries. """ for obj_info in self._iter_alive(): if "store" in obj_info: del obj_info["store"] self._alive.clear() self._dirty.clear() self._cache.clear() # The following line is untested, but then, I can't really find a way # to test it without whitebox. self._order.clear() def _mark_autoreload(self, obj=None, invalidate=False): if obj is None: obj_infos = self._iter_alive() else: obj_infos = (get_obj_info(obj),) for obj_info in obj_infos: cls_info = obj_info.cls_info for column in cls_info.columns: if id(column) not in cls_info.primary_key_idx: obj_info.variables[column].set(AutoReload) if invalidate: # Marking an object with 'invalidated' means that we're # not sure if the object is actually in the database # anymore, so before the object is returned from the cache # (e.g. by a get()), the database should be queried to see # if the object's still there. obj_info["invalidated"] = True # We want to make sure we've marked all objects as invalidated and set # up their autoreloads before calling the invalidated hook on *any* of # them, because an invalidated hook might use other objects and we want # to prevent invalidation ordering issues. if invalidate: for obj_info in obj_infos: self._run_hook(obj_info, "__storm_invalidated__") def add_flush_order(self, before, after): """Explicitly specify the order of flushing two objects. When the next database flush occurs, the order of data modification statements will be ensured. 
@param before: The object to flush first. @param after: The object to flush after C{before}. """ pair = (get_obj_info(before), get_obj_info(after)) try: self._order[pair] += 1 except KeyError: self._order[pair] = 1 def remove_flush_order(self, before, after): """Cancel an explicit flush order specified with L{add_flush_order}. @param before: The C{before} object previously specified in a call to L{add_flush_order}. @param after: The C{after} object previously specified in a call to L{add_flush_order}. """ pair = (get_obj_info(before), get_obj_info(after)) self._order[pair] -= 1 def flush(self): """Flush all dirty objects in cache to database. This method will first call the __storm_pre_flush__ hook of all dirty objects. If more objects become dirty as a result of executing code in the hooks, the hook is also called on them. The hook is only called once for each object. It will then flush each dirty object to the database, that is, execute the SQL code to insert/delete/update them. After each object is flushed, the hook __storm_flushed__ is called on it, and if changes are made to the object it will get back to the dirty list, and be flushed again. Note that Storm will flush objects for you automatically, so you'll only need to call this method explicitly in very rare cases where normal flushing times are insufficient, such as when you want to make sure a database trigger gets run at a particular time. """ self._event.emit("flush") # The _dirty list may change under us while we're running # the flush hooks, so we cannot just simply loop over it # once. To prevent infinite looping we keep track of which # objects we've called the hook for using a `flushing` dict. flushing = {} while self._dirty: (obj_info, obj) = self._dirty.popitem() if obj_info not in flushing: flushing[obj_info] = obj self._run_hook(obj_info, "__storm_pre_flush__") self._dirty = flushing predecessors = {} for (before_info, after_info), n in self._order.items(): if n > 0: before_set = predecessors.get(after_info) if before_set is None: predecessors[after_info] = {before_info} else: before_set.add(before_info) key_func = itemgetter("sequence") # The external loop is important because items can get into the dirty # state while we're flushing objects, ... while self._dirty: # ... but we don't have to resort everytime an object is flushed, # so we have an internal loop too. If no objects become dirty # during flush, this will clean self._dirty and the external loop # will exit too. sorted_dirty = sorted(self._dirty, key=key_func) while sorted_dirty: for i, obj_info in enumerate(sorted_dirty): for before_info in predecessors.get(obj_info, ()): if before_info in self._dirty: break # A predecessor is still dirty. else: break # Found an item without dirty predecessors. else: raise OrderLoopError("Can't flush due to ordering loop") del sorted_dirty[i] self._dirty.pop(obj_info, None) self._flush_one(obj_info) self._order.clear() # That's not stricly necessary, but prevents getting into bigints. self._sequence = 0 def _flush_one(self, obj_info): cls_info = obj_info.cls_info pending = obj_info.pop("pending", None) if pending is PENDING_REMOVE: expr = Delete(compare_columns(cls_info.primary_key, obj_info["primary_vars"]), cls_info.table) self._connection.execute(expr, noresult=True) # We're sure the cache is valid at this point. 
obj_info.pop("invalidated", None) self._disable_change_notification(obj_info) self._remove_from_alive(obj_info) del obj_info["store"] elif pending is PENDING_ADD: # Give a chance to the backend to process primary variables. self._connection.preset_primary_key(cls_info.primary_key, obj_info.primary_vars) changes = self._get_changes_map(obj_info, True) expr = Insert(changes, cls_info.table, primary_columns=cls_info.primary_key, primary_variables=obj_info.primary_vars) result = self._connection.execute(expr) # We're sure the cache is valid at this point. We just added # the object. obj_info.pop("invalidated", None) self._fill_missing_values(obj_info, obj_info.primary_vars, result) self._enable_change_notification(obj_info) self._add_to_alive(obj_info) else: cached_primary_vars = obj_info["primary_vars"] changes = self._get_changes_map(obj_info) if changes: expr = Update(changes, compare_columns(cls_info.primary_key, cached_primary_vars), cls_info.table) self._connection.execute(expr, noresult=True) self._fill_missing_values(obj_info, obj_info.primary_vars) self._add_to_alive(obj_info) self._run_hook(obj_info, "__storm_flushed__") obj_info.event.emit("flushed") def block_implicit_flushes(self): """Block implicit flushes from operations like execute().""" self._implicit_flush_block_count += 1 def unblock_implicit_flushes(self): """Unblock implicit flushes from operations like execute().""" assert self._implicit_flush_block_count > 0 self._implicit_flush_block_count -= 1 def block_access(self): """Block access to the underlying database connection.""" self._connection.block_access() def unblock_access(self): """Unblock access to the underlying database connection.""" self._connection.unblock_access() def _get_changes_map(self, obj_info, adding=False): """Return a {column: variable} dictionary suitable for inserts/updates. @param obj_info: ObjectInfo to inspect for changes. @param adding: If true, any defined variables will be considered a change and included in the returned map. """ cls_info = obj_info.cls_info changes = {} select_variables = [] for column in cls_info.columns: variable = obj_info.variables[column] if adding or variable.has_changed(): if variable.is_defined(): changes[column] = variable else: lazy_value = variable.get_lazy() if isinstance(lazy_value, Expr): if id(column) in cls_info.primary_key_idx: select_variables.append(variable) # See below. changes[column] = variable else: changes[column] = lazy_value # If we have any expressions in the primary variables, we # have to resolve them now so that we have the identity of # the inserted object available later. if select_variables: resolve_expr = Select([variable.get_lazy() for variable in select_variables]) result = self._connection.execute(resolve_expr) for variable, value in zip(select_variables, result.get_one()): result.set_variable(variable, value) return changes def _fill_missing_values(self, obj_info, primary_vars, result=None): """Fill missing values in variables of the given obj_info. This method will verify which values are unset in obj_info, and set them to L{AutoReload}, or if it's part of the primary key, query the database for the actual values. @param obj_info: ObjectInfo to have its values filled. @param primary_vars: Variables composing the primary key with up-to-date values (cached variables may be out-of-date when this method is called). 
@param result: If some value in the set of primary variables isn't defined, it must be retrieved from the database using database-dependent logic, which is provided by the backend in the result of the query which inserted the object. """ cls_info = obj_info.cls_info cached_primary_vars = obj_info.get("primary_vars") primary_key_idx = cls_info.primary_key_idx missing_columns = [] for column in cls_info.columns: variable = obj_info.variables[column] if not variable.is_defined(): idx = primary_key_idx.get(id(column)) if idx is not None: if (cached_primary_vars is not None and variable.get_lazy() is AutoReload): # For auto-reloading a primary key, just # get the value out of the cache. variable.set(cached_primary_vars[idx].get()) else: missing_columns.append(column) else: # Any lazy values are overwritten here. This value # must have just been sent to the database, so this # was already set there. variable.set(AutoReload) else: variable.checkpoint() if missing_columns: where = result.get_insert_identity(cls_info.primary_key, primary_vars) result = self._connection.execute(Select(missing_columns, where)) self._set_values(obj_info, missing_columns, result, result.get_one()) def _validate_alive(self, obj_info): """Perform cache validation for the given obj_info.""" where = compare_columns(obj_info.cls_info.primary_key, obj_info["primary_vars"]) result = self._connection.execute(Select(SQLRaw("1"), where)) if not result.get_one(): raise LostObjectError("Object is not in the database anymore") obj_info.pop("invalidated", None) def _load_object(self, cls_info, result, values): # _set_values() need the cls_info columns for the class of the # actual object, not from a possible wrapper (e.g. an alias). cls = cls_info.cls cls_info = get_cls_info(cls) # Prepare cache key. primary_vars = [] columns = cls_info.columns for value in values: if value is not None: break else: # We've got a row full of NULLs, so consider that the object # wasn't found. This is useful for joins, where non-existent # rows are represented like that. return None for i in cls_info.primary_key_pos: value = values[i] variable = columns[i].variable_factory(value=value, from_db=True) primary_vars.append(variable) # Lookup cache. primary_values = tuple(var.get(to_db=True) for var in primary_vars) obj_info = self._alive.get((cls, primary_values)) if obj_info is not None: # Found object in cache, and it must be valid since the # primary key was extracted from result values. obj_info.pop("invalidated", None) # Take that chance and fill up any undefined variables # with fresh data, since we got it anyway. self._set_values(obj_info, cls_info.columns, result, values, keep_defined=True) # We're not sure if the obj is still in memory at this # point. This will rebuild it if needed. obj = self._get_object(obj_info) else: # Nothing found in the cache. Build everything from the ground. 
obj = cls.__new__(cls) obj_info = get_obj_info(obj) obj_info["store"] = self self._set_values(obj_info, cls_info.columns, result, values, replace_unknown_lazy=True) self._add_to_alive(obj_info) self._enable_change_notification(obj_info) self._enable_lazy_resolving(obj_info) self._run_hook(obj_info, "__storm_loaded__") return obj def _get_object(self, obj_info): """Return object for obj_info, rebuilding it if it's dead.""" obj = obj_info.get_obj() if obj is None: cls = obj_info.cls_info.cls obj = cls.__new__(cls) obj_info.set_obj(obj) set_obj_info(obj, obj_info) # Re-enable change notification, as it may have been implicitly # disabled when the previous object was collected. self._enable_change_notification(obj_info) self._run_hook(obj_info, "__storm_loaded__") # Renew the cache. self._cache.add(obj_info) return obj @staticmethod def _run_hook(obj_info, hook_name): func = getattr(obj_info.get_obj(), hook_name, None) if func is not None: func() def _set_values(self, obj_info, columns, result, values, keep_defined=False, replace_unknown_lazy=False): if values is None: raise LostObjectError("Can't obtain values from the database " "(object got removed?)") obj_info.pop("invalidated", None) for column, value in zip(columns, values): variable = obj_info.variables[column] lazy_value = variable.get_lazy() is_unknown_lazy = not (lazy_value is None or lazy_value is AutoReload) if keep_defined: if variable.is_defined() or is_unknown_lazy: continue elif is_unknown_lazy and not replace_unknown_lazy: # This should *never* happen, because whenever we get # to this point it should be after a flush() which # updated the database with lazy values and then replaced # them by AutoReload. Letting this go through means # we're blindly discarding an unknown lazy value and # replacing it by the value from the database. raise RuntimeError("Unexpected situation. " "Please contact the developers.") if value is None: variable.set(value, from_db=True) else: result.set_variable(variable, value) variable.checkpoint() def _is_dirty(self, obj_info): return obj_info in self._dirty def _set_dirty(self, obj_info): if obj_info not in self._dirty: self._dirty[obj_info] = obj_info.get_obj() obj_info["sequence"] = self._sequence = self._sequence + 1 def _set_clean(self, obj_info): self._dirty.pop(obj_info, None) def _iter_dirty(self): return self._dirty def _add_to_alive(self, obj_info): """Add an object to the set of known in-memory objects. When an object is added to the set of known in-memory objects, the key is built from a copy of the current variables that are part of the primary key. This means that, when an object is retrieved from the database, these values may be used to get the cached object which is already in memory, even if it requested the primary key value to be changed. For that reason, when changes to the primary key are flushed, the alive object key should also be updated to reflect these changes. In addition to tracking objects alive in memory, we have a strong reference cache which keeps a fixed number of last-used objects in-memory, to prevent further database access for recently fetched objects.
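As an illustrative sketch of the key construction performed in the code below (not a separate API), the alive key is essentially::

    key = (cls_info.cls,
           tuple(var.get(to_db=True) for var in obj_info.primary_vars))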
""" cls_info = obj_info.cls_info old_primary_vars = obj_info.get("primary_vars") if old_primary_vars is not None: old_primary_values = tuple( var.get(to_db=True) for var in old_primary_vars) self._alive.pop((cls_info.cls, old_primary_values), None) new_primary_vars = tuple(variable.copy() for variable in obj_info.primary_vars) new_primary_values = tuple( var.get(to_db=True) for var in new_primary_vars) self._alive[cls_info.cls, new_primary_values] = obj_info obj_info["primary_vars"] = new_primary_vars self._cache.add(obj_info) def _remove_from_alive(self, obj_info): """Remove an object from the cache. This method is only called for objects that were explicitly deleted and flushed. Objects that are unused will get removed from the cache dictionary automatically by their weakref callbacks. """ primary_vars = obj_info.get("primary_vars") if primary_vars is not None: self._cache.remove(obj_info) primary_values = tuple(var.get(to_db=True) for var in primary_vars) del self._alive[obj_info.cls_info.cls, primary_values] del obj_info["primary_vars"] def _iter_alive(self): return list(self._alive.values()) def _enable_change_notification(self, obj_info): obj_info.event.emit("start-tracking-changes", self._event) obj_info.event.hook("changed", self._variable_changed) def _disable_change_notification(self, obj_info): obj_info.event.unhook("changed", self._variable_changed) obj_info.event.emit("stop-tracking-changes", self._event) def _variable_changed(self, obj_info, variable, old_value, new_value, fromdb): # The fromdb check makes sure that values coming from the # database don't mark the object as dirty again. # XXX The fromdb check is untested. How to test it? if not fromdb: if new_value is not Undef and new_value is not AutoReload: if obj_info.get("invalidated"): # This might be a previously alive object being # updated. Let's validate it now to improve debugging. # This will raise LostObjectError if the object is gone. self._validate_alive(obj_info) self._set_dirty(obj_info) def _enable_lazy_resolving(self, obj_info): obj_info.event.hook("resolve-lazy-value", self._resolve_lazy_value) def _disable_lazy_resolving(self, obj_info): obj_info.event.unhook("resolve-lazy-value", self._resolve_lazy_value) def _resolve_lazy_value(self, obj_info, variable, lazy_value): """Resolve a variable set to a lazy value when it's touched. This method is hooked into the obj_info to resolve variables set to lazy values when they're accessed. It will first flush the store, and then set all variables set to L{AutoReload} to their database values. """ if lazy_value is not AutoReload and not isinstance(lazy_value, Expr): # It's not something we handle. return # XXX This will do it for now, but it should really flush # just this single object and ones that it depends on. # _flush_one() doesn't consider dependencies, so it may # not be used directly. Maybe allow flush(obj)? if self._implicit_flush_block_count == 0: self.flush() autoreload_columns = [] for column in obj_info.cls_info.columns: if obj_info.variables[column].get_lazy() is AutoReload: autoreload_columns.append(column) if autoreload_columns: where = compare_columns(obj_info.cls_info.primary_key, obj_info["primary_vars"]) result = self._connection.execute( Select(autoreload_columns, where)) self._set_values(obj_info, autoreload_columns, result, result.get_one()) class ResultSet: """The representation of the results of a query. Note that having an instance of this class does not indicate that a database query has necessarily been made. 
Database queries are put off until absolutely necessary. Generally these should not be constructed directly, but instead retrieved from calls to L{Store.find}. """ def __init__(self, store, find_spec, where=Undef, tables=Undef, select=Undef): self._store = store self._find_spec = find_spec self._where = where self._tables = tables self._select = select self._order_by = find_spec.default_order self._offset = Undef self._limit = Undef self._distinct = False self._group_by = Undef self._having = Undef def copy(self): """Return a copy of this ResultSet object, with the same configuration. """ result_set = object.__new__(self.__class__) result_set.__dict__.update(self.__dict__) if self._select is not Undef: # This expression must be copied because we may have to change it # in-place inside _get_select(). result_set._select = copy(self._select) return result_set def config(self, distinct=None, offset=None, limit=None): """Configure this result object in-place. All parameters are optional. @param distinct: If True, enables usage of the DISTINCT keyword in the query. If a tuple or list of columns, inserts a DISTINCT ON (only supported by PostgreSQL). @param offset: Offset where results will start to be retrieved from the result set. @param limit: Limit the number of objects retrieved from the result set. @return: self (not a copy). """ if distinct is not None: self._distinct = distinct if offset is not None: self._offset = offset if limit is not None: self._limit = limit return self def _get_select(self): if self._select is not Undef: if self._order_by is not Undef: self._select.order_by = self._order_by if self._limit is not Undef: # XXX UNTESTED! self._select.limit = self._limit if self._offset is not Undef: # XXX UNTESTED! self._select.offset = self._offset return self._select columns, default_tables = self._find_spec.get_columns_and_tables() return Select(columns, self._where, self._tables, default_tables, self._order_by, offset=self._offset, limit=self._limit, distinct=self._distinct, group_by=self._group_by, having=self._having) def _load_objects(self, result, values): return self._find_spec.load_objects(self._store, result, values) def __iter__(self): """Iterate the results of the query. """ result = self._store._connection.execute(self._get_select()) for values in result: yield self._load_objects(result, values) def __getitem__(self, index): """Get an individual item by offset, or a range of items by slice. @return: The matching object or, if a slice is used, a new L{ResultSet} will be returned appropriately modified with C{OFFSET} and C{LIMIT} clauses. 
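For example (an illustrative sketch, assuming C{result} was returned by a previous L{Store.find} call)::

    result[0]     # Fetches and returns a single object immediately.
    result[5:15]  # Returns a new ResultSet configured with OFFSET 5
                  # and LIMIT 10; no query runs until it is iterated.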
""" if isinstance(index, int): if index == 0: result_set = self else: if self._offset is not Undef: index += self._offset result_set = self.copy() result_set.config(offset=index, limit=1) obj = result_set._any() if obj is None: raise IndexError("Index out of range") return obj if not isinstance(index, slice): raise IndexError("Can't index ResultSets with %r" % (index,)) if index.step is not None: raise IndexError("Stepped slices not yet supported: %r" % (index.step,)) offset = self._offset limit = self._limit if index.start is not None: if offset is Undef: offset = index.start else: offset += index.start if limit is not Undef: limit = max(0, limit - index.start) if index.stop is not None: if index.start is None: new_limit = index.stop else: new_limit = index.stop - index.start if limit is Undef or limit > new_limit: limit = new_limit return self.copy().config(offset=offset, limit=limit) def __contains__(self, item): """Check if an item is contained within the result set.""" columns, values = self._find_spec.get_columns_and_values_for_item(item) if self._select is Undef and self._group_by is Undef: # No predefined select: adjust the where clause. dummy, default_tables = self._find_spec.get_columns_and_tables() where = [Eq(*pair) for pair in zip(columns, values)] if self._where is not Undef: where.append(self._where) select = Select(1, And(*where), self._tables, default_tables) else: # Rewrite the predefined query and use it as a subquery. aliased_columns = [Alias(column, "_key%d" % index) for (index, column) in enumerate(columns)] subquery = replace_columns(self._get_select(), aliased_columns) where = [Eq(*pair) for pair in zip(aliased_columns, values)] select = Select(1, And(*where), Alias(subquery, "_tmp")) result = self._store._connection.execute(select) return result.get_one() is not None def is_empty(self): """Return C{True} if this result set doesn't contain any results.""" subselect = self._get_select() subselect.limit = 1 subselect.order_by = Undef select = Select(1, tables=Alias(subselect, "_tmp"), limit=1) result = self._store._connection.execute(select) return (not result.get_one()) def any(self): """Return a single item from the result set. @return: An arbitrary object or C{None} if one isn't available. @see: L{one}, L{first}, and L{last}. """ select = self._get_select() select.limit = 1 select.order_by = Undef result = self._store._connection.execute(select) values = result.get_one() if values: return self._load_objects(result, values) return None def _any(self): """Return a single item from the result without changing sort order. @return: An arbitrary object or C{None} if one isn't available. """ select = self._get_select() select.limit = 1 result = self._store._connection.execute(select) values = result.get_one() if values: return self._load_objects(result, values) return None def first(self): """Return the first item from an ordered result set. @raises UnorderedError: Raised if the result set isn't ordered. @return: The first object or C{None} if one isn't available. @see: L{last}, L{one}, and L{any}. """ if self._order_by is Undef: raise UnorderedError("Can't use first() on unordered result set") return self._any() def last(self): """Return the last item from an ordered result set. @raises FeatureError: Raised if the result set has a C{LIMIT} set. @raises UnorderedError: Raised if the result set isn't ordered. @return: The last object or C{None} if one isn't available. @see: L{first}, L{one}, and L{any}. 
""" if self._order_by is Undef: raise UnorderedError("Can't use last() on unordered result set") if self._limit is not Undef: raise FeatureError("Can't use last() with a slice " "of defined stop index") select = self._get_select() select.offset = Undef select.limit = 1 select.order_by = [] for expr in self._order_by: if isinstance(expr, Desc): select.order_by.append(expr.expr) elif isinstance(expr, Asc): select.order_by.append(Desc(expr.expr)) else: select.order_by.append(Desc(expr)) result = self._store._connection.execute(select) values = result.get_one() if values: return self._load_objects(result, values) return None def one(self): """Return one item from a result set containing at most one item. @raises NotOneError: Raised if the result set contains more than one item. @return: The object or C{None} if one isn't available. @see: L{first}, L{last}, and L{any}. """ select = self._get_select() # limit could be 1 due to slicing, for instance. if select.limit is not Undef and select.limit > 2: select.limit = 2 result = self._store._connection.execute(select) values = result.get_one() if result.get_one(): raise NotOneError("one() used with more than one result available") if values: return self._load_objects(result, values) return None def order_by(self, *args): """Specify the ordering of the results. The query will be modified appropriately with an ORDER BY clause. Ascending and descending order can be specified by wrapping the columns in L{Asc} and L{Desc}. @param args: One or more L{storm.expr.Column} objects. """ if self._offset is not Undef or self._limit is not Undef: raise FeatureError("Can't reorder a sliced result set") self._order_by = args or Undef return self def remove(self): """Remove all rows represented by this ResultSet from the database. This is done efficiently with a DELETE statement, so objects are not actually loaded into Python. """ if self._group_by is not Undef: raise FeatureError("Removing isn't supported after a " " GROUP BY clause ") if self._offset is not Undef or self._limit is not Undef: raise FeatureError("Can't remove a sliced result set") if self._find_spec.default_cls_info is None: raise FeatureError("Removing not yet supported for tuple or " "expression finds") if self._select is not Undef: raise FeatureError("Removing isn't supported with " "set expressions (unions, etc)") result = self._store._connection.execute( Delete(self._where, self._find_spec.default_cls_info.table)) return result.rowcount def group_by(self, *expr): """Group this ResultSet by the given expressions. @param expr: The expressions used in the GROUP BY statement. @return: self (not a copy). """ if self._select is not Undef: raise FeatureError("Grouping isn't supported with " "set expressions (unions, etc)") find_spec = FindSpec(expr) columns, dummy = find_spec.get_columns_and_tables() self._group_by = columns return self def having(self, *expr): """Filter result previously grouped by. @param expr: Instances of L{Expr}. @return: self (not a copy). 
""" if self._group_by is Undef: raise FeatureError("having can only be called after group_by.") self._having = And(*expr) return self def _aggregate(self, aggregate_func, expr, column=None): if self._group_by is not Undef: raise FeatureError("Single aggregates aren't supported after a " " GROUP BY clause ") columns, default_tables = self._find_spec.get_columns_and_tables() if (self._select is Undef and not self._distinct and self._offset is Undef and self._limit is Undef): select = Select(aggregate_func(expr), self._where, self._tables, default_tables) else: if expr is Undef: aggregate = aggregate_func(expr) else: alias = Alias(expr, "_expr") columns.append(alias) aggregate = aggregate_func(alias) # Ordering probably doesn't matter for any aggregates, and since # replace_columns() blows up on an ordered query, we'll drop it. select = self._get_select() select.order_by = Undef subquery = replace_columns(select, columns) select = Select(aggregate, tables=Alias(subquery, "_tmp")) result = self._store._connection.execute(select) value = result.get_one()[0] variable_factory = getattr(column, "variable_factory", None) if variable_factory: variable = variable_factory(allow_none=True) result.set_variable(variable, value) return variable.get() return value def count(self, expr=Undef, distinct=False): """Get the number of objects represented by this ResultSet.""" return int(self._aggregate(lambda expr: Count(expr, distinct), expr)) def max(self, expr): """Get the highest value from an expression.""" return self._aggregate(Max, expr, expr) def min(self, expr): """Get the lowest value from an expression.""" return self._aggregate(Min, expr, expr) def avg(self, expr): """Get the average value from an expression.""" value = self._aggregate(Avg, expr) if value is None: return value return float(value) def sum(self, expr): """Get the sum of all values in an expression.""" return self._aggregate(Sum, expr, expr) def get_select_expr(self, *columns): """Get a L{Select} expression to retrieve only the specified columns. @param columns: One or more L{storm.expr.Column} objects whose values will be fetched. @raises FeatureError: Raised if no columns are specified or if this result is a set expression such as a union. @return: A L{Select} expression configured to use the query parameters specified for this result set, and also limited to only retrieving data for the specified columns. """ if not columns: raise FeatureError("select() takes at least one column " "as argument") if self._select is not Undef: raise FeatureError( "Can't generate subselect expression for set expressions") select = self._get_select() select.columns = columns return select def values(self, *columns): """Retrieve only the specified columns. This does not load full objects from the database into Python. @param columns: One or more L{storm.expr.Column} objects whose values will be fetched. @raises FeatureError: Raised if no columns are specified or if this result is a set expression such as a union. @return: An iterator of tuples of the values for each column from each matching row in the database. 
""" if not columns: raise FeatureError("values() takes at least one column " "as argument") if self._select is not Undef: raise FeatureError("values() can't be used with set expressions") select = self._get_select() select.columns = columns result = self._store._connection.execute(select) if len(columns) == 1: variable = columns[0].variable_factory() for values in result: result.set_variable(variable, values[0]) yield variable.get() else: variables = [column.variable_factory() for column in columns] for values in result: for variable, value in zip(variables, values): result.set_variable(variable, value) yield tuple(variable.get() for variable in variables) def set(self, *args, **kwargs): """Update objects in the result set with the given arguments. This method will update all objects in the current result set to match expressions given as equalities or keyword arguments. These objects may still be in the database (an UPDATE is issued) or may be cached. For instance, C{result.set(Class.attr1 == 1, attr2=2)} will set C{attr1} to 1 and C{attr2} to 2, on all matching objects. """ if self._group_by is not Undef: raise FeatureError("Setting isn't supported after a " " GROUP BY clause ") if self._find_spec.default_cls_info is None: raise FeatureError("Setting isn't supported with tuple or " "expression finds") if self._select is not Undef: raise FeatureError("Setting isn't supported with " "set expressions (unions, etc)") if not (args or kwargs): return changes = {} cls = self._find_spec.default_cls_info.cls # For now only "Class.attr == var" is supported in args. for expr in args: if not isinstance(expr, Eq): raise FeatureError("Unsupported set expression: %r" % repr(expr)) elif not isinstance(expr.expr1, Column): raise FeatureError("Unsupported left operand in set " "expression: %r" % repr(expr.expr1)) elif not isinstance(expr.expr2, (Expr, Variable)): raise FeatureError("Unsupported right operand in set " "expression: %r" % repr(expr.expr2)) changes[expr.expr1] = expr.expr2 for key, value in kwargs.items(): column = getattr(cls, key) if value is None: changes[column] = None elif isinstance(value, Expr): changes[column] = value else: changes[column] = column.variable_factory(value=value) expr = Update(changes, self._where, self._find_spec.default_cls_info.table) self._store.execute(expr, noresult=True) try: cached = self.cached() except CompileError: # We are iterating through all objects in memory here, so # check if the object type matches to avoid trying to # invalidate a column that does not exist, on an unrelated # object. for obj_info in self._store._iter_alive(): if obj_info.cls_info is self._find_spec.default_cls_info: for column in changes: obj_info.variables[column].set(AutoReload) else: changes = list(changes.items()) for obj in cached: for column, value in changes: variables = get_obj_info(obj).variables if value is None: pass elif isinstance(value, Variable): value = value.get() elif isinstance(value, Expr): # If the value is an Expression that means we # can't compute it by ourselves: we rely on # the database to compute it, so just set the # value to AutoReload. 
value = AutoReload else: value = variables[value].get() variables[column].set(value) variables[column].checkpoint() def cached(self): """Return matching objects from the cache for the current query.""" if self._find_spec.default_cls_info is None: raise FeatureError("Cache finds not supported with tuples " "or expressions") if self._tables is not Undef: raise FeatureError("Cache finds not supported with custom tables") if self._where is Undef: match = None else: match = compile_python.get_matcher(self._where) def get_column(column): return obj_info.variables[column].get() objects = [] for obj_info in self._store._iter_alive(): try: if (obj_info.cls_info is self._find_spec.default_cls_info and (match is None or match(get_column))): objects.append(self._store._get_object(obj_info)) except LostObjectError: pass # This may happen when resolving lazy values # in get_column(). return objects def find(self, *args, **kwargs): """Perform a query on objects within this result set. This is analogous to L{Store.find}, although it doesn't take a C{cls_spec} argument, instead using the same tables as the existing result set, and restricts the results to those in this set. @param args: Instances of L{Expr}. @param kwargs: Mapping of simple column names to values or expressions to query for. @return: A L{ResultSet} of matching instances. """ if self._select is not Undef: raise FeatureError("Can't query set expressions") if self._offset is not Undef or self._limit is not Undef: raise FeatureError("Can't query a sliced result set") if self._group_by is not Undef: raise FeatureError("Can't query grouped result sets") result_set = self.copy() extra_where = get_where_for_args( args, kwargs, self._find_spec.default_cls) if extra_where is not Undef: if result_set._where is Undef: result_set._where = extra_where else: result_set._where = And(result_set._where, extra_where) return result_set def _set_expr(self, expr_cls, other, all=False): if not self._find_spec.is_compatible(other._find_spec): raise FeatureError("Incompatible results for set operation") expr = expr_cls(self._get_select(), other._get_select(), all=all) return ResultSet(self._store, self._find_spec, select=expr) def union(self, other, all=False): """Get the L{Union} of this result set and another. @param all: If True, include duplicates. """ if isinstance(other, EmptyResultSet): return self return self._set_expr(Union, other, all) def difference(self, other, all=False): """Get the difference, using L{Except}, of this result set and another. @param all: If True, include duplicates. """ if isinstance(other, EmptyResultSet): return self return self._set_expr(Except, other, all) def intersection(self, other, all=False): """Get the L{Intersection} of this result set and another. @param all: If True, include duplicates. """ if isinstance(other, EmptyResultSet): return other return self._set_expr(Intersect, other, all) class EmptyResultSet: """An object that looks like a L{ResultSet} but represents no rows. This is convenient for application developers who want to provide a method which is guaranteed to return a L{ResultSet}-like object but which, in certain cases, knows there is no point in querying the database. For example:: def get_people(self, ids): if not ids: return EmptyResultSet() return store.find(People, People.id.is_in(ids)) The methods on EmptyResultSet (L{one}, L{config}, L{union}, etc) are meant to emulate a L{ResultSet} which has matched no rows. 
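Set operations short-circuit accordingly; for instance (an illustrative sketch, with C{result} an ordinary L{ResultSet})::

    empty = EmptyResultSet()
    result.union(empty)  # Returns result unchanged.
    empty.union(result)  # Delegates to result.union(empty), i.e. result.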
""" def __init__(self, ordered=False): self._order_by = ordered def copy(self): result = EmptyResultSet(self._order_by) return result def config(self, distinct=None, offset=None, limit=None): return self def __iter__(self): return yield None def __len__(self): return 0 def __getitem__(self, index): return self.copy() def __contains__(self, item): return False def is_empty(self): return True def any(self): return None def first(self): if self._order_by: return None raise UnorderedError("Can't use first() on unordered result set") def last(self): if self._order_by: return None raise UnorderedError("Can't use last() on unordered result set") def one(self): return None def order_by(self, *args): self._order_by = True return self def group_by(self, *expr): return self def remove(self): return 0 def count(self, expr=Undef, distinct=False): return 0 def max(self, column): return None def min(self, column): return None def avg(self, column): return None def sum(self, column): return None def get_select_expr(self, *columns): """Get a L{Select} expression to retrieve only the specified columns. @param columns: One or more L{storm.expr.Column} objects whose values will be fetched. @raises FeatureError: Raised if no columns are specified. @return: A L{Select} expression configured to use the query parameters specified for this result set. The result of the select will always be an empty set of rows. """ if not columns: raise FeatureError("select() takes at least one column " "as argument") return Select(columns, False) def values(self, *columns): if not columns: raise FeatureError("values() takes at least one column " "as argument") return yield None def set(self, *args, **kwargs): pass def cached(self): return [] def find(self, *args, **kwargs): return self def union(self, other): if isinstance(other, EmptyResultSet): return self return other.union(self) def difference(self, other): return self def intersection(self, other): return self class TableSet: """The representation of a set of tables which can be queried at once. This will typically be constructed by a call to L{Store.using}. """ def __init__(self, store, tables): self._store = store self._tables = tables def find(self, cls_spec, *args, **kwargs): """Perform a query on the previously specified tables. This is identical to L{Store.find} except that the tables are explicitly specified instead of relying on inference. @return: A L{ResultSet}. """ if self._store._implicit_flush_block_count == 0: self._store.flush() find_spec = FindSpec(cls_spec) where = get_where_for_args(args, kwargs, find_spec.default_cls) return self._store._result_set_factory(self._store, find_spec, where, self._tables) Store._result_set_factory = ResultSet Store._table_set = TableSet class FindSpec: """The set of tables or expressions in the result of L{Store.find}.""" def __init__(self, cls_spec): self.is_tuple = type(cls_spec) == tuple if not self.is_tuple: cls_spec = (cls_spec,) info = [] for item in cls_spec: if isinstance(item, Expr): info.append((True, item)) else: info.append((False, get_cls_info(item))) self._cls_spec_info = tuple(info) # Do we have a single non-expression item here? 
if not self.is_tuple and not info[0][0]: self.default_cls = cls_spec[0] self.default_cls_info = info[0][1] self.default_order = self.default_cls_info.default_order else: self.default_cls = None self.default_cls_info = None self.default_order = Undef def get_columns_and_tables(self): columns = [] default_tables = [] for is_expr, info in self._cls_spec_info: if is_expr: columns.append(info) if isinstance(info, Column): default_tables.append(info.table) else: columns.extend(info.columns) default_tables.append(info.table) return columns, default_tables def is_compatible(self, find_spec): """Return True if this FindSpec is compatible with a second one. Two FindSpecs are considered compatible if either the find specs are identical (i.e. specifies the same classes and columns) or the find spec columns are of the same type. """ if self.is_tuple != find_spec.is_tuple: return False if len(self._cls_spec_info) != len(find_spec._cls_spec_info): return False for (is_expr1, info1), (is_expr2, info2) in zip( self._cls_spec_info, find_spec._cls_spec_info): if is_expr1 != is_expr2: return False # If both infos are PropertyColumns, check whether they are # of the same type. Ideally we should check that the types as # defined in the database are the same, but checking the # variable class is easier and will work most of the time. if isinstance(info1, PropertyColumn): if not isinstance(info2, PropertyColumn): return False variable_class1 = info1.variable_factory().__class__ variable_class2 = info2.variable_factory().__class__ if variable_class1 is not variable_class2: return False elif info1 is not info2: return False return True def load_objects(self, store, result, values): objects = [] values_start = values_end = 0 for is_expr, info in self._cls_spec_info: if is_expr: values_end += 1 variable = getattr(info, "variable_factory", Variable)( value=values[values_start], from_db=True) objects.append(variable.get()) else: values_end += len(info.columns) obj = store._load_object(info, result, values[values_start:values_end]) objects.append(obj) values_start = values_end if self.is_tuple: return tuple(objects) else: return objects[0] def get_columns_and_values_for_item(self, item): """Generate a comparison expression with the given item.""" if isinstance(item, tuple): if not self.is_tuple: raise TypeError("Find spec does not expect tuples.") else: if self.is_tuple: raise TypeError("Find spec expects tuples.") item = (item,) columns = [] values = [] for (is_expr, info), value in zip(self._cls_spec_info, item): if is_expr: if not isinstance(value, (Expr, Variable)) and ( value is not None): value = getattr(info, "variable_factory", Variable)( value=value) columns.append(info) values.append(value) else: obj_info = get_obj_info(value) if obj_info.cls_info != info: raise TypeError("%r does not match %r" % (value, info)) columns.extend(info.primary_key) values.extend(obj_info.primary_vars) return columns, values def get_where_for_args(args, kwargs, cls=None): equals = list(args) if kwargs: if cls is None: raise FeatureError("Can't determine class that keyword " "arguments are associated with") for key, value in kwargs.items(): equals.append(getattr(cls, key) == value) if equals: return And(*equals) return Undef def replace_columns(expr, columns): if isinstance(expr, Select): select = copy(expr) select.columns = columns # Remove the ordering if it won't affect the result of the query. 
if select.limit is Undef and select.offset is Undef: select.order_by = Undef return select elif isinstance(expr, SetExpr): # The ORDER BY clause might refer to columns we have replaced. # Luckily we can ignore it if there is no limit/offset. if expr.order_by is not Undef and ( expr.limit is not Undef or expr.offset is not Undef): raise FeatureError( "__contains__() does not yet support set " "expressions that combine ORDER BY with " "LIMIT/OFFSET") subexprs = [replace_columns(subexpr, columns) for subexpr in expr.exprs] return expr.__class__( all=expr.all, limit=expr.limit, offset=expr.offset, *subexprs) else: raise FeatureError( "__contains__() does not yet support %r expressions" % (expr.__class__,)) class AutoReload(LazyValue): """A marker for reloading a single value. Often this will be used to specify that a specific attribute should be loaded from the database on the next access, like so:: storm_object.property = AutoReload On the next access to C{storm_object.property}, the value will be loaded from the database. It is also often used as a default value for a property:: class Person(object): __storm_table__ = "person" id = Int(allow_none=False, default=AutoReload) person = store.add(Person()) person.id # gets the attribute from the database. """ pass AutoReload = AutoReload() class block_access: """ Context manager that blocks database access by one or more L{Store}\\ s in the managed scope. """ def __init__(self, *args): self.stores = args def __enter__(self): for store in self.stores: store.block_access() def __exit__(self, exc_type, exc_val, exc_tb): for store in self.stores: store.unblock_access()

storm-1.0/storm/testing.py

from fixtures import Fixture from storm.tracer import BaseStatementTracer, install_tracer, remove_tracer class CaptureTracer(BaseStatementTracer, Fixture): """Trace SQL statements, appending them to a C{list}. Example:: with CaptureTracer() as tracer: # Run queries print(tracer.queries) # Print the queries that have been run @note: This class requires the fixtures package to be available. """ def __init__(self): super().__init__() self.queries = [] def _setUp(self): install_tracer(self) self.addCleanup(remove_tracer, self) def _expanded_raw_execute(self, conn, raw_cursor, statement): """Save the statement to the log.""" self.queries.append(statement)

storm-1.0/storm/tests/__init__.py

# # Copyright (c) 2011 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details.
# # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>. # __all__ = [ 'find_tests', 'has_fixtures', 'has_psycopg', 'has_subunit', ] import doctest from itertools import chain import os import unittest try: import fixtures fixtures # Silence lint. except ImportError: has_fixtures = False else: has_fixtures = True try: import psycopg2 psycopg2 # Silence lint. except ImportError: has_psycopg = False else: has_psycopg = True try: import subunit subunit # Silence lint. except ImportError: has_subunit = False else: has_subunit = True def find_tests(testpaths=()): """Find all test paths, or test paths contained in the provided sequence. @param testpaths: If provided, only tests in the given sequence will be considered. If not provided, all tests are considered. @return: a test suite containing the requested tests. """ suite = unittest.TestSuite() topdir = os.path.abspath( os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) testdir = os.path.dirname(__file__) docdir = os.path.join(os.path.dirname(testdir), "docs") testpaths = set(testpaths) for root, dirnames, filenames in chain(os.walk(testdir), os.walk(docdir)): for filename in filenames: filepath = os.path.join(root, filename) relpath = filepath[len(topdir)+1:] if (filename == "__init__.py" or filename.endswith(".pyc") or relpath == os.path.join("storm", "docs", "conf.py") or relpath == os.path.join("storm", "tests", "conftest.py")): # Skip non-tests. continue if testpaths: # Skip any tests not in testpaths. for testpath in testpaths: if relpath.startswith(testpath): break else: continue if filename.endswith(".py"): modpath = relpath.replace(os.path.sep, ".")[:-3] module = __import__(modpath, None, None, [""]) suite.addTest( unittest.defaultTestLoader.loadTestsFromModule(module)) elif filename.endswith(".rst"): load_test = True if relpath == os.path.join("storm", "docs", "zope.rst"): # Special case the inclusion of the Zope-dependent # ZStorm doctest. import storm.tests.zope as ztest load_test = ( ztest.has_transaction and ztest.has_zope_component and ztest.has_zope_security) if load_test: parent_path = os.path.dirname(relpath).replace( os.path.sep, ".") parent_module = __import__(parent_path, None, None, [""]) suite.addTest(doctest.DocFileSuite( os.path.basename(relpath), module_relative=True, package=parent_module, optionflags=doctest.ELLIPSIS)) return suite

storm-1.0/storm/tests/base.py

# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see <http://www.gnu.org/licenses/>.
# import weakref import gc from storm.properties import Property, PropertyPublisherMeta from storm.info import get_obj_info from storm.base import * from storm.tests.helper import TestHelper class BaseTest(TestHelper): def test_metaclass(self): class Class(Storm): __storm_table__ = "table_name" prop = Property(primary=True) self.assertEqual(type(Class), PropertyPublisherMeta) def test_class_is_collectable(self): class Class(Storm): __storm_table__ = "table_name" prop = Property(primary=True) obj = Class() get_obj_info(obj) # Build all wanted meta-information. obj_ref = weakref.ref(obj) del obj gc.collect() self.assertEqual(obj_ref(), None)

storm-1.0/storm/tests/cache.py

from unittest import defaultTestLoader from storm.properties import Int from storm.info import get_obj_info from storm.cache import Cache, GenerationalCache from storm.tests.helper import TestHelper class StubObjectInfo: def __init__(self, id): self.id = id self.hashed = False def get_obj(self): return str(self.id) def __repr__(self): return "%s(%s)" % (self.__class__.__name__, self.id) def __hash__(self): self.hashed = True return self.id def __lt__(self, other): return self.id < other.id class StubClass: __storm_table__ = "stub_class" id = Int(primary=True) class BaseCacheTest(TestHelper): Cache = Cache def setUp(self): super().setUp() self.obj_infos = [StubObjectInfo(i) for i in range(10)] for i in range(len(self.obj_infos)): setattr(self, "obj%d" % (i+1), self.obj_infos[i]) def clear_hashed(self): for obj_info in self.obj_infos: obj_info.hashed = False def test_initially_empty(self): cache = self.Cache() self.assertEqual(cache.get_cached(), []) def test_add(self): cache = self.Cache(5) cache.add(self.obj1) cache.add(self.obj2) cache.add(self.obj3) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj2, self.obj3]) def test_adding_similar_obj_infos(self): """If __eq__ is broken, this fails.""" obj_info1 = get_obj_info(StubClass()) obj_info2 = get_obj_info(StubClass()) cache = self.Cache(5) cache.add(obj_info1) cache.add(obj_info2) cache.add(obj_info2) cache.add(obj_info1) self.assertEqual(sorted([hash(obj_info) for obj_info in cache.get_cached()]), sorted([hash(obj_info1), hash(obj_info2)])) def test_remove(self): cache = self.Cache(5) cache.add(self.obj1) cache.add(self.obj2) cache.add(self.obj3) cache.remove(self.obj2) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj3]) def test_add_existing(self): cache = self.Cache(5) cache.add(self.obj1) cache.add(self.obj2) cache.add(self.obj3) cache.add(self.obj2) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj2, self.obj3]) def test_add_with_size_zero(self): """Cache is disabled entirely on add() if size is 0.""" cache = self.Cache(0) cache.add(self.obj1) # Ensure that we don't even check if obj_info is in the # cache, by testing if it was hashed. Hopefully, that means # we got a faster path.
self.assertEqual(self.obj1.hashed, False) def test_remove_with_size_zero(self): """Cache is disabled entirely on remove() if size is 0.""" cache = self.Cache(0) cache.remove(self.obj1) def test_clear(self): """The clear method empties the cache.""" cache = self.Cache(5) for obj_info in self.obj_infos: cache.add(obj_info) cache.clear() self.assertEqual(cache.get_cached(), []) # Just an additional check ensuring that any additional structures # which may be used were cleaned properly as well. for obj_info in self.obj_infos: self.assertEqual(cache.remove(obj_info), False) def test_set_zero_size(self): """ Setting a cache's size to zero clears the cache. """ cache = self.Cache() cache.add(self.obj1) cache.add(self.obj2) cache.set_size(0) self.assertEqual(cache.get_cached(), []) def test_fit_size(self): """ A cache of size n can hold at least n objects. """ size = 10 cache = self.Cache(size) for value in range(size): cache.add(StubObjectInfo(value)) self.assertEqual(len(cache.get_cached()), size) class CacheTest(BaseCacheTest): def test_size_and_fifo_behaviour(self): cache = Cache(5) for obj_info in self.obj_infos: cache.add(obj_info) self.assertEqual([obj_info.id for obj_info in cache.get_cached()], [9, 8, 7, 6, 5]) def test_reduce_max_size_to_zero(self): """When setting the size to zero, there's an optimization.""" cache = Cache(5) obj_info = self.obj_infos[0] cache.add(obj_info) obj_info.hashed = False cache.set_size(0) self.assertEqual(cache.get_cached(), []) # Ensure that we don't even check if obj_info is in the # cache, by testing if it was hashed. Hopefully, that means # we got a faster path. self.assertEqual(obj_info.hashed, False) def test_reduce_max_size(self): cache = Cache(5) for obj_info in self.obj_infos: cache.add(obj_info) cache.set_size(3) self.assertEqual([obj_info.id for obj_info in cache.get_cached()], [9, 8, 7]) # Adding items past the new maximum size should drop older ones. for obj_info in self.obj_infos[:2]: cache.add(obj_info) self.assertEqual([obj_info.id for obj_info in cache.get_cached()], [1, 0, 9]) def test_increase_max_size(self): cache = Cache(5) for obj_info in self.obj_infos: cache.add(obj_info) cache.set_size(10) self.assertEqual([obj_info.id for obj_info in cache.get_cached()], [9, 8, 7, 6, 5]) # Adding items past the new maximum size should drop older ones. 
for obj_info in self.obj_infos[:6]: cache.add(obj_info) self.assertEqual([obj_info.id for obj_info in cache.get_cached()], [5, 4, 3, 2, 1, 0, 9, 8, 7, 6]) class TestGenerationalCache(BaseCacheTest): Cache = GenerationalCache def setUp(self): super().setUp() self.obj1 = StubObjectInfo(1) self.obj2 = StubObjectInfo(2) self.obj3 = StubObjectInfo(3) self.obj4 = StubObjectInfo(4) def test_initially_empty(self): cache = GenerationalCache() self.assertEqual(cache.get_cached(), []) def test_cache_one_object(self): cache = GenerationalCache() cache.add(self.obj1) self.assertEqual(cache.get_cached(), [self.obj1]) def test_cache_multiple_objects(self): cache = GenerationalCache() cache.add(self.obj1) cache.add(self.obj2) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj2]) def test_clear_cache(self): cache = GenerationalCache() cache.add(self.obj1) cache.clear() self.assertEqual(cache.get_cached(), []) def test_clear_cache_clears_the_second_generation(self): cache = GenerationalCache(1) cache.add(self.obj1) cache.add(self.obj2) cache.clear() self.assertEqual(cache.get_cached(), []) def test_remove_object(self): cache = GenerationalCache() cache.add(self.obj1) cache.add(self.obj2) cache.add(self.obj3) present = cache.remove(self.obj2) self.assertTrue(present) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj3]) def test_remove_nothing(self): cache = GenerationalCache() cache.add(self.obj1) present = cache.remove(self.obj2) self.assertFalse(present) self.assertEqual(cache.get_cached(), [self.obj1]) def test_size_limit(self): """ A cache will never hold more than twice its size in objects. The generational system is what prevents it from holding exactly the requested number of objects. """ size = 10 cache = GenerationalCache(size) for value in range(5 * size): cache.add(StubObjectInfo(value)) self.assertEqual(len(cache.get_cached()), size * 2) def test_set_size_smaller_than_current_size(self): """ Setting the size to a smaller value than the number of objects currently cached will drop some of the extra content. Note that because of the generation system, it can actually hold two times the size requested in edge cases. """ cache = GenerationalCache(150) for i in range(250): cache.add(StubObjectInfo(i)) cache.set_size(100) cached = cache.get_cached() self.assertEqual(len(cached), 100) for obj_info in cache.get_cached(): self.assertTrue(obj_info.id >= 100) def test_set_size_larger_than_current_size(self): """ Setting the cache size to something more than the number of objects in the cache does not affect its current contents, and will merge any elements from the second generation into the first one. """ cache = GenerationalCache(1) cache.add(self.obj1) # new=[1] old=[] cache.add(self.obj2) # new=[2] old=[1] cache.set_size(2) # new=[1, 2] old=[] cache.add(self.obj3) # new=[3] old=[1, 2] self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj2, self.obj3]) def test_set_size_limit(self): """ Setting the size limits the cache's size just like passing an initial size would. """ size = 10 cache = GenerationalCache(size * 100) cache.set_size(size) for value in range(size * 10): cache.add(StubObjectInfo(value)) self.assertEqual(len(cache.get_cached()), size * 2) def test_two_generations(self): """ Inserting more objects than the cache's size causes the cache to contain two generations, each holding up to <size> objects.
""" cache = GenerationalCache(1) cache.add(self.obj1) cache.add(self.obj2) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj2]) def test_three_generations(self): """ If more than 2* objects come along, only 2* objects are retained. """ cache = GenerationalCache(1) cache.add(self.obj1) cache.add(self.obj2) cache.add(self.obj3) self.assertEqual(sorted(cache.get_cached()), [self.obj2, self.obj3]) def test_generational_overlap(self): """ An object that is both in the primary and the secondary generation is listed only once in the cache's contents. """ cache = GenerationalCache(2) cache.add(self.obj1) # new=[1] old=[] cache.add(self.obj2) # new=[1, 2] old=[] cache.add(self.obj3) # new=[3] old=[1, 2] cache.add(self.obj1) # new=[3, 1] old=[1, 2] self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj2, self.obj3]) def test_remove_from_overlap(self): """ Removing an object from the cache removes it from both its primary and secondary generations. """ cache = GenerationalCache(2) cache.add(self.obj1) # new=[1] old=[] cache.add(self.obj2) # new=[1, 2] old=[] cache.add(self.obj3) # new=[3] old=[1, 2] cache.add(self.obj1) # new=[3, 1] old=[1, 2] present = cache.remove(self.obj1) self.assertTrue(present) self.assertEqual(sorted(cache.get_cached()), [self.obj2, self.obj3]) def test_evict_oldest(self): """The "oldest" object is the first to be evicted.""" cache = GenerationalCache(1) cache.add(self.obj1) cache.add(self.obj2) cache.add(self.obj3) self.assertEqual(sorted(cache.get_cached()), [self.obj2, self.obj3]) def test_evict_LRU(self): """ Actually, it's not the oldest but the LRU object that is first to be evicted. Re-adding the oldest object makes it not be the LRU. """ cache = GenerationalCache(1) cache.add(self.obj1) cache.add(self.obj2) # This "refreshes" the oldest object in the cache. cache.add(self.obj1) cache.add(self.obj3) self.assertEqual(sorted(cache.get_cached()), [self.obj1, self.obj3]) def test_suite(): return defaultTestLoader.loadTestsFromName(__name__) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/database.py0000644000175000017500000005205614645174376017475 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# import sys import types import gc from storm.exceptions import ClosedError, DatabaseError, DisconnectionError from storm.variables import Variable import storm.database from storm.database import * from storm.tracer import install_tracer, remove_all_tracers, DebugTracer from storm.uri import URI from storm.expr import * from storm.tests.helper import TestHelper from storm.tests.mocker import ARGS marker = object() class RawConnection: closed = False def __init__(self, executed): self.executed = executed def cursor(self): return RawCursor(executed=self.executed) def commit(self): self.executed.append("COMMIT") def rollback(self): self.executed.append("ROLLBACK") def close(self): self.executed.append("CCLOSE") class RawCursor: def __init__(self, arraysize=1, executed=None): self.arraysize = arraysize if executed is None: self.executed = [] else: self.executed = executed self._fetchone_data = [("fetchone%d" % i,) for i in range(3)] self._fetchall_data = [("fetchall%d" % i,) for i in range(2)] self._fetchmany_data = [("fetchmany%d" % i,) for i in range(5)] def close(self): self.executed.append("RCLOSE") def execute(self, statement, params=marker): self.executed.append((statement, params)) def fetchone(self): if self._fetchone_data: return self._fetchone_data.pop(0) return None def fetchall(self): result = self._fetchall_data self._fetchall_data = [] return result def fetchmany(self): result = self._fetchmany_data[:self.arraysize] del self._fetchmany_data[:self.arraysize] return result class FakeConnection: def __init__(self): self._database = Database() def _check_disconnect(self, _function, *args, **kwargs): return _function(*args, **kwargs) class FakeTracer: def __init__(self, stream=None): self.seen = [] def connection_raw_execute(self, connection, raw_cursor, statement, params): self.seen.append(("EXECUTE", connection, type(raw_cursor), statement, params)) def connection_raw_execute_success(self, connection, raw_cursor, statement, params): self.seen.append(("SUCCESS", connection, type(raw_cursor), statement, params)) def connection_raw_execute_error(self, connection, raw_cursor, statement, params, error): self.seen.append(("ERROR", connection, type(raw_cursor), statement, params, error)) def connection_commit(self, connection, xid=None): self.seen.append(("COMMIT", connection, xid)) def connection_rollback(self, connection, xid=None): self.seen.append(("ROLLBACK", connection, xid)) class DatabaseTest(TestHelper): def setUp(self): TestHelper.setUp(self) self.database = Database() def test_connect(self): self.assertRaises(NotImplementedError, self.database.connect) class ConnectionTest(TestHelper): def setUp(self): TestHelper.setUp(self) self.executed = [] self.database = Database() self.raw_connection = RawConnection(self.executed) self.database.raw_connect = lambda: self.raw_connection self.connection = Connection(self.database) def tearDown(self): TestHelper.tearDown(self) remove_all_tracers() def test_execute(self): result = self.connection.execute("something") self.assertTrue(isinstance(result, Result)) self.assertEqual(self.executed, [("something", marker)]) def test_execute_params(self): result = self.connection.execute("something", (1,2,3)) self.assertTrue(isinstance(result, Result)) self.assertEqual(self.executed, [("something", (1,2,3))]) def test_execute_noresult(self): result = self.connection.execute("something", noresult=True) self.assertEqual(result, None) self.assertEqual(self.executed, [("something", marker), "RCLOSE"]) def test_execute_convert_param_style(self): class 
MyConnection(Connection): param_mark = "%s" connection = MyConnection(self.database) result = connection.execute("'?' ? '?' ? '?'") self.assertEqual(self.executed, [("'?' %s '?' %s '?'", marker)]) # TODO: Unsupported for now. #result = connection.execute("$$?$$ ? $asd$'?$asd$ ? '?'") #self.assertEqual(self.executed, # [("'?' %s '?' %s '?'", marker), # ("$$?$$ %s $asd'?$asd$ %s '?'", marker)]) def test_execute_select(self): select = Select([SQLToken("column1"), SQLToken("column2")], tables=[SQLToken("table1"), SQLToken("table2")]) result = self.connection.execute(select) self.assertTrue(isinstance(result, Result)) self.assertEqual(self.executed, [("SELECT column1, column2 FROM table1, table2", marker)]) def test_execute_select_and_params(self): select = Select(["column1", "column2"], tables=["table1", "table2"]) self.assertRaises(ValueError, self.connection.execute, select, ("something",)) def test_execute_closed(self): self.connection.close() self.assertRaises(ClosedError, self.connection.execute, "SELECT 1") def test_raw_execute_tracing(self): self.assertMethodsMatch(FakeTracer, DebugTracer) tracer = FakeTracer() install_tracer(tracer) self.connection.execute("something") self.assertEqual(tracer.seen, [("EXECUTE", self.connection, RawCursor, "something", ()), ("SUCCESS", self.connection, RawCursor, "something", ())]) del tracer.seen[:] self.connection.execute("something", (1, 2)) self.assertEqual(tracer.seen, [("EXECUTE", self.connection, RawCursor, "something", (1, 2)), ("SUCCESS", self.connection, RawCursor, "something", (1, 2))]) def test_raw_execute_error_tracing(self): cursor_mock = self.mocker.patch(RawCursor) cursor_mock.execute(ARGS) error = ZeroDivisionError() self.mocker.throw(error) self.mocker.replay() self.assertMethodsMatch(FakeTracer, DebugTracer) tracer = FakeTracer() install_tracer(tracer) self.assertRaises(ZeroDivisionError, self.connection.execute, "something") self.assertEqual(tracer.seen, [("EXECUTE", self.connection, RawCursor, "something", ()), ("ERROR", self.connection, RawCursor, "something", (), error)]) def test_raw_execute_setup_error_tracing(self): """ When an exception is raised in the connection_raw_execute hook of a tracer, the connection_raw_execute_error hook is called. 
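In other words: if a tracer's connection_raw_execute raises, the statement never reaches the cursor, and connection_raw_execute_error receives the very exception that was raised (simulated below with mocker and a ZeroDivisionError).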
""" cursor_mock = self.mocker.patch(FakeTracer) cursor_mock.connection_raw_execute(ARGS) error = ZeroDivisionError() self.mocker.throw(error) self.mocker.replay() tracer = FakeTracer() install_tracer(tracer) self.assertRaises(ZeroDivisionError, self.connection.execute, "something") self.assertEqual(tracer.seen, [("ERROR", self.connection, RawCursor, "something", (), error)]) def test_tracing_check_disconnect(self): tracer = FakeTracer() tracer_mock = self.mocker.patch(tracer) tracer_mock.connection_raw_execute(ARGS) self.mocker.throw(DatabaseError('connection closed')) self.mocker.replay() install_tracer(tracer_mock) self.connection.is_disconnection_error = ( lambda exc, extra_disconnection_errors=(): 'connection closed' in str(exc)) self.assertRaises(DisconnectionError, self.connection.execute, "something") def test_tracing_success_check_disconnect(self): tracer = FakeTracer() tracer_mock = self.mocker.patch(tracer) tracer_mock.connection_raw_execute(ARGS) tracer_mock.connection_raw_execute_success(ARGS) self.mocker.throw(DatabaseError('connection closed')) self.mocker.replay() install_tracer(tracer_mock) self.connection.is_disconnection_error = ( lambda exc, extra_disconnection_errors=(): 'connection closed' in str(exc)) self.assertRaises(DisconnectionError, self.connection.execute, "something") def test_tracing_error_check_disconnect(self): cursor_mock = self.mocker.patch(RawCursor) cursor_mock.execute(ARGS) error = ZeroDivisionError() self.mocker.throw(error) tracer = FakeTracer() tracer_mock = self.mocker.patch(tracer) tracer_mock.connection_raw_execute(ARGS) tracer_mock.connection_raw_execute_error(ARGS) self.mocker.throw(DatabaseError('connection closed')) self.mocker.replay() install_tracer(tracer_mock) self.connection.is_disconnection_error = ( lambda exc, extra_disconnection_errors=(): 'connection closed' in str(exc)) self.assertRaises(DisconnectionError, self.connection.execute, "something") def test_commit(self): self.connection.commit() self.assertEqual(self.executed, ["COMMIT"]) def test_commit_tracing(self): self.assertMethodsMatch(FakeTracer, DebugTracer) tracer = FakeTracer() install_tracer(tracer) self.connection.commit() self.assertEqual(tracer.seen, [("COMMIT", self.connection, None)]) def test_rollback(self): self.connection.rollback() self.assertEqual(self.executed, ["ROLLBACK"]) def test_rollback_tracing(self): self.assertMethodsMatch(FakeTracer, DebugTracer) tracer = FakeTracer() install_tracer(tracer) self.connection.rollback() self.assertEqual(tracer.seen, [("ROLLBACK", self.connection, None)]) def test_close(self): self.connection.close() self.assertEqual(self.executed, ["CCLOSE"]) def test_close_twice(self): self.connection.close() self.connection.close() self.assertEqual(self.executed, ["CCLOSE"]) def test_close_deallocates_raw_connection(self): refs_before = len(gc.get_referrers(self.raw_connection)) self.connection.close() refs_after = len(gc.get_referrers(self.raw_connection)) self.assertEqual(refs_after, refs_before-1) def test_del_deallocates_raw_connection(self): refs_before = len(gc.get_referrers(self.raw_connection)) self.connection.__del__() refs_after = len(gc.get_referrers(self.raw_connection)) self.assertEqual(refs_after, refs_before-1) def test_wb_del_with_previously_deallocated_connection(self): self.connection._raw_connection = None self.connection.__del__() def test_get_insert_identity(self): result = self.connection.execute("INSERT") self.assertRaises(NotImplementedError, result.get_insert_identity, None, None) def 
test_wb_ensure_connected_noop(self): """Check that _ensure_connected() is a no-op for STATE_CONNECTED.""" self.assertEqual(self.connection._state, storm.database.STATE_CONNECTED) def connect(): raise DatabaseError("_ensure_connected() tried to connect") self.database.raw_connect = connect self.connection._ensure_connected() def test_wb_ensure_connected_dead_connection(self): """Check that DisconnectionError is raised for STATE_DISCONNECTED.""" self.connection._state = storm.database.STATE_DISCONNECTED self.assertRaises(DisconnectionError, self.connection._ensure_connected) def test_wb_ensure_connected_reconnects(self): """Check that _ensure_connected() reconnects for STATE_RECONNECT.""" self.connection._state = storm.database.STATE_RECONNECT self.connection._raw_connection = None self.connection._ensure_connected() self.assertNotEqual(self.connection._raw_connection, None) self.assertEqual(self.connection._state, storm.database.STATE_CONNECTED) def test_wb_ensure_connected_connect_failure(self): """Check that the connection is flagged on reconnect failures.""" self.connection._state = storm.database.STATE_RECONNECT self.connection._raw_connection = None def _fail_to_connect(): raise DatabaseError("could not connect") self.database.raw_connect = _fail_to_connect self.assertRaises(DisconnectionError, self.connection._ensure_connected) self.assertEqual(self.connection._state, storm.database.STATE_DISCONNECTED) self.assertEqual(self.connection._raw_connection, None) def test_wb_check_disconnection(self): """Ensure that _check_disconnect() detects disconnections.""" class FakeException(DatabaseError): """A fake database exception that indicates a disconnection.""" self.connection.is_disconnection_error = ( lambda exc, extra_disconnection_errors=(): isinstance(exc, FakeException)) self.assertEqual(self.connection._state, storm.database.STATE_CONNECTED) # Error is converted to DisconnectionError: def raise_exception(): raise FakeException self.assertRaises(DisconnectionError, self.connection._check_disconnect, raise_exception) self.assertEqual(self.connection._state, storm.database.STATE_DISCONNECTED) self.assertEqual(self.connection._raw_connection, None) def test_wb_check_disconnection_extra_errors(self): """Ensure that _check_disconnect() can check for additional exceptions.""" class FakeException(DatabaseError): """A fake database exception that indicates a disconnection.""" self.connection.is_disconnection_error = ( lambda exc, extra_disconnection_errors=(): isinstance(exc, extra_disconnection_errors)) self.assertEqual(self.connection._state, storm.database.STATE_CONNECTED) # Error is converted to DisconnectionError: def raise_exception(): raise FakeException # Exception passes through as normal. self.assertRaises(FakeException, self.connection._check_disconnect, raise_exception) self.assertEqual(self.connection._state, storm.database.STATE_CONNECTED) # Exception treated as a disconnection when keyword argument passed. 
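# (extra_disconnection_errors widens the set of exception types that
# is_disconnection_error() treats as disconnections for a single
# _check_disconnect() call, letting callers name driver-specific
# failure types.)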
self.assertRaises(DisconnectionError, self.connection._check_disconnect, raise_exception, extra_disconnection_errors=FakeException) self.assertEqual(self.connection._state, storm.database.STATE_DISCONNECTED) def test_wb_rollback_clears_disconnected_connection(self): """Check that rollback clears the DISCONNECTED state.""" self.connection._state = storm.database.STATE_DISCONNECTED self.connection._raw_connection = None self.connection.rollback() self.assertEqual(self.executed, []) self.assertEqual(self.connection._state, storm.database.STATE_RECONNECT) class ResultTest(TestHelper): def setUp(self): TestHelper.setUp(self) self.executed = [] self.raw_cursor = RawCursor(executed=self.executed) self.result = Result(FakeConnection(), self.raw_cursor) def test_get_one(self): self.assertEqual(self.result.get_one(), ("fetchone0",)) self.assertEqual(self.result.get_one(), ("fetchone1",)) self.assertEqual(self.result.get_one(), ("fetchone2",)) self.assertEqual(self.result.get_one(), None) def test_get_all(self): self.assertEqual(self.result.get_all(), [("fetchall0",), ("fetchall1",)]) self.assertEqual(self.result.get_all(), []) def test_iter(self): result = Result(FakeConnection(), RawCursor(2)) self.assertEqual([item for item in result], [("fetchmany0",), ("fetchmany1",), ("fetchmany2",), ("fetchmany3",), ("fetchmany4",)]) def test_set_variable(self): variable = Variable() self.result.set_variable(variable, marker) self.assertEqual(variable.get(), marker) def test_close(self): self.result.close() self.assertEqual(self.executed, ["RCLOSE"]) def test_close_twice(self): self.result.close() self.result.close() self.assertEqual(self.executed, ["RCLOSE"]) def test_close_deallocates_raw_cursor(self): refs_before = len(gc.get_referrers(self.raw_cursor)) self.result.close() refs_after = len(gc.get_referrers(self.raw_cursor)) self.assertEqual(refs_after, refs_before-1) def test_del_deallocates_raw_cursor(self): refs_before = len(gc.get_referrers(self.raw_cursor)) self.result.__del__() refs_after = len(gc.get_referrers(self.raw_cursor)) self.assertEqual(refs_after, refs_before-1) def test_wb_del_with_previously_deallocated_cursor(self): self.result._raw_cursor = None self.result.__del__() def test_set_arraysize(self): """When the arraysize is 1, change it to a better value.""" raw_cursor = RawCursor() self.assertEqual(raw_cursor.arraysize, 1) result = Result(FakeConnection(), raw_cursor) self.assertEqual(raw_cursor.arraysize, 10) def test_preserve_arraysize(self): """When the arraysize is not 1, preserve it.""" raw_cursor = RawCursor(arraysize=123) result = Result(FakeConnection(), raw_cursor) self.assertEqual(raw_cursor.arraysize, 123) class CreateDatabaseTest(TestHelper): def setUp(self): TestHelper.setUp(self) self.db_module = types.ModuleType("db_module") self.uri = None def create_from_uri(uri): self.uri = uri return "RESULT" self.db_module.create_from_uri = create_from_uri sys.modules["storm.databases.db_module"] = self.db_module def tearDown(self): del sys.modules["storm.databases.db_module"] TestHelper.tearDown(self) def test_create_database_with_str(self): create_database("db_module:db") self.assertTrue(self.uri) self.assertEqual(self.uri.scheme, "db_module") self.assertEqual(self.uri.database, "db") def test_create_database_with_unicode(self): create_database("db_module:db") self.assertTrue(self.uri) self.assertEqual(self.uri.scheme, "db_module") self.assertEqual(self.uri.database, "db") def test_create_database_with_uri(self): uri = URI("db_module:db") create_database(uri) self.assertTrue(self.uri 
is uri) class RegisterSchemeTest(TestHelper): uri = None def tearDown(self): if 'factory' in storm.database._database_schemes: del storm.database._database_schemes['factory'] TestHelper.tearDown(self) def test_register_scheme(self): def factory(uri): self.uri = uri return "FACTORY RESULT" register_scheme('factory', factory) self.assertEqual(storm.database._database_schemes['factory'], factory) # Check that we can create databases that use this scheme ... result = create_database('factory:foobar') self.assertEqual(result, "FACTORY RESULT") self.assertTrue(self.uri) self.assertEqual(self.uri.scheme, 'factory') self.assertEqual(self.uri.database, 'foobar') ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4211247 storm-1.0/storm/tests/databases/0000755000175000017500000000000014645532536017272 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/storm/tests/databases/__init__.py0000644000175000017500000000000011752263216021361 0ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/databases/base.py0000644000175000017500000010754314645174376020574 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
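#
# The classes in this module are backend-agnostic mixins: a concrete
# backend's test case subclasses one of them together with TestHelper
# and fills in the template methods create_database() and
# create_tables().  A minimal sketch (illustrative only; the sqlite URI
# and table set are assumptions, and a real subclass must create every
# table that create_sample_data() touches):
#
#     class SQLiteDatabaseTest(DatabaseTest, TestHelper):
#         def create_database(self):
#             self.database = create_database("sqlite:")
#         def create_tables(self):
#             self.connection.execute(
#                 "CREATE TABLE number "
#                 "(one INTEGER, two INTEGER, three INTEGER)")
#             self.connection.execute(
#                 "CREATE TABLE test (id INTEGER PRIMARY KEY, title VARCHAR)")
#
# MySQLTest and PostgresTest later in this package follow this pattern.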
# from datetime import datetime, date, time, timedelta import pickle import shutil import sys import os from storm.uri import URI from storm.expr import Select, Column, SQLToken, SQLRaw, Count, Alias from storm.variables import (Variable, PickleVariable, BytesVariable, DecimalVariable, DateTimeVariable, DateVariable, TimeVariable, TimeDeltaVariable) from storm.database import * from storm.xid import Xid from storm.event import EventSystem from storm.exceptions import ( DatabaseError, DatabaseModuleError, ConnectionBlockedError, DisconnectionError, Error, OperationalError, ProgrammingError) from storm.tests.databases.proxy import ProxyTCPServer from storm.tests.helper import MakePath class Marker: pass marker = Marker() class DatabaseTest: supports_microseconds = True def setUp(self): super().setUp() self.create_database() self.create_connection() self.drop_tables() self.create_tables() self.create_sample_data() def tearDown(self): self.drop_sample_data() self.drop_tables() self.drop_connection() self.drop_database() super().tearDown() def create_database(self): raise NotImplementedError def create_connection(self): self.connection = self.database.connect() def create_tables(self): raise NotImplementedError def create_sample_data(self): self.connection.execute("INSERT INTO number VALUES (1, 2, 3)") self.connection.execute("INSERT INTO test VALUES (10, 'Title 10')") self.connection.execute("INSERT INTO test VALUES (20, 'Title 20')") self.connection.commit() def drop_sample_data(self): pass def drop_tables(self): for table in ["number", "test", "datetime_test", "bin_test"]: try: self.connection.execute("DROP TABLE " + table) self.connection.commit() except: self.connection.rollback() def drop_connection(self): self.connection.close() def drop_database(self): pass def test_create(self): self.assertTrue(isinstance(self.database, Database)) def test_get_uri(self): """ The get_uri() method returns the URI the database was created with. 
""" uri = self.database.get_uri() self.assertIsNotNone(uri.scheme) def test_connection(self): self.assertTrue(isinstance(self.connection, Connection)) def test_rollback(self): self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.rollback() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertFalse(result.get_one()) def test_rollback_twice(self): self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.rollback() self.connection.rollback() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertFalse(result.get_one()) def test_commit(self): self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.commit() self.connection.rollback() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertTrue(result.get_one()) def test_commit_twice(self): self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.commit() self.connection.commit() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertTrue(result.get_one()) def test_execute_result(self): result = self.connection.execute("SELECT 1") self.assertTrue(isinstance(result, Result)) self.assertTrue(result.get_one()) def test_execute_unicode_result(self): result = self.connection.execute("SELECT title FROM test") self.assertTrue(isinstance(result, Result)) row = result.get_one() self.assertEqual(row, ("Title 10",)) self.assertTrue(isinstance(row[0], str)) def test_execute_params(self): result = self.connection.execute("SELECT one FROM number " "WHERE 1=?", (1,)) self.assertTrue(result.get_one()) result = self.connection.execute("SELECT one FROM number " "WHERE 1=?", (2,)) self.assertFalse(result.get_one()) def test_execute_empty_params(self): result = self.connection.execute("SELECT one FROM number", ()) self.assertTrue(result.get_one()) def test_execute_expression(self): result = self.connection.execute(Select(1)) self.assertTrue(result.get_one(), (1,)) def test_execute_expression_empty_params(self): result = self.connection.execute(Select(SQLRaw("1"))) self.assertTrue(result.get_one(), (1,)) def test_get_one(self): result = self.connection.execute("SELECT * FROM test ORDER BY id") self.assertEqual(result.get_one(), (10, "Title 10")) def test_get_all(self): result = self.connection.execute("SELECT * FROM test ORDER BY id") self.assertEqual(result.get_all(), [(10, "Title 10"), (20, "Title 20")]) def test_iter(self): result = self.connection.execute("SELECT * FROM test ORDER BY id") self.assertEqual([item for item in result], [(10, "Title 10"), (20, "Title 20")]) def test_simultaneous_iter(self): result1 = self.connection.execute("SELECT * FROM test " "ORDER BY id ASC") result2 = self.connection.execute("SELECT * FROM test " "ORDER BY id DESC") iter1 = iter(result1) iter2 = iter(result2) self.assertEqual(next(iter1), (10, "Title 10")) self.assertEqual(next(iter2), (20, "Title 20")) self.assertEqual(next(iter1), (20, "Title 20")) self.assertEqual(next(iter2), (10, "Title 10")) self.assertRaises(StopIteration, next, iter1) self.assertRaises(StopIteration, next, iter2) def test_get_insert_identity(self): result = self.connection.execute("INSERT INTO test (title) " "VALUES ('Title 30')") primary_key = (Column("id", SQLToken("test")),) primary_variables = (Variable(),) expr = result.get_insert_identity(primary_key, primary_variables) select = Select(Column("title", SQLToken("test")), expr) result = self.connection.execute(select) 
self.assertEqual(result.get_one(), ("Title 30",)) def test_get_insert_identity_composed(self): result = self.connection.execute("INSERT INTO test (title) " "VALUES ('Title 30')") primary_key = (Column("id", SQLToken("test")), Column("title", SQLToken("test"))) primary_variables = (Variable(), Variable("Title 30")) expr = result.get_insert_identity(primary_key, primary_variables) select = Select(Column("title", SQLToken("test")), expr) result = self.connection.execute(select) self.assertEqual(result.get_one(), ("Title 30",)) def test_datetime(self): value = datetime(1977, 4, 5, 12, 34, 56, 78) self.connection.execute("INSERT INTO datetime_test (dt) VALUES (?)", (value,)) result = self.connection.execute("SELECT dt FROM datetime_test") variable = DateTimeVariable() result.set_variable(variable, result.get_one()[0]) if not self.supports_microseconds: value = value.replace(microsecond=0) self.assertEqual(variable.get(), value) def test_date(self): value = date(1977, 4, 5) self.connection.execute("INSERT INTO datetime_test (d) VALUES (?)", (value,)) result = self.connection.execute("SELECT d FROM datetime_test") variable = DateVariable() result.set_variable(variable, result.get_one()[0]) self.assertEqual(variable.get(), value) def test_time(self): value = time(12, 34, 56, 78) self.connection.execute("INSERT INTO datetime_test (t) VALUES (?)", (value,)) result = self.connection.execute("SELECT t FROM datetime_test") variable = TimeVariable() result.set_variable(variable, result.get_one()[0]) if not self.supports_microseconds: value = value.replace(microsecond=0) self.assertEqual(variable.get(), value) def test_timedelta(self): value = timedelta(12, 34, 56) self.connection.execute("INSERT INTO datetime_test (td) VALUES (?)", (value,)) result = self.connection.execute("SELECT td FROM datetime_test") variable = TimeDeltaVariable() result.set_variable(variable, result.get_one()[0]) self.assertEqual(variable.get(), value) def test_pickle(self): value = {"a": 1, "b": 2} value_dump = pickle.dumps(value, -1) self.connection.execute("INSERT INTO bin_test (b) VALUES (?)", (value_dump,)) result = self.connection.execute("SELECT b FROM bin_test") variable = PickleVariable() result.set_variable(variable, result.get_one()[0]) self.assertEqual(variable.get(), value) def test_binary(self): """Ensure database works with high bits and embedded zeros.""" value = b"\xff\x00\xff\x00" self.connection.execute("INSERT INTO bin_test (b) VALUES (?)", (value,)) result = self.connection.execute("SELECT b FROM bin_test") variable = BytesVariable() result.set_variable(variable, result.get_one()[0]) self.assertEqual(variable.get(), value) def test_binary_ascii(self): """Some databases like pysqlite2 may return unicode for strings.""" self.connection.execute("INSERT INTO bin_test VALUES (10, 'Value')") result = self.connection.execute("SELECT b FROM bin_test") variable = BytesVariable() # If the following doesn't raise a TypeError we're good. 
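# (set_variable() feeds the raw value through BytesVariable, which is
# expected to reject str input with a TypeError, so a backend returning
# unicode here would make this test fail.)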
result.set_variable(variable, result.get_one()[0]) self.assertEqual(variable.get(), b"Value") def test_order_by_group_by(self): self.connection.execute("INSERT INTO test VALUES (100, 'Title 10')") self.connection.execute("INSERT INTO test VALUES (101, 'Title 10')") id = Column("id", "test") title = Column("title", "test") expr = Select(Count(id), group_by=title, order_by=Count(id)) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(1,), (3,)]) def test_set_decimal_variable_from_str_column(self): self.connection.execute("INSERT INTO test VALUES (40, '40.5')") variable = DecimalVariable() result = self.connection.execute("SELECT title FROM test WHERE id=40") result.set_variable(variable, result.get_one()[0]) def test_get_decimal_variable_to_str_column(self): variable = DecimalVariable() variable.set("40.5", from_db=True) self.connection.execute("INSERT INTO test VALUES (40, ?)", (variable,)) result = self.connection.execute("SELECT title FROM test WHERE id=40") self.assertEqual(result.get_one()[0], "40.5") def test_quoting(self): # FIXME "with'quote" should be in the list below, but it doesn't # work because it breaks the parameter mark translation. for reserved_name in ["with space", 'with`"escape', "SELECT"]: reserved_name = SQLToken(reserved_name) expr = Select(reserved_name, tables=Alias(Select(Alias(1, reserved_name)))) result = self.connection.execute(expr) self.assertEqual(result.get_one(), (1,)) def test_concurrent_behavior(self): """The default behavior should be to handle transactions in isolation. Data committed in one transaction shouldn't be visible to another running transaction before the latter is committed or aborted. If this isn't the case, the caching made by Storm (or by anything that works with data in memory, in fact) becomes a dangerous thing. For PostgreSQL, the isolation level must be SERIALIZABLE. For MySQL, the isolation level must be REPEATABLE READ (the default), and the InnoDB engine must be in use. For SQLite, the isolation level is already SERIALIZABLE when not in autocommit mode. OTOH, PySQLite is nuts regarding transactional behavior, and will easily offer READ COMMITTED behavior inside a "transaction" (it didn't tell SQLite to open a transaction, in fact). 
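Concretely: connection1 reads "Title 10"; connection2 then commits an update to that row; a re-read on connection1, still inside its original transaction, must return the old "Title 10" (on SQLite the writer instead blocks with "database is locked").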
""" connection1 = self.connection connection2 = self.database.connect() try: result = connection1.execute("SELECT title FROM test WHERE id=10") self.assertEqual(result.get_one(), ("Title 10",)) try: connection2.execute("UPDATE test SET title='Title 100' " "WHERE id=10") connection2.commit() except OperationalError as e: self.assertEqual(str(e), "database is locked") # SQLite blocks result = connection1.execute("SELECT title FROM test WHERE id=10") self.assertEqual(result.get_one(), ("Title 10",)) finally: connection1.rollback() def test_wb_connect_sets_event_system(self): connection = self.database.connect(marker) self.assertEqual(connection._event, marker) def test_execute_sends_event(self): event = EventSystem(marker) calls = [] def register_transaction(owner): calls.append(owner) event.hook("register-transaction", register_transaction) connection = self.database.connect(event) connection.execute("SELECT 1") self.assertEqual(len(calls), 1) self.assertEqual(calls[0], marker) def from_database(self, row): return [int(item)+1 for item in row] def test_wb_result_get_one_goes_through_from_database(self): result = self.connection.execute("SELECT one, two FROM number") result.from_database = self.from_database self.assertEqual(result.get_one(), (2, 3)) def test_wb_result_get_all_goes_through_from_database(self): result = self.connection.execute("SELECT one, two FROM number") result.from_database = self.from_database self.assertEqual(result.get_all(), [(2, 3)]) def test_wb_result_iter_goes_through_from_database(self): result = self.connection.execute("SELECT one, two FROM number") result.from_database = self.from_database self.assertEqual(next(iter(result)), (2, 3)) def test_rowcount_insert(self): # All supported backends support rowcount, so far. result = self.connection.execute( "INSERT INTO test VALUES (999, '999')") self.assertEqual(result.rowcount, 1) def test_rowcount_delete(self): # All supported backends support rowcount, so far. result = self.connection.execute("DELETE FROM test") self.assertEqual(result.rowcount, 2) def test_rowcount_update(self): # All supported backends support rowcount, so far. 
result = self.connection.execute( "UPDATE test SET title='whatever'") self.assertEqual(result.rowcount, 2) def test_expr_startswith(self): self.connection.execute("INSERT INTO test VALUES (30, '!!_%blah')") self.connection.execute("INSERT INTO test VALUES (40, '!!blah')") id = Column("id", SQLToken("test")) title = Column("title", SQLToken("test")) expr = Select(id, title.startswith("!!_%")) result = list(self.connection.execute(expr)) self.assertEqual(result, [(30,)]) def test_expr_endswith(self): self.connection.execute("INSERT INTO test VALUES (30, 'blah_%!!')") self.connection.execute("INSERT INTO test VALUES (40, 'blah!!')") id = Column("id", SQLToken("test")) title = Column("title", SQLToken("test")) expr = Select(id, title.endswith("_%!!")) result = list(self.connection.execute(expr)) self.assertEqual(result, [(30,)]) def test_expr_contains_string(self): self.connection.execute("INSERT INTO test VALUES (30, 'blah_%!!x')") self.connection.execute("INSERT INTO test VALUES (40, 'blah!!x')") id = Column("id", SQLToken("test")) title = Column("title", SQLToken("test")) expr = Select(id, title.contains_string("_%!!")) result = list(self.connection.execute(expr)) self.assertEqual(result, [(30,)]) def test_block_access(self): """Access to the connection is blocked by block_access().""" self.connection.execute("SELECT 1") self.connection.block_access() self.assertRaises(ConnectionBlockedError, self.connection.execute, "SELECT 1") self.assertRaises(ConnectionBlockedError, self.connection.commit) # Allow rolling back a blocked connection. self.connection.rollback() # Unblock the connection, allowing access again. self.connection.unblock_access() self.connection.execute("SELECT 1") def test_wrap_exception_subclasses(self): """Subclasses of the generic DB-API exception types are wrapped.""" db_api_operational_error = getattr( self.database._exception_module, 'OperationalError') operational_error_types = [ type(name, (db_api_operational_error,), {}) for name in ('A', 'B')] for error_type in operational_error_types: error = error_type('error message') wrapped = self.database._wrap_exception(OperationalError, error) self.assertTrue(isinstance(wrapped, error_type)) self.assertTrue(isinstance(wrapped, OperationalError)) self.assertEqual(error_type.__name__, wrapped.__class__.__name__) self.assertEqual(('error message',), wrapped.args) class TwoPhaseCommitTest: def setUp(self): super().setUp() self.create_database() self.create_connection() self.drop_tables() self.create_tables() def tearDown(self): self.drop_tables() self.drop_connection() super().tearDown() def create_database(self): raise NotImplementedError def create_connection(self): self.connection = self.database.connect() def create_tables(self): raise NotImplementedError def drop_tables(self): try: self.connection.execute("DROP TABLE test") self.connection.commit() except: self.connection.rollback() def drop_connection(self): self.connection.close() def test_begin(self): """ begin() starts a transaction that can be ended with a two-phase commit. """ xid = Xid(0, "foo", "bar") self.connection.begin(xid) self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.prepare() self.connection.commit() self.connection.rollback() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertTrue(result.get_one()) def test_begin_inside_a_two_phase_transaction(self): """ begin() can't be used if a two-phase transaction has already started. 
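At the DB-API level this roughly corresponds to tpc_begin() refusing to start a new global transaction while another one is in progress.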
""" xid1 = Xid(0, "foo", "bar") self.connection.begin(xid1) xid2 = Xid(1, "egg", "baz") self.assertRaises(ProgrammingError, self.connection.begin, xid2) def test_begin_after_commit(self): """ After a two phase commit, it's possible to start a new transaction. """ xid = Xid(0, "foo", "bar") self.connection.begin(xid) self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.commit() self.connection.begin(xid) result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertTrue(result.get_one()) def test_begin_after_rollback(self): """ After a tpc rollback, it's possible to start a new transaction. """ xid = Xid(0, "foo", "bar") self.connection.begin(xid) self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.rollback() self.connection.begin(xid) result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertFalse(result.get_one()) def test_prepare_outside_a_two_phase_transaction(self): """ prepare() can't be used if a two-phase transaction has not begun yet. """ self.assertRaises(ProgrammingError, self.connection.prepare) def test_rollback_after_prepare(self): """ Calling rollback() after prepare() actually rolls back the changes. """ xid = Xid(0, "foo", "bar") self.connection.begin(xid) self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.prepare() self.connection.rollback() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertFalse(result.get_one()) def test_mixing_standard_and_two_phase_commits(self): """ It's possible to mix standard and two phase commits across different transactions. """ self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.commit() xid = Xid(0, "foo", "bar") self.connection.begin(xid) self.connection.execute("INSERT INTO test VALUES (40, 'Title 40')") self.connection.prepare() self.connection.commit() result = self.connection.execute("SELECT id FROM test " "WHERE id IN (30, 40)") self.assertEqual([(30,), (40,)], result.get_all()) def test_recover_and_commit(self): """ It's possible to recover and commit pending transactions that were prepared but not committed or rolled back. """ # Prepare a transaction but leave it uncommitted self.connection.begin(Xid(0, "foo", "bar")) self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.prepare() # Setup a new connection and recover the prepared transaction # committing it connection2 = self.database.connect() self.addCleanup(connection2.close) result = connection2.execute("SELECT id FROM test WHERE id=30") connection2.rollback() self.assertFalse(result.get_one()) [xid] = connection2.recover() self.assertEqual(0, xid.format_id) self.assertEqual("foo", xid.global_transaction_id) self.assertEqual("bar", xid.branch_qualifier) connection2.commit(xid) self.assertEqual([], connection2.recover()) # Reconnect, changes are be visible self.connection.close() self.connection = self.database.connect() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertTrue(result.get_one()) def test_recover_and_rollback(self): """ It's possible to recover and rollback pending transactions that were prepared but not committed or rolled back. 
""" # Prepare a transaction but leave it uncommitted self.connection.begin(Xid(0, "foo", "bar")) self.connection.execute("INSERT INTO test VALUES (30, 'Title 30')") self.connection.prepare() # Setup a new connection and recover the prepared transaction # rolling it back connection2 = self.database.connect() self.addCleanup(connection2.close) [xid] = connection2.recover() self.assertEqual(0, xid.format_id) self.assertEqual("foo", xid.global_transaction_id) self.assertEqual("bar", xid.branch_qualifier) connection2.rollback(xid) self.assertEqual([], connection2.recover()) # Reconnect, changes were rolled back self.connection.close() self.connection = self.database.connect() result = self.connection.execute("SELECT id FROM test WHERE id=30") self.assertFalse(result.get_one()) class UnsupportedDatabaseTest: helpers = [MakePath] dbapi_module_names = [] db_module_name = None def test_exception_when_unsupported(self): # Install a directory in front of the search path. module_dir = self.make_path() os.mkdir(module_dir) sys.path.insert(0, module_dir) # Copy the real module over to a new place, since the old one is # already using the real module, if it's available. db_module = __import__("storm.databases."+self.db_module_name, None, None, [""]) db_module_filename = db_module.__file__ if db_module_filename.endswith(".pyc"): db_module_filename = db_module_filename[:-1] shutil.copyfile(db_module_filename, os.path.join(module_dir, "_fake_.py")) dbapi_modules = {} for dbapi_module_name in self.dbapi_module_names: # If the real module is available, remove it from sys.modules. dbapi_module = sys.modules.pop(dbapi_module_name, None) if dbapi_module is not None: dbapi_modules[dbapi_module_name] = dbapi_module # Create a module which raises ImportError when imported, to fake # a missing module. dirname = self.make_path(path=os.path.join(module_dir, dbapi_module_name)) os.mkdir(dirname) self.make_path("raise ImportError", os.path.join(module_dir, dbapi_module_name, "__init__.py")) # Finally, test it. import _fake_ uri = URI("_fake_://db") try: self.assertRaises(DatabaseModuleError, _fake_.create_from_uri, uri) finally: # Unhack the environment. del sys.path[0] del sys.modules["_fake_"] sys.modules.update(dbapi_modules) class DatabaseDisconnectionMixin: environment_variable = "" host_environment_variable = "" default_port = None def setUp(self): super().setUp() self.create_database_and_proxy() self.create_connection() def tearDown(self): self.drop_connection() self.drop_database() self.proxy.close() super().tearDown() def is_supported(self): return bool(self.get_uri()) def get_uri(self): """Return URI instance with a defined host (and port, for TCP).""" if not self.environment_variable: raise RuntimeError( "Define at least %s.environment_variable" % type(self).__name__) uri_str = os.environ.get(self.host_environment_variable) if uri_str: uri = URI(uri_str) if not uri.host: raise RuntimeError("The URI in %s must include a host." 
% self.host_environment_variable) if not uri.host.startswith("/") and not uri.port: if not self.default_port: raise RuntimeError( "Define at least %s.default_port" % type(self).__name__) uri.port = self.default_port return uri else: uri_str = os.environ.get(self.environment_variable) if uri_str: uri = URI(uri_str) if uri.host: if not uri.host.startswith("/") and not uri.port: if not self.default_port: raise RuntimeError( "Define at least %s.default_port" % type(self).__name__) uri.port = self.default_port return uri return None def create_proxy(self, uri): """Create a TCP proxy forwarding requests to `uri`.""" return ProxyTCPServer((uri.host, uri.port)) def create_database_and_proxy(self): """Set up the TCP proxy and database object. The TCP proxy should forward requests on to the database. The database object should point at the TCP proxy. """ uri = self.get_uri() self.proxy = self.create_proxy(uri) uri.host, uri.port = self.proxy.server_address self.proxy_uri = uri self.database = create_database(uri) def create_connection(self): self.connection = self.database.connect() def drop_connection(self): self.connection.close() def drop_database(self): pass class DatabaseDisconnectionTest(DatabaseDisconnectionMixin): def test_proxy_works(self): """Ensure that we can talk to the database through the proxy.""" result = self.connection.execute("SELECT 1") self.assertEqual(result.get_one(), (1,)) def test_catch_disconnect_on_execute(self): """Test that database disconnections get caught on execute().""" result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.restart() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") def test_catch_disconnect_on_commit(self): """Test that database disconnections get caught on commit().""" result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.restart() self.assertRaises(DisconnectionError, self.connection.commit) def test_wb_catch_already_disconnected_on_rollback(self): """Connection.rollback() swallows disconnection errors. If the connection is being used outside of Storm's control, then it is possible that Storm won't see the disconnection. It should be able to recover from this situation though. """ result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.restart() # Perform an action that should result in a disconnection. try: cursor = self.connection._raw_connection.cursor() cursor.execute("SELECT 1") cursor.fetchone() except Error as exc: self.assertTrue(self.connection.is_disconnection_error(exc)) else: self.fail("Disconnection was not caught.") # Make sure our raw connection's rollback() raises a disconnection # error when called. try: self.connection._raw_connection.rollback() except Error as exc: self.assertTrue(self.connection.is_disconnection_error(exc)) else: self.fail("Disconnection was not raised.") # Our rollback() will catch and swallow that disconnection error, # though. self.connection.rollback() def test_wb_catch_already_disconnected(self): """Storm detects connections that have already been disconnected. If the connection is being used outside of Storm's control, then it is possible that Storm won't see the disconnection. It should be able to recover from this situation though. """ result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.restart() # Perform an action that should result in a disconnection. 
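# Going through the raw DB-API cursor below bypasses Storm's
# _check_disconnect() wrapper, so Storm itself has not yet noticed that
# the server went away.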
try: cursor = self.connection._raw_connection.cursor() cursor.execute("SELECT 1") cursor.fetchone() except DatabaseError as exc: self.assertTrue(self.connection.is_disconnection_error(exc)) else: self.fail("Disconnection was not caught.") self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") def test_connection_stays_disconnected_in_transaction(self): """Test that connection does not immediately reconnect.""" result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.restart() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") def test_reconnect_after_rollback(self): """Test that we reconnect after rolling back the connection.""" result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.restart() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") self.connection.rollback() result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) def test_catch_disconnect_on_reconnect(self): """Test that reconnection failures result in DisconnectionError.""" result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.stop() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") # Rollback the connection, but because the proxy is still # down, we get a DisconnectionError again. self.connection.rollback() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") def test_close_connection_after_disconnect(self): result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) self.proxy.stop() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") self.connection.close() class TwoPhaseCommitDisconnectionTest: def test_begin_after_rollback_with_disconnection_error(self): """ If a rollback fails because of a disconnection error, the two-phase transaction should be properly reset. """ xid1 = Xid(0, "foo", "bar") self.connection.begin(xid1) self.connection.execute("SELECT 1") self.proxy.stop() self.connection.rollback() self.proxy.start() xid2 = Xid(0, "egg", "baz") self.connection.begin(xid2) result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) def test_begin_after_with_statement_disconnection_error_and_rollback(self): """ The two-phase transaction state is properly reset if a disconnection happens before the rollback. """ xid1 = Xid(0, "foo", "bar") self.connection.begin(xid1) self.proxy.close() self.assertRaises(DisconnectionError, self.connection.execute, "SELECT 1") self.connection.rollback() self.proxy.start() xid2 = Xid(0, "egg", "baz") self.connection.begin(xid2) result = self.connection.execute("SELECT 1") self.assertTrue(result.get_one()) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/databases/mysql.py0000644000175000017500000001557414645174376021031 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. 
# # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os from urllib.parse import urlunsplit from storm.databases.mysql import MySQL from storm.database import create_database from storm.expr import Column, Insert from storm.uri import URI from storm.variables import IntVariable, UnicodeVariable from storm.tests.databases.base import ( DatabaseTest, DatabaseDisconnectionTest, UnsupportedDatabaseTest) from storm.tests.databases.proxy import ProxyTCPServer from storm.tests.helper import TestHelper def create_proxy_and_uri(uri): """Create a TCP proxy to a Unix-domain database identified by `uri`.""" proxy = ProxyTCPServer(uri.options["unix_socket"]) proxy_host, proxy_port = proxy.server_address proxy_uri = URI(urlunsplit( ("mysql", "%s:%s" % (proxy_host, proxy_port), "/storm_test", "", ""))) return proxy, proxy_uri class MySQLTest(DatabaseTest, TestHelper): supports_microseconds = False def is_supported(self): return bool(os.environ.get("STORM_MYSQL_URI")) def create_database(self): self.database = create_database(os.environ["STORM_MYSQL_URI"]) def create_tables(self): self.connection.execute("CREATE TABLE number " "(one INTEGER, two INTEGER, three INTEGER)") self.connection.execute("CREATE TABLE test " "(id INT AUTO_INCREMENT PRIMARY KEY," " title VARCHAR(50)) ENGINE=InnoDB") self.connection.execute("CREATE TABLE datetime_test " "(id INT AUTO_INCREMENT PRIMARY KEY," " dt TIMESTAMP, d DATE, t TIME, td TEXT) " "ENGINE=InnoDB") self.connection.execute("CREATE TABLE bin_test " "(id INT AUTO_INCREMENT PRIMARY KEY," " b BLOB) ENGINE=InnoDB") def test_wb_create_database(self): database = create_database("mysql://un:pw@ht:12/db?unix_socket=us") self.assertTrue(isinstance(database, MySQL)) for key, value in [("db", "db"), ("host", "ht"), ("port", 12), ("user", "un"), ("passwd", "pw"), ("unix_socket", "us")]: self.assertEqual(database._connect_kwargs.get(key), value) def test_charset_defaults_to_utf8mb3(self): result = self.connection.execute("SELECT @@character_set_client") self.assertEqual(result.get_one(), ("utf8mb3",)) def test_charset_option(self): uri = URI(os.environ["STORM_MYSQL_URI"]) uri.options["charset"] = "ascii" database = create_database(uri) connection = database.connect() result = connection.execute("SELECT @@character_set_client") self.assertEqual(result.get_one(), ("ascii",)) def test_get_insert_identity(self): # Primary keys are filled in during execute() for MySQL pass def test_get_insert_identity_composed(self): # Primary keys are filled in during execute() for MySQL pass def test_execute_insert_auto_increment_primary_key(self): id_column = Column("id", "test") id_variable = IntVariable() title_column = Column("title", "test") title_variable = UnicodeVariable("testing") # This is not part of the table. It is just used to show that # only one primary key variable is set from the insert ID. 
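# (MySQL reports a single AUTO_INCREMENT value per INSERT, the DB-API
# cursor's lastrowid, so only id_variable can be filled in from it.)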
dummy_column = Column("dummy", "test") dummy_variable = IntVariable() insert = Insert({title_column: title_variable}, primary_columns=(id_column, dummy_column), primary_variables=(id_variable, dummy_variable)) self.connection.execute(insert) self.assertTrue(id_variable.is_defined()) self.assertFalse(dummy_variable.is_defined()) # The newly inserted row should have the maximum id value for # the table. result = self.connection.execute("SELECT MAX(id) FROM test") self.assertEqual(result.get_one()[0], id_variable.get()) def test_mysql_specific_reserved_words(self): reserved_words = """ accessible analyze asensitive before bigint binary blob call change condition current_user database databases day_hour day_microsecond day_minute day_second delayed deterministic distinctrow div dual each elseif enclosed escaped exit explain float4 float8 force fulltext high_priority hour_microsecond hour_minute hour_second if ignore index infile inout int1 int2 int3 int4 int8 iterate keys kill leave limit linear lines load localtime localtimestamp lock long longblob longtext loop low_priority master_ssl_verify_server_cert mediumblob mediumint mediumtext middleint minute_microsecond minute_second mod modifies no_write_to_binlog optimize optionally out outfile purge range read_write reads regexp release rename repeat replace require return rlike schemas second_microsecond sensitive separator show spatial specific sql_big_result sql_calc_found_rows sql_small_result sqlexception sqlwarning ssl starting straight_join terminated tinyblob tinyint tinytext trigger undo unlock unsigned use utc_date utc_time utc_timestamp varbinary varcharacter while xor year_month zerofill """.split() for word in reserved_words: self.assertTrue(self.connection.compile.is_reserved_word(word), "Word missing: %s" % (word,)) class MySQLUnsupportedTest(UnsupportedDatabaseTest, TestHelper): dbapi_module_names = ["MySQLdb"] db_module_name = "mysql" class MySQLDisconnectionTest(DatabaseDisconnectionTest, TestHelper): environment_variable = "STORM_MYSQL_URI" host_environment_variable = "STORM_MYSQL_HOST_URI" default_port = 3306 def create_proxy(self, uri): """See `DatabaseDisconnectionMixin.create_proxy`.""" if "unix_socket" in uri.options: return create_proxy_and_uri(uri)[0] else: return super().create_proxy(uri) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/databases/postgres.py0000644000175000017500000010656314645174376021531 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# from datetime import date, time, timedelta import os import json from urllib.parse import urlunsplit from storm.databases.postgres import ( Postgres, compile, currval, Returning, Case, PostgresTimeoutTracer, make_dsn, JSONElement, JSONTextElement, JSON) from storm.database import create_database from storm.store import Store from storm.exceptions import InterfaceError, ProgrammingError from storm.variables import DateTimeVariable, BytesVariable from storm.variables import ListVariable, IntVariable, Variable from storm.properties import Int from storm.exceptions import DisconnectionError, OperationalError from storm.expr import (Union, Select, Insert, Update, Alias, SQLRaw, State, Sequence, Like, Column, COLUMN, Cast, Func) from storm.tracer import install_tracer, TimeoutError from storm.uri import URI # We need the info to register the 'type' compiler. In normal # circumstances this is naturally imported. import storm.info storm # Silence lint. from storm.tests import has_fixtures, has_subunit from storm.tests.databases.base import ( DatabaseTest, DatabaseDisconnectionTest, UnsupportedDatabaseTest, TwoPhaseCommitTest, TwoPhaseCommitDisconnectionTest) from storm.tests.databases.proxy import ProxyTCPServer from storm.tests.expr import column1, column2, column3, elem1, table1, TrackContext from storm.tests.tracer import TimeoutTracerTestBase from storm.tests.helper import TestHelper try: import pgbouncer except ImportError: has_pgbouncer = False else: has_pgbouncer = True def create_proxy_and_uri(uri): """Create a TCP proxy to a Unix-domain database identified by `uri`.""" proxy = ProxyTCPServer(os.path.join(uri.host, ".s.PGSQL.5432")) proxy_host, proxy_port = proxy.server_address proxy_uri = URI(urlunsplit( ("postgres", "%s:%s" % (proxy_host, proxy_port), "/storm_test", "", ""))) return proxy, proxy_uri def terminate_other_backends(connection): """Terminate all connections to the database except the one given.""" pid_column = "procpid" if connection._database._version < 90200 else "pid" connection.execute( "SELECT pg_terminate_backend(%(pid_column)s)" " FROM pg_stat_activity" " WHERE datname = current_database()" " AND %(pid_column)s != pg_backend_pid()" % {"pid_column": pid_column}) def terminate_all_backends(database): """Terminate all connections to the given database.""" connection = database.connect() terminate_other_backends(connection) connection.close() class PostgresTest(DatabaseTest, TestHelper): def is_supported(self): return bool(os.environ.get("STORM_POSTGRES_URI")) def create_database(self): self.database = create_database(os.environ["STORM_POSTGRES_URI"]) def create_tables(self): self.connection.execute("CREATE TABLE number " "(one INTEGER, two INTEGER, three INTEGER)") self.connection.execute("CREATE TABLE test " "(id SERIAL PRIMARY KEY, title VARCHAR)") self.connection.execute("CREATE TABLE datetime_test " "(id SERIAL PRIMARY KEY," " dt TIMESTAMP, d DATE, t TIME, td INTERVAL)") self.connection.execute("CREATE TABLE bin_test " "(id SERIAL PRIMARY KEY, b BYTEA)") self.connection.execute("CREATE TABLE like_case_insensitive_test " "(id SERIAL PRIMARY KEY, description TEXT)") self.connection.execute("CREATE TABLE returning_test " "(id1 INTEGER DEFAULT 123, " " id2 INTEGER DEFAULT 456)") self.connection.execute("CREATE TABLE json_test " "(id SERIAL PRIMARY KEY, " " json JSON)") def drop_tables(self): super().drop_tables() tables = ("like_case_insensitive_test", "returning_test", "json_test") for table in tables: try: self.connection.execute("DROP TABLE %s" % table) 
self.connection.commit() except: self.connection.rollback() def create_sample_data(self): super().create_sample_data() self.connection.execute("INSERT INTO like_case_insensitive_test " "(description) VALUES ('hullah')") self.connection.execute("INSERT INTO like_case_insensitive_test " "(description) VALUES ('HULLAH')") self.connection.commit() def test_wb_create_database(self): database = create_database("postgres://un:pw@ht:12/db") self.assertTrue(isinstance(database, Postgres)) self.assertEqual(database._dsn, "dbname=db host=ht port=12 user=un password=pw") def test_wb_version(self): version = self.database._version self.assertEqual(type(version), int) try: result = self.connection.execute("SHOW server_version_num") except ProgrammingError: self.assertEqual(version, 0) else: server_version = int(result.get_one()[0]) self.assertEqual(version, server_version) def test_utf8_client_encoding(self): connection = self.database.connect() result = connection.execute("SHOW client_encoding") encoding = result.get_one()[0] self.assertEqual(encoding.upper(), "UTF8") def test_unicode(self): raw_str = b"\xc3\xa1\xc3\xa9\xc3\xad\xc3\xb3\xc3\xba" uni_str = raw_str.decode("UTF-8") connection = self.database.connect() connection.execute( (b"INSERT INTO test VALUES (1, '%s')" % raw_str).decode("UTF-8")) result = connection.execute("SELECT title FROM test WHERE id=1") title = result.get_one()[0] self.assertTrue(isinstance(title, str)) self.assertEqual(title, uni_str) def test_unicode_array(self): raw_str = b"\xc3\xa1\xc3\xa9\xc3\xad\xc3\xb3\xc3\xba" uni_str = raw_str.decode("UTF-8") connection = self.database.connect() result = connection.execute( (b"""SELECT '{"%s"}'::TEXT[]""" % raw_str).decode("UTF-8")) self.assertEqual(result.get_one()[0], [uni_str]) result = connection.execute("""SELECT ?::TEXT[]""", ([uni_str],)) self.assertEqual(result.get_one()[0], [uni_str]) def test_time(self): connection = self.database.connect() value = time(12, 34) result = connection.execute("SELECT ?::TIME", (value,)) self.assertEqual(result.get_one()[0], value) def test_date(self): connection = self.database.connect() value = date(2007, 6, 22) result = connection.execute("SELECT ?::DATE", (value,)) self.assertEqual(result.get_one()[0], value) def test_interval(self): connection = self.database.connect() value = timedelta(365) result = connection.execute("SELECT ?::INTERVAL", (value,)) self.assertEqual(result.get_one()[0], value) def test_datetime_with_none(self): self.connection.execute("INSERT INTO datetime_test (dt) VALUES (NULL)") result = self.connection.execute("SELECT dt FROM datetime_test") variable = DateTimeVariable() result.set_variable(variable, result.get_one()[0]) self.assertEqual(variable.get(), None) def test_array_support(self): try: self.connection.execute("DROP TABLE array_test") self.connection.commit() except: self.connection.rollback() self.connection.execute("CREATE TABLE array_test " "(id SERIAL PRIMARY KEY, a INT[])") variable = ListVariable(IntVariable) variable.set([1,2,3,4]) state = State() statement = compile(variable, state) self.connection.execute("INSERT INTO array_test VALUES (1, %s)" % statement, state.parameters) result = self.connection.execute("SELECT a FROM array_test WHERE id=1") array = result.get_one()[0] self.assertTrue(isinstance(array, list)) variable = ListVariable(IntVariable) result.set_variable(variable, array) self.assertEqual(variable.get(), [1,2,3,4]) def test_array_support_with_empty(self): try: self.connection.execute("DROP TABLE array_test") self.connection.commit() except: 
self.connection.rollback() self.connection.execute("CREATE TABLE array_test " "(id SERIAL PRIMARY KEY, a INT[])") variable = ListVariable(IntVariable) variable.set([]) state = State() statement = compile(variable, state) self.connection.execute("INSERT INTO array_test VALUES (1, %s)" % statement, state.parameters) result = self.connection.execute("SELECT a FROM array_test WHERE id=1") array = result.get_one()[0] self.assertTrue(isinstance(array, list)) variable = ListVariable(IntVariable) result.set_variable(variable, array) self.assertEqual(variable.get(), []) def test_expressions_in_union_order_by(self): # The following statement breaks in postgres: # SELECT 1 AS id UNION SELECT 1 ORDER BY id+1; # With the error: # ORDER BY on a UNION/INTERSECT/EXCEPT result must # be on one of the result columns column = SQLRaw("1") Alias.auto_counter = 0 alias = Alias(column, "id") expr = Union(Select(alias), Select(column), order_by=alias+1, limit=1, offset=1, all=True) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT * FROM ' '((SELECT 1 AS id) UNION ALL (SELECT 1)) AS "_1" ' 'ORDER BY id + ? LIMIT 1 OFFSET 1') self.assertVariablesEqual(state.parameters, [Variable(1)]) result = self.connection.execute(expr) self.assertEqual(result.get_one(), (1,)) def test_expressions_in_union_in_union_order_by(self): column = SQLRaw("1") alias = Alias(column, "id") expr = Union(Select(alias), Select(column), order_by=alias+1, limit=1, offset=1, all=True) expr = Union(expr, expr, order_by=alias+1, all=True) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(1,), (1,)]) def test_sequence(self): expr1 = Select(Sequence("test_id_seq")) expr2 = "SELECT currval('test_id_seq')" value1 = self.connection.execute(expr1).get_one()[0] value2 = self.connection.execute(expr2).get_one()[0] value3 = self.connection.execute(expr1).get_one()[0] self.assertEqual(value1, value2) self.assertEqual(value3-value1, 1) def test_like_case(self): expr = Like("name", "value") statement = compile(expr) self.assertEqual(statement, "? LIKE ?") expr = Like("name", "value", case_sensitive=True) statement = compile(expr) self.assertEqual(statement, "? LIKE ?") expr = Like("name", "value", case_sensitive=False) statement = compile(expr) self.assertEqual(statement, "? 
ILIKE ?") def test_case_default_like(self): like = Like(SQLRaw("description"), "%hullah%") expr = Select(SQLRaw("id"), like, tables=["like_case_insensitive_test"]) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(1,)]) like = Like(SQLRaw("description"), "%HULLAH%") expr = Select(SQLRaw("id"), like, tables=["like_case_insensitive_test"]) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(2,)]) def test_case_sensitive_like(self): like = Like(SQLRaw("description"), "%hullah%", case_sensitive=True) expr = Select(SQLRaw("id"), like, tables=["like_case_insensitive_test"]) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(1,)]) like = Like(SQLRaw("description"), "%HULLAH%", case_sensitive=True) expr = Select(SQLRaw("id"), like, tables=["like_case_insensitive_test"]) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(2,)]) def test_case_insensitive_like(self): like = Like(SQLRaw("description"), "%hullah%", case_sensitive=False) expr = Select(SQLRaw("id"), like, tables=["like_case_insensitive_test"]) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(1,), (2,)]) like = Like(SQLRaw("description"), "%HULLAH%", case_sensitive=False) expr = Select(SQLRaw("id"), like, tables=["like_case_insensitive_test"]) result = self.connection.execute(expr) self.assertEqual(result.get_all(), [(1,), (2,)]) def test_none_on_string_variable(self): """ Verify that the logic to enforce fix E''-styled strings isn't breaking on NULL values. """ variable = BytesVariable(value=None) result = self.connection.execute(Select(variable)) self.assertEqual(result.get_one(), (None,)) def test_compile_table_with_schema(self): class Foo: __storm_table__ = "my schema.my table" id = Int("my.column", primary=True) self.assertEqual(compile(Select(Foo.id)), 'SELECT "my schema"."my table"."my.column" ' 'FROM "my schema"."my table"') def test_compile_case(self): """The Case expr is compiled in a Postgres' CASE expression.""" cases = [ (Column("foo") > 3, "big"), (Column("bar") == None, 4)] state = State() statement = compile(Case(cases), state) self.assertEqual( "CASE WHEN (foo > ?) THEN ? WHEN (bar IS NULL) THEN ? END", statement) self.assertEqual( [3, "big", 4], [param.get() for param in state.parameters]) def test_compile_case_with_default(self): """ If a default is provided, the resulting CASE expression includes an ELSE clause. """ cases = [(Column("foo") > 3, "big")] state = State() statement = compile(Case(cases, default=9), state) self.assertEqual( "CASE WHEN (foo > ?) THEN ? ELSE ? END", statement) self.assertEqual( [3, "big", 9], [param.get() for param in state.parameters]) def test_compile_case_with_expression(self): """ If an expression is provided, the resulting CASE expression uses the simple syntax. """ cases = [(1, "one"), (2, "two")] state = State() statement = compile(Case(cases, expression=Column("foo")), state) self.assertEqual( "CASE foo WHEN ? THEN ? WHEN ? THEN ? 
END", statement) self.assertEqual( [1, "one", 2, "two"], [param.get() for param in state.parameters]) def test_currval_no_escaping(self): expr = currval(Column("thecolumn", "theschema.thetable")) statement = compile(expr) expected = """currval('theschema.thetable_thecolumn_seq')""" self.assertEqual(statement, expected) def test_currval_escaped_schema(self): expr = currval(Column("thecolumn", "the schema.thetable")) statement = compile(expr) expected = """currval('"the schema".thetable_thecolumn_seq')""" self.assertEqual(statement, expected) def test_currval_escaped_table(self): expr = currval(Column("thecolumn", "theschema.the table")) statement = compile(expr) expected = """currval('theschema."the table_thecolumn_seq"')""" self.assertEqual(statement, expected) def test_currval_escaped_column(self): expr = currval(Column("the column", "theschema.thetable")) statement = compile(expr) expected = """currval('theschema."thetable_the column_seq"')""" self.assertEqual(statement, expected) def test_currval_escaped_column_no_schema(self): expr = currval(Column("the column", "thetable")) statement = compile(expr) expected = """currval('"thetable_the column_seq"')""" self.assertEqual(statement, expected) def test_currval_escaped_schema_table_and_column(self): expr = currval(Column("the column", "the schema.the table")) statement = compile(expr) expected = """currval('"the schema"."the table_the column_seq"')""" self.assertEqual(statement, expected) def test_get_insert_identity(self): column = Column("thecolumn", "thetable") variable = IntVariable() result = self.connection.execute("SELECT 1") where = result.get_insert_identity((column,), (variable,)) self.assertEqual(compile(where), "thetable.thecolumn = " "(SELECT currval('thetable_thecolumn_seq'))") def test_returning_column_context(self): column2 = TrackContext() insert = Insert({column1: elem1}, table1, primary_columns=column2) compile(Returning(insert)) self.assertEqual(column2.context, COLUMN) def test_returning_update(self): update = Update({column1: elem1}, table=table1, primary_columns=(column2, column3)) self.assertEqual(compile(Returning(update)), 'UPDATE "table 1" SET column1=elem1 ' 'RETURNING column2, column3') def test_returning_update_with_columns(self): update = Update({column1: elem1}, table=table1, primary_columns=(column2, column3)) self.assertEqual(compile(Returning(update, columns=[column3])), 'UPDATE "table 1" SET column1=elem1 ' 'RETURNING column3') def test_execute_insert_returning(self): if self.database._version < 80200: return # Can't run this test with old PostgreSQL versions. 
column1 = Column("id1", "returning_test") column2 = Column("id2", "returning_test") variable1 = IntVariable() variable2 = IntVariable() insert = Insert({}, primary_columns=(column1, column2), primary_variables=(variable1, variable2)) self.connection.execute(insert) self.assertTrue(variable1.is_defined()) self.assertTrue(variable2.is_defined()) self.assertEqual(variable1.get(), 123) self.assertEqual(variable2.get(), 456) result = self.connection.execute("SELECT * FROM returning_test") self.assertEqual(result.get_one(), (123, 456)) def test_wb_execute_insert_returning_not_used_with_old_postgres(self): """Shouldn't try to use RETURNING with PostgreSQL < 8.2.""" column1 = Column("id1", "returning_test") column2 = Column("id2", "returning_test") variable1 = IntVariable() variable2 = IntVariable() insert = Insert({}, primary_columns=(column1, column2), primary_variables=(variable1, variable2)) self.database._version = 80109 self.connection.execute(insert) self.assertFalse(variable1.is_defined()) self.assertFalse(variable2.is_defined()) result = self.connection.execute("SELECT * FROM returning_test") self.assertEqual(result.get_one(), (123, 456)) def test_execute_insert_returning_without_columns(self): """Without primary_columns, the RETURNING system won't be used.""" column1 = Column("id1", "returning_test") variable1 = IntVariable() insert = Insert({column1: 123}, primary_variables=(variable1,)) self.connection.execute(insert) self.assertFalse(variable1.is_defined()) result = self.connection.execute("SELECT * FROM returning_test") self.assertEqual(result.get_one(), (123, 456)) def test_execute_insert_returning_without_variables(self): """Without primary_variables, the RETURNING system won't be used.""" column1 = Column("id1", "returning_test") insert = Insert({}, primary_columns=(column1,)) self.connection.execute(insert) result = self.connection.execute("SELECT * FROM returning_test") self.assertEqual(result.get_one(), (123, 456)) def test_execute_update_returning(self): if self.database._version < 80200: return # Can't run this test with old PostgreSQL versions. 
column1 = Column("id1", "returning_test") column2 = Column("id2", "returning_test") self.connection.execute( "INSERT INTO returning_test VALUES (1, 2)") update = Update({"id2": 3}, column1 == 1, primary_columns=(column1, column2)) result = self.connection.execute(Returning(update)) self.assertEqual(result.get_one(), (1, 3)) def test_isolation_autocommit(self): database = create_database( os.environ["STORM_POSTGRES_URI"] + "?isolation=autocommit") connection = database.connect() self.addCleanup(connection.close) result = connection.execute("SHOW TRANSACTION ISOLATION LEVEL") # It matches read committed in Postgres internel self.assertEqual(result.get_one()[0], "read committed") connection.execute("INSERT INTO bin_test VALUES (1, 'foo')") result = self.connection.execute("SELECT id FROM bin_test") # I didn't commit, but data should already be there self.assertEqual(result.get_all(), [(1,)]) connection.rollback() def test_isolation_read_committed(self): database = create_database( os.environ["STORM_POSTGRES_URI"] + "?isolation=read-committed") connection = database.connect() self.addCleanup(connection.close) result = connection.execute("SHOW TRANSACTION ISOLATION LEVEL") self.assertEqual(result.get_one()[0], "read committed") connection.execute("INSERT INTO bin_test VALUES (1, 'foo')") result = self.connection.execute("SELECT id FROM bin_test") # Data should not be there already self.assertEqual(result.get_all(), []) connection.rollback() # Start a transaction result = connection.execute("SELECT 1") self.assertEqual(result.get_one(), (1,)) self.connection.execute("INSERT INTO bin_test VALUES (1, 'foo')") self.connection.commit() result = connection.execute("SELECT id FROM bin_test") # Data is already here! self.assertEqual(result.get_one(), (1,)) connection.rollback() def test_isolation_serializable(self): database = create_database( os.environ["STORM_POSTGRES_URI"] + "?isolation=serializable") connection = database.connect() self.addCleanup(connection.close) result = connection.execute("SHOW TRANSACTION ISOLATION LEVEL") self.assertEqual(result.get_one()[0], "serializable") # Start a transaction result = connection.execute("SELECT 1") self.assertEqual(result.get_one(), (1,)) self.connection.execute("INSERT INTO bin_test VALUES (1, 'foo')") self.connection.commit() result = connection.execute("SELECT id FROM bin_test") # We can't see data yet, because transaction started before self.assertEqual(result.get_one(), None) connection.rollback() def test_default_isolation(self): """ The default isolation level is REPEATABLE READ, but it's only supported by psycopg2 2.4.2 and newer. Before, SERIALIZABLE is used instead. """ result = self.connection.execute("SHOW TRANSACTION ISOLATION LEVEL") import psycopg2 psycopg2_version = psycopg2.__version__.split(None, 1)[0] if psycopg2_version < "2.4.2": self.assertEqual(result.get_one()[0], "serializable") else: self.assertEqual(result.get_one()[0], "repeatable read") def test_unknown_serialization(self): self.assertRaises(ValueError, create_database, os.environ["STORM_POSTGRES_URI"] + "?isolation=stuff") def test_is_disconnection_error_with_ssl_syscall_error(self): """ If the underlying driver raises a ProgrammingError with 'SSL SYSCALL error', we consider the connection dead and mark it as needing reconnection. 
""" exc = ProgrammingError("SSL SYSCALL error: Connection timed out") self.assertTrue(self.connection.is_disconnection_error(exc)) def test_is_disconnection_error_with_could_not_send_data(self): """ If the underlying driver raises an OperationalError with 'could not send data to server', we consider the connection dead and mark it as needing reconnection. """ exc = OperationalError("could not send data to server") self.assertTrue(self.connection.is_disconnection_error(exc)) def test_is_disconnection_error_with_could_not_receive_data(self): """ If the underlying driver raises an OperationalError with 'could not receive data from server', we consider the connection dead and mark it as needing reconnection. """ exc = OperationalError("could not receive data from server") self.assertTrue(self.connection.is_disconnection_error(exc)) def test_json_element(self): "JSONElement returns an element from a json field." connection = self.database.connect() json_value = Cast('{"a": 1}', "json") expr = JSONElement(json_value, "a") # Need to cast as text since newer psycopg versions decode JSON # automatically. result = connection.execute(Select(Cast(expr, "text"))) self.assertEqual("1", result.get_one()[0]) result = connection.execute(Select(Func("pg_typeof", expr))) self.assertEqual("json", result.get_one()[0]) def test_json_text_element(self): "JSONTextElement returns an element from a json field as text." connection = self.database.connect() json_value = Cast('{"a": 1}', "json") expr = JSONTextElement(json_value, "a") result = connection.execute(Select(expr)) self.assertEqual("1", result.get_one()[0]) result = connection.execute(Select(Func("pg_typeof", expr))) self.assertEqual("text", result.get_one()[0]) def test_json_property(self): """The JSON property is encoded as JSON""" class TestModel: __storm_table__ = "json_test" id = Int(primary=True) json = JSON() connection = self.database.connect() value = {"a": 3, "b": "foo", "c": None} connection.execute( "INSERT INTO json_test (json) VALUES (?)", (json.dumps(value),)) connection.commit() store = Store(self.database) obj = store.find(TestModel).one() store.close() # The JSON object is decoded to python self.assertEqual(value, obj.json) _max_prepared_transactions = None class PostgresTwoPhaseCommitTest(TwoPhaseCommitTest, TestHelper): def is_supported(self): uri = os.environ.get("STORM_POSTGRES_URI") if not uri: return False global _max_prepared_transactions if _max_prepared_transactions is None: database = create_database(uri) connection = database.connect() result = connection.execute("SHOW MAX_PREPARED_TRANSACTIONS") _max_prepared_transactions = int(result.get_one()[0]) connection.close() return _max_prepared_transactions > 0 def create_database(self): self.database = create_database(os.environ["STORM_POSTGRES_URI"]) def create_tables(self): self.connection.execute("CREATE TABLE test " "(id SERIAL PRIMARY KEY, title VARCHAR)") self.connection.commit() class PostgresUnsupportedTest(UnsupportedDatabaseTest, TestHelper): dbapi_module_names = ["psycopg2"] db_module_name = "postgres" class PostgresDisconnectionTest(DatabaseDisconnectionTest, TwoPhaseCommitDisconnectionTest, TestHelper): environment_variable = "STORM_POSTGRES_URI" host_environment_variable = "STORM_POSTGRES_HOST_URI" default_port = 5432 def create_proxy(self, uri): """See `DatabaseDisconnectionMixin.create_proxy`.""" if uri.host.startswith("/"): return create_proxy_and_uri(uri)[0] else: return super().create_proxy(uri) def test_rollback_swallows_InterfaceError(self): """Test that 
InterfaceErrors get caught on rollback(). InterfaceErrors are a form of a disconnection error, so rollback() must swallow them and reconnect. """ class FakeConnection: def rollback(self): raise InterfaceError('connection already closed') self.connection._raw_connection = FakeConnection() try: self.connection.rollback() except Exception as exc: self.fail('Exception should have been swallowed: %s' % repr(exc)) class PostgresDisconnectionTestWithoutProxyBase: # DatabaseDisconnectionTest uses a socket proxy to simulate broken # connections. This class tests some other causes of disconnection. database_uri = None def is_supported(self): return bool(self.database_uri) and super().is_supported() def setUp(self): super().setUp() self.database = create_database(self.database_uri) def test_terminated_backend(self): # The error raised when trying to use a connection that has been # terminated at the server is considered a disconnection error. connection = self.database.connect() terminate_all_backends(self.database) self.assertRaises( DisconnectionError, connection.execute, "SELECT current_database()") if has_subunit: # Some of the following tests are prone to segfaults, presumably in # _psycopg.so. Run them in a subprocess if possible. from subunit import IsolatedTestCase class MisbehavingTestCase(TestHelper, IsolatedTestCase): pass else: # If we can't run them in a subprocess we still want to create tests, but # prevent them from running, so that the skip is reported class MisbehavingTestCase(TestHelper): def is_supported(self): return False class PostgresDisconnectionTestWithoutProxyUnixSockets( PostgresDisconnectionTestWithoutProxyBase, MisbehavingTestCase): """Disconnection tests using Unix sockets.""" database_uri = os.environ.get("STORM_POSTGRES_URI") class PostgresDisconnectionTestWithoutProxyTCPSockets( PostgresDisconnectionTestWithoutProxyBase, MisbehavingTestCase): """Disconnection tests using TCP sockets.""" database_uri = os.environ.get("STORM_POSTGRES_HOST_URI") def setUp(self): super().setUp() if self.database.get_uri().host.startswith("/"): proxy, proxy_uri = create_proxy_and_uri(self.database.get_uri()) self.addCleanup(proxy.close) self.database = create_database(proxy_uri) class PostgresDisconnectionTestWithPGBouncerBase: # Connecting via pgbouncer # introduces new possible causes of disconnections. def is_supported(self): return ( has_fixtures and has_pgbouncer and bool(os.environ.get("STORM_POSTGRES_HOST_URI"))) def setUp(self): super().setUp() database_uri = URI(os.environ["STORM_POSTGRES_HOST_URI"]) if database_uri.host.startswith("/"): proxy, database_uri = create_proxy_and_uri(database_uri) self.addCleanup(proxy.close) database_user = database_uri.username or os.environ['USER'] database_dsn = make_dsn(database_uri) # Create a pgbouncer fixture. self.pgbouncer = pgbouncer.fixture.PGBouncerFixture() self.pgbouncer.databases[database_uri.database] = database_dsn self.pgbouncer.users[database_user] = "trusted" self.pgbouncer.admin_users = [database_user] self.useFixture(self.pgbouncer) # Create a Database that uses pgbouncer. pgbouncer_uri = database_uri.copy() pgbouncer_uri.host = self.pgbouncer.host pgbouncer_uri.port = self.pgbouncer.port self.database = create_database(pgbouncer_uri) def test_terminated_backend(self): # The error raised when trying to use a connection through pgbouncer # that has been terminated at the server is considered a disconnection # error. 
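# terminate_all_backends() is a helper used by these disconnection tests; presumably it kills the server processes backing the test database (e.g. with pg_terminate_backend()), so the execute() below hits a dead connection.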
connection = self.database.connect() terminate_all_backends(self.database) self.assertRaises( DisconnectionError, connection.execute, "SELECT current_database()") def test_pgbouncer_stopped(self): # The error raised from a connection that is no longer connected # because pgbouncer has been immediately shut down (via SIGTERM; see # man 1 pgbouncer) is considered a disconnection error. connection = self.database.connect() self.pgbouncer.stop() self.assertRaises( DisconnectionError, connection.execute, "SELECT current_database()") if has_fixtures: # Upgrade to full test case class with fixtures. from fixtures import TestWithFixtures class PostgresDisconnectionTestWithPGBouncer( PostgresDisconnectionTestWithPGBouncerBase, TestWithFixtures, TestHelper): pass class PostgresTimeoutTracerTest(TimeoutTracerTestBase): tracer_class = PostgresTimeoutTracer def is_supported(self): return bool(os.environ.get("STORM_POSTGRES_URI")) def setUp(self): super().setUp() self.database = create_database(os.environ["STORM_POSTGRES_URI"]) self.connection = self.database.connect() install_tracer(self.tracer) self.tracer.get_remaining_time = lambda: self.remaining_time self.remaining_time = 10.5 def tearDown(self): self.connection.close() super().tearDown() def test_set_statement_timeout(self): result = self.connection.execute("SHOW statement_timeout") self.assertEqual(result.get_one(), ("10500ms",)) def test_connection_raw_execute_error(self): statement = "SELECT pg_sleep(0.5)" self.remaining_time = 0.001 try: self.connection.execute(statement) except TimeoutError as e: self.assertEqual("SQL server cancelled statement", e.message) self.assertEqual(statement, e.statement) self.assertEqual((), e.params) else: self.fail("TimeoutError not raised") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/databases/proxy.py0000644000175000017500000001111514645174376021030 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import errno import os import select import socket import socketserver import threading TIMEOUT = 0.1 class ProxyRequestHandler(socketserver.BaseRequestHandler): """A request handler that proxies traffic to another TCP port.""" def __init__(self, request, client_address, server): self._generation = server._generation socketserver.BaseRequestHandler.__init__( self, request, client_address, server) def handle(self): if isinstance(self.server.proxy_dest, (bytes, str)): dst = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) else: dst = socket.socket(socket.AF_INET, socket.SOCK_STREAM) dst.connect(self.server.proxy_dest) readers = [self.request, dst] while readers: rlist, wlist, xlist = select.select(readers, [], [], TIMEOUT) # If the server generation has been incremented, close the # connection.
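# (stop() bumps the server's _generation; each handler remembers the value current when it started, so a mismatch after a select() timeout means the proxy is being restarted and this connection should be dropped.)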
if self._generation != self.server._generation: return if self.request in rlist: chunk = os.read(self.request.fileno(), 1024) try: dst.send(chunk) except OSError as e: if e.errno == errno.EPIPE: return raise if chunk == b"": readers.remove(self.request) dst.shutdown(socket.SHUT_WR) if dst in rlist: try: chunk = os.read(dst.fileno(), 1024) except OSError as e: if e.errno == errno.ECONNRESET: return raise self.request.send(chunk) if chunk == b"": readers.remove(dst) self.request.shutdown(socket.SHUT_WR) class ProxyTCPServer(socketserver.ThreadingTCPServer): allow_reuse_address = True def __init__(self, proxy_dest): socketserver.ThreadingTCPServer.__init__( self, ("127.0.0.1", 0), ProxyRequestHandler) # Python 2.4 doesn't retrieve the socket details, so record # them here. We need to do this so we can recreate the socket # with the same address later. self.server_address = self.socket.getsockname() self.proxy_dest = proxy_dest self._start_lock = threading.Lock() self._thread = None self._generation = 0 self._running = False self.start() def __del__(self): self.close() def close(self): if self._running: self.stop() def start(self): assert not self._running, "Server should not be running" self._thread = threading.Thread(target=self._run) self._thread.daemon = True self._running = True self._start_lock.acquire() self._thread.start() # Wait for server to start self._start_lock.acquire() self._start_lock.release() def _run(self): self.server_activate() self.socket.settimeout(TIMEOUT) self._start_lock.release() while self._running: try: self.handle_request() except socket.timeout: pass def stop(self): assert self._running, "Server should be running" # Increment server generation, and wait for thread to stop. self._generation += 1 self._running = False self._thread.join() # Recreate socket, to kill listen queue. As we've allowed # address reuse, this should work. self.socket.close() self.socket = socket.socket(self.address_family, self.socket_type) self.server_bind() def restart(self): self.stop() self.start() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/databases/sqlite.py0000644000175000017500000002361414645174376021157 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see .
# from datetime import timedelta import time import os from storm.exceptions import OperationalError from storm.databases.sqlite import SQLite from storm.database import create_database from storm.uri import URI from storm.tests.databases.base import DatabaseTest, UnsupportedDatabaseTest from storm.tests.helper import TestHelper, MakePath class SQLiteMemoryTest(DatabaseTest, TestHelper): helpers = [MakePath] def get_path(self): return "" def create_database(self): self.database = SQLite(URI("sqlite:%s?synchronous=OFF&timeout=0" % self.get_path())) def create_tables(self): self.connection.execute("CREATE TABLE number " "(one INTEGER, two INTEGER, three INTEGER)") self.connection.execute("CREATE TABLE test " "(id INTEGER PRIMARY KEY, title VARCHAR)") self.connection.execute("CREATE TABLE datetime_test " "(id INTEGER PRIMARY KEY," " dt TIMESTAMP, d DATE, t TIME, td INTERVAL)") self.connection.execute("CREATE TABLE bin_test " "(id INTEGER PRIMARY KEY, b BLOB)") def drop_tables(self): pass def test_wb_create_database(self): database = create_database("sqlite:") self.assertTrue(isinstance(database, SQLite)) self.assertEqual(database._filename, ":memory:") def test_concurrent_behavior(self): pass # We can't connect to the in-memory database twice, so we can't # exercise the concurrency behavior (nor does it make sense). def test_synchronous(self): synchronous_values = {"OFF": 0, "NORMAL": 1, "FULL": 2} for value in synchronous_values: database = SQLite(URI("sqlite:%s?synchronous=%s" % (self.get_path(), value))) connection = database.connect() result = connection.execute("PRAGMA synchronous") self.assertEqual(result.get_one()[0], synchronous_values[value]) def test_sqlite_specific_reserved_words(self): """Check that sqlite-specific reserved words are recognized. This uses a list copied from http://www.sqlite.org/lang_keywords.html with the reserved words from SQL1992 removed. """ reserved_words = """ abort after analyze attach autoincrement before conflict database detach each exclusive explain fail glob if ignore index indexed instead isnull limit notnull offset plan pragma query raise regexp reindex release rename replace row savepoint temp trigger vacuum virtual """.split() for word in reserved_words: self.assertTrue(self.connection.compile.is_reserved_word(word), "Word missing: %s" % (word,)) class SQLiteFileTest(SQLiteMemoryTest): def get_path(self): return self.make_path() def test_wb_create_database(self): filename = self.make_path() database = create_database("sqlite:%s" % filename) self.assertTrue(isinstance(database, SQLite)) self.assertEqual(database._filename, filename) def test_timeout(self): database = create_database("sqlite:%s?timeout=0.3" % self.get_path()) connection1 = database.connect() connection2 = database.connect() connection1.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") connection1.commit() connection1.execute("INSERT INTO test VALUES (1)") started = time.time() try: connection2.execute("INSERT INTO test VALUES (2)") except OperationalError as exception: self.assertEqual(str(exception), "database is locked") self.assertTrue(time.time()-started >= 0.3) else: self.fail("OperationalError not raised") def test_commit_timeout(self): """Regression test for commit observing the timeout. In 0.10, the timeout wasn't observed for connection.commit(). """ # Create a database with a table.
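# (The plan: a second connection holding a read transaction keeps a shared lock on the database, so connection1's COMMIT must block and then give up once the 0.3s timeout elapses.)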
database = create_database("sqlite:%s?timeout=0.3" % self.get_path()) connection1 = database.connect() connection1.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") connection1.commit() # Put some data in, but also make a second connection to the database, # which will prevent a commit until it is closed. connection1.execute("INSERT INTO test VALUES (1)") connection2 = database.connect() connection2.execute("SELECT id FROM test") started = time.time() try: connection1.commit() except OperationalError as exception: self.assertEqual(str(exception), "database is locked") # In 0.10, the next assertion failed because the timeout wasn't # enforced for the "COMMIT" statement. self.assertTrue(time.time()-started >= 0.3) else: self.fail("OperationalError not raised") def test_recover_after_timeout(self): """Regression test for recovering from database locked exception. In 0.10, connection.commit() would forget that a transaction was in progress if an exception was raised, such as an OperationalError due to another connection being open. As a result, a subsequent modification to the database would cause BEGIN to be issued to the database, which would complain that a transaction was already in progress. """ # Create a database with a table. database = create_database("sqlite:%s?timeout=0.3" % self.get_path()) connection1 = database.connect() connection1.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") connection1.commit() # Put some data in, but also make a second connection to the database, # which will prevent a commit until it is closed. connection1.execute("INSERT INTO test VALUES (1)") connection2 = database.connect() connection2.execute("SELECT id FROM test") self.assertRaises(OperationalError, connection1.commit) # Close the second connection - it should now be possible to commit. 
connection2.close() # In 0.10, the next statement raised OperationalError: cannot start a # transaction within a transaction connection1.execute("INSERT INTO test VALUES (2)") connection1.commit() # Check that the correct data is present self.assertEqual(connection1.execute("SELECT id FROM test").get_all(), [(1,), (2,)]) def test_journal(self): journal_values = {"DELETE": 'delete', "TRUNCATE": 'truncate', "PERSIST": 'persist', "MEMORY": 'memory', "WAL": 'wal', "OFF": 'off'} for value in journal_values: database = SQLite(URI("sqlite:%s?journal_mode=%s" % (self.get_path(), value))) connection = database.connect() result = connection.execute("PRAGMA journal_mode").get_one()[0] self.assertEqual(result, journal_values[value]) def test_journal_persistency_to_rollback(self): journal_values = {"DELETE": 'delete', "TRUNCATE": 'truncate', "PERSIST": 'persist', "MEMORY": 'memory', "WAL": 'wal', "OFF": 'off'} for value in journal_values: database = SQLite(URI("sqlite:%s?journal_mode=%s" % (self.get_path(), value))) connection = database.connect() connection.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") connection.rollback() result = connection.execute("PRAGMA journal_mode").get_one()[0] self.assertEqual(result, journal_values[value]) def test_foreign_keys(self): foreign_keys_values = {"ON": 1, "OFF": 0} for value in foreign_keys_values: database = SQLite(URI("sqlite:%s?foreign_keys=%s" % (self.get_path(), value))) connection = database.connect() result = connection.execute("PRAGMA foreign_keys").get_one()[0] self.assertEqual(result, foreign_keys_values[value]) def test_foreign_keys_persistency_to_rollback(self): foreign_keys_values = {"ON": 1, "OFF": 0} for value in foreign_keys_values: database = SQLite(URI("sqlite:%s?foreign_keys=%s" % (self.get_path(), value))) connection = database.connect() connection.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)") connection.rollback() result = connection.execute("PRAGMA foreign_keys").get_one()[0] self.assertEqual(result, foreign_keys_values[value]) class SQLiteUnsupportedTest(UnsupportedDatabaseTest, TestHelper): dbapi_module_names = ["pysqlite2", "sqlite3"] db_module_name = "sqlite" ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1721152862.425125 storm-1.0/storm/tests/django/0000755000175000017500000000000014645532536016605 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/storm/tests/django/__init__.py0000644000175000017500000000000011752263216020674 0ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/event.py0000644000175000017500000000645514645174376017054 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# from storm.event import EventSystem from storm.tests.helper import TestHelper class Marker: def __eq__(self, other): return self is other def __lt__(self, other): return False if self is other else NotImplemented __gt__ = __lt__ def __le__(self, other): return True if self is other else NotImplemented __ge__ = __le__ marker = Marker() class EventTest(TestHelper): def setUp(self): TestHelper.setUp(self) self.event = EventSystem(marker) def test_hook_unhook_emit(self): called1 = [] called2 = [] def callback1(owner, arg1, arg2): called1.append((owner, arg1, arg2)) def callback2(owner, arg1, arg2, data1, data2): called2.append((owner, arg1, arg2, data1, data2)) self.event.hook("one", callback1) self.event.hook("one", callback1) self.event.hook("one", callback2, 10, 20) self.event.hook("two", callback2, 10, 20) self.event.hook("two", callback2, 10, 20) self.event.hook("two", callback2, 30, 40) self.event.hook("three", callback1) self.event.emit("one", 1, 2) self.event.emit("two", 3, 4) self.event.unhook("two", callback2, 10, 20) self.event.emit("two", 3, 4) self.event.emit("three", 5, 6) self.assertEqual(sorted(called1), [ (marker, 1, 2), (marker, 5, 6), ]) self.assertEqual(sorted(called2), [ (marker, 1, 2, 10, 20), (marker, 3, 4, 10, 20), (marker, 3, 4, 30, 40), (marker, 3, 4, 30, 40), ]) def test_unhook_by_returning_false(self): called = [] def callback(owner): called.append(owner) return len(called) < 2 self.event.hook("event", callback) self.event.emit("event") self.event.emit("event") self.event.emit("event") self.event.emit("event") self.assertEqual(called, [marker, marker]) def test_weak_reference(self): marker = Marker() called = [] def callback(owner): called.append(owner) self.event = EventSystem(marker) self.event.hook("event", callback) self.event.emit("event") self.assertEqual(called, [marker]) del called[:] del marker self.event.emit("event") self.assertEqual(called, []) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/expr.py0000644000175000017500000027276714645174376016724 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from decimal import Decimal from storm.variables import * from storm.expr import * from storm.tests.helper import TestHelper class Func1(NamedFunc): name = "func1" class Func2(NamedFunc): name = "func2" # Create columnN, tableN, and elemN variables. 
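# (The exec loop below defines module-level names column0..column9 and elem0..elem9 as SQLTokens, and table0..table9 as plain strings of the form "table N"; the space in those table names is why expected SQL throughout these tests quotes them as "table 1".)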
for i in range(10): for name in ["column", "elem"]: exec("%s%d = SQLToken('%s%d')" % (name, i, name, i)) for name in ["table"]: exec("%s%d = '%s %d'" % (name, i, name, i)) class TrackContext(FromExpr): context = None @compile.when(TrackContext) def compile_track_context(compile, expr, state): expr.context = state.context return "" def track_contexts(n): return [TrackContext() for i in range(n)] class ExprTest(TestHelper): def test_select_default(self): expr = Select(()) self.assertEqual(expr.columns, ()) self.assertEqual(expr.where, Undef) self.assertEqual(expr.tables, Undef) self.assertEqual(expr.default_tables, Undef) self.assertEqual(expr.order_by, Undef) self.assertEqual(expr.group_by, Undef) self.assertEqual(expr.limit, Undef) self.assertEqual(expr.offset, Undef) self.assertEqual(expr.distinct, False) def test_select_constructor(self): objects = [object() for i in range(9)] expr = Select(*objects) self.assertEqual(expr.columns, objects[0]) self.assertEqual(expr.where, objects[1]) self.assertEqual(expr.tables, objects[2]) self.assertEqual(expr.default_tables, objects[3]) self.assertEqual(expr.order_by, objects[4]) self.assertEqual(expr.group_by, objects[5]) self.assertEqual(expr.limit, objects[6]) self.assertEqual(expr.offset, objects[7]) self.assertEqual(expr.distinct, objects[8]) def test_insert_default(self): expr = Insert(None) self.assertEqual(expr.map, None) self.assertEqual(expr.table, Undef) self.assertEqual(expr.default_table, Undef) self.assertEqual(expr.primary_columns, Undef) self.assertEqual(expr.primary_variables, Undef) def test_insert_constructor(self): objects = [object() for i in range(5)] expr = Insert(*objects) self.assertEqual(expr.map, objects[0]) self.assertEqual(expr.table, objects[1]) self.assertEqual(expr.default_table, objects[2]) self.assertEqual(expr.primary_columns, objects[3]) self.assertEqual(expr.primary_variables, objects[4]) def test_update_default(self): expr = Update(None) self.assertEqual(expr.map, None) self.assertEqual(expr.where, Undef) self.assertEqual(expr.table, Undef) self.assertEqual(expr.default_table, Undef) def test_update_constructor(self): objects = [object() for i in range(4)] expr = Update(*objects) self.assertEqual(expr.map, objects[0]) self.assertEqual(expr.where, objects[1]) self.assertEqual(expr.table, objects[2]) self.assertEqual(expr.default_table, objects[3]) def test_delete_default(self): expr = Delete() self.assertEqual(expr.where, Undef) self.assertEqual(expr.table, Undef) def test_delete_constructor(self): objects = [object() for i in range(3)] expr = Delete(*objects) self.assertEqual(expr.where, objects[0]) self.assertEqual(expr.table, objects[1]) self.assertEqual(expr.default_table, objects[2]) def test_and(self): expr = And(elem1, elem2, elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) def test_or(self): expr = Or(elem1, elem2, elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) def test_column_default(self): expr = Column() self.assertEqual(expr.name, Undef) self.assertEqual(expr.table, Undef) self.assertIdentical(expr.compile_cache, None) # Test for identity. We don't want False there. self.assertIs(expr.primary, 0) self.assertEqual(expr.variable_factory, Variable) def test_column_constructor(self): objects = [object() for i in range(3)] objects.insert(2, True) expr = Column(*objects) self.assertEqual(expr.name, objects[0]) self.assertEqual(expr.table, objects[1]) # Test for identity. We don't want True there either. 
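# (assertEqual would also accept True here, since bool subclasses int and True == 1; assertIs pins the attribute to the plain int object.)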
self.assertIs(expr.primary, 1) self.assertEqual(expr.variable_factory, objects[3]) def test_func(self): expr = Func("myfunc", elem1, elem2) self.assertEqual(expr.name, "myfunc") self.assertEqual(expr.args, (elem1, elem2)) def test_named_func(self): class MyFunc(NamedFunc): name = "myfunc" expr = MyFunc(elem1, elem2) self.assertEqual(expr.name, "myfunc") self.assertEqual(expr.args, (elem1, elem2)) def test_like(self): expr = Like(elem1, elem2) self.assertEqual(expr.expr1, elem1) self.assertEqual(expr.expr2, elem2) def test_like_escape(self): expr = Like(elem1, elem2, elem3) self.assertEqual(expr.expr1, elem1) self.assertEqual(expr.expr2, elem2) self.assertEqual(expr.escape, elem3) def test_like_case(self): expr = Like(elem1, elem2, elem3) self.assertEqual(expr.case_sensitive, None) expr = Like(elem1, elem2, elem3, True) self.assertEqual(expr.case_sensitive, True) expr = Like(elem1, elem2, elem3, False) self.assertEqual(expr.case_sensitive, False) def test_startswith(self): expr = Func1() self.assertRaises(ExprError, expr.startswith, b"not a unicode string") like_expr = expr.startswith("abc!!_%") self.assertTrue(isinstance(like_expr, Like)) self.assertIs(like_expr.expr1, expr) self.assertEqual(like_expr.expr2, "abc!!!!!_!%%") self.assertEqual(like_expr.escape, "!") def test_startswith_case(self): expr = Func1() like_expr = expr.startswith("abc!!_%") self.assertIsNone(like_expr.case_sensitive) like_expr = expr.startswith("abc!!_%", case_sensitive=True) self.assertIs(True, like_expr.case_sensitive) like_expr = expr.startswith("abc!!_%", case_sensitive=False) self.assertIs(False, like_expr.case_sensitive) def test_endswith(self): expr = Func1() self.assertRaises(ExprError, expr.startswith, b"not a unicode string") like_expr = expr.endswith("abc!!_%") self.assertTrue(isinstance(like_expr, Like)) self.assertIs(like_expr.expr1, expr) self.assertEqual(like_expr.expr2, "%abc!!!!!_!%") self.assertEqual(like_expr.escape, "!") def test_endswith_case(self): expr = Func1() like_expr = expr.endswith("abc!!_%") self.assertIsNone(like_expr.case_sensitive) like_expr = expr.endswith("abc!!_%", case_sensitive=True) self.assertIs(True, like_expr.case_sensitive) like_expr = expr.endswith("abc!!_%", case_sensitive=False) self.assertIs(False, like_expr.case_sensitive) def test_contains_string(self): expr = Func1() self.assertRaises( ExprError, expr.contains_string, b"not a unicode string") like_expr = expr.contains_string("abc!!_%") self.assertTrue(isinstance(like_expr, Like)) self.assertIs(like_expr.expr1, expr) self.assertEqual(like_expr.expr2, "%abc!!!!!_!%%") self.assertEqual(like_expr.escape, "!") def test_contains_string_case(self): expr = Func1() like_expr = expr.contains_string("abc!!_%") self.assertIsNone(like_expr.case_sensitive) like_expr = expr.contains_string("abc!!_%", case_sensitive=True) self.assertIs(True, like_expr.case_sensitive) like_expr = expr.contains_string("abc!!_%", case_sensitive=False) self.assertIs(False, like_expr.case_sensitive) def test_is(self): expr = Is(elem1, elem2) self.assertEqual(expr.expr1, elem1) self.assertEqual(expr.expr2, elem2) def test_is_not(self): expr = IsNot(elem1, elem2) self.assertEqual(expr.expr1, elem1) self.assertEqual(expr.expr2, elem2) def test_eq(self): expr = Eq(elem1, elem2) self.assertEqual(expr.expr1, elem1) self.assertEqual(expr.expr2, elem2) def test_sql_default(self): expr = SQL(None) self.assertEqual(expr.expr, None) self.assertEqual(expr.params, Undef) self.assertEqual(expr.tables, Undef) def test_sql_constructor(self): objects = [object() for i in 
range(3)] expr = SQL(*objects) self.assertEqual(expr.expr, objects[0]) self.assertEqual(expr.params, objects[1]) self.assertEqual(expr.tables, objects[2]) def test_join_expr_right(self): expr = JoinExpr(None) self.assertEqual(expr.right, None) self.assertEqual(expr.left, Undef) self.assertEqual(expr.on, Undef) def test_join_expr_on(self): on = Expr() expr = JoinExpr(None, on) self.assertEqual(expr.right, None) self.assertEqual(expr.left, Undef) self.assertEqual(expr.on, on) def test_join_expr_on_keyword(self): on = Expr() expr = JoinExpr(None, on=on) self.assertEqual(expr.right, None) self.assertEqual(expr.left, Undef) self.assertEqual(expr.on, on) def test_join_expr_on_invalid(self): on = Expr() self.assertRaises(ExprError, JoinExpr, None, on, None) def test_join_expr_right_left(self): objects = [object() for i in range(2)] expr = JoinExpr(*objects) self.assertEqual(expr.left, objects[0]) self.assertEqual(expr.right, objects[1]) self.assertEqual(expr.on, Undef) def test_join_expr_right_left_on(self): objects = [object() for i in range(3)] expr = JoinExpr(*objects) self.assertEqual(expr.left, objects[0]) self.assertEqual(expr.right, objects[1]) self.assertEqual(expr.on, objects[2]) def test_join_expr_right_join(self): join = JoinExpr(None) expr = JoinExpr(None, join) self.assertEqual(expr.right, join) self.assertEqual(expr.left, None) self.assertEqual(expr.on, Undef) def test_table(self): objects = [object() for i in range(1)] expr = Table(*objects) self.assertEqual(expr.name, objects[0]) def test_alias_default(self): expr = Alias(None) self.assertEqual(expr.expr, None) self.assertTrue(isinstance(expr.name, str)) def test_alias_constructor(self): objects = [object() for i in range(2)] expr = Alias(*objects) self.assertEqual(expr.expr, objects[0]) self.assertEqual(expr.name, objects[1]) def test_union(self): expr = Union(elem1, elem2, elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) def test_union_with_kwargs(self): expr = Union(elem1, elem2, all=True, order_by=(), limit=1, offset=2) self.assertEqual(expr.exprs, (elem1, elem2)) self.assertEqual(expr.all, True) self.assertEqual(expr.order_by, ()) self.assertEqual(expr.limit, 1) self.assertEqual(expr.offset, 2) def test_union_collapse(self): expr = Union(Union(elem1, elem2), elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) # Only first expression is collapsed. expr = Union(elem1, Union(elem2, elem3)) self.assertEqual(expr.exprs[0], elem1) self.assertTrue(isinstance(expr.exprs[1], Union)) # Don't collapse if all is different. expr = Union(Union(elem1, elem2, all=True), elem3) self.assertTrue(isinstance(expr.exprs[0], Union)) expr = Union(Union(elem1, elem2), elem3, all=True) self.assertTrue(isinstance(expr.exprs[0], Union)) expr = Union(Union(elem1, elem2, all=True), elem3, all=True) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) # Don't collapse if limit or offset are set. expr = Union(Union(elem1, elem2, limit=1), elem3) self.assertTrue(isinstance(expr.exprs[0], Union)) expr = Union(Union(elem1, elem2, offset=3), elem3) self.assertTrue(isinstance(expr.exprs[0], Union)) # Don't collapse other set expressions. 
expr = Union(Except(elem1, elem2), elem3) self.assertTrue(isinstance(expr.exprs[0], Except)) expr = Union(Intersect(elem1, elem2), elem3) self.assertTrue(isinstance(expr.exprs[0], Intersect)) def test_except(self): expr = Except(elem1, elem2, elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) def test_except_with_kwargs(self): expr = Except(elem1, elem2, all=True, order_by=(), limit=1, offset=2) self.assertEqual(expr.exprs, (elem1, elem2)) self.assertEqual(expr.all, True) self.assertEqual(expr.order_by, ()) self.assertEqual(expr.limit, 1) self.assertEqual(expr.offset, 2) def test_except_collapse(self): expr = Except(Except(elem1, elem2), elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) # Only first expression is collapsed. expr = Except(elem1, Except(elem2, elem3)) self.assertEqual(expr.exprs[0], elem1) self.assertTrue(isinstance(expr.exprs[1], Except)) # Don't collapse if all is different. expr = Except(Except(elem1, elem2, all=True), elem3) self.assertTrue(isinstance(expr.exprs[0], Except)) expr = Except(Except(elem1, elem2), elem3, all=True) self.assertTrue(isinstance(expr.exprs[0], Except)) expr = Except(Except(elem1, elem2, all=True), elem3, all=True) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) # Don't collapse if limit or offset are set. expr = Except(Except(elem1, elem2, limit=1), elem3) self.assertTrue(isinstance(expr.exprs[0], Except)) expr = Except(Except(elem1, elem2, offset=3), elem3) self.assertTrue(isinstance(expr.exprs[0], Except)) # Don't collapse other set expressions. expr = Except(Union(elem1, elem2), elem3) self.assertTrue(isinstance(expr.exprs[0], Union)) expr = Except(Intersect(elem1, elem2), elem3) self.assertTrue(isinstance(expr.exprs[0], Intersect)) def test_intersect(self): expr = Intersect(elem1, elem2, elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) def test_intersect_with_kwargs(self): expr = Intersect( elem1, elem2, all=True, order_by=(), limit=1, offset=2) self.assertEqual(expr.exprs, (elem1, elem2)) self.assertEqual(expr.all, True) self.assertEqual(expr.order_by, ()) self.assertEqual(expr.limit, 1) self.assertEqual(expr.offset, 2) def test_intersect_collapse(self): expr = Intersect(Intersect(elem1, elem2), elem3) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) # Only first expression is collapsed. expr = Intersect(elem1, Intersect(elem2, elem3)) self.assertEqual(expr.exprs[0], elem1) self.assertTrue(isinstance(expr.exprs[1], Intersect)) # Don't collapse if all is different. expr = Intersect(Intersect(elem1, elem2, all=True), elem3) self.assertTrue(isinstance(expr.exprs[0], Intersect)) expr = Intersect(Intersect(elem1, elem2), elem3, all=True) self.assertTrue(isinstance(expr.exprs[0], Intersect)) expr = Intersect(Intersect(elem1, elem2, all=True), elem3, all=True) self.assertEqual(expr.exprs, (elem1, elem2, elem3)) # Don't collapse if limit or offset are set. expr = Intersect(Intersect(elem1, elem2, limit=1), elem3) self.assertTrue(isinstance(expr.exprs[0], Intersect)) expr = Intersect(Intersect(elem1, elem2, offset=3), elem3) self.assertTrue(isinstance(expr.exprs[0], Intersect)) # Don't collapse other set expressions. 
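# (Folding a different set operation into the operand list would change the query's meaning, so a Union nested inside an Intersect keeps its own parenthesised subexpression.)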
expr = Intersect(Union(elem1, elem2), elem3) self.assertTrue(isinstance(expr.exprs[0], Union)) expr = Intersect(Except(elem1, elem2), elem3) self.assertTrue(isinstance(expr.exprs[0], Except)) def test_auto_tables(self): expr = AutoTables(elem1, [elem2]) self.assertEqual(expr.expr, elem1) self.assertEqual(expr.tables, [elem2]) def test_sequence(self): expr = Sequence(elem1) self.assertEqual(expr.name, elem1) class StateTest(TestHelper): def setUp(self): TestHelper.setUp(self) self.state = State() def test_attrs(self): self.assertEqual(self.state.parameters, []) self.assertEqual(self.state.auto_tables, []) self.assertEqual(self.state.context, None) def test_push_pop(self): self.state.parameters.extend([1, 2]) self.state.push("parameters", []) self.assertEqual(self.state.parameters, []) self.state.pop() self.assertEqual(self.state.parameters, [1, 2]) self.state.push("parameters") self.assertEqual(self.state.parameters, [1, 2]) self.state.parameters.append(3) self.assertEqual(self.state.parameters, [1, 2, 3]) self.state.pop() self.assertEqual(self.state.parameters, [1, 2]) def test_push_pop_unexistent(self): self.state.push("nonexistent") self.assertEqual(self.state.nonexistent, None) self.state.nonexistent = "something" self.state.pop() self.assertEqual(self.state.nonexistent, None) class CompileTest(TestHelper): def test_simple_inheritance(self): custom_compile = compile.create_child() statement = custom_compile(Func1()) self.assertEqual(statement, "func1()") def test_customize(self): custom_compile = compile.create_child() @custom_compile.when(type(None)) def compile_none(compile, state, expr): return "None" statement = custom_compile(Func1(None)) self.assertEqual(statement, "func1(None)") def test_customize_inheritance(self): class C: pass compile_parent = Compile() compile_child = compile_parent.create_child() @compile_parent.when(C) def compile_in_parent(compile, state, expr): return "parent" statement = compile_child(C()) self.assertEqual(statement, "parent") @compile_child.when(C) def compile_in_child(compile, state, expr): return "child" statement = compile_child(C()) self.assertEqual(statement, "child") def test_precedence(self): e = [SQLRaw('%d' % i) for i in range(10)] expr = And(e[1], Or(e[2], e[3]), Add(e[4], Mul(e[5], Sub(e[6], Div(e[7], Div(e[8], e[9])))))) statement = compile(expr) self.assertEqual(statement, "1 AND (2 OR 3) AND 4 + 5 * (6 - 7 / (8 / 9))") expr = Func1(Select(Count()), [Select(Count())]) statement = compile(expr) self.assertEqual(statement, "func1((SELECT COUNT(*)), (SELECT COUNT(*)))") def test_get_precedence(self): self.assertTrue(compile.get_precedence(Or) < compile.get_precedence(And)) self.assertTrue(compile.get_precedence(Add) < compile.get_precedence(Mul)) self.assertTrue(compile.get_precedence(Sub) < compile.get_precedence(Div)) def test_customize_precedence(self): expr = And(elem1, Or(elem2, elem3)) custom_compile = compile.create_child() custom_compile.set_precedence(10, And) custom_compile.set_precedence(11, Or) statement = custom_compile(expr) self.assertEqual(statement, "elem1 AND elem2 OR elem3") custom_compile.set_precedence(10, Or) statement = custom_compile(expr) self.assertEqual(statement, "elem1 AND elem2 OR elem3") custom_compile.set_precedence(9, Or) statement = custom_compile(expr) self.assertEqual(statement, "elem1 AND (elem2 OR elem3)") def test_customize_precedence_inheritance(self): compile_parent = compile.create_child() compile_child = compile_parent.create_child() expr = And(elem1, Or(elem2, elem3)) 
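# A child compiler reads precedences through its parent until it overrides an entry itself, as the paired get_precedence() assertions below show.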
compile_parent.set_precedence(10, And) compile_parent.set_precedence(11, Or) self.assertEqual(compile_child.get_precedence(Or), 11) self.assertEqual(compile_parent.get_precedence(Or), 11) statement = compile_child(expr) self.assertEqual(statement, "elem1 AND elem2 OR elem3") compile_parent.set_precedence(10, Or) self.assertEqual(compile_child.get_precedence(Or), 10) self.assertEqual(compile_parent.get_precedence(Or), 10) statement = compile_child(expr) self.assertEqual(statement, "elem1 AND elem2 OR elem3") compile_child.set_precedence(9, Or) self.assertEqual(compile_child.get_precedence(Or), 9) self.assertEqual(compile_parent.get_precedence(Or), 10) statement = compile_child(expr) self.assertEqual(statement, "elem1 AND (elem2 OR elem3)") def test_compile_sequence(self): expr = [elem1, Func1(), (Func2(), None)] statement = compile(expr) self.assertEqual(statement, "elem1, func1(), func2(), NULL") def test_compile_invalid(self): self.assertRaises(CompileError, compile, object()) self.assertRaises(CompileError, compile, [object()]) def test_bytes(self): state = State() statement = compile(b"str", state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [BytesVariable(b"str")]) def test_unicode(self): state = State() statement = compile("str", state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [UnicodeVariable("str")]) def test_int(self): state = State() statement = compile(1, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [IntVariable(1)]) def test_bool(self): state = State() statement = compile(True, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [BoolVariable(1)]) def test_float(self): state = State() statement = compile(1.1, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [FloatVariable(1.1)]) def test_decimal(self): state = State() statement = compile(Decimal("1.1"), state) self.assertEqual(statement, "?") self.assertVariablesEqual( state.parameters, [DecimalVariable(Decimal("1.1"))]) def test_datetime(self): dt = datetime(1977, 5, 4, 12, 34) state = State() statement = compile(dt, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [DateTimeVariable(dt)]) def test_date(self): d = date(1977, 5, 4) state = State() statement = compile(d, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [DateVariable(d)]) def test_time(self): t = time(12, 34) state = State() statement = compile(t, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [TimeVariable(t)]) def test_timedelta(self): td = timedelta(days=1, seconds=2, microseconds=3) state = State() statement = compile(td, state) self.assertEqual(statement, "?") self.assertVariablesEqual(state.parameters, [TimeDeltaVariable(td)]) def test_none(self): state = State() statement = compile(None, state) self.assertEqual(statement, "NULL") self.assertEqual(state.parameters, []) def test_select(self): expr = Select([column1, column2]) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT column1, column2") self.assertEqual(state.parameters, []) def test_select_distinct(self): expr = Select([column1, column2], Undef, [table1], distinct=True) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT DISTINCT column1, column2 FROM "table 1"') self.assertEqual(state.parameters, []) def test_select_distinct_on(self): expr = Select([column1, 
    def test_compile_sequence(self):
        expr = [elem1, Func1(), (Func2(), None)]
        statement = compile(expr)
        self.assertEqual(statement, "elem1, func1(), func2(), NULL")

    def test_compile_invalid(self):
        self.assertRaises(CompileError, compile, object())
        self.assertRaises(CompileError, compile, [object()])

    def test_bytes(self):
        state = State()
        statement = compile(b"str", state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [BytesVariable(b"str")])

    def test_unicode(self):
        state = State()
        statement = compile("str", state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [UnicodeVariable("str")])

    def test_int(self):
        state = State()
        statement = compile(1, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [IntVariable(1)])

    def test_bool(self):
        state = State()
        statement = compile(True, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [BoolVariable(1)])

    def test_float(self):
        state = State()
        statement = compile(1.1, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [FloatVariable(1.1)])

    def test_decimal(self):
        state = State()
        statement = compile(Decimal("1.1"), state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(
            state.parameters, [DecimalVariable(Decimal("1.1"))])

    def test_datetime(self):
        dt = datetime(1977, 5, 4, 12, 34)
        state = State()
        statement = compile(dt, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [DateTimeVariable(dt)])

    def test_date(self):
        d = date(1977, 5, 4)
        state = State()
        statement = compile(d, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [DateVariable(d)])

    def test_time(self):
        t = time(12, 34)
        state = State()
        statement = compile(t, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [TimeVariable(t)])

    def test_timedelta(self):
        td = timedelta(days=1, seconds=2, microseconds=3)
        state = State()
        statement = compile(td, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [TimeDeltaVariable(td)])

    def test_none(self):
        state = State()
        statement = compile(None, state)
        self.assertEqual(statement, "NULL")
        self.assertEqual(state.parameters, [])

    def test_select(self):
        expr = Select([column1, column2])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "SELECT column1, column2")
        self.assertEqual(state.parameters, [])

    def test_select_distinct(self):
        expr = Select([column1, column2], Undef, [table1], distinct=True)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT DISTINCT column1, column2 FROM "table 1"')
        self.assertEqual(state.parameters, [])

    def test_select_distinct_on(self):
        expr = Select([column1, column2], Undef, [table1],
                      distinct=[column2, column1])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT DISTINCT ON (column2, column1) '
                         'column1, column2 FROM "table 1"')
        self.assertEqual(state.parameters, [])

    def test_select_where(self):
        expr = Select([column1, Func1()],
                      Func1(),
                      [table1, Func1()],
                      order_by=[column2, Func1()],
                      group_by=[column3, Func1()],
                      limit=3, offset=4)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1, func1() '
                         'FROM "table 1", func1() '
                         'WHERE func1() '
                         'GROUP BY column3, func1() '
                         'ORDER BY column2, func1() '
                         'LIMIT 3 OFFSET 4')
        self.assertEqual(state.parameters, [])

    def test_select_join_where(self):
        expr = Select(column1,
                      Func1() == "value1",
                      Join(table1, Func2() == "value2"))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1 FROM '
                         'JOIN "table 1" ON func2() = ? '
                         'WHERE func1() = ?')
        self.assertEqual([variable.get() for variable in state.parameters],
                         ["value2", "value1"])

    def test_select_auto_table(self):
        expr = Select(Column(column1, table1),
                      Column(column2, table2) == 1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT "table 1".column1 '
                         'FROM "table 1", "table 2" '
                         'WHERE "table 2".column2 = ?')
        self.assertVariablesEqual(state.parameters, [Variable(1)])

    def test_select_auto_table_duplicated(self):
        expr = Select(Column(column1, table1),
                      Column(column2, table1) == 1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT "table 1".column1 '
                         'FROM "table 1" WHERE '
                         '"table 1".column2 = ?')
        self.assertVariablesEqual(state.parameters, [Variable(1)])

    def test_select_auto_table_default(self):
        expr = Select(Column(column1),
                      Column(column2) == 1,
                      default_tables=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1 FROM "table 1" '
                         'WHERE column2 = ?')
        self.assertVariablesEqual(state.parameters, [Variable(1)])

    def test_select_auto_table_default_with_joins(self):
        expr = Select(Column(column1),
                      default_tables=[table1, Join(table2)])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1 '
                         'FROM "table 1" JOIN "table 2"')
        self.assertEqual(state.parameters, [])

    def test_select_auto_table_unknown(self):
        statement = compile(Select(elem1))
        self.assertEqual(statement, "SELECT elem1")

    def test_select_auto_table_sub(self):
        col1 = Column(column1, table1)
        col2 = Column(column2, table2)
        expr = Select(col1, In(elem1, Select(col2, col1 == col2, col2.table)))
        statement = compile(expr)
        self.assertEqual(statement,
                         'SELECT "table 1".column1 FROM "table 1" WHERE '
                         'elem1 IN (SELECT "table 2".column2 FROM "table 2" '
                         'WHERE "table 1".column1 = "table 2".column2)')

    def test_select_join(self):
        expr = Select([column1, Func1()], Func1(),
                      [table1, Join(table2), Join(table3)])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1, func1() '
                         'FROM "table 1" JOIN "table 2"'
                         ' JOIN "table 3" '
                         'WHERE func1()')
        self.assertEqual(state.parameters, [])

    def test_select_join_right_left(self):
        expr = Select([column1, Func1()], Func1(),
                      [table1, Join(table2, table3)])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1, func1() '
                         'FROM "table 1", "table 2" '
                         'JOIN "table 3" WHERE func1()')
        self.assertEqual(state.parameters, [])

    def test_select_with_strings(self):
        expr = Select(column1, "1 = 2", table1,
                      order_by="column1", group_by="column2")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1 FROM "table 1" '
                         'WHERE 1 = 2 GROUP BY column2 '
                         'ORDER BY column1')
        self.assertEqual(state.parameters, [])

    def test_select_with_unicode(self):
        expr = Select(column1, "1 = 2", table1,
                      order_by="column1", group_by=["column2"])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1 FROM "table 1" '
                         'WHERE 1 = 2 GROUP BY column2 '
                         'ORDER BY column1')
        self.assertEqual(state.parameters, [])

    def test_select_having(self):
        expr = Select(column1, tables=table1, order_by="column1",
                      group_by=["column2"], having="1 = 2")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT column1 FROM "table 1" '
                         'GROUP BY column2 HAVING 1 = 2 '
                         'ORDER BY column1')
        self.assertEqual(state.parameters, [])

    def test_select_contexts(self):
        column, where, table, order_by, group_by = track_contexts(5)
        expr = Select(column, where, table,
                      order_by=order_by, group_by=group_by)
        compile(expr)
        self.assertEqual(column.context, COLUMN)
        self.assertEqual(where.context, EXPR)
        self.assertEqual(table.context, TABLE)
        self.assertEqual(order_by.context, EXPR)
        self.assertEqual(group_by.context, EXPR)
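
    # Contexts, as exercised here and below: the compiler visits each child
    # under a named context -- COLUMN for the result list, TABLE for table
    # references, COLUMN_NAME and COLUMN_PREFIX for the two halves of a
    # qualified column, SELECT for union/except/intersect operands, and EXPR
    # everywhere else -- so tokens can be quoted or resolved per position.
    # track_contexts() (defined earlier in this module) presumably hands
    # back stub expressions that record the context they were compiled under.
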
    def test_insert(self):
        expr = Insert({column1: elem1, Func1(): Func2()}, Func2())
        state = State()
        statement = compile(expr, state)
        self.assertTrue(statement in (
            "INSERT INTO func2() (column1, func1()) "
            "VALUES (elem1, func2())",
            "INSERT INTO func2() (func1(), column1) "
            "VALUES (func2(), elem1)"), statement)
        self.assertEqual(state.parameters, [])

    def test_insert_with_columns(self):
        expr = Insert({Column(column1, table1): elem1,
                       Column(column2, table1): elem2}, table2)
        state = State()
        statement = compile(expr, state)
        self.assertTrue(statement in (
            'INSERT INTO "table 2" (column1, column2) '
            'VALUES (elem1, elem2)',
            'INSERT INTO "table 2" (column2, column1) '
            'VALUES (elem2, elem1)'), statement)
        self.assertEqual(state.parameters, [])

    def test_insert_with_columns_to_escape(self):
        expr = Insert({Column("column 1", table1): elem1}, table2)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'INSERT INTO "table 2" ("column 1") VALUES (elem1)')
        self.assertEqual(state.parameters, [])

    def test_insert_with_columns_as_raw_strings(self):
        expr = Insert({"column 1": elem1}, table2)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'INSERT INTO "table 2" ("column 1") VALUES (elem1)')
        self.assertEqual(state.parameters, [])

    def test_insert_auto_table(self):
        expr = Insert({Column(column1, table1): elem1})
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'INSERT INTO "table 1" (column1) '
                         'VALUES (elem1)')
        self.assertEqual(state.parameters, [])

    def test_insert_auto_table_default(self):
        expr = Insert({Column(column1): elem1}, default_table=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'INSERT INTO "table 1" (column1) '
                         'VALUES (elem1)')
        self.assertEqual(state.parameters, [])

    def test_insert_auto_table_unknown(self):
        expr = Insert({Column(column1): elem1})
        self.assertRaises(NoTableError, compile, expr)

    def test_insert_contexts(self):
        column, value, table = track_contexts(3)
        expr = Insert({column: value}, table)
        compile(expr)
        self.assertEqual(column.context, COLUMN_NAME)
        self.assertEqual(value.context, EXPR)
        self.assertEqual(table.context, TABLE)

    def test_insert_bulk(self):
        expr = Insert((Column(column1, table1), Column(column2, table1)),
                      values=[(elem1, elem2), (elem3, elem4)])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(
            statement,
            'INSERT INTO "table 1" (column1, column2) '
            'VALUES (elem1, elem2), (elem3, elem4)')
        self.assertEqual(state.parameters, [])

    def test_insert_select(self):
        expr = Insert((Column(column1, table1), Column(column2, table1)),
                      values=Select(
                          (Column(column3, table3), Column(column4, table4))))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(
            statement,
            'INSERT INTO "table 1" (column1, column2) '
            'SELECT "table 3".column3, "table 4".column4 '
            'FROM "table 3", "table 4"')
        self.assertEqual(state.parameters, [])

    def test_update(self):
        expr = Update({column1: elem1, Func1(): Func2()}, table=Func1())
        state = State()
        statement = compile(expr, state)
        self.assertTrue(statement in (
            "UPDATE func1() SET column1=elem1, func1()=func2()",
            "UPDATE func1() SET func1()=func2(), column1=elem1"), statement)
        self.assertEqual(state.parameters, [])

    def test_update_with_columns(self):
        expr = Update({Column(column1, table1): elem1}, table=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'UPDATE "table 1" SET column1=elem1')
        self.assertEqual(state.parameters, [])

    def test_update_with_columns_to_escape(self):
        expr = Update({Column("column x", table1): elem1}, table=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'UPDATE "table 1" SET "column x"=elem1')
        self.assertEqual(state.parameters, [])

    def test_update_with_columns_as_raw_strings(self):
        expr = Update({"column 1": elem1}, table=table2)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'UPDATE "table 2" SET "column 1"=elem1')
        self.assertEqual(state.parameters, [])

    def test_update_where(self):
        expr = Update({column1: elem1}, Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         "UPDATE func2() SET column1=elem1 WHERE func1()")
        self.assertEqual(state.parameters, [])

    def test_update_auto_table(self):
        expr = Update({Column(column1, table1): elem1})
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'UPDATE "table 1" SET column1=elem1')
        self.assertEqual(state.parameters, [])

    def test_update_auto_table_default(self):
        expr = Update({Column(column1): elem1}, default_table=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'UPDATE "table 1" SET column1=elem1')
        self.assertEqual(state.parameters, [])

    def test_update_auto_table_unknown(self):
        expr = Update({Column(column1): elem1})
        self.assertRaises(CompileError, compile, expr)

    def test_update_with_strings(self):
        expr = Update({column1: elem1}, "1 = 2", table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'UPDATE "table 1" SET column1=elem1 WHERE 1 = 2')
        self.assertEqual(state.parameters, [])

    def test_update_contexts(self):
        set_left, set_right, where, table = track_contexts(4)
        expr = Update({set_left: set_right}, where, table)
        compile(expr)
        self.assertEqual(set_left.context, COLUMN_NAME)
        self.assertEqual(set_right.context, COLUMN_NAME)
        self.assertEqual(where.context, EXPR)
        self.assertEqual(table.context, TABLE)

    def test_delete(self):
        expr = Delete(table=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'DELETE FROM "table 1"')
        self.assertEqual(state.parameters, [])

    def test_delete_where(self):
        expr = Delete(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "DELETE FROM func2() WHERE func1()")
        self.assertEqual(state.parameters, [])

    def test_delete_with_strings(self):
        expr = Delete("1 = 2", table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, 'DELETE FROM "table 1" WHERE 1 = 2')
        self.assertEqual(state.parameters, [])

    def test_delete_auto_table(self):
        expr = Delete(Column(column1, table1) == 1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'DELETE FROM "table 1" WHERE "table 1".column1 = ?')
        self.assertVariablesEqual(state.parameters, [Variable(1)])

    def test_delete_auto_table_default(self):
        expr = Delete(Column(column1) == 1, default_table=table1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'DELETE FROM "table 1" WHERE column1 = ?')
        self.assertVariablesEqual(state.parameters, [Variable(1)])

    def test_delete_auto_table_unknown(self):
        expr = Delete(Column(column1) == 1)
        self.assertRaises(NoTableError, compile, expr)

    def test_delete_contexts(self):
        where, table = track_contexts(2)
        expr = Delete(where, table)
        compile(expr)
        self.assertEqual(where.context, EXPR)
        self.assertEqual(table.context, TABLE)

    def test_column(self):
        expr = Column(column1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "column1")
        self.assertEqual(state.parameters, [])
        self.assertEqual(expr.compile_cache, "column1")

    def test_column_table(self):
        column = Column(column1, Func1())
        expr = Select(column)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "SELECT func1().column1 FROM func1()")
        self.assertEqual(state.parameters, [])
        self.assertEqual(column.compile_cache, "column1")

    def test_column_contexts(self):
        table, = track_contexts(1)
        expr = Column(column1, table)
        compile(expr)
        self.assertEqual(table.context, COLUMN_PREFIX)

    def test_column_with_reserved_words(self):
        expr = Select(Column("name 1", "table 1"))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement,
                         'SELECT "table 1"."name 1" FROM "table 1"')

    def test_row(self):
        expr = Row(column1, column2)
        statement = compile(expr)
        self.assertEqual(statement, "ROW(column1, column2)")

    def test_variable(self):
        expr = Variable("value")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_is_null(self):
        expr = Is(Func1(), None)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS NULL")
        self.assertEqual(state.parameters, [])

    def test_is_true(self):
        expr = Is(Func1(), True)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS TRUE")
        self.assertEqual(state.parameters, [])

    def test_is_false(self):
        expr = Is(Func1(), False)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS FALSE")
        self.assertEqual(state.parameters, [])

    def test_is_invalid(self):
        expr = Is(Func1(), "x")
        self.assertRaises(CompileError, compile, expr)

    def test_is_not_null(self):
        expr = IsNot(Func1(), None)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS NOT NULL")
        self.assertEqual(state.parameters, [])

    def test_is_not_true(self):
        expr = IsNot(Func1(), True)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS NOT TRUE")
        self.assertEqual(state.parameters, [])

    def test_is_not_false(self):
        expr = IsNot(Func1(), False)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS NOT FALSE")
        self.assertEqual(state.parameters, [])

    def test_is_not_invalid(self):
        expr = IsNot(Func1(), "x")
        self.assertRaises(CompileError, compile, expr)

    def test_eq(self):
        expr = Eq(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() = func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() == "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() = ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_is_in(self):
        expr = Func1().is_in(["Hello", "World"])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IN (?, ?)")
        self.assertVariablesEqual(
            state.parameters, [Variable("Hello"), Variable("World")])

    def test_is_in_empty(self):
        expr = Func1().is_in([])
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "?")
        self.assertVariablesEqual(state.parameters, [BoolVariable(False)])

    def test_is_in_expr(self):
        expr = Func1().is_in(Select(column1))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IN (SELECT column1)")
        self.assertEqual(state.parameters, [])

    def test_eq_none(self):
        expr = Func1() == None
        self.assertIsNone(expr.expr2)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS NULL")
        self.assertEqual(state.parameters, [])

    def test_ne(self):
        expr = Ne(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() != func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() != "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() != ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_ne_none(self):
        expr = Func1() != None
        self.assertIsNone(expr.expr2)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IS NOT NULL")
        self.assertEqual(state.parameters, [])

    def test_gt(self):
        expr = Gt(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() > func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() > "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() > ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_ge(self):
        expr = Ge(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() >= func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() >= "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() >= ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_lt(self):
        expr = Lt(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() < func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() < "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() < ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_le(self):
        expr = Le(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() <= func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() <= "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() <= ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_lshift(self):
        expr = LShift(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() << func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() << "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() << ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_rshift(self):
        expr = RShift(Func1(), Func2())
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() >> func2()")
        self.assertEqual(state.parameters, [])

        expr = Func1() >> "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() >> ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_like(self):
        expr = Like(Func1(), b"value")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() LIKE ?")
        self.assertVariablesEqual(state.parameters, [BytesVariable(b"value")])

        expr = Func1().like("Hello")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() LIKE ?")
        self.assertVariablesEqual(state.parameters, [Variable("Hello")])

    def test_like_escape(self):
        expr = Like(Func1(), b"value", b"!")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() LIKE ? ESCAPE ?")
        self.assertVariablesEqual(state.parameters,
                                  [BytesVariable(b"value"),
                                   BytesVariable(b"!")])

        expr = Func1().like("Hello", b"!")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() LIKE ? ESCAPE ?")
        self.assertVariablesEqual(state.parameters,
                                  [Variable("Hello"), BytesVariable(b"!")])

    def test_like_compareable_case(self):
        expr = Func1().like("Hello")
        self.assertEqual(expr.case_sensitive, None)
        expr = Func1().like("Hello", case_sensitive=True)
        self.assertEqual(expr.case_sensitive, True)
        expr = Func1().like("Hello", case_sensitive=False)
        self.assertEqual(expr.case_sensitive, False)

    def test_in(self):
        expr = In(Func1(), b"value")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IN (?)")
        self.assertVariablesEqual(state.parameters, [BytesVariable(b"value")])

        expr = In(Func1(), elem1)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() IN (elem1)")
        self.assertEqual(state.parameters, [])

    def test_and(self):
        expr = And(elem1, elem2, And(elem3, elem4))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 AND elem2 AND elem3 AND elem4")
        self.assertEqual(state.parameters, [])

        expr = Func1() & "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() AND ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_or(self):
        expr = Or(elem1, elem2, Or(elem3, elem4))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 OR elem2 OR elem3 OR elem4")
        self.assertEqual(state.parameters, [])

        expr = Func1() | "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() OR ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_and_with_strings(self):
        expr = And("elem1", "elem2")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 AND elem2")
        self.assertEqual(state.parameters, [])

    def test_or_with_strings(self):
        expr = Or("elem1", "elem2")
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 OR elem2")
        self.assertEqual(state.parameters, [])

    def test_add(self):
        expr = Add(elem1, elem2, Add(elem3, elem4))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 + elem2 + elem3 + elem4")
        self.assertEqual(state.parameters, [])

        expr = Func1() + "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() + ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_sub(self):
        expr = Sub(elem1, Sub(elem2, elem3))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 - (elem2 - elem3)")
        self.assertEqual(state.parameters, [])

        expr = Sub(Sub(elem1, elem2), elem3)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 - elem2 - elem3")
        self.assertVariablesEqual(state.parameters, [])

        expr = Func1() - "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() - ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_mul(self):
        expr = Mul(elem1, elem2, Mul(elem3, elem4))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 * elem2 * elem3 * elem4")
        self.assertEqual(state.parameters, [])

        expr = Func1() * "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() * ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_div(self):
        expr = Div(elem1, Div(elem2, elem3))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 / (elem2 / elem3)")
        self.assertEqual(state.parameters, [])

        expr = Div(Div(elem1, elem2), elem3)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 / elem2 / elem3")
        self.assertEqual(state.parameters, [])

        expr = Func1() / "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() / ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])

    def test_mod(self):
        expr = Mod(elem1, Mod(elem2, elem3))
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 % (elem2 % elem3)")
        self.assertEqual(state.parameters, [])

        expr = Mod(Mod(elem1, elem2), elem3)
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "elem1 % elem2 % elem3")
        self.assertEqual(state.parameters, [])

        expr = Func1() % "value"
        state = State()
        statement = compile(expr, state)
        self.assertEqual(statement, "func1() % ?")
        self.assertVariablesEqual(state.parameters, [Variable("value")])
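
    # Associativity, as the three tests above show: Sub, Div and Mod compile
    # left-associatively, so only a right-nested operand needs parentheses --
    # Sub(Sub(a, b), c) flattens to "a - b - c" while Sub(a, Sub(b, c))
    # becomes "a - (b - c)".  And, Or, Add and Mul are associative and
    # variadic, so nesting them flattens with no parentheses at all.
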
""" expr = Cast(Func1(), "TEXT") state = State() statement = compile(expr, state) self.assertEqual(statement, "CAST(func1() AS TEXT)") self.assertEqual(state.parameters, []) def test_max(self): expr = Max(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "MAX(func1())") self.assertEqual(state.parameters, []) def test_min(self): expr = Min(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "MIN(func1())") self.assertEqual(state.parameters, []) def test_avg(self): expr = Avg(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "AVG(func1())") self.assertEqual(state.parameters, []) def test_sum(self): expr = Sum(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "SUM(func1())") self.assertEqual(state.parameters, []) def test_lower(self): expr = Lower(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "LOWER(func1())") self.assertEqual(state.parameters, []) expr = Func1().lower() state = State() statement = compile(expr, state) self.assertEqual(statement, "LOWER(func1())") self.assertEqual(state.parameters, []) def test_upper(self): expr = Upper(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "UPPER(func1())") self.assertEqual(state.parameters, []) expr = Func1().upper() state = State() statement = compile(expr, state) self.assertEqual(statement, "UPPER(func1())") self.assertEqual(state.parameters, []) def test_coalesce(self): expr = Coalesce(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "COALESCE(func1())") self.assertEqual(state.parameters, []) def test_coalesce_with_many_arguments(self): expr = Coalesce(Func1(), Func2(), None) state = State() statement = compile(expr, state) self.assertEqual(statement, "COALESCE(func1(), func2(), NULL)") self.assertEqual(state.parameters, []) def test_not(self): expr = Not(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "NOT func1()") self.assertEqual(state.parameters, []) def test_exists(self): expr = Exists(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "EXISTS func1()") self.assertEqual(state.parameters, []) def test_neg(self): expr = Neg(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "- func1()") self.assertEqual(state.parameters, []) expr = -Func1() state = State() statement = compile(expr, state) self.assertEqual(statement, "- func1()") self.assertEqual(state.parameters, []) def test_asc(self): expr = Asc(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() ASC") self.assertEqual(state.parameters, []) def test_desc(self): expr = Desc(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() DESC") self.assertEqual(state.parameters, []) def test_asc_with_string(self): expr = Asc("column") state = State() statement = compile(expr, state) self.assertEqual(statement, "column ASC") self.assertEqual(state.parameters, []) def test_desc_with_string(self): expr = Desc("column") state = State() statement = compile(expr, state) self.assertEqual(statement, "column DESC") self.assertEqual(state.parameters, []) def test_sql(self): expr = SQL("expression") state = State() statement = compile(expr, state) self.assertEqual(statement, "expression") self.assertEqual(state.parameters, []) def test_sql_params(self): expr = 
SQL("expression", ["params"]) state = State() statement = compile(expr, state) self.assertEqual(statement, "expression") self.assertEqual(state.parameters, ["params"]) def test_sql_invalid_params(self): expr = SQL("expression", "not a list or tuple") self.assertRaises(CompileError, compile, expr) def test_sql_tables(self): expr = Select([column1, Func1()], SQL("expression", [], Func2())) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT column1, func1() FROM func2() " "WHERE expression") self.assertEqual(state.parameters, []) def test_sql_tables_with_list_or_tuple(self): sql = SQL("expression", [], [Func1(), Func2()]) expr = Select(column1, sql) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT column1 FROM func1(), func2() " "WHERE expression") self.assertEqual(state.parameters, []) sql = SQL("expression", [], (Func1(), Func2())) expr = Select(column1, sql) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT column1 FROM func1(), func2() " "WHERE expression") self.assertEqual(state.parameters, []) def test_sql_comparison(self): expr = SQL("expression1") & SQL("expression2") state = State() statement = compile(expr, state) self.assertEqual(statement, "(expression1) AND (expression2)") self.assertEqual(state.parameters, []) def test_table(self): expr = Table(table1) self.assertIdentical(expr.compile_cache, None) state = State() statement = compile(expr, state) self.assertEqual(statement, '"table 1"') self.assertEqual(state.parameters, []) self.assertEqual(expr.compile_cache, '"table 1"') def test_alias(self): expr = Alias(Table(table1), "name") state = State() statement = compile(expr, state) self.assertEqual(statement, "name") self.assertEqual(state.parameters, []) def test_alias_in_tables(self): expr = Select(column1, tables=Alias(Table(table1), "alias 1")) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT column1 FROM "table 1" AS "alias 1"') self.assertEqual(state.parameters, []) def test_alias_in_tables_auto_name(self): expr = Select(column1, tables=Alias(Table(table1))) state = State() statement = compile(expr, state) self.assertEqual(statement[:statement.rfind("_")+1], 'SELECT column1 FROM "table 1" AS "_') self.assertEqual(state.parameters, []) def test_alias_in_column_prefix(self): expr = Select(Column(column1, Alias(Table(table1), "alias 1"))) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "alias 1".column1 ' 'FROM "table 1" AS "alias 1"') self.assertEqual(state.parameters, []) def test_alias_for_column(self): expr = Select(Alias(Column(column1, table1), "alias 1")) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".column1 AS "alias 1" ' 'FROM "table 1"') self.assertEqual(state.parameters, []) def test_alias_union(self): union = Union(Select(elem1), Select(elem2)) expr = Select(elem3, tables=Alias(union, "alias")) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT elem3 FROM " "((SELECT elem1) UNION (SELECT elem2)) AS alias") self.assertEqual(state.parameters, []) def test_distinct(self): """L{Distinct} adds a DISTINCT prefix to the given expression.""" distinct = Distinct(Column(elem1)) state = State() statement = compile(distinct, state) self.assertEqual(statement, "DISTINCT elem1") self.assertEqual(state.parameters, []) def test_join(self): expr = Join(Func1()) state = State() statement = compile(expr, state) 
self.assertEqual(statement, "JOIN func1()") self.assertEqual(state.parameters, []) def test_join_on(self): expr = Join(Func1(), Func2() == "value") state = State() statement = compile(expr, state) self.assertEqual(statement, "JOIN func1() ON func2() = ?") self.assertVariablesEqual(state.parameters, [Variable("value")]) def test_join_on_with_string(self): expr = Join(Func1(), on="a = b") state = State() statement = compile(expr, state) self.assertEqual(statement, "JOIN func1() ON a = b") self.assertEqual(state.parameters, []) def test_join_left_right(self): expr = Join(table1, table2) state = State() statement = compile(expr, state) self.assertEqual(statement, '"table 1" JOIN "table 2"') self.assertEqual(state.parameters, []) def test_join_nested(self): expr = Join(table1, Join(table2, table3)) state = State() statement = compile(expr, state) self.assertEqual(statement, '"table 1" JOIN ' '("table 2" JOIN "table 3")') self.assertEqual(state.parameters, []) def test_join_double_nested(self): expr = Join(Join(table1, table2), Join(table3, table4)) state = State() statement = compile(expr, state) self.assertEqual(statement, '"table 1" JOIN "table 2" JOIN ' '("table 3" JOIN "table 4")') self.assertEqual(state.parameters, []) def test_join_table(self): expr = Join(Table(table1), Table(table2)) state = State() statement = compile(expr, state) self.assertEqual(statement, '"table 1" JOIN "table 2"') self.assertEqual(state.parameters, []) def test_join_contexts(self): table1, table2, on = track_contexts(3) expr = Join(table1, table2, on) compile(expr) self.assertEqual(table1.context, None) self.assertEqual(table2.context, None) self.assertEqual(on.context, EXPR) def test_left_join(self): expr = LeftJoin(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "LEFT JOIN func1()") self.assertEqual(state.parameters, []) def test_left_join_on(self): expr = LeftJoin(Func1(), Func2() == "value") state = State() statement = compile(expr, state) self.assertEqual(statement, "LEFT JOIN func1() ON func2() = ?") self.assertVariablesEqual(state.parameters, [Variable("value")]) def test_right_join(self): expr = RightJoin(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "RIGHT JOIN func1()") self.assertEqual(state.parameters, []) def test_right_join_on(self): expr = RightJoin(Func1(), Func2() == "value") state = State() statement = compile(expr, state) self.assertEqual(statement, "RIGHT JOIN func1() ON func2() = ?") self.assertVariablesEqual(state.parameters, [Variable("value")]) def test_natural_join(self): expr = NaturalJoin(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "NATURAL JOIN func1()") self.assertEqual(state.parameters, []) def test_natural_join_on(self): expr = NaturalJoin(Func1(), Func2() == "value") state = State() statement = compile(expr, state) self.assertEqual(statement, "NATURAL JOIN func1() ON func2() = ?") self.assertVariablesEqual(state.parameters, [Variable("value")]) def test_natural_left_join(self): expr = NaturalLeftJoin(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "NATURAL LEFT JOIN func1()") self.assertEqual(state.parameters, []) def test_natural_left_join_on(self): expr = NaturalLeftJoin(Func1(), Func2() == "value") state = State() statement = compile(expr, state) self.assertEqual(statement, "NATURAL LEFT JOIN func1() " "ON func2() = ?") self.assertVariablesEqual(state.parameters, [Variable("value")]) def test_natural_right_join(self): expr = 
NaturalRightJoin(Func1()) state = State() statement = compile(expr, state) self.assertEqual(statement, "NATURAL RIGHT JOIN func1()") self.assertEqual(state.parameters, []) def test_natural_right_join_on(self): expr = NaturalRightJoin(Func1(), Func2() == "value") state = State() statement = compile(expr, state) self.assertEqual(statement, "NATURAL RIGHT JOIN func1() " "ON func2() = ?") self.assertVariablesEqual(state.parameters, [Variable("value")]) def test_union(self): expr = Union(Func1(), elem2, elem3) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() UNION elem2 UNION elem3") self.assertEqual(state.parameters, []) def test_union_all(self): expr = Union(Func1(), elem2, elem3, all=True) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() UNION ALL elem2 UNION ALL elem3") self.assertEqual(state.parameters, []) def test_union_order_by_limit_offset(self): expr = Union(elem1, elem2, order_by=Func1(), limit=1, offset=2) state = State() statement = compile(expr, state) self.assertEqual(statement, "elem1 UNION elem2 ORDER BY func1() " "LIMIT 1 OFFSET 2") self.assertEqual(state.parameters, []) def test_union_select(self): expr = Union(Select(elem1), Select(elem2)) state = State() statement = compile(expr, state) self.assertEqual(statement, "(SELECT elem1) UNION (SELECT elem2)") self.assertEqual(state.parameters, []) def test_union_select_nested(self): expr = Union(Select(elem1), Union(Select(elem2), Select(elem3))) state = State() statement = compile(expr, state) self.assertEqual(statement, "(SELECT elem1) UNION" " ((SELECT elem2) UNION (SELECT elem3))") self.assertEqual(state.parameters, []) def test_union_order_by_and_select(self): """ When ORDER BY is present, databases usually have trouble using fully qualified column names. Because of that, we transform pure column names into aliases, and use them in the ORDER BY. 
""" Alias.auto_counter = 0 column1 = Column(elem1) column2 = Column(elem2) expr = Union(Select(column1), Select(column2), order_by=(column1, column2)) state = State() statement = compile(expr, state) self.assertEqual( statement, '(SELECT elem1 AS "_1") UNION (SELECT elem2 AS "_2") ' 'ORDER BY "_1", "_2"') self.assertEqual(state.parameters, []) def test_union_contexts(self): select1, select2, order_by = track_contexts(3) expr = Union(select1, select2, order_by=order_by) compile(expr) self.assertEqual(select1.context, SELECT) self.assertEqual(select2.context, SELECT) self.assertEqual(order_by.context, COLUMN_NAME) def test_except(self): expr = Except(Func1(), elem2, elem3) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() EXCEPT elem2 EXCEPT elem3") self.assertEqual(state.parameters, []) def test_except_all(self): expr = Except(Func1(), elem2, elem3, all=True) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() EXCEPT ALL elem2 " "EXCEPT ALL elem3") self.assertEqual(state.parameters, []) def test_except_order_by_limit_offset(self): expr = Except(elem1, elem2, order_by=Func1(), limit=1, offset=2) state = State() statement = compile(expr, state) self.assertEqual(statement, "elem1 EXCEPT elem2 ORDER BY func1() " "LIMIT 1 OFFSET 2") self.assertEqual(state.parameters, []) def test_except_select(self): expr = Except(Select(elem1), Select(elem2)) state = State() statement = compile(expr, state) self.assertEqual(statement, "(SELECT elem1) EXCEPT (SELECT elem2)") self.assertEqual(state.parameters, []) def test_except_select_nested(self): expr = Except(Select(elem1), Except(Select(elem2), Select(elem3))) state = State() statement = compile(expr, state) self.assertEqual(statement, "(SELECT elem1) EXCEPT" " ((SELECT elem2) EXCEPT (SELECT elem3))") self.assertEqual(state.parameters, []) def test_except_contexts(self): select1, select2, order_by = track_contexts(3) expr = Except(select1, select2, order_by=order_by) compile(expr) self.assertEqual(select1.context, SELECT) self.assertEqual(select2.context, SELECT) self.assertEqual(order_by.context, COLUMN_NAME) def test_intersect(self): expr = Intersect(Func1(), elem2, elem3) state = State() statement = compile(expr, state) self.assertEqual(statement, "func1() INTERSECT elem2 INTERSECT elem3") self.assertEqual(state.parameters, []) def test_intersect_all(self): expr = Intersect(Func1(), elem2, elem3, all=True) state = State() statement = compile(expr, state) self.assertEqual( statement, "func1() INTERSECT ALL elem2 INTERSECT ALL elem3") self.assertEqual(state.parameters, []) def test_intersect_order_by_limit_offset(self): expr = Intersect(elem1, elem2, order_by=Func1(), limit=1, offset=2) state = State() statement = compile(expr, state) self.assertEqual(statement, "elem1 INTERSECT elem2 ORDER BY func1() " "LIMIT 1 OFFSET 2") self.assertEqual(state.parameters, []) def test_intersect_select(self): expr = Intersect(Select(elem1), Select(elem2)) state = State() statement = compile(expr, state) self.assertEqual(statement, "(SELECT elem1) INTERSECT (SELECT elem2)") self.assertEqual(state.parameters, []) def test_intersect_select_nested(self): expr = Intersect( Select(elem1), Intersect(Select(elem2), Select(elem3))) state = State() statement = compile(expr, state) self.assertEqual( statement, "(SELECT elem1) INTERSECT" " ((SELECT elem2) INTERSECT (SELECT elem3))") self.assertEqual(state.parameters, []) def test_intersect_contexts(self): select1, select2, order_by = track_contexts(3) expr = 
Intersect(select1, select2, order_by=order_by) compile(expr) self.assertEqual(select1.context, SELECT) self.assertEqual(select2.context, SELECT) self.assertEqual(order_by.context, COLUMN_NAME) def test_auto_table(self): expr = Select(AutoTables(1, [table1])) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT ? FROM "table 1"') self.assertVariablesEqual(state.parameters, [IntVariable(1)]) def test_auto_tables_with_column(self): expr = Select(AutoTables(Column(elem1, table1), [table2])) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".elem1 ' 'FROM "table 1", "table 2"') self.assertEqual(state.parameters, []) def test_auto_tables_with_column_and_replace(self): expr = Select(AutoTables(Column(elem1, table1), [table2], replace=True)) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".elem1 FROM "table 2"') self.assertEqual(state.parameters, []) def test_auto_tables_with_join(self): expr = Select(AutoTables(Column(elem1, table1), [LeftJoin(table2)])) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".elem1 FROM "table 1" ' 'LEFT JOIN "table 2"') self.assertEqual(state.parameters, []) def test_auto_tables_with_join_with_left_table(self): expr = Select(AutoTables(Column(elem1, table1), [LeftJoin(table1, table2)])) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".elem1 FROM "table 1" ' 'LEFT JOIN "table 2"') self.assertEqual(state.parameters, []) def test_auto_tables_duplicated(self): expr = Select([AutoTables(Column(elem1, table1), [Join(table2)]), AutoTables(Column(elem2, table2), [Join(table1)]), AutoTables(Column(elem3, table1), [Join(table1)]), AutoTables(Column(elem4, table3), [table1]), AutoTables(Column(elem5, table1), [Join(table4, table5)])]) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".elem1, "table 2".elem2, ' '"table 1".elem3, "table 3".elem4, "table 1".elem5 ' 'FROM "table 3", "table 4" JOIN "table 5" JOIN ' '"table 1" JOIN "table 2"') self.assertEqual(state.parameters, []) def test_auto_tables_duplicated_nested(self): expr = Select(AutoTables(Column(elem1, table1), [Join(table2)]), In(1, Select(AutoTables(Column(elem1, table1), [Join(table2)])))) state = State() statement = compile(expr, state) self.assertEqual(statement, 'SELECT "table 1".elem1 FROM "table 1" JOIN ' '"table 2" WHERE ? 
IN (SELECT "table 1".elem1 ' 'FROM "table 1" JOIN "table 2")') self.assertVariablesEqual(state.parameters, [IntVariable(1)]) def test_sql_token(self): expr = SQLToken("something") state = State() statement = compile(expr, state) self.assertEqual(statement, "something") self.assertEqual(state.parameters, []) def test_sql_token_spaces(self): expr = SQLToken("some thing") statement = compile(expr) self.assertEqual(statement, '"some thing"') def test_sql_token_quotes(self): expr = SQLToken("some'thing") statement = compile(expr) self.assertEqual(statement, '"some\'thing"') def test_sql_token_double_quotes(self): expr = SQLToken('some"thing') statement = compile(expr) self.assertEqual(statement, '"some""thing"') def test_sql_token_reserved(self): custom_compile = compile.create_child() custom_compile.add_reserved_words(["something"]) expr = SQLToken("something") state = State() statement = custom_compile(expr, state) self.assertEqual(statement, '"something"') self.assertEqual(state.parameters, []) def test_sql_token_reserved_from_parent(self): expr = SQLToken("something") parent_compile = compile.create_child() child_compile = parent_compile.create_child() statement = child_compile(expr) self.assertEqual(statement, "something") parent_compile.add_reserved_words(["something"]) statement = child_compile(expr) self.assertEqual(statement, '"something"') def test_sql_token_remove_reserved_word_on_child(self): expr = SQLToken("something") parent_compile = compile.create_child() parent_compile.add_reserved_words(["something"]) child_compile = parent_compile.create_child() statement = child_compile(expr) self.assertEqual(statement, '"something"') child_compile.remove_reserved_words(["something"]) statement = child_compile(expr) self.assertEqual(statement, "something") def test_is_reserved_word(self): parent_compile = compile.create_child() child_compile = parent_compile.create_child() self.assertEqual(child_compile.is_reserved_word("someTHING"), False) parent_compile.add_reserved_words(["SOMEthing"]) self.assertEqual(child_compile.is_reserved_word("somETHing"), True) child_compile.remove_reserved_words(["soMETHing"]) self.assertEqual(child_compile.is_reserved_word("somethING"), False) def test_sql1992_reserved_words(self): reserved_words = """ absolute action add all allocate alter and any are as asc assertion at authorization avg begin between bit bit_length both by cascade cascaded case cast catalog char character char_ length character_length check close coalesce collate collation column commit connect connection constraint constraints continue convert corresponding count create cross current current_date current_time current_timestamp current_ user cursor date day deallocate dec decimal declare default deferrable deferred delete desc describe descriptor diagnostics disconnect distinct domain double drop else end end-exec escape except exception exec execute exists external extract false fetch first float for foreign found from full get global go goto grant group having hour identity immediate in indicator initially inner input insensitive insert int integer intersect interval into is isolation join key language last leading left level like local lower match max min minute module month names national natural nchar next no not null nullif numeric octet_length of on only open option or order outer output overlaps pad partial position precision prepare preserve primary prior privileges procedure public read real references relative restrict revoke right rollback rows schema scroll second section select 
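
    # Reserved-word checks are case-insensitive: the mixed-case spellings
    # used above ("someTHING", "SOMEthing", "soMETHing") all resolve to the
    # same entry, so the word table is presumably normalized (lower-cased)
    # on both add and lookup.
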
    def test_sql1992_reserved_words(self):
        reserved_words = """
            absolute action add all allocate alter and any are as asc
            assertion at authorization avg begin between bit bit_length
            both by cascade cascaded case cast catalog char character
            char_length character_length check close coalesce collate
            collation column commit connect connection constraint
            constraints continue convert corresponding count create cross
            current current_date current_time current_timestamp
            current_user cursor date day deallocate dec decimal declare
            default deferrable deferred delete desc describe descriptor
            diagnostics disconnect distinct domain double drop else end
            end-exec escape except exception exec execute exists external
            extract false fetch first float for foreign found from full
            get global go goto grant group having hour identity immediate
            in indicator initially inner input insensitive insert int
            integer intersect interval into is isolation join key language
            last leading left level like local lower match max min minute
            module month names national natural nchar next no not null
            nullif numeric octet_length of on only open option or order
            outer output overlaps pad partial position precision prepare
            preserve primary prior privileges procedure public read real
            references relative restrict revoke right rollback rows schema
            scroll second section select session session_user set size
            smallint some space sql sqlcode sqlerror sqlstate substring
            sum system_user table temporary then time timestamp
            timezone_hour timezone_minute to trailing transaction
            translate translation trim true union unique unknown update
            upper usage user using value values varchar varying view when
            whenever where with work write year zone
            """.split()
        for word in reserved_words:
            self.assertEqual(compile.is_reserved_word(word), True)
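
# compile_python is a second compiler over the same expression tree: instead
# of SQL it emits Python source, inlining simple literals and binding the
# rest to positional names (_0, _1, ...) collected in State.parameters.
# A sketch of the intended round trip, mirroring test_match below
# (hypothetical columns and values):
#
#     match = compile_python.get_matcher((col1 > 10) & (col2 < 10))
#     match({col1: 15, col2: 5}.get)   # -> True
#     match({col1: 5, col2: 15}.get)   # -> False
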
class CompilePythonTest(TestHelper):

    def test_precedence(self):
        e = [SQLRaw('%d' % i) for i in range(10)]
        expr = And(e[1], Or(e[2], e[3]),
                   Add(e[4], Mul(e[5], Sub(e[6], Div(e[7],
                                                     Div(e[8], e[9]))))))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr,
                         "1 and (2 or 3) and 4 + 5 * (6 - 7 / (8 / 9))")

    def test_get_precedence(self):
        self.assertTrue(compile_python.get_precedence(Or) <
                        compile_python.get_precedence(And))
        self.assertTrue(compile_python.get_precedence(Add) <
                        compile_python.get_precedence(Mul))
        self.assertTrue(compile_python.get_precedence(Sub) <
                        compile_python.get_precedence(Div))

    def test_compile_sequence(self):
        expr = [elem1, Variable(1), (Variable(2), None)]
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "elem1, _0, _1, None")
        self.assertEqual(state.parameters, [1, 2])

    def test_compile_invalid(self):
        self.assertRaises(CompileError, compile_python, object())
        self.assertRaises(CompileError, compile_python, [object()])

    def test_compile_unsupported(self):
        self.assertRaises(CompileError, compile_python, Expr())
        self.assertRaises(CompileError, compile_python, Func1())

    def test_bytes(self):
        py_expr = compile_python(b"str")
        self.assertEqual(py_expr, "b'str'")

    def test_unicode(self):
        py_expr = compile_python("str")
        self.assertEqual(py_expr, "'str'")

    def test_int(self):
        py_expr = compile_python(1)
        self.assertEqual(py_expr, "1")

    def test_bool(self):
        state = State()
        py_expr = compile_python(True, state)
        self.assertEqual(py_expr, "_0")
        self.assertEqual(state.parameters, [True])

    def test_float(self):
        py_expr = compile_python(1.1)
        self.assertEqual(py_expr, repr(1.1))

    def test_datetime(self):
        dt = datetime(1977, 5, 4, 12, 34)
        state = State()
        py_expr = compile_python(dt, state)
        self.assertEqual(py_expr, "_0")
        self.assertEqual(state.parameters, [dt])

    def test_date(self):
        d = date(1977, 5, 4)
        state = State()
        py_expr = compile_python(d, state)
        self.assertEqual(py_expr, "_0")
        self.assertEqual(state.parameters, [d])

    def test_time(self):
        t = time(12, 34)
        state = State()
        py_expr = compile_python(t, state)
        self.assertEqual(py_expr, "_0")
        self.assertEqual(state.parameters, [t])

    def test_timedelta(self):
        td = timedelta(days=1, seconds=2, microseconds=3)
        state = State()
        py_expr = compile_python(td, state)
        self.assertEqual(py_expr, "_0")
        self.assertEqual(state.parameters, [td])

    def test_none(self):
        py_expr = compile_python(None)
        self.assertEqual(py_expr, "None")

    def test_column(self):
        expr = Column(column1)
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "get_column(_0)")
        self.assertEqual(state.parameters, [expr])

    def test_column_table(self):
        expr = Column(column1, table1)
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "get_column(_0)")
        self.assertEqual(state.parameters, [expr])

    def test_variable(self):
        expr = Variable("value")
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0")
        self.assertEqual(state.parameters, ["value"])

    def test_is_null(self):
        expr = Is(Variable(True), None)
        state = State()
        statement = compile_python(expr, state)
        self.assertEqual(statement, "_0 is None")
        self.assertEqual(state.parameters, [True])

    def test_is_true(self):
        expr = Is(Variable(True), True)
        state = State()
        statement = compile_python(expr, state)
        self.assertEqual(statement, "_0 is _1")
        self.assertEqual(state.parameters, [True, True])

    def test_is_false(self):
        expr = Is(Variable(True), False)
        state = State()
        statement = compile_python(expr, state)
        self.assertEqual(statement, "_0 is _1")
        self.assertEqual(state.parameters, [True, False])

    def test_is_not_null(self):
        expr = IsNot(Variable(True), None)
        state = State()
        statement = compile_python(expr, state)
        self.assertEqual(statement, "_0 is not None")
        self.assertEqual(state.parameters, [True])

    def test_is_not_true(self):
        expr = IsNot(Variable(True), True)
        state = State()
        statement = compile_python(expr, state)
        self.assertEqual(statement, "_0 is not _1")
        self.assertEqual(state.parameters, [True, True])

    def test_is_not_false(self):
        expr = IsNot(Variable(True), False)
        state = State()
        statement = compile_python(expr, state)
        self.assertEqual(statement, "_0 is not _1")
        self.assertEqual(state.parameters, [True, False])

    def test_eq(self):
        expr = Eq(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 == _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_ne(self):
        expr = Ne(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 != _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_gt(self):
        expr = Gt(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 > _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_ge(self):
        expr = Ge(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 >= _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_lt(self):
        expr = Lt(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 < _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_le(self):
        expr = Le(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 <= _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_lshift(self):
        expr = LShift(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 << _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_rshift(self):
        expr = RShift(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 >> _1")
        self.assertEqual(state.parameters, [1, 2])

    def test_in(self):
        expr = In(Variable(1), Variable(2))
        state = State()
        py_expr = compile_python(expr, state)
        self.assertEqual(py_expr, "_0 in (_1,)")
        self.assertEqual(state.parameters, [1, 2])

    def test_and(self):
        expr = And(elem1, elem2, And(elem3, elem4))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 and elem2 and elem3 and elem4")

    def test_or(self):
        expr = Or(elem1, elem2, Or(elem3, elem4))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 or elem2 or elem3 or elem4")

    def test_add(self):
        expr = Add(elem1, elem2, Add(elem3, elem4))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 + elem2 + elem3 + elem4")

    def test_neg(self):
        expr = Neg(elem1)
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "-elem1")

    def test_sub(self):
        expr = Sub(elem1, Sub(elem2, elem3))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 - (elem2 - elem3)")

        expr = Sub(Sub(elem1, elem2), elem3)
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 - elem2 - elem3")

    def test_mul(self):
        expr = Mul(elem1, elem2, Mul(elem3, elem4))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 * elem2 * elem3 * elem4")

    def test_div(self):
        expr = Div(elem1, Div(elem2, elem3))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 / (elem2 / elem3)")

        expr = Div(Div(elem1, elem2), elem3)
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 / elem2 / elem3")

    def test_mod(self):
        expr = Mod(elem1, Mod(elem2, elem3))
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 % (elem2 % elem3)")

        expr = Mod(Mod(elem1, elem2), elem3)
        py_expr = compile_python(expr)
        self.assertEqual(py_expr, "elem1 % elem2 % elem3")

    def test_match(self):
        col1 = Column(column1)
        col2 = Column(column2)

        match = compile_python.get_matcher((col1 > 10) & (col2 < 10))

        self.assertTrue(match({col1: 15, col2: 5}.get))
        self.assertFalse(match({col1: 5, col2: 15}.get))

    def test_match_bad_repr(self):
        """The get_matcher() works for expressions containing values
        whose repr is not valid Python syntax."""
        class BadRepr:
            def __repr__(self):
                return "$Not a valid Python expression$"

        value = BadRepr()
        col1 = Column(column1)
        match = compile_python.get_matcher(col1 == Variable(value))
        self.assertTrue(match({col1: value}.get))


class LazyValueExprTest(TestHelper):

    def test_expr_is_lazy_value(self):
        marker = object()
        expr = SQL("Hullah!")
        variable = Variable()
        variable.set(expr)
        self.assertIs(variable.get(marker), marker)


storm-1.0/storm/tests/helper.py

#
# Copyright (c) 2006, 2007 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of Storm Object Relational Mapper.
#
# Storm is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation; either version 2.1 of
# the License, or (at your option) any later version.
#
# Storm is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

from io import StringIO
import logging
import shutil
import sys
import tempfile

from storm.tests import mocker

__all__ = ["TestHelper", "MakePath", "LogKeeper"]


class TestHelper(mocker.MockerTestCase):

    helpers = []

    def is_supported(self):
        return True

    def setUp(self):
        super().setUp()
        self._helper_instances = []
        for helper_factory in self.helpers:
            helper = helper_factory()
            helper.set_up(self)
            self._helper_instances.append(helper)

    def tearDown(self):
        for helper in reversed(self._helper_instances):
            helper.tear_down(self)
        super().tearDown()

    @property
    def _testMethod(self):
        return getattr(self, self._testMethodName)

    def run(self, result=None):
        # Skip if is_supported() does not return True.
        if not self.is_supported():
            if hasattr(result, "addSkip"):
                result.startTest(self)
                result.addSkip(self, "Test not supported")
            return
        super().run(result)

    def assertVariablesEqual(self, checked, expected):
        self.assertEqual(len(checked), len(expected))
        for check, expect in zip(checked, expected):
            self.assertEqual(check.__class__, expect.__class__)
            self.assertEqual(check.get(), expect.get())


class MakePath:

    def set_up(self, test_case):
        self.dirname = tempfile.mkdtemp()
        self.dirs = []
        self.counter = 0
        test_case.make_dir = self.make_dir
        test_case.make_path = self.make_path

    def tear_down(self, test_case):
        shutil.rmtree(self.dirname)
        [shutil.rmtree(dir) for dir in self.dirs]

    def make_dir(self):
        path = tempfile.mkdtemp()
        self.dirs.append(path)
        return path

    def make_path(self, content=None, path=None):
        if path is None:
            self.counter += 1
            path = "%s/%03d" % (self.dirname, self.counter)
        if content is not None:
            file = open(path, "w")
            try:
                file.write(content)
            finally:
                file.close()
        return path


class LogKeeper:
    """Record logging information.

    Puts a 'logfile' attribute on your test case, which is a StringIO
    containing all log output.
    """

    def set_up(self, test_case):
        logger = logging.getLogger()
        test_case.logfile = StringIO()
        handler = logging.StreamHandler(test_case.logfile)
        self.old_handlers = logger.handlers
        # Sanity check; this might not be 100% what we want
        if self.old_handlers:
            test_case.assertEqual(len(self.old_handlers), 1)
            test_case.assertEqual(self.old_handlers[0].stream, sys.stderr)
        logger.handlers = [handler]

    def tear_down(self, test_case):
        logging.getLogger().handlers = self.old_handlers
# from weakref import ref import gc from storm.exceptions import ClassInfoError from storm.properties import Property from storm.variables import Variable from storm.expr import Undef, Select, compile from storm.info import * from storm.tests.helper import TestHelper class Wrapper: def __init__(self, obj): self.obj = obj __storm_object_info__ = property(lambda self: self.obj.__storm_object_info__) class GetTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Class: __storm_table__ = "table" prop1 = Property("column1", primary=True) self.Class = Class self.obj = Class() def test_get_cls_info(self): cls_info = get_cls_info(self.Class) self.assertTrue(isinstance(cls_info, ClassInfo)) self.assertTrue(cls_info is get_cls_info(self.Class)) def test_get_obj_info(self): obj_info = get_obj_info(self.obj) self.assertTrue(isinstance(obj_info, ObjectInfo)) self.assertTrue(obj_info is get_obj_info(self.obj)) def test_get_obj_info_on_obj_info(self): obj_info = get_obj_info(self.obj) self.assertTrue(get_obj_info(obj_info) is obj_info) def test_set_obj_info(self): obj_info1 = get_obj_info(self.obj) obj_info2 = ObjectInfo(self.obj) self.assertEqual(get_obj_info(self.obj), obj_info1) set_obj_info(self.obj, obj_info2) self.assertEqual(get_obj_info(self.obj), obj_info2) class ClassInfoTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Class: __storm_table__ = "table" prop1 = Property("column1", primary=True) prop2 = Property("column2") self.Class = Class self.cls_info = get_cls_info(Class) def test_invalid_class(self): class Class: pass self.assertRaises(ClassInfoError, ClassInfo, Class) def test_cls(self): self.assertEqual(self.cls_info.cls, self.Class) def test_columns(self): self.assertEqual(self.cls_info.columns, (self.Class.prop1, self.Class.prop2)) def test_table(self): self.assertEqual(self.cls_info.table.name, "table") def test_primary_key(self): # Can't use == for props. self.assertTrue(self.cls_info.primary_key[0] is self.Class.prop1) self.assertEqual(len(self.cls_info.primary_key), 1) def test_primary_key_with_attribute(self): class SubClass(self.Class): __storm_primary__ = "prop2" cls_info = get_cls_info(SubClass) self.assertTrue(cls_info.primary_key[0] is SubClass.prop2) self.assertEqual(len(self.cls_info.primary_key), 1) def test_primary_key_composed(self): class Class: __storm_table__ = "table" prop1 = Property("column1", primary=2) prop2 = Property("column2", primary=1) cls_info = ClassInfo(Class) # Can't use == for props, since they're columns. self.assertTrue(cls_info.primary_key[0] is Class.prop2) self.assertTrue(cls_info.primary_key[1] is Class.prop1) self.assertEqual(len(cls_info.primary_key), 2) def test_primary_key_composed_with_attribute(self): class Class: __storm_table__ = "table" __storm_primary__ = "prop2", "prop1" # Define primary=True to ensure that the attribute # has precedence. prop1 = Property("column1", primary=True) prop2 = Property("column2") cls_info = ClassInfo(Class) # Can't use == for props, since they're columns. 
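# __storm_primary__ names prop2 first, so it outranks the primary=True flag on prop1.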
self.assertTrue(cls_info.primary_key[0] is Class.prop2) self.assertTrue(cls_info.primary_key[1] is Class.prop1) self.assertEqual(len(cls_info.primary_key), 2) def test_primary_key_composed_duplicated(self): class Class: __storm_table__ = "table" prop1 = Property("column1", primary=True) prop2 = Property("column2", primary=True) self.assertRaises(ClassInfoError, ClassInfo, Class) def test_primary_key_missing(self): class Class: __storm_table__ = "table" prop1 = Property("column1") prop2 = Property("column2") self.assertRaises(ClassInfoError, ClassInfo, Class) def test_primary_key_attribute_missing(self): class Class: __storm_table__ = "table" __storm_primary__ = () prop1 = Property("column1", primary=True) prop2 = Property("column2") self.assertRaises(ClassInfoError, ClassInfo, Class) def test_primary_key_pos(self): class Class: __storm_table__ = "table" prop1 = Property("column1", primary=2) prop2 = Property("column2") prop3 = Property("column3", primary=1) cls_info = ClassInfo(Class) self.assertEqual(cls_info.primary_key_pos, (2, 0)) class ObjectInfoTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Class: __storm_table__ = "table" prop1 = Property("column1", primary=True) prop2 = Property("column2") self.Class = Class self.obj = Class() self.obj_info = get_obj_info(self.obj) self.cls_info = get_cls_info(Class) self.variable1 = self.obj_info.variables[Class.prop1] self.variable2 = self.obj_info.variables[Class.prop2] def test_hashing(self): self.assertEqual(hash(self.obj_info), hash(self.obj_info)) def test_equals(self): obj_info1 = self.obj_info obj_info2 = get_obj_info(self.Class()) self.assertFalse(obj_info1 == obj_info2) def test_not_equals(self): obj_info1 = self.obj_info obj_info2 = get_obj_info(self.Class()) self.assertTrue(obj_info1 != obj_info2) def test_dict_subclass(self): self.assertTrue(isinstance(self.obj_info, dict)) def test_variables(self): self.assertTrue(isinstance(self.obj_info.variables, dict)) for column in self.cls_info.columns: variable = self.obj_info.variables.get(column) self.assertTrue(isinstance(variable, Variable)) self.assertTrue(variable.column is column) self.assertEqual(len(self.obj_info.variables), len(self.cls_info.columns)) def test_variable_has_validator_object_factory(self): args = [] def validator(obj, attr, value): args.append((obj, attr, value)) class Class: __storm_table__ = "table" prop = Property(primary=True, variable_kwargs={"validator": validator}) obj = Class() get_obj_info(obj).variables[Class.prop].set(123) self.assertEqual(args, [(obj, "prop", 123)]) def test_primary_vars(self): self.assertTrue(isinstance(self.obj_info.primary_vars, tuple)) for column, variable in zip(self.cls_info.primary_key, self.obj_info.primary_vars): self.assertEqual(self.obj_info.variables.get(column), variable) self.assertEqual(len(self.obj_info.primary_vars), len(self.cls_info.primary_key)) def test_checkpoint(self): self.obj.prop1 = 10 self.obj_info.checkpoint() self.assertEqual(self.obj.prop1, 10) self.assertEqual(self.variable1.has_changed(), False) self.obj.prop1 = 20 self.assertEqual(self.obj.prop1, 20) self.assertEqual(self.variable1.has_changed(), True) self.obj_info.checkpoint() self.assertEqual(self.obj.prop1, 20) self.assertEqual(self.variable1.has_changed(), False) self.obj.prop1 = 20 self.assertEqual(self.obj.prop1, 20) self.assertEqual(self.variable1.has_changed(), False) def test_add_change_notification(self): changes1 = [] changes2 = [] def object_changed1(obj_info, variable, old_value, new_value, fromdb): changes1.append((1, 
obj_info, variable, old_value, new_value, fromdb)) def object_changed2(obj_info, variable, old_value, new_value, fromdb): changes2.append((2, obj_info, variable, old_value, new_value, fromdb)) self.obj_info.checkpoint() self.obj_info.event.hook("changed", object_changed1) self.obj_info.event.hook("changed", object_changed2) self.obj.prop2 = 10 self.obj.prop1 = 20 self.assertEqual(changes1, [(1, self.obj_info, self.variable2, Undef, 10, False), (1, self.obj_info, self.variable1, Undef, 20, False)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable2, Undef, 10, False), (2, self.obj_info, self.variable1, Undef, 20, False)]) del changes1[:] del changes2[:] self.obj.prop1 = None self.obj.prop2 = None self.assertEqual(changes1, [(1, self.obj_info, self.variable1, 20, None, False), (1, self.obj_info, self.variable2, 10, None, False)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable1, 20, None, False), (2, self.obj_info, self.variable2, 10, None, False)]) del changes1[:] del changes2[:] del self.obj.prop1 del self.obj.prop2 self.assertEqual(changes1, [(1, self.obj_info, self.variable1, None, Undef, False), (1, self.obj_info, self.variable2, None, Undef, False)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable1, None, Undef, False), (2, self.obj_info, self.variable2, None, Undef, False)]) def test_add_change_notification_with_arg(self): changes1 = [] changes2 = [] def object_changed1(obj_info, variable, old_value, new_value, fromdb, arg): changes1.append((1, obj_info, variable, old_value, new_value, fromdb, arg)) def object_changed2(obj_info, variable, old_value, new_value, fromdb, arg): changes2.append((2, obj_info, variable, old_value, new_value, fromdb, arg)) self.obj_info.checkpoint() obj = object() self.obj_info.event.hook("changed", object_changed1, obj) self.obj_info.event.hook("changed", object_changed2, obj) self.obj.prop2 = 10 self.obj.prop1 = 20 self.assertEqual(changes1, [(1, self.obj_info, self.variable2, Undef, 10, False, obj), (1, self.obj_info, self.variable1, Undef, 20, False, obj)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable2, Undef, 10, False, obj), (2, self.obj_info, self.variable1, Undef, 20, False, obj)]) del changes1[:] del changes2[:] self.obj.prop1 = None self.obj.prop2 = None self.assertEqual(changes1, [(1, self.obj_info, self.variable1, 20, None, False, obj), (1, self.obj_info, self.variable2, 10, None, False, obj)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable1, 20, None, False, obj), (2, self.obj_info, self.variable2, 10, None, False, obj)]) del changes1[:] del changes2[:] del self.obj.prop1 del self.obj.prop2 self.assertEqual(changes1, [(1, self.obj_info, self.variable1, None, Undef, False, obj), (1, self.obj_info, self.variable2, None, Undef, False, obj)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable1, None, Undef, False, obj), (2, self.obj_info, self.variable2, None, Undef, False, obj)]) def test_remove_change_notification(self): changes1 = [] changes2 = [] def object_changed1(obj_info, variable, old_value, new_value, fromdb): changes1.append((1, obj_info, variable, old_value, new_value, fromdb)) def object_changed2(obj_info, variable, old_value, new_value, fromdb): changes2.append((2, obj_info, variable, old_value, new_value, fromdb)) self.obj_info.checkpoint() self.obj_info.event.hook("changed", object_changed1) self.obj_info.event.hook("changed", object_changed2) self.obj_info.event.unhook("changed", object_changed1) self.obj.prop2 = 20 self.obj.prop1 = 10 
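# object_changed1 was unhooked above, so only object_changed2 should have recorded anything.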
self.assertEqual(changes1, []) self.assertEqual(changes2, [(2, self.obj_info, self.variable2, Undef, 20, False), (2, self.obj_info, self.variable1, Undef, 10, False)]) def test_remove_change_notification_with_arg(self): changes1 = [] changes2 = [] def object_changed1(obj_info, variable, old_value, new_value, fromdb, arg): changes1.append((1, obj_info, variable, old_value, new_value, fromdb, arg)) def object_changed2(obj_info, variable, old_value, new_value, fromdb, arg): changes2.append((2, obj_info, variable, old_value, new_value, fromdb, arg)) self.obj_info.checkpoint() obj = object() self.obj_info.event.hook("changed", object_changed1, obj) self.obj_info.event.hook("changed", object_changed2, obj) self.obj_info.event.unhook("changed", object_changed1, obj) self.obj.prop2 = 20 self.obj.prop1 = 10 self.assertEqual(changes1, []) self.assertEqual(changes2, [(2, self.obj_info, self.variable2, Undef, 20, False, obj), (2, self.obj_info, self.variable1, Undef, 10, False, obj)]) def test_auto_remove_change_notification(self): changes1 = [] changes2 = [] def object_changed1(obj_info, variable, old_value, new_value, fromdb): changes1.append((1, obj_info, variable, old_value, new_value, fromdb)) return False def object_changed2(obj_info, variable, old_value, new_value, fromdb): changes2.append((2, obj_info, variable, old_value, new_value, fromdb)) return False self.obj_info.checkpoint() self.obj_info.event.hook("changed", object_changed1) self.obj_info.event.hook("changed", object_changed2) self.obj.prop2 = 20 self.obj.prop1 = 10 self.assertEqual(changes1, [(1, self.obj_info, self.variable2, Undef, 20, False)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable2, Undef, 20, False)]) def test_auto_remove_change_notification_with_arg(self): changes1 = [] changes2 = [] def object_changed1(obj_info, variable, old_value, new_value, fromdb, arg): changes1.append((1, obj_info, variable, old_value, new_value, fromdb, arg)) return False def object_changed2(obj_info, variable, old_value, new_value, fromdb, arg): changes2.append((2, obj_info, variable, old_value, new_value, fromdb, arg)) return False self.obj_info.checkpoint() obj = object() self.obj_info.event.hook("changed", object_changed1, obj) self.obj_info.event.hook("changed", object_changed2, obj) self.obj.prop2 = 20 self.obj.prop1 = 10 self.assertEqual(changes1, [(1, self.obj_info, self.variable2, Undef, 20, False, obj)]) self.assertEqual(changes2, [(2, self.obj_info, self.variable2, Undef, 20, False, obj)]) def test_get_obj(self): self.assertTrue(self.obj_info.get_obj() is self.obj) def test_get_obj_reference(self): """ We used to assign the get_obj() manually. This breaks stored references to the method (IOW, what we do in the test below). It was a bit faster, but in exchange for the danger of introducing subtle bugs which are super hard to debug. 
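Keeping get_obj() a real bound method means stored references stay valid even after set_obj(), which is what this test exercises.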
""" get_obj = self.obj_info.get_obj self.assertTrue(get_obj() is self.obj) another_obj = self.Class() self.obj_info.set_obj(another_obj) self.assertTrue(get_obj() is another_obj) def test_set_obj(self): obj = self.Class() self.obj_info.set_obj(obj) self.assertTrue(self.obj_info.get_obj() is obj) def test_weak_reference(self): obj = self.Class() obj_info = get_obj_info(obj) del obj self.assertEqual(obj_info.get_obj(), None) def test_object_deleted_notification(self): obj = self.Class() obj_info = get_obj_info(obj) obj_info["tainted"] = True deleted = [] def object_deleted(obj_info): deleted.append(obj_info) obj_info.event.hook("object-deleted", object_deleted) del obj_info del obj self.assertEqual(len(deleted), 1) self.assertTrue("tainted" in deleted[0]) def test_object_deleted_notification_after_set_obj(self): obj = self.Class() obj_info = get_obj_info(obj) obj_info["tainted"] = True obj = self.Class() obj_info.set_obj(obj) deleted = [] def object_deleted(obj_info): deleted.append(obj_info) obj_info.event.hook("object-deleted", object_deleted) del obj_info del obj self.assertEqual(len(deleted), 1) self.assertTrue("tainted" in deleted[0]) class ClassAliasTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Class: __storm_table__ = "table" prop1 = Property("column1", primary=True) self.Class = Class self.ClassAlias = ClassAlias(self.Class, "alias") def test_cls_info_cls(self): cls_info = get_cls_info(self.ClassAlias) self.assertEqual(cls_info.cls, self.Class) self.assertEqual(cls_info.table.name, "alias") self.assertEqual(self.ClassAlias.prop1.name, "column1") self.assertEqual(self.ClassAlias.prop1.table, self.ClassAlias) def test_compile(self): statement = compile(self.ClassAlias) self.assertEqual(statement, "alias") def test_compile_with_reserved_keyword(self): Alias = ClassAlias(self.Class, "select") statement = compile(Alias) self.assertEqual(statement, '"select"') def test_compile_in_select(self): expr = Select(self.ClassAlias.prop1, self.ClassAlias.prop1 == 1, self.ClassAlias) statement = compile(expr) self.assertEqual(statement, 'SELECT alias.column1 FROM "table" AS alias ' 'WHERE alias.column1 = ?') def test_compile_in_select_with_reserved_keyword(self): Alias = ClassAlias(self.Class, "select") expr = Select(Alias.prop1, Alias.prop1 == 1, Alias) statement = compile(expr) self.assertEqual(statement, 'SELECT "select".column1 FROM "table" AS "select" ' 'WHERE "select".column1 = ?') def test_crazy_metaclass(self): """We don't want metaclasses playing around when we build an alias.""" TestHelper.setUp(self) class MetaClass(type): def __new__(meta_cls, name, bases, dict): cls = type.__new__(meta_cls, name, bases, dict) cls.__storm_table__ = "HAH! GOTCH YA!" return cls class Class(metaclass=MetaClass): __storm_table__ = "table" prop1 = Property("column1", primary=True) Alias = ClassAlias(Class, "USE_THIS") self.assertEqual(Alias.__storm_table__, "USE_THIS") def test_cached_aliases(self): """ Class aliases are cached such that multiple invocations of C{ClassAlias} return the same object. 
""" alias1 = ClassAlias(self.Class, "something_unlikely") alias2 = ClassAlias(self.Class, "something_unlikely") self.assertIdentical(alias1, alias2) alias3 = ClassAlias(self.Class, "something_unlikely2") self.assertNotIdentical(alias1, alias3) alias4 = ClassAlias(self.Class, "something_unlikely2") self.assertIdentical(alias3, alias4) def test_unnamed_aliases_not_cached(self): alias1 = ClassAlias(self.Class) alias2 = ClassAlias(self.Class) self.assertNotIdentical(alias1, alias2) def test_alias_cache_is_per_class(self): """ The cache of class aliases is not as bad as it once was. """ class LocalClass(self.Class): pass alias = ClassAlias(self.Class, "something_unlikely") alias2 = ClassAlias(LocalClass, "something_unlikely") self.assertNotIdentical(alias, alias2) def test_aliases_only_last_as_long_as_class(self): """ The cached ClassAliases only last for as long as the class is alive. """ class LocalClass(self.Class): pass alias = ClassAlias(LocalClass, "something_unlikely3") alias_ref = ref(alias) class_ref = ref(LocalClass) del alias del LocalClass gc.collect(); gc.collect(); gc.collect() self.assertIdentical(class_ref(), None) self.assertIdentical(alias_ref(), None) class TypeCompilerTest(TestHelper): def test_nested_classes(self): """Convoluted case checking that the model is right.""" class Class1: __storm_table__ = "class1" id = Property(primary=True) class Class2: __storm_table__ = Class1 id = Property(primary=True) statement = compile(Class2) self.assertEqual(statement, "class1") alias = ClassAlias(Class2, "alias") statement = compile(Select(alias.id)) self.assertEqual(statement, "SELECT alias.id FROM class1 AS alias") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039171.0 storm-1.0/storm/tests/mocker.py0000644000175000017500000022120014645174503017166 0ustar00cjwatsoncjwatson""" Copyright (c) 2007 Gustavo Niemeyer Graceful platform for test doubles in Python (mocks, stubs, fakes, and dummies). """ import builtins import tempfile import unittest import inspect import shutil import sys import os import gc __all__ = ["Mocker", "expect", "IS", "CONTAINS", "IN", "MATCH", "ANY", "ARGS", "KWARGS"] __author__ = "Gustavo Niemeyer " __license__ = "PSF License" __version__ = "0.10.1" ERROR_PREFIX = "[Mocker] " # -------------------------------------------------------------------- # Exceptions class MatchError(AssertionError): """Raised when an unknown expression is seen in playback mode.""" # -------------------------------------------------------------------- # Helper for chained-style calling. class expect: """This is a simple helper that allows a different call-style. With this class one can comfortably do chaining of calls to the mocker object responsible by the object being handled. For instance:: expect(obj.attr).result(3).count(1, 2) Is the same as:: obj.attr mocker.result(3) mocker.count(1, 2) """ def __init__(self, mock, attr=None): self._mock = mock self._attr = attr def __getattr__(self, attr): return self.__class__(self._mock, attr) def __call__(self, *args, **kwargs): getattr(self._mock.__mocker__, self._attr)(*args, **kwargs) return self # -------------------------------------------------------------------- # Extensions to Python's unittest. class MockerTestCase(unittest.TestCase): """unittest.TestCase subclass with Mocker support. @ivar mocker: The mocker instance. This is a convenience only. Mocker may easily be used with the standard C{unittest.TestCase} class if wanted. Test methods have a Mocker instance available on C{self.mocker}. 
At the end of each test method, expectations of the mocker will be verified, and any requested changes made to the environment will be restored. In addition to the integration with Mocker, this class provides a few additional helper methods. """ expect = expect def __init__(self, methodName="runTest"): # So here is the trick: we take the real test method, wrap it in # a function that does the job we have to do, and insert it in the # *instance* dictionary, so that getattr() will return our # replacement rather than the class method. test_method = getattr(self, methodName, None) if test_method is not None: def test_method_wrapper(): try: result = test_method() except: raise else: if (self.mocker.is_recording() and self.mocker.get_events()): raise RuntimeError("Mocker must be put in replay " "mode with self.mocker.replay()") if (hasattr(result, "addCallback") and hasattr(result, "addErrback")): def verify(result): self.mocker.verify() return result result.addCallback(verify) else: self.mocker.verify() return result # Copy all attributes from the original method.. for attr in dir(test_method): # .. unless they're present in our wrapper already. if not hasattr(test_method_wrapper, attr) or attr == "__doc__": setattr(test_method_wrapper, attr, getattr(test_method, attr)) setattr(self, methodName, test_method_wrapper) # We could overload run() normally, but other well-known testing # frameworks do it as well, and some of them won't call the super, # which might mean that cleanup wouldn't happen. With that in mind, # we make integration easier by using the following trick. run_method = self.run def run_wrapper(*args, **kwargs): try: return run_method(*args, **kwargs) finally: self.__cleanup() self.run = run_wrapper self.mocker = Mocker() self.__cleanup_funcs = [] self.__cleanup_paths = [] super().__init__(methodName) def __cleanup(self): for path in self.__cleanup_paths: if os.path.isfile(path): os.unlink(path) elif os.path.isdir(path): shutil.rmtree(path) self.mocker.restore() for func, args, kwargs in reversed(self.__cleanup_funcs): func(*args, **kwargs) def addCleanup(self, func, *args, **kwargs): self.__cleanup_funcs.append((func, args, kwargs)) def makeFile(self, content=None, suffix="", prefix="tmp", basename=None, dirname=None, path=None): """Create a temporary file and return the path to it. @param content: Initial content for the file. @param suffix: Suffix to be given to the file's basename. @param prefix: Prefix to be given to the file's basename. @param basename: Full basename for the file. @param dirname: Put file inside this directory. The file is removed after the test runs. """ if path is not None: self.__cleanup_paths.append(path) elif basename is not None: if dirname is None: dirname = tempfile.mkdtemp() self.__cleanup_paths.append(dirname) path = os.path.join(dirname, basename) else: fd, path = tempfile.mkstemp(suffix, prefix, dirname) self.__cleanup_paths.append(path) os.close(fd) if content is None: os.unlink(path) if content is not None: file = open(path, "w") file.write(content) file.close() return path def makeDir(self, suffix="", prefix="tmp", dirname=None, path=None): """Create a temporary directory and return the path to it. @param suffix: Suffix to be given to the directory's basename. @param prefix: Prefix to be given to the directory's basename. @param dirname: Put directory inside this parent directory. The directory is removed after the test runs.
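@param path: Use this exact path for the directory instead of generating a temporary one; it is created for you.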
""" if path is not None: os.makedirs(path) else: path = tempfile.mkdtemp(suffix, prefix, dirname) self.__cleanup_paths.append(path) return path def failUnlessIs(self, first, second, msg=None): """Assert that C{first} is the same object as C{second}.""" if first is not second: raise self.failureException(msg or "%r is not %r" % (first, second)) def failIfIs(self, first, second, msg=None): """Assert that C{first} is not the same object as C{second}.""" if first is second: raise self.failureException(msg or "%r is %r" % (first, second)) def failUnlessIn(self, first, second, msg=None): """Assert that C{first} is contained in C{second}.""" if first not in second: raise self.failureException(msg or "%r not in %r" % (first, second)) def failUnlessStartsWith(self, first, second, msg=None): """Assert that C{first} starts with C{second}.""" if first[:len(second)] != second: raise self.failureException(msg or "%r doesn't start with %r" % (first, second)) def failIfStartsWith(self, first, second, msg=None): """Assert that C{first} doesn't start with C{second}.""" if first[:len(second)] == second: raise self.failureException(msg or "%r starts with %r" % (first, second)) def failUnlessEndsWith(self, first, second, msg=None): """Assert that C{first} starts with C{second}.""" if first[len(first)-len(second):] != second: raise self.failureException(msg or "%r doesn't end with %r" % (first, second)) def failIfEndsWith(self, first, second, msg=None): """Assert that C{first} doesn't start with C{second}.""" if first[len(first)-len(second):] == second: raise self.failureException(msg or "%r ends with %r" % (first, second)) def failIfIn(self, first, second, msg=None): """Assert that C{first} is not contained in C{second}.""" if first in second: raise self.failureException(msg or "%r in %r" % (first, second)) def failUnlessApproximates(self, first, second, tolerance, msg=None): """Assert that C{first} is near C{second} by at most C{tolerance}.""" if abs(first - second) > tolerance: raise self.failureException(msg or "abs(%r - %r) > %r" % (first, second, tolerance)) def failIfApproximates(self, first, second, tolerance, msg=None): """Assert that C{first} is far from C{second} by at least C{tolerance}. """ if abs(first - second) <= tolerance: raise self.failureException(msg or "abs(%r - %r) <= %r" % (first, second, tolerance)) def failUnlessMethodsMatch(self, first, second): """Assert that public methods in C{first} are present in C{second}. This method asserts that all public methods found in C{first} are also present in C{second} and accept the same arguments. C{first} may have its own private methods, though, and may not have all methods found in C{second}. Note that if a private method in C{first} matches the name of one in C{second}, their specification is still compared. This is useful to verify if a fake or stub class have the same API as the real class being simulated. """ first_methods = dict(inspect.getmembers(first, inspect.ismethod)) second_methods = dict(inspect.getmembers(second, inspect.ismethod)) for name, first_method in first_methods.items(): first_argspec = inspect.getargspec(first_method) first_formatted = inspect.formatargspec(*first_argspec) second_method = second_methods.get(name) if second_method is None: if name[:1] == "_": continue # First may have its own private methods. 
raise self.failureException("%s.%s%s not present in %s" % (first.__name__, name, first_formatted, second.__name__)) second_argspec = inspect.getargspec(second_method) if first_argspec != second_argspec: second_formatted = inspect.formatargspec(*second_argspec) raise self.failureException("%s.%s%s != %s.%s%s" % (first.__name__, name, first_formatted, second.__name__, name, second_formatted)) assertIs = failUnlessIs assertIsNot = failIfIs assertIn = failUnlessIn assertNotIn = failIfIn assertStartsWith = failUnlessStartsWith assertNotStartsWith = failIfStartsWith assertEndsWith = failUnlessEndsWith assertNotEndsWith = failIfEndsWith assertApproximates = failUnlessApproximates assertNotApproximates = failIfApproximates assertMethodsMatch = failUnlessMethodsMatch # The following is provided for compatibility with Twisted's trial. assertIdentical = assertIs assertNotIdentical = assertIsNot failUnlessIdentical = failUnlessIs failIfIdentical = failIfIs # -------------------------------------------------------------------- # Mocker. class classinstancemethod: def __init__(self, method): self.method = method def __get__(self, obj, cls=None): def bound_method(*args, **kwargs): return self.method(cls, obj, *args, **kwargs) return bound_method class MockerMeta(type): def __init__(self, name, bases, dict): # Make independent lists on each subclass, inheriting from parent. self._recorders = list(getattr(self, "_recorders", ())) class MockerBase(metaclass=MockerMeta): """Controller of mock objects. A mocker instance is used to command recording and replay of expectations on any number of mock objects. Expectations should be expressed for the mock object while in record mode (the initial one) by using the mock object itself, and using the mocker (and/or C{expect()} as a helper) to define additional behavior for each event. For instance:: mock = mocker.mock() mock.hello() mocker.result("Hi!") mocker.replay() assert mock.hello() == "Hi!" mock.restore() mock.verify() In this short excerpt a mock object is being created, then an expectation of a call to the C{hello()} method was recorded, and when called the method should return the value C{10}. Then, the mocker is put in replay mode, and the expectation is satisfied by calling the C{hello()} method, which indeed returns 10. Finally, a call to the L{restore()} method is performed to undo any needed changes made in the environment, and the L{verify()} method is called to ensure that all defined expectations were met. The same logic can be expressed more elegantly using the C{with mocker:} statement, as follows:: mock = mocker.mock() mock.hello() mocker.result("Hi!") with mocker: assert mock.hello() == "Hi!" Also, the MockerTestCase class, which integrates the mocker on a unittest.TestCase subclass, may be used to reduce the overhead of controlling the mocker. A test could be written as follows:: class SampleTest(MockerTestCase): def test_hello(self): mock = self.mocker.mock() mock.hello() self.mocker.result("Hi!") self.mocker.replay() self.assertEqual(mock.hello(), "Hi!") """ _recorders = [] # For convenience only. on = expect def __init__(self): self._recorders = self._recorders[:] self._events = [] self._recording = True self._ordering = False self._last_orderer = None def is_recording(self): """Return True if in recording mode, False if in replay mode. Recording is the initial state. """ return self._recording def replay(self): """Change to replay mode, where recorded events are reproduced. 
If already in replay mode, the mocker will be restored, with all expectations reset, and then put again in replay mode. An alternative and more comfortable way to replay changes is using the 'with' statement, as follows:: mocker = Mocker() with mocker: The 'with' statement will automatically put mocker in replay mode, and will also verify if all events were correctly reproduced at the end (using L{verify()}), and also restore any changes done in the environment (with L{restore()}). Also check the MockerTestCase class, which integrates the unittest.TestCase class with mocker. """ if not self._recording: for event in self._events: event.restore() else: self._recording = False for event in self._events: event.replay() def restore(self): """Restore changes in the environment, and return to recording mode. This should always be called after the test is complete (succeeding or not). There are ways to call this method automatically on completion (e.g. using a C{with mocker:} statement, or using the L{MockerTestCase} class. """ if not self._recording: self._recording = True for event in self._events: event.restore() def reset(self): """Reset the mocker state. This will restore environment changes, if currently in replay mode, and then remove all events previously recorded. """ if not self._recording: self.restore() self.unorder() del self._events[:] def get_events(self): """Return all recorded events.""" return self._events[:] def add_event(self, event): """Add an event. This method is used internally by the implementation, and shouldn't be needed on normal mocker usage. """ self._events.append(event) if self._ordering: orderer = event.add_task(Orderer(event.path)) if self._last_orderer: orderer.add_dependency(self._last_orderer) self._last_orderer = orderer return event def verify(self): """Check if all expectations were met, and raise AssertionError if not. The exception message will include a nice description of which expectations were not met, and why. """ errors = [] for event in self._events: try: event.verify() except AssertionError as e: error = str(e) if not error: raise RuntimeError("Empty error message from %r" % event) errors.append(error) if errors: message = [ERROR_PREFIX + "Unmet expectations:", ""] for error in errors: lines = error.splitlines() message.append("=> " + lines.pop(0)) message.extend([" " + line for line in lines]) message.append("") raise AssertionError(os.linesep.join(message)) def mock(self, spec_and_type=None, spec=None, type=None, name=None, count=True): """Return a new mock object. @param spec_and_type: Handy positional argument which sets both spec and type. @param spec: Method calls will be checked for correctness against the given class. @param type: If set, the Mock's __class__ attribute will return the given type. This will make C{isinstance()} calls on the object work. @param name: Name for the mock object, used in the representation of expressions. The name is rarely needed, as it's usually guessed correctly from the variable name used. @param count: If set to false, expressions may be executed any number of times, unless an expectation is explicitly set using the L{count()} method. By default, expressions are expected once. """ if spec_and_type is not None: spec = type = spec_and_type return Mock(self, spec=spec, type=type, name=name, count=count) def proxy(self, object, spec=True, type=True, name=None, count=True, passthrough=True): """Return a new mock object which proxies to the given object. 
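For instance, a minimal sketch (C{obj} stands for any real object)::

    mock = mocker.proxy(obj)
    mock.hello()
    mocker.result("Hi!")
    mocker.replay()

Only the recorded C{hello()} call is intercepted here; unrecorded expressions fall through to C{obj} while passthrough is enabled.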
Proxies are useful when only part of the behavior of an object is to be mocked. Unknown expressions may be passed through to the real implementation implicitly (if the C{passthrough} argument is True), or explicitly (using the L{passthrough()} method on the event). @param object: Real object to be proxied, and replaced by the mock on replay mode. It may also be an "import path", such as C{"time.time"}, in which case the object will be the C{time} function from the C{time} module. @param spec: Method calls will be checked for correctness against the given object, which may be a class or an instance where attributes will be looked up. Defaults to the the C{object} parameter. May be set to None explicitly, in which case spec checking is disabled. Checks may also be disabled explicitly on a per-event basis with the L{nospec()} method. @param type: If set, the Mock's __class__ attribute will return the given type. This will make C{isinstance()} calls on the object work. Defaults to the type of the C{object} parameter. May be set to None explicitly. @param name: Name for the mock object, used in the representation of expressions. The name is rarely needed, as it's usually guessed correctly from the variable name used. @param count: If set to false, expressions may be executed any number of times, unless an expectation is explicitly set using the L{count()} method. By default, expressions are expected once. @param passthrough: If set to False, passthrough of actions on the proxy to the real object will only happen when explicitly requested via the L{passthrough()} method. """ if isinstance(object, str): if name is None: name = object import_stack = object.split(".") attr_stack = [] while import_stack: module_path = ".".join(import_stack) try: object = __import__(module_path, {}, {}, [""]) except ImportError: attr_stack.insert(0, import_stack.pop()) if not import_stack: raise continue else: for attr in attr_stack: object = getattr(object, attr) break if spec is True: spec = object if type is True: type = builtins.type(object) return Mock(self, spec=spec, type=type, object=object, name=name, count=count, passthrough=passthrough) def replace(self, object, spec=True, type=True, name=None, count=True, passthrough=True): """Create a proxy, and replace the original object with the mock. On replay, the original object will be replaced by the returned proxy in all dictionaries found in the running interpreter via the garbage collecting system. This should cover module namespaces, class namespaces, instance namespaces, and so on. @param object: Real object to be proxied, and replaced by the mock on replay mode. It may also be an "import path", such as C{"time.time"}, in which case the object will be the C{time} function from the C{time} module. @param spec: Method calls will be checked for correctness against the given object, which may be a class or an instance where attributes will be looked up. Defaults to the the C{object} parameter. May be set to None explicitly, in which case spec checking is disabled. Checks may also be disabled explicitly on a per-event basis with the L{nospec()} method. @param type: If set, the Mock's __class__ attribute will return the given type. This will make C{isinstance()} calls on the object work. Defaults to the type of the C{object} parameter. May be set to None explicitly. @param name: Name for the mock object, used in the representation of expressions. The name is rarely needed, as it's usually guessed correctly from the variable name used. 
@param passthrough: If set to False, passthrough of actions on the proxy to the real object will only happen when explicitly requested via the L{passthrough()} method. """ mock = self.proxy(object, spec, type, name, count, passthrough) event = self._get_replay_restore_event() event.add_task(ProxyReplacer(mock)) return mock def patch(self, object, spec=True): """Patch an existing object to reproduce recorded events. The result of this method is still a mock object, which can be used like any other mock object to record events. The difference is that when the mocker is put on replay mode, the *real* object will be modified to behave according to recorded expectations. Patching works in individual instances, and also in classes. When an instance is patched, recorded events will only be considered on this specific instance, and other instances should behave normally. When a class is patched, the reproduction of events will be considered on any instance of this class once created (collectively). Observe that, unlike with proxies which catch only events done through the mock object, *all* accesses to recorded expectations will be considered; even these coming from the object itself (e.g. C{self.hello()} is considered if this method was patched). While this is a very powerful feature, and many times the reason to use patches in the first place, it's important to keep this behavior in mind. Patching of the original object only takes place when the mocker is put on replay mode, and the patched object will be restored to its original state once the L{restore()} method is called (explicitly, or implicitly with alternative conventions, such as a C{with mocker:} block, or a MockerTestCase class). @param object: Class or instance to be patched. @param spec: Method calls will be checked for correctness against the given object, which may be a class or an instance where attributes will be looked up. Defaults to the the C{object} parameter. May be set to None explicitly, in which case spec checking is disabled. Checks may also be disabled explicitly on a per-event basis with the L{nospec()} method. """ if spec is True: spec = object patcher = Patcher() event = self._get_replay_restore_event() event.add_task(patcher) mock = Mock(self, object=object, patcher=patcher, passthrough=True, spec=spec) object.__mocker_mock__ = mock return mock def act(self, path): """This is called by mock objects whenever something happens to them. This method is part of the implementation between the mocker and mock objects. """ if self._recording: event = self.add_event(Event(path)) for recorder in self._recorders: recorder(self, event) return Mock(self, path) else: # First run events that may run, then run unsatisfied events, then # ones not previously run. We put the index in the ordering tuple # instead of the actual event because we want a stable sort # (ordering between 2 events is undefined). events = self._events order = [(events[i].satisfied()*2 + events[i].has_run(), i) for i in range(len(events))] order.sort() postponed = None for weight, i in order: event = events[i] if event.matches(path): if event.may_run(path): return event.run(path) elif postponed is None: postponed = event if postponed is not None: return postponed.run(path) raise MatchError(ERROR_PREFIX + "Unexpected expression: %s" % path) def get_recorders(cls, self): """Return recorders associated with this mocker class or instance. This method may be called on mocker instances and also on mocker classes. See the L{add_recorder()} method for more information. 
""" return (self or cls)._recorders[:] get_recorders = classinstancemethod(get_recorders) def add_recorder(cls, self, recorder): """Add a recorder to this mocker class or instance. @param recorder: Callable accepting C{(mocker, event)} as parameters. This is part of the implementation of mocker. All registered recorders are called for translating events that happen during recording into expectations to be met once the state is switched to replay mode. This method may be called on mocker instances and also on mocker classes. When called on a class, the recorder will be used by all instances, and also inherited on subclassing. When called on instances, the recorder is added only to the given instance. """ (self or cls)._recorders.append(recorder) return recorder add_recorder = classinstancemethod(add_recorder) def remove_recorder(cls, self, recorder): """Remove the given recorder from this mocker class or instance. This method may be called on mocker classes and also on mocker instances. See the L{add_recorder()} method for more information. """ (self or cls)._recorders.remove(recorder) remove_recorder = classinstancemethod(remove_recorder) def result(self, value): """Make the last recorded event return the given value on replay. @param value: Object to be returned when the event is replayed. """ self.call(lambda *args, **kwargs: value) def generate(self, sequence): """Last recorded event will return a generator with the given sequence. @param sequence: Sequence of values to be generated. """ def generate(*args, **kwargs): yield from sequence self.call(generate) def throw(self, exception): """Make the last recorded event raise the given exception on replay. @param exception: Class or instance of exception to be raised. """ def raise_exception(*args, **kwargs): raise exception self.call(raise_exception) def call(self, func): """Make the last recorded event cause the given function to be called. @param func: Function to be called. The result of the function will be used as the event result. """ self._events[-1].add_task(FunctionRunner(func)) def count(self, min, max=False): """Last recorded event must be replayed between min and max times. @param min: Minimum number of times that the event must happen. @param max: Maximum number of times that the event must happen. If not given, it defaults to the same value of the C{min} parameter. If set to None, there is no upper limit, and the expectation is met as long as it happens at least C{min} times. """ event = self._events[-1] for task in event.get_tasks(): if isinstance(task, RunCounter): event.remove_task(task) event.add_task(RunCounter(min, max)) def is_ordering(self): """Return true if all events are being ordered. See the L{order()} method. """ return self._ordering def unorder(self): """Disable the ordered mode. See the L{order()} method for more information. """ self._ordering = False self._last_orderer = None def order(self, *path_holders): """Create an expectation of order between two or more events. @param path_holders: Objects returned as the result of recorded events. By default, mocker won't force events to happen precisely in the order they were recorded. Calling this method will change this behavior so that events will only match if reproduced in the correct order. There are two ways in which this method may be used. Which one is used in a given occasion depends only on convenience. 
If no arguments are passed, the mocker will be put in a mode where all the recorded events following the method call will only be met if they happen in order. When that's used, the mocker may be put back in unordered mode by calling the L{unorder()} method, or by using a 'with' block, like so:: with mocker.ordered(): In this case, only expressions in will be ordered, and the mocker will be back in unordered mode after the 'with' block. The second way to use it is by specifying precisely which events should be ordered. As an example:: mock = mocker.mock() expr1 = mock.hello() expr2 = mock.world expr3 = mock.x.y.z mocker.order(expr1, expr2, expr3) This method of ordering only works when the expression returns another object. Also check the L{after()} and L{before()} methods, which are alternative ways to perform this. """ if not path_holders: self._ordering = True return OrderedContext(self) last_orderer = None for path_holder in path_holders: if type(path_holder) is Path: path = path_holder else: path = path_holder.__mocker_path__ for event in self._events: if event.path is path: for task in event.get_tasks(): if isinstance(task, Orderer): orderer = task break else: orderer = Orderer(path) event.add_task(orderer) if last_orderer: orderer.add_dependency(last_orderer) last_orderer = orderer break def after(self, *path_holders): """Last recorded event must happen after events referred to. As an example, the idiom:: expect(mock.x).after(mock.y, mock.z) is an alternative way to say:: expr_x = mock.x expr_y = mock.y expr_z = mock.z mocker.order(expr_y, expr_x) mocker.order(expr_z, expr_x) See L{order()} for more information. @param path_holders: Objects returned as the result of recorded events which should happen before the last recorded event """ last_path = self._events[-1].path for path_holder in path_holders: self.order(path_holder, last_path) def before(self, *path_holders): """Last recorded event must happen before events referred to. As an example, the idiom:: expect(mock.x).before(mock.y, mock.z) is an alternative way to say:: expr_x = mock.x expr_y = mock.y expr_z = mock.z mocker.order(expr_x, expr_y) mocker.order(expr_x, expr_z) See L{order()} for more information. @param path_holders: Objects returned as the result of recorded events which should happen after the last recorded event """ last_path = self._events[-1].path for path_holder in path_holders: self.order(last_path, path_holder) def nospec(self): """Don't check method specification of real object on last event. By default, when using a mock created as the result of a call to L{proxy()}, L{replace()}, and C{patch()}, or when passing the spec attribute to the L{mock()} method, method calls on the given object are checked for correctness against the specification of the real object (or the explicitly provided spec). This method will disable that check specifically for the last recorded event. """ event = self._events[-1] for task in event.get_tasks(): if isinstance(task, SpecChecker): event.remove_task(task) def passthrough(self, result_callback=None): """Make the last recorded event run on the real object once seen. This can only be used on proxies, as returned by the L{proxy()} and L{replace()} methods, or on mocks representing patched objects, as returned by the L{patch()} method. @param result_callback: If given, this function will be called with the result of the *real* method call as the only argument. 
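A short sketch of typical use (C{obj} stands for the proxied object)::

    mock = mocker.proxy(obj)
    mock.hello()
    mocker.passthrough()

On replay, C{mock.hello()} then invokes the real C{obj.hello()} as well.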
""" event = self._events[-1] if event.path.root_object is None: raise TypeError("Mock object isn't a proxy") event.add_task(PathExecuter(result_callback)) def __enter__(self): """Enter in a 'with' context. This will run replay().""" self.replay() return self def __exit__(self, type, value, traceback): """Exit from a 'with' context. This will run restore() at all times, but will only run verify() if the 'with' block itself hasn't raised an exception. Exceptions in that block are never swallowed. """ self.restore() if type is None: self.verify() return False def _get_replay_restore_event(self): """Return unique L{ReplayRestoreEvent}, creating if needed. Some tasks only want to replay/restore. When that's the case, they shouldn't act on other events during replay. Also, they can all be put in a single event when that's the case. Thus, we add a single L{ReplayRestoreEvent} as the first element of the list. """ if not self._events or type(self._events[0]) != ReplayRestoreEvent: self._events.insert(0, ReplayRestoreEvent()) return self._events[0] class OrderedContext: def __init__(self, mocker): self._mocker = mocker def __enter__(self): return None def __exit__(self, type, value, traceback): self._mocker.unorder() class Mocker(MockerBase): __doc__ = MockerBase.__doc__ # Decorator to add recorders on the standard Mocker class. recorder = Mocker.add_recorder # -------------------------------------------------------------------- # Mock object. class Mock: def __init__(self, mocker, path=None, name=None, spec=None, type=None, object=None, passthrough=False, patcher=None, count=True): self.__mocker__ = mocker self.__mocker_path__ = path or Path(self, object) self.__mocker_name__ = name self.__mocker_spec__ = spec self.__mocker_object__ = object self.__mocker_passthrough__ = passthrough self.__mocker_patcher__ = patcher self.__mocker_replace__ = False self.__mocker_type__ = type self.__mocker_count__ = count def __mocker_act__(self, kind, args=(), kwargs={}, object=None): if self.__mocker_name__ is None: self.__mocker_name__ = find_object_name(self, 2) action = Action(kind, args, kwargs, self.__mocker_path__) path = self.__mocker_path__ + action if object is not None: path.root_object = object try: return self.__mocker__.act(path) except MatchError as exception: root_mock = path.root_mock if (path.root_object is not None and root_mock.__mocker_passthrough__): return path.execute(path.root_object) # Reinstantiate to show raise statement on traceback, and # also to make the traceback shown shorter. 
raise MatchError(str(exception)) except AssertionError as e: lines = str(e).splitlines() message = [ERROR_PREFIX + "Unmet expectation:", ""] message.append("=> " + lines.pop(0)) message.extend([" " + line for line in lines]) message.append("") raise AssertionError(os.linesep.join(message)) def __getattribute__(self, name): if name.startswith("__mocker_"): return super().__getattribute__(name) if name == "__class__": if self.__mocker__.is_recording() or self.__mocker_type__ is None: return type(self) return self.__mocker_type__ return self.__mocker_act__("getattr", (name,)) def __setattr__(self, name, value): if name.startswith("__mocker_"): return super().__setattr__(name, value) return self.__mocker_act__("setattr", (name, value)) def __delattr__(self, name): return self.__mocker_act__("delattr", (name,)) def __call__(self, *args, **kwargs): return self.__mocker_act__("call", args, kwargs) def __contains__(self, value): return self.__mocker_act__("contains", (value,)) def __getitem__(self, key): return self.__mocker_act__("getitem", (key,)) def __setitem__(self, key, value): return self.__mocker_act__("setitem", (key, value)) def __delitem__(self, key): return self.__mocker_act__("delitem", (key,)) def __len__(self): # MatchError is turned on an AttributeError so that list() and # friends act properly when trying to get length hints on # something that doesn't offer them. try: result = self.__mocker_act__("len") except MatchError as e: raise AttributeError(str(e)) if type(result) is Mock: return 0 return result def __bool__(self): try: return self.__mocker_act__("bool") except MatchError as e: return True def __iter__(self): # XXX On py3k, when next() becomes __next__(), we'll be able # to return the mock itself because it will be considered # an iterator (we'll be mocking __next__ as well, which we # can't now). result = self.__mocker_act__("iter") if type(result) is Mock: return iter([]) return result # When adding a new action kind here, also add support for it on # Action.execute() and Path.__str__(). def find_object_name(obj, depth=0): """Try to detect how the object is named on a previous scope.""" try: frame = sys._getframe(depth+1) except: return None for name, frame_obj in frame.f_locals.items(): if frame_obj is obj: return name self = frame.f_locals.get("self") if self is not None: try: items = list(self.__dict__.items()) except: pass else: for name, self_obj in items: if self_obj is obj: return name return None # -------------------------------------------------------------------- # Action and path. class Action: def __init__(self, kind, args, kwargs, path=None): self.kind = kind self.args = args self.kwargs = kwargs self.path = path self._execute_cache = {} def __repr__(self): if self.path is None: return "Action(%r, %r, %r)" % (self.kind, self.args, self.kwargs) return "Action(%r, %r, %r, %r)" % \ (self.kind, self.args, self.kwargs, self.path) def __eq__(self, other): return (self.kind == other.kind and self.args == other.args and self.kwargs == other.kwargs) def __ne__(self, other): return not self.__eq__(other) def matches(self, other): return (self.kind == other.kind and match_params(self.args, self.kwargs, other.args, other.kwargs)) def execute(self, object): # This caching scheme may fail if the object gets deallocated before # the action, as the id might get reused. It's somewhat easy to fix # that with a weakref callback. For our uses, though, the object # should never get deallocated before the action itself, so we'll # just keep it simple. 
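# The cache below is keyed on id(object), so replaying the same action against the same object reuses the first result.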
if id(object) in self._execute_cache: return self._execute_cache[id(object)] execute = getattr(object, "__mocker_execute__", None) if execute is not None: result = execute(self, object) else: kind = self.kind if kind == "getattr": result = getattr(object, self.args[0]) elif kind == "setattr": result = setattr(object, self.args[0], self.args[1]) elif kind == "delattr": result = delattr(object, self.args[0]) elif kind == "call": result = object(*self.args, **self.kwargs) elif kind == "contains": result = self.args[0] in object elif kind == "getitem": result = object[self.args[0]] elif kind == "setitem": result = object[self.args[0]] = self.args[1] elif kind == "delitem": del object[self.args[0]] result = None elif kind == "len": result = len(object) elif kind == "bool": result = bool(object) elif kind == "iter": result = iter(object) else: raise RuntimeError("Don't know how to execute %r kind." % kind) self._execute_cache[id(object)] = result return result class Path: def __init__(self, root_mock, root_object=None, actions=()): self.root_mock = root_mock self.root_object = root_object self.actions = tuple(actions) self.__mocker_replace__ = False def parent_path(self): if not self.actions: return None return self.actions[-1].path parent_path = property(parent_path) def __add__(self, action): """Return a new path which includes the given action at the end.""" return self.__class__(self.root_mock, self.root_object, self.actions + (action,)) def __eq__(self, other): """Verify if the two paths are equal. Two paths are equal if they refer to the same mock object, and have the actions with equal kind, args and kwargs. """ if (self.root_mock is not other.root_mock or self.root_object is not other.root_object or len(self.actions) != len(other.actions)): return False for action, other_action in zip(self.actions, other.actions): if action != other_action: return False return True def matches(self, other): """Verify if the two paths are equivalent. Two paths are equal if they refer to the same mock object, and have the same actions performed on them. """ if (self.root_mock is not other.root_mock or len(self.actions) != len(other.actions)): return False for action, other_action in zip(self.actions, other.actions): if not action.matches(other_action): return False return True def execute(self, object): """Execute all actions sequentially on object, and return result. 
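For instance, a path built from C{obj.attr(1)} performs a C{getattr} action followed by a C{call} action on the given object.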
""" for action in self.actions: object = action.execute(object) return object def __str__(self): """Transform the path into a nice string such as obj.x.y('z').""" result = self.root_mock.__mocker_name__ or "" for action in self.actions: if action.kind == "getattr": result = "%s.%s" % (result, action.args[0]) elif action.kind == "setattr": result = "%s.%s = %r" % (result, action.args[0], action.args[1]) elif action.kind == "delattr": result = "del %s.%s" % (result, action.args[0]) elif action.kind == "call": args = [repr(x) for x in action.args] items = list(action.kwargs.items()) items.sort() for pair in items: args.append("%s=%r" % pair) result = "%s(%s)" % (result, ", ".join(args)) elif action.kind == "contains": result = "%r in %s" % (action.args[0], result) elif action.kind == "getitem": result = "%s[%r]" % (result, action.args[0]) elif action.kind == "setitem": result = "%s[%r] = %r" % (result, action.args[0], action.args[1]) elif action.kind == "delitem": result = "del %s[%r]" % (result, action.args[0]) elif action.kind == "len": result = "len(%s)" % result elif action.kind == "bool": result = "bool(%s)" % result elif action.kind == "iter": result = "iter(%s)" % result else: raise RuntimeError("Don't know how to format kind %r" % action.kind) return result class SpecialArgument: """Base for special arguments for matching parameters.""" def __init__(self, object=None): self.object = object def __repr__(self): if self.object is None: return self.__class__.__name__ else: return "%s(%r)" % (self.__class__.__name__, self.object) def matches(self, other): return True def __eq__(self, other): return type(other) == type(self) and self.object == other.object class ANY(SpecialArgument): """Matches any single argument.""" ANY = ANY() class ARGS(SpecialArgument): """Matches zero or more positional arguments.""" ARGS = ARGS() class KWARGS(SpecialArgument): """Matches zero or more keyword arguments.""" KWARGS = KWARGS() class IS(SpecialArgument): def matches(self, other): return self.object is other def __eq__(self, other): return type(other) == type(self) and self.object is other.object class CONTAINS(SpecialArgument): def matches(self, other): try: other.__contains__ except AttributeError: try: iter(other) except TypeError: # If an object can't be iterated, and has no __contains__ # hook, it'd blow up on the test below. We test this in # advance to prevent catching more errors than we really # want. return False return self.object in other class IN(SpecialArgument): def matches(self, other): return other in self.object class MATCH(SpecialArgument): def matches(self, other): return bool(self.object(other)) def __eq__(self, other): return type(other) == type(self) and self.object is other.object def match_params(args1, kwargs1, args2, kwargs2): """Match the two sets of parameters, considering special parameters.""" has_args = ARGS in args1 has_kwargs = KWARGS in args1 if has_kwargs: args1 = [arg1 for arg1 in args1 if arg1 is not KWARGS] elif len(kwargs1) != len(kwargs2): return False if not has_args and len(args1) != len(args2): return False # Either we have the same number of kwargs, or unknown keywords are # accepted (KWARGS was used), so check just the ones in kwargs1. for key, arg1 in kwargs1.items(): if key not in kwargs2: return False arg2 = kwargs2[key] if isinstance(arg1, SpecialArgument): if not arg1.matches(arg2): return False elif arg1 != arg2: return False # Keywords match. Now either we have the same number of # arguments, or ARGS was used. 
If ARGS wasn't used, arguments # must match one-to-one. if not has_args: for arg1, arg2 in zip(args1, args2): if isinstance(arg1, SpecialArgument): if not arg1.matches(arg2): return False elif arg1 != arg2: return False return True # Easy choice. Keywords are matching, and anything on args is accepted. if (ARGS,) == args1: return True # We have something different there. If we don't have positional # arguments on the original call, it can't match. if not args2: # Unless we have just several ARGS (which is bizarre, but..). for arg1 in args1: if arg1 is not ARGS: return False return True # Ok, all bets are off. We have to actually do the more expensive # matching. This is an algorithm based on the idea of the Levenshtein # distance between two strings, but heavily hacked for this purpose. args2l = len(args2) if args1[0] is ARGS: args1 = args1[1:] array = [0]*args2l else: array = [1]*args2l for i in range(len(args1)): last = array[0] if args1[i] is ARGS: for j in range(1, args2l): last, array[j] = array[j], min(array[j-1], array[j], last) else: array[0] = i or int(args1[i] != args2[0]) for j in range(1, args2l): last, array[j] = array[j], last or int(args1[i] != args2[j]) if 0 not in array: return False if array[-1] != 0: return False return True # -------------------------------------------------------------------- # Event and task base. class Event: """Aggregation of tasks that keep track of a recorded action. An event represents something that may or may not happen while the mocked environment is running, such as an attribute access, or a method call. The event is composed of several tasks that are orchestrated together to create a composed meaning for the event, including for which actions it should be run, what happens when it runs, and what the expectations about those actions are. """ def __init__(self, path=None): self.path = path self._tasks = [] self._has_run = False def add_task(self, task): """Add a new task to this event.""" self._tasks.append(task) return task def remove_task(self, task): self._tasks.remove(task) def get_tasks(self): return self._tasks[:] def matches(self, path): """Return true if *all* tasks match the given path.""" for task in self._tasks: if not task.matches(path): return False return bool(self._tasks) def has_run(self): return self._has_run def may_run(self, path): """Return false if running this event would certainly raise an error. This will call the C{may_run()} method on each task and return false if any of them returns false. """ for task in self._tasks: if not task.may_run(path): return False return True def run(self, path): """Run all tasks with the given action. @param path: The path of the expression run. Running an event means running all of its tasks individually and in order. An event should only ever be run if all of its tasks claim to match the given action. The result of this method will be the last result of a task which isn't None, or None if they're all None.
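For instance, an event recorded for obj.method(1) will typically carry a PathMatcher, an ImplicitRunCounter and possibly a SpecChecker, as wired up by the recorder functions defined later in this module.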
""" self._has_run = True result = None errors = [] for task in self._tasks: try: task_result = task.run(path) except AssertionError as e: error = str(e) if not error: raise RuntimeError("Empty error message from %r" % task) errors.append(error) else: if task_result is not None: result = task_result if errors: message = [str(self.path)] if str(path) != message[0]: message.append("- Run: %s" % path) for error in errors: lines = error.splitlines() message.append("- " + lines.pop(0)) message.extend([" " + line for line in lines]) raise AssertionError(os.linesep.join(message)) return result def satisfied(self): """Return true if all tasks are satisfied. Being satisfied means that there are no unmet expectations. """ for task in self._tasks: try: task.verify() except AssertionError: return False return True def verify(self): """Run verify on all tasks. The verify method is supposed to raise an AssertionError if the task has unmet expectations, with a one-line explanation about why this item is unmet. This method should be safe to be called multiple times without side effects. """ errors = [] for task in self._tasks: try: task.verify() except AssertionError as e: error = str(e) if not error: raise RuntimeError("Empty error message from %r" % task) errors.append(error) if errors: message = [str(self.path)] for error in errors: lines = error.splitlines() message.append("- " + lines.pop(0)) message.extend([" " + line for line in lines]) raise AssertionError(os.linesep.join(message)) def replay(self): """Put all tasks in replay mode.""" self._has_run = False for task in self._tasks: task.replay() def restore(self): """Restore the state of all tasks.""" for task in self._tasks: task.restore() class ReplayRestoreEvent(Event): """Helper event for tasks which need replay/restore but shouldn't match.""" def matches(self, path): return False class Task: """Element used to track one specific aspect on an event. A task is responsible for adding any kind of logic to an event. Examples of that are counting the number of times the event was made, verifying parameters if any, and so on. """ def matches(self, path): """Return true if the task is supposed to be run for the given path. """ return True def may_run(self, path): """Return false if running this task would certainly raise an error.""" return True def run(self, path): """Perform the task item, considering that the given action happened. """ def verify(self): """Raise AssertionError if expectations for this item are unmet. The verify method is supposed to raise an AssertionError if the task has unmet expectations, with a one-line explanation about why this item is unmet. This method should be safe to be called multiple times without side effects. """ def replay(self): """Put the task in replay mode. Any expectations of the task should be reset. """ def restore(self): """Restore any environmental changes made by the task. Verify should continue to work after this is called. """ # -------------------------------------------------------------------- # Task implementations. 
class OnRestoreCaller(Task): """Call a given callback when restoring.""" def __init__(self, callback): self._callback = callback def restore(self): self._callback() class PathMatcher(Task): """Match the action path against a given path.""" def __init__(self, path): self.path = path def matches(self, path): return self.path.matches(path) def path_matcher_recorder(mocker, event): event.add_task(PathMatcher(event.path)) Mocker.add_recorder(path_matcher_recorder) class RunCounter(Task): """Task which verifies if the number of runs is within given boundaries. """ def __init__(self, min, max=False): self.min = min if max is None: self.max = sys.maxsize elif max is False: self.max = min else: self.max = max self._runs = 0 def replay(self): self._runs = 0 def may_run(self, path): return self._runs < self.max def run(self, path): self._runs += 1 if self._runs > self.max: self.verify() def verify(self): if not self.min <= self._runs <= self.max: if self._runs < self.min: raise AssertionError("Performed fewer times than expected.") raise AssertionError("Performed more times than expected.") class ImplicitRunCounter(RunCounter): """RunCounter inserted by default on any event. This is a way to differentiate between explicitly added counters and implicit ones. """ def run_counter_recorder(mocker, event): """By default an event is expected to run exactly once, unless counting is disabled for the mock.""" if event.path.root_mock.__mocker_count__: event.add_task(ImplicitRunCounter(1)) Mocker.add_recorder(run_counter_recorder) def run_counter_removal_recorder(mocker, event): """ Events created by getattr actions which lead to other events may be repeated any number of times. For that, we remove implicit run counters of any getattr actions leading to the current one. """ parent_path = event.path.parent_path for other_event in mocker.get_events()[::-1]: if (other_event.path is parent_path and other_event.path.actions[-1].kind == "getattr"): for task in other_event.get_tasks(): if type(task) is ImplicitRunCounter: other_event.remove_task(task) Mocker.add_recorder(run_counter_removal_recorder) class MockReturner(Task): """Return a mock based on the action path.""" def __init__(self, mocker): self.mocker = mocker def run(self, path): return Mock(self.mocker, path) def mock_returner_recorder(mocker, event): """Events that lead to other events must return mock objects.""" parent_path = event.path.parent_path for other_event in mocker.get_events(): if other_event.path is parent_path: for task in other_event.get_tasks(): if isinstance(task, MockReturner): break else: other_event.add_task(MockReturner(mocker)) break Mocker.add_recorder(mock_returner_recorder) class FunctionRunner(Task): """Task that runs a function every time it's run. Arguments of the last action in the path are passed to the function, and the function result is also returned. """ def __init__(self, func): self._func = func def run(self, path): action = path.actions[-1] return self._func(*action.args, **action.kwargs) class PathExecuter(Task): """Task that executes a path in the real object, and returns the result.""" def __init__(self, result_callback=None): self._result_callback = result_callback def get_result_callback(self): return self._result_callback def run(self, path): result = path.execute(path.root_object) if self._result_callback is not None: self._result_callback(result) return result class Orderer(Task): """Task to establish an order relation between two events. An orderer task will only match once all its dependencies have been run.
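A dependency is another Orderer registered through add_dependency(); until every dependency has run, may_run() returns false and run() raises an AssertionError.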
""" def __init__(self, path): self.path = path self._run = False self._dependencies = [] def replay(self): self._run = False def has_run(self): return self._run def may_run(self, path): for dependency in self._dependencies: if not dependency.has_run(): return False return True def run(self, path): for dependency in self._dependencies: if not dependency.has_run(): raise AssertionError("Should be after: %s" % dependency.path) self._run = True def add_dependency(self, orderer): self._dependencies.append(orderer) def get_dependencies(self): return self._dependencies class SpecChecker(Task): """Task to check if arguments of the last action conform to a real method. """ def __init__(self, method): self._method = method self._unsupported = False if method: try: # On Python 3, inspect.getargspec includes the bound first # argument (self or similar) for bound methods, which # confuses matters. The modern signature API doesn't have # this problem. self._signature = inspect.signature(method) # Method descriptors don't have the first argument already # bound, but we want to skip it anyway. if getattr(method, "__objclass__", None) is not None: parameters = list(self._signature.parameters.values()) # This is positional-only for unbound methods that are # implemented in C. if (parameters[0].kind == inspect.Parameter.POSITIONAL_ONLY): self._signature = self._signature.replace( parameters=parameters[1:]) except TypeError: self._unsupported = True def get_method(self): return self._method def _raise(self, message): raise AssertionError("Specification is %s%s: %s" % (self._method.__name__, self._signature, message)) def verify(self): if not self._method: raise AssertionError("Method not found in real specification") def may_run(self, path): try: self.run(path) except AssertionError: return False return True def run(self, path): if not self._method: raise AssertionError("Method not found in real specification") if self._unsupported: return # Can't check it. Happens with builtin functions. :-( action = path.actions[-1] try: self._signature.bind(*action.args, **action.kwargs) except TypeError as e: self._raise(str(e)) def spec_checker_recorder(mocker, event): spec = event.path.root_mock.__mocker_spec__ if spec: actions = event.path.actions if len(actions) == 1: if actions[0].kind == "call": method = getattr(spec, "__call__", None) event.add_task(SpecChecker(method)) elif len(actions) == 2: if actions[0].kind == "getattr" and actions[1].kind == "call": method = getattr(spec, actions[0].args[0], None) event.add_task(SpecChecker(method)) Mocker.add_recorder(spec_checker_recorder) class ProxyReplacer(Task): """Task which installs and deinstalls proxy mocks. This task will replace a real object by a mock in all dictionaries found in the running interpreter via the garbage collecting system. 
""" def __init__(self, mock): self.mock = mock self.__mocker_replace__ = False def replay(self): global_replace(self.mock.__mocker_object__, self.mock) def restore(self): global_replace(self.mock, self.mock.__mocker_object__) def global_replace(remove, install): """Replace object 'remove' with object 'install' on all dictionaries.""" for referrer in gc.get_referrers(remove): if (type(referrer) is dict and referrer.get("__mocker_replace__", True)): for key, value in list(referrer.items()): if value is remove: referrer[key] = install class Undefined: def __repr__(self): return "Undefined" Undefined = Undefined() class Patcher(Task): def __init__(self): super().__init__() self._monitored = {} # {kind: {id(object): object}} self._patched = {} def is_monitoring(self, obj, kind): monitored = self._monitored.get(kind) if monitored: if id(obj) in monitored: return True cls = type(obj) if issubclass(cls, type): cls = obj bases = {id(base) for base in cls.__mro__} bases.intersection_update(monitored) return bool(bases) return False def monitor(self, obj, kind): if kind not in self._monitored: self._monitored[kind] = {} self._monitored[kind][id(obj)] = obj def patch_attr(self, obj, attr, value): original = obj.__dict__.get(attr, Undefined) self._patched[id(obj), attr] = obj, attr, original setattr(obj, attr, value) def get_unpatched_attr(self, obj, attr): cls = type(obj) if issubclass(cls, type): cls = obj result = Undefined for mro_cls in cls.__mro__: key = (id(mro_cls), attr) if key in self._patched: result = self._patched[key][2] if result is not Undefined: break elif attr in mro_cls.__dict__: result = mro_cls.__dict__.get(attr, Undefined) break if isinstance(result, object) and hasattr(type(result), "__get__"): if cls is obj: obj = None return result.__get__(obj, cls) return result def _get_kind_attr(self, kind): if kind == "getattr": return "__getattribute__" return "__%s__" % kind def replay(self): for kind in self._monitored: attr = self._get_kind_attr(kind) seen = set() for obj in self._monitored[kind].values(): cls = type(obj) if issubclass(cls, type): cls = obj if cls not in seen: seen.add(cls) unpatched = getattr(cls, attr, Undefined) self.patch_attr(cls, attr, PatchedMethod(kind, unpatched, self.is_monitoring)) self.patch_attr(cls, "__mocker_execute__", self.execute) def restore(self): for obj, attr, original in self._patched.values(): if original is Undefined: delattr(obj, attr) else: setattr(obj, attr, original) self._patched.clear() def execute(self, action, object): attr = self._get_kind_attr(action.kind) unpatched = self.get_unpatched_attr(object, attr) try: return unpatched(*action.args, **action.kwargs) except AttributeError: if action.kind == "getattr": # The normal behavior of Python is to try __getattribute__, # and if it raises AttributeError, try __getattr__. We've # tried the unpatched __getattribute__ above, and we'll now # try __getattr__. 
try: __getattr__ = unpatched("__getattr__") except AttributeError: pass else: return __getattr__(*action.args, **action.kwargs) raise class PatchedMethod: def __init__(self, kind, unpatched, is_monitoring): self._kind = kind self._unpatched = unpatched self._is_monitoring = is_monitoring def __get__(self, obj, cls=None): object = obj or cls if not self._is_monitoring(object, self._kind): return self._unpatched.__get__(obj, cls) def method(*args, **kwargs): if self._kind == "getattr" and args[0].startswith("__mocker_"): return self._unpatched.__get__(obj, cls)(args[0]) mock = object.__mocker_mock__ return mock.__mocker_act__(self._kind, args, kwargs, object) return method def __call__(self, obj, *args, **kwargs): # At least with __getattribute__, Python seems to use *both* the # descriptor API and also call the class attribute directly. It # looks like an interpreter bug, or at least an undocumented # inconsistency. return self.__get__(obj)(*args, **kwargs) def patcher_recorder(mocker, event): mock = event.path.root_mock if mock.__mocker_patcher__ and len(event.path.actions) == 1: patcher = mock.__mocker_patcher__ patcher.monitor(mock.__mocker_object__, event.path.actions[0].kind) Mocker.add_recorder(patcher_recorder) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/properties.py0000644000175000017500000011736314645174376020130 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# from datetime import datetime, date, time, timedelta from decimal import Decimal as decimal import gc import json import uuid from storm.exceptions import NoneError, PropertyPathError from storm.properties import PropertyPublisherMeta from storm.properties import * from storm.variables import * from storm.info import get_obj_info from storm.expr import Column, Select, compile, State, SQLRaw from storm.tests.info import Wrapper from storm.tests.helper import TestHelper class CustomVariable(Variable): pass class Custom(SimpleProperty): variable_class = CustomVariable class PropertyTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Class: __storm_table__ = "mytable" prop1 = Custom("column1", primary=True) prop2 = Custom() prop3 = Custom("column3", default=50, allow_none=False) class SubClass(Class): __storm_table__ = "mysubtable" self.Class = Class self.SubClass = SubClass def test_column(self): self.assertTrue(isinstance(self.Class.prop1, Column)) def test_cls(self): self.assertEqual(self.Class.prop1.cls, self.Class) self.assertEqual(self.Class.prop2.cls, self.Class) self.assertEqual(self.SubClass.prop1.cls, self.SubClass) self.assertEqual(self.SubClass.prop2.cls, self.SubClass) self.assertEqual(self.Class.prop1.cls, self.Class) self.assertEqual(self.Class.prop2.cls, self.Class) def test_cls_reverse(self): self.assertEqual(self.SubClass.prop1.cls, self.SubClass) self.assertEqual(self.SubClass.prop2.cls, self.SubClass) self.assertEqual(self.Class.prop1.cls, self.Class) self.assertEqual(self.Class.prop2.cls, self.Class) self.assertEqual(self.SubClass.prop1.cls, self.SubClass) self.assertEqual(self.SubClass.prop2.cls, self.SubClass) def test_name(self): self.assertEqual(self.Class.prop1.name, "column1") def test_auto_name(self): self.assertEqual(self.Class.prop2.name, "prop2") def test_auto_table(self): self.assertEqual(self.Class.prop1.table, self.Class) self.assertEqual(self.Class.prop2.table, self.Class) def test_auto_table_subclass(self): self.assertEqual(self.Class.prop1.table, self.Class) self.assertEqual(self.Class.prop2.table, self.Class) self.assertEqual(self.SubClass.prop1.table, self.SubClass) self.assertEqual(self.SubClass.prop2.table, self.SubClass) def test_auto_table_subclass_reverse_initialization(self): self.assertEqual(self.SubClass.prop1.table, self.SubClass) self.assertEqual(self.SubClass.prop2.table, self.SubClass) self.assertEqual(self.Class.prop1.table, self.Class) self.assertEqual(self.Class.prop2.table, self.Class) def test_variable_factory(self): variable = self.Class.prop1.variable_factory() self.assertTrue(isinstance(variable, CustomVariable)) self.assertFalse(variable.is_defined()) variable = self.Class.prop3.variable_factory() self.assertTrue(isinstance(variable, CustomVariable)) self.assertTrue(variable.is_defined()) def test_variable_factory_validator_attribute(self): # Should work even if we make things harder by reusing properties. 
prop = Custom() class Class1: __storm_table__ = "table1" prop1 = prop class Class2: __storm_table__ = "table2" prop2 = prop args = [] def validator(obj, attr, value): args.append((obj, attr, value)) variable1 = Class1.prop1.variable_factory(validator=validator) variable2 = Class2.prop2.variable_factory(validator=validator) variable1.set(1) variable2.set(2) self.assertEqual(args, [(None, "prop1", 1), (None, "prop2", 2)]) def test_default(self): obj = self.SubClass() self.assertEqual(obj.prop1, None) self.assertEqual(obj.prop2, None) self.assertEqual(obj.prop3, 50) def test_set_get(self): obj = self.Class() obj.prop1 = 10 obj.prop2 = 20 obj.prop3 = 30 self.assertEqual(obj.prop1, 10) self.assertEqual(obj.prop2, 20) self.assertEqual(obj.prop3, 30) def test_set_get_none(self): obj = self.Class() obj.prop1 = None obj.prop2 = None self.assertEqual(obj.prop1, None) self.assertEqual(obj.prop2, None) self.assertRaises(NoneError, setattr, obj, "prop3", None) def test_set_with_validator(self): args = [] def validator(obj, attr, value): args[:] = obj, attr, value return 42 class Class: __storm_table__ = "mytable" prop = Custom("column", primary=True, validator=validator) obj = Class() obj.prop = 21 self.assertEqual(args, [obj, "prop", 21]) self.assertEqual(obj.prop, 42) def test_set_get_subclass(self): obj = self.SubClass() obj.prop1 = 10 obj.prop2 = 20 obj.prop3 = 30 self.assertEqual(obj.prop1, 10) self.assertEqual(obj.prop2, 20) self.assertEqual(obj.prop3, 30) def test_set_get_explicitly(self): obj = self.Class() prop1 = self.Class.prop1 prop2 = self.Class.prop2 prop3 = self.Class.prop3 prop1.__set__(obj, 10) prop2.__set__(obj, 20) prop3.__set__(obj, 30) self.assertEqual(prop1.__get__(obj), 10) self.assertEqual(prop2.__get__(obj), 20) self.assertEqual(prop3.__get__(obj), 30) def test_set_get_subclass_explicitly(self): obj = self.SubClass() prop1 = self.Class.prop1 prop2 = self.Class.prop2 prop3 = self.Class.prop3 prop1.__set__(obj, 10) prop2.__set__(obj, 20) prop3.__set__(obj, 30) self.assertEqual(prop1.__get__(obj), 10) self.assertEqual(prop2.__get__(obj), 20) self.assertEqual(prop3.__get__(obj), 30) def test_delete(self): obj = self.Class() obj.prop1 = 10 obj.prop2 = 20 obj.prop3 = 30 del obj.prop1 del obj.prop2 del obj.prop3 self.assertEqual(obj.prop1, None) self.assertEqual(obj.prop2, None) self.assertEqual(obj.prop3, None) def test_delete_subclass(self): obj = self.SubClass() obj.prop1 = 10 obj.prop2 = 20 obj.prop3 = 30 del obj.prop1 del obj.prop2 del obj.prop3 self.assertEqual(obj.prop1, None) self.assertEqual(obj.prop2, None) self.assertEqual(obj.prop3, None) def test_delete_explicitly(self): obj = self.Class() obj.prop1 = 10 obj.prop2 = 20 obj.prop3 = 30 self.Class.prop1.__delete__(obj) self.Class.prop2.__delete__(obj) self.Class.prop3.__delete__(obj) self.assertEqual(obj.prop1, None) self.assertEqual(obj.prop2, None) self.assertEqual(obj.prop3, None) def test_delete_subclass_explicitly(self): obj = self.SubClass() obj.prop1 = 10 obj.prop2 = 20 obj.prop3 = 30 self.Class.prop1.__delete__(obj) self.Class.prop2.__delete__(obj) self.Class.prop3.__delete__(obj) self.assertEqual(obj.prop1, None) self.assertEqual(obj.prop2, None) self.assertEqual(obj.prop3, None) def test_comparable_expr(self): prop1 = self.Class.prop1 prop2 = self.Class.prop2 prop3 = self.Class.prop3 expr = Select(SQLRaw("*"), (prop1 == "value1") & (prop2 == "value2") & (prop3 == "value3")) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT * FROM mytable WHERE " "mytable.column1 = ? 
AND " "mytable.prop2 = ? AND " "mytable.column3 = ?") self.assertVariablesEqual( state.parameters, [CustomVariable("value1"), CustomVariable("value2"), CustomVariable("value3")]) def test_comparable_expr_subclass(self): prop1 = self.SubClass.prop1 prop2 = self.SubClass.prop2 prop3 = self.SubClass.prop3 expr = Select(SQLRaw("*"), (prop1 == "value1") & (prop2 == "value2") & (prop3 == "value3")) state = State() statement = compile(expr, state) self.assertEqual(statement, "SELECT * FROM mysubtable WHERE " "mysubtable.column1 = ? AND " "mysubtable.prop2 = ? AND " "mysubtable.column3 = ?") self.assertVariablesEqual( state.parameters, [CustomVariable("value1"), CustomVariable("value2"), CustomVariable("value3")]) def test_set_get_delete_with_wrapper(self): obj = self.Class() get_obj_info(obj) # Ensure the obj_info exists for obj. self.Class.prop1.__set__(Wrapper(obj), 10) self.assertEqual(self.Class.prop1.__get__(Wrapper(obj)), 10) self.Class.prop1.__delete__(Wrapper(obj)) self.assertEqual(self.Class.prop1.__get__(Wrapper(obj)), None) def test_reuse_of_instance(self): """Properties are dynamically bound to the class where they're used. It basically means that the property may be instantiated independently from the class itself, and reused in any number of classes. It's not something we should announce as granted, but right now it works, and we should try not to break it. """ prop = Custom() class Class1: __storm_table__ = "table1" prop1 = prop class Class2: __storm_table__ = "table2" prop2 = prop self.assertEqual(Class1.prop1.name, "prop1") self.assertEqual(Class1.prop1.table, Class1) self.assertEqual(Class2.prop2.name, "prop2") self.assertEqual(Class2.prop2.table, Class2) class PropertyKindsTest(TestHelper): def setup(self, property, *args, **kwargs): prop2_kwargs = kwargs.pop("prop2_kwargs", {}) kwargs["primary"] = True class Class: __storm_table__ = "mytable" prop1 = property("column1", *args, **kwargs) prop2 = property(**prop2_kwargs) class SubClass(Class): pass self.Class = Class self.SubClass = SubClass self.obj = SubClass() self.obj_info = get_obj_info(self.obj) self.column1 = self.SubClass.prop1 self.column2 = self.SubClass.prop2 self.variable1 = self.obj_info.variables[self.column1] self.variable2 = self.obj_info.variables[self.column2] def test_bool(self): self.setup(Bool, default=True, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, BoolVariable)) self.assertTrue(isinstance(self.variable2, BoolVariable)) self.assertEqual(self.obj.prop1, True) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = 1 self.assertTrue(self.obj.prop1 is True) self.obj.prop1 = 0 self.assertTrue(self.obj.prop1 is False) def test_int(self): self.setup(Int, default=50, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, IntVariable)) self.assertTrue(isinstance(self.variable2, IntVariable)) self.assertEqual(self.obj.prop1, 
50) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = False self.assertEqual(self.obj.prop1, 0) self.obj.prop1 = True self.assertEqual(self.obj.prop1, 1) def test_float(self): self.setup(Float, default=50.5, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, FloatVariable)) self.assertTrue(isinstance(self.variable2, FloatVariable)) self.assertEqual(self.obj.prop1, 50.5) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = 1 self.assertTrue(isinstance(self.obj.prop1, float)) def test_decimal(self): self.setup(Decimal, default=decimal("50.5"), allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, DecimalVariable)) self.assertTrue(isinstance(self.variable2, DecimalVariable)) self.assertEqual(self.obj.prop1, decimal("50.5")) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = 1 self.assertTrue(isinstance(self.obj.prop1, decimal)) def test_bytes(self): self.setup(Bytes, default=b"def", allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, BytesVariable)) self.assertTrue(isinstance(self.variable2, BytesVariable)) self.assertEqual(self.obj.prop1, b"def") self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.assertRaises(TypeError, setattr, self.obj, "prop1", "unicode") def test_unicode(self): self.setup(Unicode, default="def", allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, UnicodeVariable)) self.assertTrue(isinstance(self.variable2, UnicodeVariable)) self.assertEqual(self.obj.prop1, "def") self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.assertRaises(TypeError, setattr, self.obj, "prop1", b"str") def test_datetime(self): self.setup(DateTime, default=0, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) 
self.assertTrue(isinstance(self.variable1, DateTimeVariable)) self.assertTrue(isinstance(self.variable2, DateTimeVariable)) self.assertEqual(self.obj.prop1, datetime.utcfromtimestamp(0)) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = 0.0 self.assertEqual(self.obj.prop1, datetime.utcfromtimestamp(0)) self.obj.prop1 = datetime(2006, 1, 1, 12, 34) self.assertEqual(self.obj.prop1, datetime(2006, 1, 1, 12, 34)) self.assertRaises(TypeError, setattr, self.obj, "prop1", object()) def test_date(self): self.setup(Date, default=date(2006, 1, 1), allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, DateVariable)) self.assertTrue(isinstance(self.variable2, DateVariable)) self.assertEqual(self.obj.prop1, date(2006, 1, 1)) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = datetime(2006, 1, 1, 12, 34, 56) self.assertEqual(self.obj.prop1, date(2006, 1, 1)) self.obj.prop1 = date(2006, 1, 1) self.assertEqual(self.obj.prop1, date(2006, 1, 1)) self.assertRaises(TypeError, setattr, self.obj, "prop1", object()) def test_time(self): self.setup(Time, default=time(12, 34), allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, TimeVariable)) self.assertTrue(isinstance(self.variable2, TimeVariable)) self.assertEqual(self.obj.prop1, time(12, 34)) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = datetime(2006, 1, 1, 12, 34, 56) self.assertEqual(self.obj.prop1, time(12, 34, 56)) self.obj.prop1 = time(12, 34, 56) self.assertEqual(self.obj.prop1, time(12, 34, 56)) self.assertRaises(TypeError, setattr, self.obj, "prop1", object()) def test_timedelta(self): self.setup(TimeDelta, default=timedelta(days=1, seconds=2, microseconds=3), allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, TimeDeltaVariable)) self.assertTrue(isinstance(self.variable2, TimeDeltaVariable)) self.assertEqual(self.obj.prop1, timedelta(days=1, seconds=2, microseconds=3)) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = timedelta(days=42, seconds=42, microseconds=42) self.assertEqual(self.obj.prop1, timedelta(days=42, seconds=42, microseconds=42)) self.assertRaises(TypeError, setattr, self.obj, "prop1", object()) def test_uuid(self): value1 = uuid.UUID("{0609f76b-878f-4546-baf5-c1b135e8de72}") value2 = uuid.UUID("{c9703f9d-0abb-47d7-a793-8f90f1b98d5e}") self.setup(UUID, default=value1, 
allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, UUIDVariable)) self.assertTrue(isinstance(self.variable2, UUIDVariable)) self.assertEqual(self.obj.prop1, value1) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = value1 self.assertEqual(self.obj.prop1, value1) self.obj.prop1 = value2 self.assertEqual(self.obj.prop1, value2) self.assertRaises(TypeError, setattr, self.obj, "prop1", "{0609f76b-878f-4546-baf5-c1b135e8de72}") def test_enum(self): self.setup(Enum, map={"foo": 1, "bar": 2}, default="foo", allow_none=False, prop2_kwargs=dict(map={"foo": 1, "bar": 2})) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, EnumVariable)) self.assertTrue(isinstance(self.variable2, EnumVariable)) self.assertEqual(self.obj.prop1, "foo") self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = "foo" self.assertEqual(self.obj.prop1, "foo") self.obj.prop1 = "bar" self.assertEqual(self.obj.prop1, "bar") self.assertRaises(ValueError, setattr, self.obj, "prop1", "baz") self.assertRaises(ValueError, setattr, self.obj, "prop1", 1) def test_enum_with_set_map(self): self.setup(Enum, map={"foo": 1, "bar": 2}, set_map={"fooics": 1, "barics": 2}, default="fooics", allow_none=False, prop2_kwargs=dict(map={"foo": 1, "bar": 2})) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, EnumVariable)) self.assertTrue(isinstance(self.variable2, EnumVariable)) self.assertEqual(self.obj.prop1, "foo") self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = "fooics" self.assertEqual(self.obj.prop1, "foo") self.obj.prop1 = "barics" self.assertEqual(self.obj.prop1, "bar") self.assertRaises(ValueError, setattr, self.obj, "prop1", "foo") self.assertRaises(ValueError, setattr, self.obj, "prop1", 1) def test_pickle(self): self.setup(Pickle, default_factory=dict, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, PickleVariable)) self.assertTrue(isinstance(self.variable2, PickleVariable)) self.assertEqual(self.obj.prop1, {}) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = [] self.assertEqual(self.obj.prop1, []) 
self.obj.prop1.append("a") self.assertEqual(self.obj.prop1, ["a"]) def test_pickle_events(self): self.setup(Pickle, default_factory=list, allow_none=False) changes = [] def changed(owner, variable, old_value, new_value, fromdb): changes.append((variable, old_value, new_value, fromdb)) # Can't checkpoint Undef. self.obj.prop2 = [] self.obj_info.checkpoint() self.obj_info.event.emit("start-tracking-changes", self.obj_info.event) self.obj_info.event.hook("changed", changed) self.assertEqual(self.obj.prop1, []) self.assertEqual(changes, []) self.obj.prop1.append("a") self.assertEqual(changes, []) # Check "flush" event. Notice that the other variable wasn't # listed, since it wasn't changed. self.obj_info.event.emit("flush") self.assertEqual(changes, [(self.variable1, None, ["a"], False)]) del changes[:] # Check "object-deleted" event. Notice that the other variable # wasn't listed again, since it wasn't changed. del self.obj self.assertEqual(changes, [(self.variable1, None, ["a"], False)]) def test_json(self): # Skip test if json support is not available. if json is None: return self.setup(JSON, default_factory=dict, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, JSONVariable)) self.assertTrue(isinstance(self.variable2, JSONVariable)) self.assertEqual(self.obj.prop1, {}) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = [] self.assertEqual(self.obj.prop1, []) self.obj.prop1.append("a") self.assertEqual(self.obj.prop1, ["a"]) def test_json_events(self): # Skip test if json support is not available. if json is None: return self.setup(JSON, default_factory=list, allow_none=False) changes = [] def changed(owner, variable, old_value, new_value, fromdb): changes.append((variable, old_value, new_value, fromdb)) # Can't checkpoint Undef. self.obj.prop2 = [] self.obj_info.checkpoint() self.obj_info.event.emit("start-tracking-changes", self.obj_info.event) self.obj_info.event.hook("changed", changed) self.assertEqual(self.obj.prop1, []) self.assertEqual(changes, []) self.obj.prop1.append("a") self.assertEqual(changes, []) # Check "flush" event. Notice that the other variable wasn't # listed, since it wasn't changed. self.obj_info.event.emit("flush") self.assertEqual(changes, [(self.variable1, None, ["a"], False)]) del changes[:] # Check "object-deleted" event. Notice that the other variable # wasn't listed again, since it wasn't changed. 
del self.obj self.assertEqual(changes, [(self.variable1, None, ["a"], False)]) def test_list(self): self.setup(List, default_factory=list, allow_none=False) self.assertTrue(isinstance(self.column1, Column)) self.assertTrue(isinstance(self.column2, Column)) self.assertEqual(self.column1.name, "column1") self.assertEqual(self.column1.table, self.SubClass) self.assertEqual(self.column2.name, "prop2") self.assertEqual(self.column2.table, self.SubClass) self.assertTrue(isinstance(self.variable1, ListVariable)) self.assertTrue(isinstance(self.variable2, ListVariable)) self.assertEqual(self.obj.prop1, []) self.assertRaises(NoneError, setattr, self.obj, "prop1", None) self.obj.prop2 = None self.assertEqual(self.obj.prop2, None) self.obj.prop1 = ["a"] self.assertEqual(self.obj.prop1, ["a"]) self.obj.prop1.append("b") self.assertEqual(self.obj.prop1, ["a", "b"]) def test_list_events(self): self.setup(List, default_factory=list, allow_none=False) changes = [] def changed(owner, variable, old_value, new_value, fromdb): changes.append((variable, old_value, new_value, fromdb)) self.obj_info.checkpoint() self.obj_info.event.emit("start-tracking-changes", self.obj_info.event) self.obj_info.event.hook("changed", changed) self.assertEqual(self.obj.prop1, []) self.assertEqual(changes, []) self.obj.prop1.append("a") self.assertEqual(changes, []) # Check "flush" event. Notice that the other variable wasn't # listed, since it wasn't changed. self.obj_info.event.emit("flush") self.assertEqual(changes, [(self.variable1, None, ["a"], False)]) del changes[:] # Check "object-deleted" event. Notice that the other variable # wasn't listed again, since it wasn't changed. del self.obj self.assertEqual(changes, [(self.variable1, None, ["a"], False)]) def test_variable_factory_arguments(self): class Class: __storm_table__ = "test" id = Int(primary=True) validator_args = [] def validator(obj, attr, value): validator_args[:] = obj, attr, value return value for func, cls, value in [ (Bool, BoolVariable, True), (Int, IntVariable, 1), (Float, FloatVariable, 1.1), (Bytes, BytesVariable, b"str"), (Unicode, UnicodeVariable, "unicode"), (DateTime, DateTimeVariable, datetime.now()), (Date, DateVariable, date.today()), (Time, TimeVariable, datetime.now().time()), (Pickle, PickleVariable, {}), ]: # Test no default and allow_none=True. Class.prop = func(name="name") column = Class.prop.__get__(None, Class) self.assertEqual(column.name, "name") self.assertEqual(column.table, Class) variable = column.variable_factory() self.assertTrue(isinstance(variable, cls)) self.assertEqual(variable.get(), None) variable.set(None) self.assertEqual(variable.get(), None) # Test default and allow_none=False. Class.prop = func(name="name", default=value, allow_none=False) column = Class.prop.__get__(None, Class) self.assertEqual(column.name, "name") self.assertEqual(column.table, Class) variable = column.variable_factory() self.assertTrue(isinstance(variable, cls)) self.assertRaises(NoneError, variable.set, None) self.assertEqual(variable.get(), value) # Test default=None and allow_none=False (incoherent). Class.prop = func(name="name", default=None, allow_none=False) column = Class.prop.__get__(None, Class) self.assertEqual(column.name, "name") self.assertEqual(column.table, Class) self.assertRaises(NoneError, column.variable_factory) # Test default_factory. 
Class.prop = func(name="name", default_factory=lambda:value) column = Class.prop.__get__(None, Class) self.assertEqual(column.name, "name") self.assertEqual(column.table, Class) variable = column.variable_factory() self.assertTrue(isinstance(variable, cls)) self.assertEqual(variable.get(), value) # Test validator. Class.prop = func(name="name", validator=validator, default=value) column = Class.prop.__get__(None, Class) self.assertEqual(column.name, "name") self.assertEqual(column.table, Class) del validator_args[:] variable = column.variable_factory() self.assertTrue(isinstance(variable, cls)) # Validator is not called on instantiation. self.assertEqual(validator_args, []) # But is when setting the variable. variable.set(value) self.assertEqual(validator_args, [None, "prop", value]) class PropertyRegistryTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Class: __storm_table__ = "mytable" prop1 = Property("column1", primary=True) prop2 = Property() class SubClass(Class): __storm_table__ = "mysubtable" self.Class = Class self.SubClass = SubClass self.AnotherClass = type("Class", (Class,), {}) self.registry = PropertyRegistry() def test_get_empty(self): self.assertRaises(PropertyPathError, self.registry.get, "unexistent") def test_get(self): self.registry.add_class(self.Class) prop1 = self.registry.get("prop1") prop2 = self.registry.get("prop2") self.assertTrue(prop1 is self.Class.prop1) self.assertTrue(prop2 is self.Class.prop2) def test_get_with_class_name(self): self.registry.add_class(self.Class) prop1 = self.registry.get("Class.prop1") prop2 = self.registry.get("Class.prop2") self.assertTrue(prop1 is self.Class.prop1) self.assertTrue(prop2 is self.Class.prop2) def test_get_with_two_classes(self): self.registry.add_class(self.Class) self.registry.add_class(self.SubClass) prop1 = self.registry.get("Class.prop1") prop2 = self.registry.get("Class.prop2") self.assertTrue(prop1 is self.Class.prop1) self.assertTrue(prop2 is self.Class.prop2) prop1 = self.registry.get("SubClass.prop1") prop2 = self.registry.get("SubClass.prop2") self.assertTrue(prop1 is self.SubClass.prop1) self.assertTrue(prop2 is self.SubClass.prop2) def test_get_ambiguous(self): self.AnotherClass.__module__ += ".foo" self.registry.add_class(self.Class) self.registry.add_class(self.SubClass) self.registry.add_class(self.AnotherClass) self.assertRaises(PropertyPathError, self.registry.get, "Class.prop1") self.assertRaises(PropertyPathError, self.registry.get, "Class.prop2") prop1 = self.registry.get("SubClass.prop1") prop2 = self.registry.get("SubClass.prop2") self.assertTrue(prop1 is self.SubClass.prop1) self.assertTrue(prop2 is self.SubClass.prop2) def test_get_ambiguous_but_different_path(self): self.AnotherClass.__module__ += ".foo" self.registry.add_class(self.Class) self.registry.add_class(self.SubClass) self.registry.add_class(self.AnotherClass) prop1 = self.registry.get("properties.Class.prop1") prop2 = self.registry.get("properties.Class.prop2") self.assertTrue(prop1 is self.Class.prop1) self.assertTrue(prop2 is self.Class.prop2) prop1 = self.registry.get("SubClass.prop1") prop2 = self.registry.get("SubClass.prop2") self.assertTrue(prop1 is self.SubClass.prop1) self.assertTrue(prop2 is self.SubClass.prop2) prop1 = self.registry.get("foo.Class.prop1") prop2 = self.registry.get("foo.Class.prop2") self.assertTrue(prop1 is self.AnotherClass.prop1) self.assertTrue(prop2 is self.AnotherClass.prop2) def test_get_ambiguous_but_different_path_with_namespace(self): self.AnotherClass.__module__ += ".foo" 
self.registry.add_class(self.Class) self.registry.add_class(self.SubClass) self.registry.add_class(self.AnotherClass) prop1 = self.registry.get("Class.prop1", "storm.tests.properties") prop2 = self.registry.get("Class.prop2", "storm.tests.properties.bar") self.assertTrue(prop1 is self.Class.prop1) self.assertTrue(prop2 is self.Class.prop2) prop1 = self.registry.get("Class.prop1", "storm.tests.properties.foo") prop2 = self.registry.get( "Class.prop2", "storm.tests.properties.foo.bar") self.assertTrue(prop1 is self.AnotherClass.prop1) self.assertTrue(prop2 is self.AnotherClass.prop2) def test_class_is_collectable(self): self.AnotherClass.__module__ += ".foo" self.registry.add_class(self.Class) self.registry.add_class(self.AnotherClass) del self.AnotherClass gc.collect() prop1 = self.registry.get("prop1") prop2 = self.registry.get("prop2") self.assertTrue(prop1 is self.Class.prop1) self.assertTrue(prop2 is self.Class.prop2) def test_add_property(self): self.registry.add_property(self.Class, self.Class.prop1, "custom_name") prop1 = self.registry.get("Class.custom_name") self.assertEqual(prop1, self.Class.prop1) self.assertRaises(PropertyPathError, self.registry.get, "Class.prop1") class PropertyPublisherMetaTest(TestHelper): def setUp(self): TestHelper.setUp(self) class Base(metaclass=PropertyPublisherMeta): pass class Class(Base): __storm_table__ = "mytable" prop1 = Property("column1", primary=True) prop2 = Property() class SubClass(Class): __storm_table__ = "mysubtable" self.Class = Class self.SubClass = SubClass class Class(Class): __module__ += ".foo" prop3 = Property("column3") self.AnotherClass = Class self.registry = Base._storm_property_registry def test_get_empty(self): self.assertRaises(PropertyPathError, self.registry.get, "unexistent") def test_get_subclass(self): prop1 = self.registry.get("SubClass.prop1") prop2 = self.registry.get("SubClass.prop2") self.assertTrue(prop1 is self.SubClass.prop1) self.assertTrue(prop2 is self.SubClass.prop2) def test_get_ambiguous(self): self.assertRaises(PropertyPathError, self.registry.get, "Class.prop1") self.assertRaises(PropertyPathError, self.registry.get, "Class.prop2") ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1721152862.425125 storm-1.0/storm/tests/schema/0000755000175000017500000000000014645532536016603 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/storm/tests/schema/__init__.py0000644000175000017500000000143611752263216020710 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039171.0 storm-1.0/storm/tests/schema/patch.py0000644000175000017500000003712714645174503020262 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os import shutil import sys import traceback from storm.locals import StormError, Store, create_database from storm.schema.patch import ( Patch, PatchApplier, UnknownPatchError, BadPatchError, PatchSet) from storm.tests.mocker import MockerTestCase patch_test_0 = """ x = None def apply(store): global x x = 42 from mypackage import shared_data shared_data.append(42) """ patch_test_1 = """ y = None def apply(store): global y y = 380 from mypackage import shared_data shared_data.append(380) """ patch_explosion = """ from storm.locals import StormError def apply(store): raise StormError('KABOOM!') """ patch_after_explosion = """ def apply(store): pass """ patch_no_args_apply = """ def apply(): pass """ patch_missing_apply = """ def misnamed_apply(store): pass """ patch_name_error = """ def apply(store): blah # Comment """ class MockPatchStore: def __init__(self, database, patches=[]): self.database = database self.rolled_back = 0 self.committed = 0 self.patches = patches self.objs = [] def execute(self, statement, params=None, noresult=False): pass def find(self, cls_spec, *args, **kwargs): return self.patches def rollback(self): self.rolled_back += 1 def add(self, obj): self.objs.append(obj) def commit(self): self.committed += 1 class PatchApplierTest(MockerTestCase): def setUp(self): super().setUp() self.patchdir = self.makeDir() self.pkgdir = os.path.join(self.patchdir, "mypackage") os.makedirs(self.pkgdir) f = open(os.path.join(self.pkgdir, "__init__.py"), "w") f.write("shared_data = []") f.close() # Order of creation here is important to try to screw up the # patch ordering, as os.listdir returns in order of mtime (or # something). for pname, data in [("patch_380.py", patch_test_1), ("patch_42.py", patch_test_0)]: self.add_module(pname, data) sys.path.append(self.patchdir) self.filename = self.makeFile() self.uri = "sqlite:///%s" % self.filename self.store = Store(create_database(self.uri)) self.store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") self.assertFalse(self.store.get(Patch, (42))) self.assertFalse(self.store.get(Patch, (380))) import mypackage self.mypackage = mypackage self.patch_set = PatchSet(mypackage) # Create another connection just to keep track of the state of the # whole transaction manager. See the assertion functions below. 
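# prepare_for_transaction_check() inserts an uncommitted sentinel row
# into this second store; whether that row survives tells the assertion
# helpers below if the whole transaction was committed or rolled back.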
self.another_store = Store(create_database("sqlite:")) self.another_store.execute("CREATE TABLE test (id INT)") self.another_store.commit() self.prepare_for_transaction_check() class Committer: def commit(committer): self.store.commit() self.another_store.commit() def rollback(committer): self.store.rollback() self.another_store.rollback() self.committer = Committer() self.patch_applier = PatchApplier(self.store, self.patch_set, self.committer) def tearDown(self): super().tearDown() self.committer.rollback() sys.path.remove(self.patchdir) for name in list(sys.modules): if name == "mypackage" or name.startswith("mypackage."): del sys.modules[name] def add_module(self, module_filename, contents): filename = os.path.join(self.pkgdir, module_filename) file = open(filename, "w") file.write(contents) file.close() def remove_all_modules(self): for filename in os.listdir(self.pkgdir): path = os.path.join(self.pkgdir, filename) if os.path.isdir(path): shutil.rmtree(path) else: os.unlink(path) def prepare_for_transaction_check(self): self.another_store.execute("DELETE FROM test") self.another_store.execute("INSERT INTO test VALUES (1)") def assert_transaction_committed(self): self.another_store.rollback() result = self.another_store.execute("SELECT * FROM test").get_one() self.assertEqual(result, (1,), "Transaction manager wasn't committed.") def assert_transaction_aborted(self): self.another_store.commit() result = self.another_store.execute("SELECT * FROM test").get_one() self.assertEqual(result, None, "Transaction manager wasn't aborted.") def test_apply(self): """ L{PatchApplier.apply} executes the patch with the given version. """ self.patch_applier.apply(42) x = getattr(self.mypackage, "patch_42").x self.assertEqual(x, 42) self.assertTrue(self.store.get(Patch, (42))) self.assertTrue("mypackage.patch_42" in sys.modules) self.assert_transaction_committed() def test_apply_with_patch_directory(self): """ If the given L{PatchSet} uses sub-level patches, then the L{PatchApplier.apply} method will look at the per-patch directory and apply the relevant sub-level patch. """ path = os.path.join(self.pkgdir, "patch_99") self.makeDir(path=path) self.makeFile(content="", path=os.path.join(path, "__init__.py")) self.makeFile(content=patch_test_0, path=os.path.join(path, "foo.py")) self.patch_set._sub_level = "foo" self.add_module("patch_99/foo.py", patch_test_0) self.patch_applier.apply(99) self.assertTrue(self.store.get(Patch, (99))) def test_apply_all(self): """ L{PatchApplier.apply_all} executes all unapplied patches. """ self.patch_applier.apply_all() self.assertTrue("mypackage.patch_42" in sys.modules) self.assertTrue("mypackage.patch_380" in sys.modules) x = getattr(self.mypackage, "patch_42").x y = getattr(self.mypackage, "patch_380").y self.assertEqual(x, 42) self.assertEqual(y, 380) self.assert_transaction_committed() def test_apply_exploding_patch(self): """ L{PatchApplier.apply} aborts the transaction if the patch fails. """ self.remove_all_modules() self.add_module("patch_666.py", patch_explosion) self.assertRaises(StormError, self.patch_applier.apply, 666) self.assert_transaction_aborted() def test_wb_apply_all_exploding_patch(self): """ When a patch explodes the store is rolled back to make sure that any changes the patch made to the database are removed. Any other patches that have been applied successfully before it should not be rolled back. Any patches pending after the exploding patch should remain unapplied. 
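        A sketch of that recovery behaviour, with the versions created
        in this test's fixture:

            try:
                self.patch_applier.apply_all()  # 42 and 380 succeed,
            except StormError:                  # 666 raises
                pass
            # 42 and 380 stay applied; 666 and 667 remain pending,
            # matching the get_unapplied_versions() assertions below.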
""" self.add_module("patch_666.py", patch_explosion) self.add_module("patch_667.py", patch_after_explosion) self.assertEqual(list(self.patch_applier.get_unapplied_versions()), [42, 380, 666, 667]) self.assertRaises(StormError, self.patch_applier.apply_all) self.assertEqual(list(self.patch_applier.get_unapplied_versions()), [666, 667]) def test_mark_applied(self): """ L{PatchApplier.mark} marks a patch has applied by inserting a new row in the patch table. """ self.patch_applier.mark_applied(42) self.assertFalse("mypackage.patch_42" in sys.modules) self.assertFalse("mypackage.patch_380" in sys.modules) self.assertTrue(self.store.get(Patch, 42)) self.assertFalse(self.store.get(Patch, 380)) self.assert_transaction_committed() def test_mark_applied_all(self): """ L{PatchApplier.mark_applied_all} marks all pending patches as applied. """ self.patch_applier.mark_applied_all() self.assertFalse("mypackage.patch_42" in sys.modules) self.assertFalse("mypackage.patch_380" in sys.modules) self.assertTrue(self.store.get(Patch, 42)) self.assertTrue(self.store.get(Patch, 380)) self.assert_transaction_committed() def test_application_order(self): """ L{PatchApplier.apply_all} applies the patches in increasing version order. """ self.patch_applier.apply_all() self.assertEqual(self.mypackage.shared_data, [42, 380]) def test_has_pending_patches(self): """ L{PatchApplier.has_pending_patches} returns C{True} if there are patches to be applied, C{False} otherwise. """ self.assertTrue(self.patch_applier.has_pending_patches()) self.patch_applier.apply_all() self.assertFalse(self.patch_applier.has_pending_patches()) def test_abort_if_unknown_patches(self): """ L{PatchApplier.mark_applied} raises and error if the patch table contains patches without a matching file in the patch module. """ self.patch_applier.mark_applied(381) self.assertRaises(UnknownPatchError, self.patch_applier.apply_all) def test_get_unknown_patch_versions(self): """ L{PatchApplier.get_unknown_patch_versions} returns the versions of all unapplied patches. """ patches = [Patch(42), Patch(380), Patch(381)] my_store = MockPatchStore("database", patches=patches) patch_applier = PatchApplier(my_store, self.mypackage) self.assertEqual({381}, patch_applier.get_unknown_patch_versions()) def test_no_unknown_patch_versions(self): """ L{PatchApplier.get_unknown_patch_versions} returns an empty set if no patches are unapplied. """ patches = [Patch(42), Patch(380)] my_store = MockPatchStore("database", patches=patches) patch_applier = PatchApplier(my_store, self.mypackage) self.assertEqual(set(), patch_applier.get_unknown_patch_versions()) def test_patch_with_incorrect_apply(self): """ L{PatchApplier.apply_all} raises an error as soon as one of the patches to be applied fails. """ self.add_module("patch_999.py", patch_no_args_apply) try: self.patch_applier.apply_all() except BadPatchError as e: self.assertTrue("mypackage/patch_999.py" in str(e)) self.assertTrue("takes 0 positional arguments" in str(e)) self.assertTrue("TypeError" in str(e)) else: self.fail("BadPatchError not raised") def test_patch_with_missing_apply(self): """ L{PatchApplier.apply_all} raises an error if one of the patches to to be applied has no 'apply' function defined. 
""" self.add_module("patch_999.py", patch_missing_apply) try: self.patch_applier.apply_all() except BadPatchError as e: self.assertTrue("mypackage/patch_999.py" in str(e)) self.assertTrue("no attribute" in str(e)) self.assertTrue("AttributeError" in str(e)) else: self.fail("BadPatchError not raised") def test_patch_with_syntax_error(self): """ L{PatchApplier.apply_all} raises an error if one of the patches to to be applied contains a syntax error. """ self.add_module("patch_999.py", "that's not python") try: self.patch_applier.apply_all() except BadPatchError as e: self.assertTrue(" 999 " in str(e)) self.assertTrue("SyntaxError" in str(e)) else: self.fail("BadPatchError not raised") def test_patch_error_includes_traceback(self): """ The exception raised by L{PatchApplier.apply_all} when a patch fails include the relevant traceback from the patch. """ self.add_module("patch_999.py", patch_name_error) try: self.patch_applier.apply_all() except BadPatchError as e: self.assertTrue("mypackage/patch_999.py" in str(e)) self.assertTrue("NameError" in str(e)) self.assertTrue("blah" in str(e)) formatted = traceback.format_exc() self.assertTrue("# Comment" in formatted) else: self.fail("BadPatchError not raised") class PatchSetTest(MockerTestCase): def setUp(self): super().setUp() self.sys_dir = self.makeDir() self.package_dir = os.path.join(self.sys_dir, "mypackage") os.makedirs(self.package_dir) self.makeFile( content="", dirname=self.package_dir, basename="__init__.py") sys.path.append(self.sys_dir) import mypackage self.patch_package = PatchSet(mypackage, sub_level="foo") def tearDown(self): super().tearDown() for name in list(sys.modules): if name == "mypackage" or name.startswith("mypackage."): del sys.modules[name] def test_get_patch_versions(self): """ The C{get_patch_versions} method returns the available patch versions, by looking at directories named like "patch_N". """ patch_1_dir = os.path.join(self.package_dir, "patch_1") os.makedirs(patch_1_dir) self.assertEqual([1], self.patch_package.get_patch_versions()) def test_get_patch_versions_ignores_non_patch_directories(self): """ The C{get_patch_versions} method ignores files or directories not matching the required name pattern. """ random_dir = os.path.join(self.package_dir, "random") os.makedirs(random_dir) self.assertEqual([], self.patch_package.get_patch_versions()) def test_get_patch_module(self): """ The C{get_patch_module} method returns the Python module for the patch with the given version. """ patch_1_dir = os.path.join(self.package_dir, "patch_1") os.makedirs(patch_1_dir) self.makeFile(content="", dirname=patch_1_dir, basename="__init__.py") self.makeFile(content="", dirname=patch_1_dir, basename="foo.py") patch_module = self.patch_package.get_patch_module(1) self.assertEqual("mypackage.patch_1.foo", patch_module.__name__) def test_get_patch_module_no_sub_level(self): """ The C{get_patch_module} method returns a dummy patch module if no sub-level file exists in the patch directory for the given version. """ patch_1_dir = os.path.join(self.package_dir, "patch_1") os.makedirs(patch_1_dir) patch_module = self.patch_package.get_patch_module(1) store = object() self.assertIsNone(patch_module.apply(store)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/schema/schema.py0000644000175000017500000002106314645174376020423 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os import sys from storm.locals import StormError, Store, create_database from storm.schema.schema import ( Schema, SchemaMissingError, UnappliedPatchesError) from storm.tests.mocker import MockerTestCase class Package: def __init__(self, package_dir, name): self.name = name self._package_dir = package_dir def create_module(self, filename, contents): filename = os.path.join(self._package_dir, filename) file = open(filename, "w") file.write(contents) file.close() class SchemaTest(MockerTestCase): def setUp(self): super().setUp() self.database = create_database("sqlite:///%s" % self.makeFile()) self.store = Store(self.database) self._package_dirs = set() self._package_names = set() self.package = self.create_package(self.makeDir(), "patch_package") import patch_package creates = ["CREATE TABLE person (id INTEGER, name TEXT)"] drops = ["DROP TABLE person"] deletes = ["DELETE FROM person"] self.schema = Schema(creates, drops, deletes, patch_package) def tearDown(self): for package_dir in self._package_dirs: sys.path.remove(package_dir) for name in list(sys.modules): if name in self._package_names: del sys.modules[name] elif any(name.startswith("%s." % x) for x in self._package_names): del sys.modules[name] super().tearDown() def create_package(self, base_dir, name, init_module=None): """Create a Python package. Packages created using this method will be removed from L{sys.path} and L{sys.modules} during L{tearDown}. @param package_dir: The directory in which to create the new package. @param name: The name of the package. @param init_module: Optionally, the text to include in the __init__.py file. @return: A L{Package} instance that can be used to create modules. """ package_dir = os.path.join(base_dir, name) self._package_names.add(name) os.makedirs(package_dir) file = open(os.path.join(package_dir, "__init__.py"), "w") if init_module: file.write(init_module) file.close() sys.path.append(base_dir) self._package_dirs.add(base_dir) return Package(package_dir, name) def test_check_with_missing_schema(self): """ L{Schema.check} raises an exception if the given store is completely pristine and no schema has been applied yet. The transaction doesn't get rolled back so it's still usable. """ self.store.execute("CREATE TABLE foo (bar INT)") self.assertRaises(SchemaMissingError, self.schema.check, self.store) self.assertIsNone(self.store.execute("SELECT 1 FROM foo").get_one()) def test_check_with_unapplied_patches(self): """ L{Schema.check} raises an exception if the given store has unapplied schema patches. """ self.schema.create(self.store) contents = """ def apply(store): pass """ self.package.create_module("patch_1.py", contents) self.assertRaises(UnappliedPatchesError, self.schema.check, self.store) def test_create(self): """ L{Schema.create} can be used to create the tables of a L{Store}. 
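        A sketch of the wider lifecycle (statement lists as defined in
        setUp above): a Schema bundles CREATE/DROP/DELETE statements and
        a patch package, and create() also sets up the "patch"
        bookkeeping table:

            schema = Schema(creates, drops, deletes, patch_package)
            schema.create(store)   # run creates, create the patch table
            schema.drop(store)     # run drops
            schema.delete(store)   # run deletes (clear data, keep schema)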
""" self.assertRaises(StormError, self.store.execute, "SELECT * FROM person") self.schema.create(self.store) self.assertEqual(list(self.store.execute("SELECT * FROM person")), []) # By default changes are committed store2 = Store(self.database) self.assertEqual(list(store2.execute("SELECT * FROM person")), []) def test_create_with_autocommit_off(self): """ L{Schema.autocommit} can be used to turn automatic commits off. """ self.schema.autocommit(False) self.schema.create(self.store) self.store.rollback() self.assertRaises(StormError, self.store.execute, "SELECT * FROM patch") def test_drop(self): """ L{Schema.drop} can be used to drop the tables of a L{Store}. """ self.schema.create(self.store) self.assertEqual(list(self.store.execute("SELECT * FROM person")), []) self.schema.drop(self.store) self.assertRaises(StormError, self.store.execute, "SELECT * FROM person") def test_drop_with_missing_patch_table(self): """ L{Schema.drop} works fine even if the user's supplied statements end up dropping the patch table that we created. """ import patch_package schema = Schema([], ["DROP TABLE patch"], [], patch_package) schema.create(self.store) schema.drop(self.store) self.assertRaises(StormError, self.store.execute, "SELECT * FROM patch") def test_delete(self): """ L{Schema.delete} can be used to clear the tables of a L{Store}. """ self.schema.create(self.store) self.store.execute("INSERT INTO person (id, name) VALUES (1, 'Jane')") self.assertEqual(list(self.store.execute("SELECT * FROM person")), [(1, "Jane")]) self.schema.delete(self.store) self.assertEqual(list(self.store.execute("SELECT * FROM person")), []) def test_upgrade_creates_schema(self): """ L{Schema.upgrade} creates a schema from scratch if no exist, and is effectively equivalent to L{Schema.create} in such case. """ self.assertRaises(StormError, self.store.execute, "SELECT * FROM person") self.schema.upgrade(self.store) self.assertEqual(list(self.store.execute("SELECT * FROM person")), []) def test_upgrade_marks_patches_applied(self): """ L{Schema.upgrade} updates the patch table after applying the needed patches. """ contents = """ def apply(store): store.execute('ALTER TABLE person ADD COLUMN phone TEXT') """ self.package.create_module("patch_1.py", contents) statement = "SELECT * FROM patch" self.assertRaises(StormError, self.store.execute, statement) self.schema.upgrade(self.store) self.assertEqual(list(self.store.execute("SELECT * FROM patch")), [(1,)]) def test_upgrade_applies_patches(self): """ L{Schema.upgrade} executes the needed patches, that typically modify the existing schema. """ self.schema.create(self.store) contents = """ def apply(store): store.execute('ALTER TABLE person ADD COLUMN phone TEXT') """ self.package.create_module("patch_1.py", contents) self.schema.upgrade(self.store) self.store.execute( "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')") self.assertEqual(list(self.store.execute("SELECT * FROM person")), [(1, "Jane", "123")]) def test_advance(self): """ L{Schema.advance} executes the given patch version. 
""" self.schema.create(self.store) contents1 = """ def apply(store): store.execute('ALTER TABLE person ADD COLUMN phone TEXT') """ contents2 = """ def apply(store): store.execute('ALTER TABLE person ADD COLUMN address TEXT') """ self.package.create_module("patch_1.py", contents1) self.package.create_module("patch_2.py", contents2) self.schema.advance(self.store, 1) self.store.execute( "INSERT INTO person (id, name, phone) VALUES (1, 'Jane', '123')") self.assertEqual(list(self.store.execute("SELECT * FROM person")), [(1, "Jane", "123")]) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/schema/sharding.py0000644000175000017500000000712014645174376020760 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2014 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from storm.schema.schema import SchemaMissingError, UnappliedPatchesError from storm.schema.sharding import Sharding, PatchLevelMismatchError from storm.tests.mocker import MockerTestCase class FakeSchema: patches = 2 def __init__(self): self.applied = [] def check(self, store): if store.pristine: raise SchemaMissingError() if store.patch < self.patches: unapplied_versions = range(store.patch + 1, self.patches + 1) raise UnappliedPatchesError(unapplied_versions) def create(self, store): store.pristine = False def upgrade(self, store): for i in range(2): store.patch += 1 self.applied.append((store, store.patch)) def advance(self, store, version): store.patch = version self.applied.append((store, store.patch)) class FakeStore: pristine = True # If no schema was ever applied patch = 0 # Current patch level of the store class ShardingTest(MockerTestCase): def setUp(self): super().setUp() self.store = FakeStore() self.schema = FakeSchema() self.sharding = Sharding() def test_upgrade_pristine_store(self): """ Pristine L{Store}s get their L{Schema} created from scratch. """ self.sharding.add(self.store, self.schema) self.sharding.upgrade() self.assertFalse(self.store.pristine) def test_upgrade_apply_patches(self): """ If a L{Store}s is not at the latest patch level, all pending patches get applied. """ self.store.pristine = False self.sharding.add(self.store, self.schema) self.sharding.upgrade() self.assertEqual(2, self.store.patch) def test_upgrade_multi_store(self): """ If a L{Store}s is not at the latest patch level, all pending patches get applied, one level at a time. 
""" self.store.pristine = False self.sharding.add(self.store, self.schema) store2 = FakeStore() store2.pristine = False self.sharding.add(store2, self.schema) self.sharding.upgrade() self.assertEqual(2, self.store.patch) self.assertEqual(2, store2.patch) self.assertEqual( [(self.store, 1), (store2, 1), (self.store, 2), (store2, 2)], self.schema.applied) def test_upgrade_patch_level_mismatch(self): """ If not all L{Store}s are at the same patch level, an exception is raised. """ self.store.pristine = False self.sharding.add(self.store, self.schema) store2 = FakeStore() store2.pristine = False store2.patch = 1 self.sharding.add(store2, self.schema) self.assertRaises(PatchLevelMismatchError, self.sharding.upgrade) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/sqlobject.py0000644000175000017500000013105614645174376017715 0ustar00cjwatsoncjwatson# # Copyright (c) 2006-2010 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import datetime import operator from storm.database import create_database from storm.exceptions import NoneError from storm.sqlobject import * from storm.store import Store from storm.tz import tzutc from storm.tests.helper import TestHelper class SQLObjectTest(TestHelper): def setUp(self): TestHelper.setUp(self) # Allow classes with the same name in different tests to resolve # property path strings properly. 
SQLObjectBase._storm_property_registry.clear() self.store = Store(create_database("sqlite:")) class SQLObject(SQLObjectBase): @staticmethod def _get_store(): return self.store self.SQLObject = SQLObject self.store.execute("CREATE TABLE person " "(id INTEGER PRIMARY KEY, name TEXT, age INTEGER," " ts TIMESTAMP, delta INTERVAL," " address_id INTEGER)") self.store.execute("INSERT INTO person VALUES " "(1, 'John Joe', 20, '2007-02-05 19:53:15'," " '1 day, 12:34:56', 1)") self.store.execute("INSERT INTO person VALUES " "(2, 'John Doe', 20, '2007-02-05 20:53:15'," " '42 days 12:34:56.78', 2)") self.store.execute("CREATE TABLE address " "(id INTEGER PRIMARY KEY, city TEXT)") self.store.execute("INSERT INTO address VALUES (1, 'Curitiba')") self.store.execute("INSERT INTO address VALUES (2, 'Sao Carlos')") self.store.execute("CREATE TABLE phone " "(id INTEGER PRIMARY KEY, person_id INTEGER," "number TEXT)") self.store.execute("INSERT INTO phone VALUES (1, 2, '1234-5678')") self.store.execute("INSERT INTO phone VALUES (2, 1, '8765-4321')") self.store.execute("INSERT INTO phone VALUES (3, 2, '8765-5678')") self.store.execute("CREATE TABLE person_phone " "(id INTEGER PRIMARY KEY, person_id INTEGER, " "phone_id INTEGER)") self.store.execute("INSERT INTO person_phone VALUES (1, 2, 1)") self.store.execute("INSERT INTO person_phone VALUES (2, 2, 2)") self.store.execute("INSERT INTO person_phone VALUES (3, 1, 1)") class Person(self.SQLObject): _defaultOrder = "-Person.name" name = StringCol() age = IntCol() ts = UtcDateTimeCol() self.Person = Person def test_get(self): person = self.Person.get(2) self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_get_not_found(self): self.assertRaises(SQLObjectNotFound, self.Person.get, 1000) def test_get_typecast(self): person = self.Person.get("2") self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_destroySelf(self): person = self.Person.get(2) person.destroySelf() self.assertRaises(SQLObjectNotFound, self.Person.get, 2) def test_delete(self): self.Person.delete(2) self.assertRaises(SQLObjectNotFound, self.Person.get, 2) def test_custom_table_name(self): class MyPerson(self.Person): _table = "person" person = MyPerson.get(2) self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_custom_id_name(self): class MyPerson(self.SQLObject): _defaultOrder = "-Person.name" _table = "person" _idName = "name" _idType = str age = IntCol() ts = UtcDateTimeCol() person = MyPerson.get("John Doe") self.assertTrue(person) self.assertEqual(person.id, "John Doe") def test_create(self): person = self.Person(name="John Joe") self.assertTrue(Store.of(person) is self.store) self.assertEqual(type(person.id), int) self.assertEqual(person.name, "John Joe") def test_SO_creating(self): test = self class Person(self.Person): def set(self, **args): test.assertEqual(self._SO_creating, True) test.assertEqual(args, {"name": "John Joe"}) person = Person(name="John Joe") self.assertEqual(person._SO_creating, False) def test_object_not_added_if__create_fails(self): objects = [] class Person(self.Person): def _create(self, id, **kwargs): objects.append(self) raise RuntimeError self.assertRaises(RuntimeError, Person, name="John Joe") self.assertEqual(len(objects), 1) person = objects[0] self.assertEqual(Store.of(person), None) def test_init_hook(self): called = [] class Person(self.Person): def _init(self, *args, **kwargs): called.append(True) person = Person(name="John Joe") self.assertEqual(called, [True]) Person.get(2) 
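        # Two calls are expected below because, in the compatibility
        # layer, _init() runs both when an object is constructed and
        # when an existing row is loaded; Person.get(2) above takes the
        # load path.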
self.assertEqual(called, [True, True]) def test_alternateID(self): class Person(self.SQLObject): name = StringCol(alternateID=True) person = Person.byName("John Doe") self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_alternateMethodName(self): class Person(self.SQLObject): name = StringCol(alternateMethodName="byFoo") person = Person.byFoo("John Doe") self.assertTrue(person) self.assertEqual(person.name, "John Doe") self.assertRaises(SQLObjectNotFound, Person.byFoo, "John None") def test_select(self): result = self.Person.select("name = 'John Joe'") self.assertEqual(result[0].name, "John Joe") def test_select_sqlbuilder(self): result = self.Person.select(self.Person.q.name == "John Joe") self.assertEqual(result[0].name, "John Joe") def test_select_orderBy(self): result = self.Person.select("name LIKE 'John%'", orderBy=("name","id")) self.assertEqual(result[0].name, "John Doe") def test_select_orderBy_expr(self): result = self.Person.select("name LIKE 'John%'", orderBy=self.Person.name) self.assertEqual(result[0].name, "John Doe") def test_select_all(self): result = self.Person.select() self.assertEqual(result[0].name, "John Joe") def test_select_empty_string(self): result = self.Person.select('') self.assertEqual(result[0].name, "John Joe") def test_select_limit(self): result = self.Person.select(limit=1) self.assertEqual(len(list(result)), 1) def test_select_negative_offset(self): result = self.Person.select(orderBy="name") self.assertEqual(result[-1].name, "John Joe") def test_select_slice_negative_offset(self): result = self.Person.select(orderBy="name")[-1:] self.assertEqual(result[0].name, "John Joe") def test_select_distinct(self): result = self.Person.select("person.name = 'John Joe'", clauseTables=["phone"], distinct=True) self.assertEqual(len(list(result)), 1) def test_select_selectAlso(self): # Since John Doe has two phone numbers, this would return him # twice without the distinct=True bit. 
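        # Roughly the statement the select() below builds (illustrative
        # SQL, not necessarily Storm's exact output):
        #
        #     SELECT DISTINCT person.*, LOWER(name) AS lower_name
        #       FROM person, phone
        #      WHERE person.id = phone.person_id
        #      ORDER BY lower_name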
result = self.Person.select( "person.id = phone.person_id", clauseTables=["phone"], selectAlso="LOWER(name) AS lower_name", orderBy="lower_name", distinct=True) people = list(result) self.assertEqual(len(people), 2) self.assertEqual(people[0].name, "John Doe") self.assertEqual(people[1].name, "John Joe") def test_select_selectAlso_with_prejoin(self): class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id", notNull=True) class Address(self.SQLObject): city = StringCol() result = Person.select( prejoins=["address"], selectAlso="LOWER(person.name) AS lower_name", orderBy="lower_name") people = list(result) self.assertEqual(len(people), 2) self.assertEqual([(person.name, person.address.city) for person in people], [("John Doe", "Sao Carlos"), ("John Joe", "Curitiba")]) def test_select_clauseTables_simple(self): result = self.Person.select("name = 'John Joe'", ["person"]) self.assertEqual(result[0].name, "John Joe") def test_select_clauseTables_implicit_join(self): result = self.Person.select("person.name = 'John Joe' and " "phone.person_id = person.id", ["person", "phone"]) self.assertEqual(result[0].name, "John Joe") def test_select_clauseTables_no_cls_table(self): result = self.Person.select("person.name = 'John Joe' and " "phone.person_id = person.id", ["phone"]) self.assertEqual(result[0].name, "John Joe") def test_selectBy(self): result = self.Person.selectBy(name="John Joe") self.assertEqual(result[0].name, "John Joe") def test_selectBy_orderBy(self): result = self.Person.selectBy(age=20, orderBy="name") self.assertEqual(result[0].name, "John Doe") result = self.Person.selectBy(age=20, orderBy="-name") self.assertEqual(result[0].name, "John Joe") def test_selectOne(self): person = self.Person.selectOne("name = 'John Joe'") self.assertTrue(person) self.assertEqual(person.name, "John Joe") nobody = self.Person.selectOne("name = 'John None'") self.assertEqual(nobody, None) # SQLBuilder style expression: person = self.Person.selectOne(self.Person.q.name == "John Joe") self.assertNotEqual(person, None) self.assertEqual(person.name, "John Joe") def test_selectOne_multiple_results(self): self.assertRaises(SQLObjectMoreThanOneResultError, self.Person.selectOne) def test_selectOne_clauseTables(self): person = self.Person.selectOne("person.name = 'John Joe' and " "phone.person_id = person.id", ["phone"]) self.assertEqual(person.name, "John Joe") def test_selectOneBy(self): person = self.Person.selectOneBy(name="John Joe") self.assertTrue(person) self.assertEqual(person.name, "John Joe") nobody = self.Person.selectOneBy(name="John None") self.assertEqual(nobody, None) def test_selectOneBy_multiple_results(self): self.assertRaises(SQLObjectMoreThanOneResultError, self.Person.selectOneBy) def test_selectFirst(self): person = self.Person.selectFirst("name LIKE 'John%'", orderBy="name") self.assertTrue(person) self.assertEqual(person.name, "John Doe") person = self.Person.selectFirst("name LIKE 'John%'", orderBy="-name") self.assertTrue(person) self.assertEqual(person.name, "John Joe") nobody = self.Person.selectFirst("name = 'John None'", orderBy="name") self.assertEqual(nobody, None) # SQLBuilder style expression: person = self.Person.selectFirst(LIKE(self.Person.q.name, "John%"), orderBy="name") self.assertNotEqual(person, None) self.assertEqual(person.name, "John Doe") def test_selectFirst_default_order(self): person = self.Person.selectFirst("name LIKE 'John%'") self.assertTrue(person) self.assertEqual(person.name, "John Joe") def 
test_selectFirst_default_order_list(self): class Person(self.Person): _defaultOrder = ["name"] person = Person.selectFirst("name LIKE 'John%'") self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_selectFirst_default_order_expr(self): class Person(self.Person): _defaultOrder = [SQLConstant("name")] person = Person.selectFirst("name LIKE 'John%'") self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_selectFirst_default_order_fully_qualified(self): class Person(self.Person): _defaultOrder = ["person.name"] person = Person.selectFirst("name LIKE 'John%'") self.assertTrue(person) self.assertEqual(person.name, "John Doe") def test_selectFirstBy(self): person = self.Person.selectFirstBy(age=20, orderBy="name") self.assertTrue(person) self.assertEqual(person.name, "John Doe") person = self.Person.selectFirstBy(age=20, orderBy="-name") self.assertTrue(person) self.assertEqual(person.name, "John Joe") nobody = self.Person.selectFirstBy(age=1000, orderBy="name") self.assertEqual(nobody, None) def test_selectFirstBy_default_order(self): person = self.Person.selectFirstBy(age=20) self.assertTrue(person) self.assertEqual(person.name, "John Joe") def test_syncUpdate(self): """syncUpdate() flushes pending changes to the database.""" person = self.Person.get(id=1) person.name = "John Smith" person.syncUpdate() name = self.store.execute( "SELECT name FROM person WHERE id = 1").get_one()[0] self.assertEqual(name, "John Smith") def test_sync(self): """sync() flushes pending changes and invalidates the cache.""" person = self.Person.get(id=1) person.name = "John Smith" person.sync() name = self.store.execute( "SELECT name FROM person WHERE id = 1").get_one()[0] self.assertEqual(name, "John Smith") # Now make a change behind Storm's back and show that sync() # makes the new value from the database visible. 
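        # A sketch of what the compat methods map to (assuming the usual
        # Store API): syncUpdate() roughly does Store.flush(), while
        # sync() flushes and then invalidates the object's cached column
        # values, so the out-of-band UPDATE below is reloaded on the
        # next attribute access.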
self.store.execute("UPDATE person SET name = 'Jane Smith' " "WHERE id = 1", noresult=True) person.sync() self.assertEqual(person.name, "Jane Smith") def test_col_name(self): class Person(self.SQLObject): foo = StringCol(dbName="name") person = Person.get(2) self.assertEqual(person.foo, "John Doe") class Person(self.SQLObject): foo = StringCol("name") person = Person.get(2) self.assertEqual(person.foo, "John Doe") def test_col_default(self): class Person(self.SQLObject): name = StringCol(default="Johny") person = Person() self.assertEqual(person.name, "Johny") def test_col_default_factory(self): class Person(self.SQLObject): name = StringCol(default=lambda: "Johny") person = Person() self.assertEqual(person.name, "Johny") def test_col_not_null(self): class Person(self.SQLObject): name = StringCol(notNull=True) person = Person.get(2) self.assertRaises(NoneError, setattr, person, "name", None) def test_col_storm_validator(self): calls = [] def validator(obj, attr, value): calls.append((obj, attr, value)) return value class Person(self.SQLObject): name = StringCol(storm_validator=validator) person = Person.get(2) person.name = 'foo' self.assertEqual(calls, [(person, 'name', 'foo')]) def test_string_col(self): class Person(self.SQLObject): name = StringCol() person = Person.get(2) self.assertEqual(person.name, "John Doe") def test_int_col(self): class Person(self.SQLObject): age = IntCol() person = Person.get(2) self.assertEqual(person.age, 20) def test_bool_col(self): class Person(self.SQLObject): age = BoolCol() person = Person.get(2) self.assertEqual(person.age, True) def test_float_col(self): class Person(self.SQLObject): age = FloatCol() person = Person.get(2) self.assertTrue(abs(person.age - 20.0) < 1e-6) def test_utcdatetime_col(self): class Person(self.SQLObject): ts = UtcDateTimeCol() person = Person.get(2) self.assertEqual(person.ts, datetime.datetime(2007, 2, 5, 20, 53, 15, tzinfo=tzutc())) def test_date_col(self): class Person(self.SQLObject): ts = DateCol() person = Person.get(2) self.assertEqual(person.ts, datetime.date(2007, 2, 5)) def test_interval_col(self): class Person(self.SQLObject): delta = IntervalCol() person = Person.get(2) self.assertEqual(person.delta, datetime.timedelta(42, 45296, 780000)) def test_foreign_key(self): class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id", notNull=True) class Address(self.SQLObject): city = StringCol() person = Person.get(2) self.assertEqual(person.addressID, 2) self.assertEqual(person.address.city, "Sao Carlos") def test_foreign_key_no_dbname(self): self.store.execute("CREATE TABLE another_person " "(id INTEGER PRIMARY KEY, name TEXT, age INTEGER," " ts TIMESTAMP, address INTEGER)") self.store.execute("INSERT INTO another_person VALUES " "(2, 'John Doe', 20, '2007-02-05 20:53:15', 2)") class AnotherPerson(self.Person): address = ForeignKey(foreignKey="Address", notNull=True) class Address(self.SQLObject): city = StringCol() person = AnotherPerson.get(2) self.assertEqual(person.addressID, 2) self.assertEqual(person.address.city, "Sao Carlos") def test_foreign_key_orderBy(self): class Person(self.Person): _defaultOrder = "address" address = ForeignKey(foreignKey="Address", dbName="address_id", notNull=True) class Address(self.SQLObject): city = StringCol() person = Person.selectFirst() self.assertEqual(person.addressID, 1) def test_foreign_key_storm_validator(self): calls = [] def validator(obj, attr, value): calls.append((obj, attr, value)) return value class Person(self.SQLObject): address = 
ForeignKey(foreignKey="Address", dbName="address_id", storm_validator=validator) class Address(self.SQLObject): city = StringCol() person = Person.get(2) address = Address.get(1) person.address = address self.assertEqual(calls, [(person, 'addressID', 1)]) def test_multiple_join(self): class AnotherPerson(self.Person): _table = "person" phones = SQLMultipleJoin("Phone", joinColumn="person") class Phone(self.SQLObject): person = ForeignKey("AnotherPerson", dbName="person_id") number = StringCol() person = AnotherPerson.get(2) # Make sure that the result is wrapped. result = person.phones.orderBy("-number") self.assertEqual([phone.number for phone in result], ["8765-5678", "1234-5678"]) # Test add/remove methods. number = Phone.selectOneBy(number="1234-5678") person.removePhone(number) self.assertEqual(sorted(phone.number for phone in person.phones), ["8765-5678"]) person.addPhone(number) self.assertEqual(sorted(phone.number for phone in person.phones), ["1234-5678", "8765-5678"]) def test_multiple_join_prejoins(self): self.store.execute("ALTER TABLE phone ADD COLUMN address_id INT") self.store.execute("UPDATE phone SET address_id = 1") self.store.execute("UPDATE phone SET address_id = 2 WHERE id = 3") class AnotherPerson(self.Person): _table = "person" phones = SQLMultipleJoin("Phone", joinColumn="person", orderBy="number", prejoins=["address"]) class Phone(self.SQLObject): person = ForeignKey("AnotherPerson", dbName="person_id") address = ForeignKey("Address", dbName="address_id") number = StringCol() class Address(self.SQLObject): city = StringCol() person = AnotherPerson.get(2) [phone1, phone2] = person.phones # Delete addresses behind Storm's back to show that the # addresses have been loaded. self.store.execute("DELETE FROM address") self.assertEqual(phone1.number, "1234-5678") self.assertEqual(phone1.address.city, "Curitiba") self.assertEqual(phone2.number, "8765-5678") self.assertEqual(phone2.address.city, "Sao Carlos") def test_related_join(self): class AnotherPerson(self.Person): _table = "person" phones = SQLRelatedJoin("Phone", otherColumn="phone_id", intermediateTable="PersonPhone", joinColumn="person_id", orderBy="id") class PersonPhone(self.Person): person_id = IntCol() phone_id = IntCol() class Phone(self.SQLObject): number = StringCol() person = AnotherPerson.get(2) self.assertEqual([phone.number for phone in person.phones], ["1234-5678", "8765-4321"]) # Make sure that the result is wrapped. result = person.phones.orderBy("-number") self.assertEqual([phone.number for phone in result], ["8765-4321", "1234-5678"]) # Test add/remove methods. 
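        # For SQLRelatedJoin the generated add/remove helpers touch only
        # the intermediate person_phone table, roughly (illustrative SQL):
        #
        #     person.addPhone(phone)     # INSERT INTO person_phone
        #                                #   (person_id, phone_id) VALUES ...
        #     person.removePhone(phone)  # DELETE FROM person_phone
        #                                #   WHERE person_id = ...
        #                                #   AND phone_id = ...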
number = Phone.selectOneBy(number="1234-5678") person.removePhone(number) self.assertEqual(sorted(phone.number for phone in person.phones), ["8765-4321"]) person.addPhone(number) self.assertEqual(sorted(phone.number for phone in person.phones), ["1234-5678", "8765-4321"]) def test_related_join_prejoins(self): self.store.execute("ALTER TABLE phone ADD COLUMN address_id INT") self.store.execute("UPDATE phone SET address_id = 1") self.store.execute("UPDATE phone SET address_id = 2 WHERE id = 2") class AnotherPerson(self.Person): _table = "person" phones = SQLRelatedJoin("Phone", otherColumn="phone_id", intermediateTable="PersonPhone", joinColumn="person_id", orderBy="id", prejoins=["address"]) class PersonPhone(self.Person): person_id = IntCol() phone_id = IntCol() class Phone(self.SQLObject): number = StringCol() address = ForeignKey("Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() person = AnotherPerson.get(2) [phone1, phone2] = person.phones # Delete addresses behind Storm's back to show that the # addresses have been loaded. self.store.execute("DELETE FROM address") self.assertEqual(phone1.number, "1234-5678") self.assertEqual(phone1.address.city, "Curitiba") self.assertEqual(phone2.number, "8765-4321") self.assertEqual(phone2.address.city, "Sao Carlos") def test_single_join(self): self.store.execute("CREATE TABLE office " "(id INTEGER PRIMARY KEY, phone_id INTEGER," "name TEXT)") self.store.execute("INSERT INTO office VALUES (1, 1, 'An office')") class Phone(self.SQLObject): office = SingleJoin("Office", joinColumn="phoneID") class Office(self.SQLObject): phone = ForeignKey(foreignKey="Phone", dbName="phone_id", notNull=True) name = StringCol() office = Office.get(1) self.assertEqual(office.name, "An office") phone = Phone.get(1) self.assertEqual(phone.office, office) # The single join returns None for a phone with no office phone = Phone.get(2) self.assertEqual(phone.office, None) def test_result_set_orderBy(self): result = self.Person.select() result = result.orderBy("-name") self.assertEqual([person.name for person in result], ["John Joe", "John Doe"]) result = result.orderBy("name") self.assertEqual([person.name for person in result], ["John Doe", "John Joe"]) def test_result_set_orderBy_fully_qualified(self): result = self.Person.select() result = result.orderBy("-person.name") self.assertEqual([person.name for person in result], ["John Joe", "John Doe"]) result = result.orderBy("person.name") self.assertEqual([person.name for person in result], ["John Doe", "John Joe"]) def test_result_set_count(self): result = self.Person.select() self.assertEqual(result.count(), 2) def test_result_set_count_limit(self): result = self.Person.select(limit=1) self.assertEqual(len(list(result)), 1) self.assertEqual(result.count(), 1) def test_result_set_count_sliced(self): result = self.Person.select() sliced_result = result[1:] self.assertEqual(len(list(sliced_result)), 1) self.assertEqual(sliced_result.count(), 1) def test_result_set_count_sliced_empty(self): result = self.Person.select() sliced_result = result[1:1] self.assertEqual(len(list(sliced_result)), 0) self.assertEqual(sliced_result.count(), 0) def test_result_set_count_sliced_empty_zero(self): result = self.Person.select() sliced_result = result[0:0] self.assertEqual(len(list(sliced_result)), 0) self.assertEqual(sliced_result.count(), 0) def test_result_set_count_sliced_none(self): result = self.Person.select() sliced_result = result[None:None] self.assertEqual(len(list(sliced_result)), 2) 
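        # Slicing composes LIMIT/OFFSET into the underlying SELECT
        # rather than fetching rows; a None bound means "no limit" or
        # "no offset", so result[None:None] is equivalent to the
        # unsliced result, as the surrounding assertions show.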
self.assertEqual(sliced_result.count(), 2) def test_result_set_count_sliced_start_none(self): result = self.Person.select() sliced_result = result[None:1] self.assertEqual(len(list(sliced_result)), 1) self.assertEqual(sliced_result.count(), 1) def test_result_set_count_sliced_end_none(self): result = self.Person.select() sliced_result = result[1:None] self.assertEqual(len(list(sliced_result)), 1) self.assertEqual(sliced_result.count(), 1) def test_result_set_count_distinct(self): result = self.Person.select( "person.id = phone.person_id", clauseTables=["phone"], distinct=True) self.assertEqual(result.count(), 2) def test_result_set_count_union_distinct(self): result1 = self.Person.select("person.id = 1", distinct=True) result2 = self.Person.select("person.id = 2", distinct=True) self.assertEqual(result1.union(result2).count(), 2) def test_result_set_count_with_joins(self): result = self.Person.select( "person.address_id = address.id", clauseTables=["address"]) self.assertEqual(result.count(), 2) def test_result_set__getitem__(self): result = self.Person.select() self.assertEqual(result[0].name, "John Joe") def test_result_set__iter__(self): result = self.Person.select() self.assertEqual(list(result.__iter__())[0].name, "John Joe") def test_result_set__bool__(self): """ L{SQLObjectResultSet.__bool__} returns C{True} if the result set contains results. If it contains no results, C{False} is returned. """ result = self.Person.select() self.assertEqual(result.__bool__(), True) result = self.Person.select(self.Person.q.name == "No Person") self.assertEqual(result.__bool__(), False) def test_result_set_is_empty(self): """ L{SQLObjectResultSet.is_empty} returns C{True} if the result set doesn't contain any results. If it does contain results, C{False} is returned. 
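        Both this check and __bool__() above issue a lightweight
        single-row query (roughly SELECT ... LIMIT 1) instead of
        fetching the whole result.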
""" result = self.Person.select() self.assertEqual(result.is_empty(), False) result = self.Person.select(self.Person.q.name == "No Person") self.assertEqual(result.is_empty(), True) def test_result_set_distinct(self): result = self.Person.select("person.name = 'John Joe'", clauseTables=["phone"]) self.assertEqual(len(list(result.distinct())), 1) def test_result_set_limit(self): result = self.Person.select() self.assertEqual(len(list(result.limit(1))), 1) def test_result_set_union(self): result1 = self.Person.selectBy(id=1) result2 = self.Person.selectBy(id=2) result3 = result1.union(result2, orderBy="name") self.assertEqual([person.name for person in result3], ["John Doe", "John Joe"]) def test_result_set_union_all(self): result1 = self.Person.selectBy(id=1) result2 = result1.union(result1, unionAll=True) self.assertEqual([person.name for person in result2], ["John Joe", "John Joe"]) def test_result_set_except_(self): person = self.Person(id=3, name="John Moe") result1 = self.Person.select() result2 = self.Person.selectBy(id=2) result3 = result1.except_(result2, orderBy="name") self.assertEqual([person.name for person in result3], ["John Joe", "John Moe"]) def test_result_set_intersect(self): person = self.Person(id=3, name="John Moe") result1 = self.Person.select() result2 = self.Person.select(self.Person.id.is_in((2, 3))) result3 = result1.intersect(result2, orderBy="name") self.assertEqual([person.name for person in result3], ["John Doe", "John Moe"]) def test_result_set_prejoin(self): self.store.execute("ALTER TABLE person ADD COLUMN phone_id INTEGER") self.store.execute("UPDATE person SET phone_id=1 WHERE name='John Doe'") class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") phone = ForeignKey(foreignKey="Phone", dbName="phone_id") class Address(self.SQLObject): city = StringCol() class Phone(self.SQLObject): number = StringCol() result = Person.select("person.name = 'John Doe'") result = result.prejoin(["address", "phone"]) people = list(result) # Remove rows behind its back. self.store.execute("DELETE FROM address") self.store.execute("DELETE FROM phone") # They were prefetched, so it should work even then. self.assertEqual([person.address.city for person in people], ["Sao Carlos"]) self.assertEqual([person.phone.number for person in people], ["1234-5678"]) def test_result_set_prejoin_getitem(self): """Ensure that detuplelizing is used on getitem.""" class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() result = Person.select("person.name = 'John Doe'", prejoins=["address"]) person = result[0] # Remove the row behind its back. self.store.execute("DELETE FROM address") # They were prefetched, so it should work even then. self.assertEqual(person.address.city, "Sao Carlos") def test_result_set_prejoin_one(self): """Ensure that detuplelizing is used on selectOne().""" class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() person = Person.selectOne("person.name = 'John Doe'", prejoins=["address"]) # Remove the row behind its back. self.store.execute("DELETE FROM address") # They were prefetched, so it should work even then. 
self.assertEqual(person.address.city, "Sao Carlos") def test_result_set_prejoin_first(self): """Ensure that detuplelizing is used on selectFirst().""" class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() person = Person.selectFirst("person.name = 'John Doe'", prejoins=["address"], orderBy="name") # Remove the row behind Storm's back. self.store.execute("DELETE FROM address") # They were prefetched, so it should work even then. self.assertEqual(person.address.city, "Sao Carlos") def test_result_set_prejoin_by(self): """Ensure that prejoins work with selectBy() queries.""" class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() result = Person.selectBy(name="John Doe").prejoin(["address"]) person = result[0] # Remove the row behind Storm's back. self.store.execute("DELETE FROM address") # They were prefetched, so it should work even then. self.assertEqual(person.address.city, "Sao Carlos") def test_result_set_prejoin_related(self): """Dotted prejoins are used to prejoin through another table.""" class Phone(self.SQLObject): person = ForeignKey(foreignKey="AnotherPerson", dbName="person_id") number = StringCol() class AnotherPerson(self.Person): _table = "person" address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() phone = Phone.selectOne("phone.number = '1234-5678'", prejoins=["person.address"]) # Remove the rows behind Storm's back. self.store.execute("DELETE FROM address") self.store.execute("DELETE FROM person") # They were prefetched, so it should work even then. self.assertEqual(phone.person.name, "John Doe") self.assertEqual(phone.person.address.city, "Sao Carlos") def test_result_set_prejoin_table_twice(self): """A single table can be prejoined multiple times.""" self.store.execute("CREATE TABLE lease " "(id INTEGER PRIMARY KEY," " landlord_id INTEGER, tenant_id INTEGER)") self.store.execute("INSERT INTO lease VALUES (1, 1, 2)") class Address(self.SQLObject): city = StringCol() class AnotherPerson(self.Person): _table = "person" address = ForeignKey(foreignKey="Address", dbName="address_id") class Lease(self.SQLObject): landlord = ForeignKey(foreignKey="AnotherPerson", dbName="landlord_id") tenant = ForeignKey(foreignKey="AnotherPerson", dbName="tenant_id") lease = Lease.select(prejoins=["landlord", "landlord.address", "tenant", "tenant.address"])[0] # Remove the person rows behind Storm's back. self.store.execute("DELETE FROM address") self.store.execute("DELETE FROM person") self.assertEqual(lease.landlord.name, "John Joe") self.assertEqual(lease.landlord.address.city, "Curitiba") self.assertEqual(lease.tenant.name, "John Doe") self.assertEqual(lease.tenant.address.city, "Sao Carlos") def test_result_set_prejoin_count(self): """Prejoins do not affect the result of aggregates like COUNT().""" class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() result = Person.select("name = 'John Doe'", prejoins=["address"]) self.assertEqual(result.count(), 1) def test_result_set_prejoin_mismatch_union(self): """Prejoins do not cause UNION incompatibilities. """ class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() # The prejoin should not prevent the union from working. 
At # the moment this is done by unconditionally stripping the # prejoins (which is what our SQLObject patch did), but could # be smarter. result1 = Person.select("name = 'John Doe'", prejoins=["address"]) result2 = Person.select("name = 'John Joe'") result = result1.union(result2) names = sorted(person.name for person in result) self.assertEqual(names, ["John Doe", "John Joe"]) def test_result_set_prejoin_mismatch_except(self): """Prejoins do not cause EXCEPT incompatibilities. """ class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() # The prejoin should not prevent the union from working. At # the moment this is done by unconditionally stripping the # prejoins (which is what our SQLObject patch did), but could # be smarter. result1 = Person.select("name = 'John Doe'", prejoins=["address"]) result2 = Person.select("name = 'John Joe'") result = result1.except_(result2) names = sorted(person.name for person in result) self.assertEqual(names, ["John Doe"]) def test_result_set_prejoin_mismatch_intersect(self): """Prejoins do not cause INTERSECT incompatibilities. """ class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() # The prejoin should not prevent the union from working. At # the moment this is done by unconditionally stripping the # prejoins (which is what our SQLObject patch did), but could # be smarter. result1 = Person.select("name = 'John Doe'", prejoins=["address"]) result2 = Person.select("name = 'John Doe'") result = result1.intersect(result2) names = sorted(person.name for person in result) self.assertEqual(names, ["John Doe"]) def test_result_set_prejoinClauseTables(self): self.store.execute("ALTER TABLE person ADD COLUMN phone_id INTEGER") self.store.execute("UPDATE person SET phone_id=1 WHERE name='John Doe'") class Person(self.Person): address = ForeignKey(foreignKey="AddressClass", dbName="address_id") phone = ForeignKey(foreignKey="PhoneClass", dbName="phone_id") # Name the class so that it doesn't match the table name, to ensure # that the prejoin is actually using table names, rather than class # names. class AddressClass(self.SQLObject): _table = "address" city = StringCol() class PhoneClass(self.SQLObject): _table = "phone" number = StringCol() result = Person.select("person.name = 'John Doe' and " "person.phone_id = phone.id and " "person.address_id = address.id", clauseTables=["address", "phone"]) result = result.prejoinClauseTables(["address", "phone"]) people = list(result) # Remove rows behind its back. self.store.execute("DELETE FROM address") self.store.execute("DELETE FROM phone") # They were prefetched, so it should work even then. self.assertEqual([person.address.city for person in people], ["Sao Carlos"]) self.assertEqual([person.phone.number for person in people], ["1234-5678"]) def test_result_set_sum_string(self): result = self.Person.select() self.assertEqual(result.sum('age'), 40) def test_result_set_sum_expr(self): result = self.Person.select() self.assertEqual(result.sum(self.Person.q.age), 40) def test_result_set_contains(self): john = self.Person.selectOneBy(name="John Doe") self.assertTrue(john in self.Person.select()) self.assertFalse(john in self.Person.selectBy(name="John Joe")) self.assertFalse(john in self.Person.select( "Person.name = 'John Joe'")) def test_result_set_contains_does_not_use_iter(self): """Calling 'item in result_set' does not iterate over the set. 
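        Membership is instead compiled into a targeted query, roughly
        adding the object's primary key to the result set's conditions,
        which is why making __iter__ explode below doesn't break it.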
""" def no_iter(self): raise RuntimeError real_iter = SQLObjectResultSet.__iter__ SQLObjectResultSet.__iter__ = no_iter try: john = self.Person.selectOneBy(name="John Doe") self.assertTrue(john in self.Person.select()) finally: SQLObjectResultSet.__iter__ = real_iter def test_result_set_contains_wrong_type(self): class Address(self.SQLObject): city = StringCol() address = Address.get(1) result_set = self.Person.select() self.assertRaises(TypeError, operator.contains, result_set, address) def test_result_set_contains_with_prejoins(self): class Person(self.Person): address = ForeignKey(foreignKey="Address", dbName="address_id") class Address(self.SQLObject): city = StringCol() john = Person.selectOneBy(name="John Doe") result_set = Person.select("name = 'John Doe'", prejoins=["address"]) self.assertTrue(john in result_set) def test_table_dot_q(self): # Table.q.fieldname is a syntax used in SQLObject for # sqlbuilder expressions. Storm can use the main properties # for this, so the Table.q syntax just returns those # properties: class Person(self.SQLObject): _idName = "name" _idType = str address = ForeignKey(foreignKey="Phone", dbName="address_id", notNull=True) self.assertEqual(id(Person.q.id), id(Person.id)) self.assertEqual(id(Person.q.address), id(Person.address)) self.assertEqual(id(Person.q.addressID), id(Person.addressID)) person = Person.get("John Joe") self.assertEqual(id(person.q.id), id(Person.id)) self.assertEqual(id(person.q.address), id(Person.address)) self.assertEqual(id(person.q.addressID), id(Person.addressID)) def test_set(self): class Person(self.Person): def set(self, **kw): kw["id"] += 1 super().set(**kw) person = Person(id=3, name="John Moe") self.assertEqual(person.id, 4) self.assertEqual(person.name, "John Moe") def test_CONTAINSSTRING(self): expr = CONTAINSSTRING(self.Person.q.name, "Do") result = self.Person.select(expr) self.assertEqual([person.name for person in result], ["John Doe"]) result[0].name = "Funny !%_ Name" expr = NOT(CONTAINSSTRING(self.Person.q.name, "!%_")) result = self.Person.select(expr) self.assertEqual([person.name for person in result], ["John Joe"]) ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1721152862.425125 storm-1.0/storm/tests/store/0000755000175000017500000000000014645532536016477 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/storm/tests/store/__init__.py0000644000175000017500000000000011752263216020566 0ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/store/base.py0000644000175000017500000064364014645174376020004 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# import decimal import gc from io import StringIO import operator import pickle from uuid import uuid4 import weakref from storm.references import Reference, ReferenceSet, Proxy from storm.database import Result, STATE_DISCONNECTED from storm.properties import ( Int, Float, Bytes, Unicode, Property, Pickle, UUID) from storm.properties import PropertyPublisherMeta, Decimal from storm.variables import PickleVariable from storm.expr import ( Asc, Desc, Select, LeftJoin, SQL, Count, Sum, Avg, And, Or, Eq, Lower) from storm.variables import Variable, UnicodeVariable, IntVariable from storm.info import get_obj_info, ClassAlias from storm.exceptions import ( ClosedError, ConnectionBlockedError, FeatureError, LostObjectError, NoStoreError, NotFlushedError, NotOneError, OrderLoopError, UnorderedError, WrongStoreError, DisconnectionError) from storm.cache import Cache from storm.store import AutoReload, EmptyResultSet, Store, ResultSet from storm.tracer import debug from storm.tests.info import Wrapper from storm.tests.helper import TestHelper class Foo: __storm_table__ = "foo" id = Int(primary=True) title = Unicode() class Bar: __storm_table__ = "bar" id = Int(primary=True) title = Unicode() foo_id = Int() foo = Reference(foo_id, Foo.id) class UniqueID: __storm_table__ = "unique_id" id = UUID(primary=True) def __init__(self, id): self.id = id class Blob: __storm_table__ = "bin" id = Int(primary=True) bin = Bytes() class Link: __storm_table__ = "link" __storm_primary__ = "foo_id", "bar_id" foo_id = Int() bar_id = Int() class SelfRef: __storm_table__ = "selfref" id = Int(primary=True) title = Unicode() selfref_id = Int() selfref = Reference(selfref_id, id) selfref_on_remote = Reference(id, selfref_id, on_remote=True) class FooRef(Foo): bar = Reference(Foo.id, Bar.foo_id) class FooRefSet(Foo): bars = ReferenceSet(Foo.id, Bar.foo_id) class FooRefSetOrderID(Foo): bars = ReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.id) class FooRefSetOrderTitle(Foo): bars = ReferenceSet(Foo.id, Bar.foo_id, order_by=Bar.title) class FooIndRefSet(Foo): bars = ReferenceSet(Foo.id, Link.foo_id, Link.bar_id, Bar.id) class FooIndRefSetOrderID(Foo): bars = ReferenceSet(Foo.id, Link.foo_id, Link.bar_id, Bar.id, order_by=Bar.id) class FooIndRefSetOrderTitle(Foo): bars = ReferenceSet(Foo.id, Link.foo_id, Link.bar_id, Bar.id, order_by=Bar.title) class FooValue: __storm_table__ = "foovalue" id = Int(primary=True) foo_id = Int() value1 = Int() value2 = Int() class BarProxy: __storm_table__ = "bar" id = Int(primary=True) title = Unicode() foo_id = Int() foo = Reference(foo_id, Foo.id) foo_title = Proxy(foo, Foo.title) class Money: __storm_table__ = "money" id = Int(primary=True) value = Decimal() class DecorateVariable(Variable): def parse_get(self, value, to_db): return "to_%s(%s)" % (to_db and "db" or "py", value) def parse_set(self, value, from_db): return "from_%s(%s)" % (from_db and "db" or "py", value) class FooVariable(Foo): title = Property(variable_class=DecorateVariable) class DummyDatabase: def connect(self, event=None): return None class StoreCacheTest(TestHelper): def test_wb_custom_cache(self): cache = Cache(25) store = Store(DummyDatabase(), cache=cache) self.assertEqual(store._cache, cache) def test_wb_default_cache_size(self): store = Store(DummyDatabase()) self.assertEqual(store._cache._size, 1000) class StoreDatabaseTest(TestHelper): def test_store_has_reference_to_its_database(self): database = DummyDatabase() store = Store(database) self.assertIdentical(store.get_database(), database) class StoreTest: def 
setUp(self): self.store = None self.stores = [] self.create_database() self.connection = self.database.connect() self.drop_tables() self.create_tables() self.create_sample_data() self.create_store() def tearDown(self): self.drop_store() self.drop_sample_data() self.drop_tables() self.drop_database() self.connection.close() def create_database(self): raise NotImplementedError def create_tables(self): raise NotImplementedError def create_sample_data(self): connection = self.connection connection.execute("INSERT INTO foo (id, title)" " VALUES (10, 'Title 30')") connection.execute("INSERT INTO foo (id, title)" " VALUES (20, 'Title 20')") connection.execute("INSERT INTO foo (id, title)" " VALUES (30, 'Title 10')") connection.execute("INSERT INTO bar (id, foo_id, title)" " VALUES (100, 10, 'Title 300')") connection.execute("INSERT INTO bar (id, foo_id, title)" " VALUES (200, 20, 'Title 200')") connection.execute("INSERT INTO bar (id, foo_id, title)" " VALUES (300, 30, 'Title 100')") connection.execute("INSERT INTO bin (id, bin) VALUES (10, 'Blob 30')") connection.execute("INSERT INTO bin (id, bin) VALUES (20, 'Blob 20')") connection.execute("INSERT INTO bin (id, bin) VALUES (30, 'Blob 10')") connection.execute("INSERT INTO link (foo_id, bar_id) VALUES (10, 100)") connection.execute("INSERT INTO link (foo_id, bar_id) VALUES (10, 200)") connection.execute("INSERT INTO link (foo_id, bar_id) VALUES (10, 300)") connection.execute("INSERT INTO link (foo_id, bar_id) VALUES (20, 100)") connection.execute("INSERT INTO link (foo_id, bar_id) VALUES (20, 200)") connection.execute("INSERT INTO link (foo_id, bar_id) VALUES (30, 300)") connection.execute("INSERT INTO money (id, value)" " VALUES (10, '12.3455')") connection.execute("INSERT INTO selfref (id, title, selfref_id)" " VALUES (15, 'SelfRef 15', NULL)") connection.execute("INSERT INTO selfref (id, title, selfref_id)" " VALUES (25, 'SelfRef 25', NULL)") connection.execute("INSERT INTO selfref (id, title, selfref_id)" " VALUES (35, 'SelfRef 35', 15)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (1, 10, 2, 1)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (2, 10, 2, 1)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (3, 10, 2, 1)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (4, 10, 2, 2)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (5, 20, 1, 3)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (6, 20, 1, 3)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (7, 20, 1, 4)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (8, 20, 1, 4)") connection.execute("INSERT INTO foovalue (id, foo_id, value1, value2)" " VALUES (9, 20, 1, 2)") connection.commit() def create_store(self): store = Store(self.database) self.stores.append(store) if self.store is None: self.store = store return store def drop_store(self): for store in self.stores: store.rollback() # Closing the store is needed because testcase objects are all # instantiated at once, and thus connections are kept open. 
store.close() def drop_sample_data(self): pass def drop_tables(self): for table in ["foo", "bar", "bin", "link", "money", "selfref", "foovalue", "unique_id"]: try: self.connection.execute("DROP TABLE %s" % table) self.connection.commit() except: self.connection.rollback() def drop_database(self): pass def get_items(self): # Bypass the store to avoid flushing. connection = self.store._connection result = connection.execute("SELECT * FROM foo ORDER BY id") return list(result) def get_committed_items(self): connection = self.database.connect() result = connection.execute("SELECT * FROM foo ORDER BY id") return list(result) def get_cache(self, store): # We don't offer a public API for this just yet. return store._cache def test_execute(self): result = self.store.execute("SELECT 1") self.assertTrue(isinstance(result, Result)) self.assertEqual(result.get_one(), (1,)) result = self.store.execute("SELECT 1", noresult=True) self.assertEqual(result, None) def test_execute_params(self): result = self.store.execute("SELECT ?", [1]) self.assertTrue(isinstance(result, Result)) self.assertEqual(result.get_one(), (1,)) def test_execute_flushes(self): foo = self.store.get(Foo, 10) foo.title = "New Title" result = self.store.execute("SELECT title FROM foo WHERE id=10") self.assertEqual(result.get_one(), ("New Title",)) def test_close(self): store = Store(self.database) store.close() self.assertRaises(ClosedError, store.execute, "SELECT 1") def test_get(self): foo = self.store.get(Foo, 10) self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") foo = self.store.get(Foo, 20) self.assertEqual(foo.id, 20) self.assertEqual(foo.title, "Title 20") foo = self.store.get(Foo, 40) self.assertEqual(foo, None) def test_get_cached(self): foo = self.store.get(Foo, 10) self.assertTrue(self.store.get(Foo, 10) is foo) def test_wb_get_cached_doesnt_need_connection(self): foo = self.store.get(Foo, 10) connection = self.store._connection self.store._connection = None self.store.get(Foo, 10) self.store._connection = connection def test_cache_cleanup(self): # Disable the cache, which holds strong references. self.get_cache(self.store).set_size(0) foo = self.store.get(Foo, 10) foo.taint = True del foo gc.collect() foo = self.store.get(Foo, 10) self.assertFalse(getattr(foo, "taint", False)) def test_add_returns_object(self): """ Store.add() returns the object passed to it. This allows this kind of code: thing = Thing() store.add(thing) return thing to be simplified as: return store.add(Thing()) """ foo = Foo() self.assertEqual(self.store.add(foo), foo) def test_add_and_stop_referencing(self): # After adding an object, no references should be needed in # python for it still to be added to the database. foo = Foo() foo.title = "live" self.store.add(foo) del foo gc.collect() self.assertTrue(self.store.find(Foo, title="live").one()) def test_obj_info_with_deleted_object(self): # Let's try to put Storm in trouble by killing the object # while still holding a reference to the obj_info. # Disable the cache, which holds strong references. self.get_cache(self.store).set_size(0) class MyFoo(Foo): loaded = False def __storm_loaded__(self): self.loaded = True foo = self.store.get(MyFoo, 20) foo.tainted = True obj_info = get_obj_info(foo) del foo gc.collect() self.assertEqual(obj_info.get_obj(), None) foo = self.store.find(MyFoo, id=20).one() self.assertTrue(foo) self.assertFalse(getattr(foo, "tainted", False)) # The object was rebuilt, so the loaded hook must have run. 
        self.assertTrue(foo.loaded)

    def test_obj_info_with_deleted_object_and_changed_event(self):
        """
        When an object is collected, the variables disable change
        notification so as not to create a leak.  If we're holding a
        reference to the obj_info and rebuild the object, it should
        re-enable change notification.
        """
        class PickleBlob(Blob):
            bin = Pickle()

        # Disable the cache, which holds strong references.
        self.get_cache(self.store).set_size(0)

        blob = self.store.get(Blob, 20)
        blob.bin = b"\x80\x02}q\x01U\x01aK\x01s."
        self.store.flush()
        del blob
        gc.collect()

        pickle_blob = self.store.get(PickleBlob, 20)
        obj_info = get_obj_info(pickle_blob)
        del pickle_blob
        gc.collect()
        self.assertEqual(obj_info.get_obj(), None)

        pickle_blob = self.store.get(PickleBlob, 20)
        pickle_blob.bin = "foobin"
        events = []
        obj_info.event.hook("changed", lambda *args: events.append(args))
        self.store.flush()
        self.assertEqual(len(events), 1)

    def test_wb_flush_event_with_deleted_object_before_flush(self):
        """
        When an object is deleted before flush and it contains mutable
        variables, those variables unhook from the global event system
        to prevent a leak.
        """
        class PickleBlob(Blob):
            bin = Pickle()

        # Disable the cache, which holds strong references.
        self.get_cache(self.store).set_size(0)

        blob = self.store.get(Blob, 20)
        blob.bin = b"\x80\x02}q\x01U\x01aK\x01s."
        self.store.flush()
        del blob
        gc.collect()

        pickle_blob = self.store.get(PickleBlob, 20)
        pickle_blob.bin = "foobin"
        del pickle_blob

        self.store.flush()
        self.assertEqual(self.store._event._hooks["flush"], set())

    def test_mutable_variable_detect_change_from_alive(self):
        """
        Changes in a mutable variable like a L{PickleVariable} are
        correctly detected, even if the object comes from the alive
        cache.
        """
        class PickleBlob(Blob):
            bin = Pickle()

        blob = PickleBlob()
        blob.bin = {"k": "v"}
        blob.id = 4000
        self.store.add(blob)
        self.store.commit()

        blob = self.store.find(PickleBlob, PickleBlob.id == 4000).one()
        blob.bin["k1"] = "v1"
        self.store.commit()

        blob = self.store.find(PickleBlob, PickleBlob.id == 4000).one()
        self.assertEqual(blob.bin, {"k1": "v1", "k": "v"})

    def test_mutable_variable_no_reference_cycle(self):
        """
        Mutable variables only hold weak refs to EventSystem, to prevent
        leaks.
        """
        class PickleBlob(Blob):
            bin = Pickle()

        blob = self.store.get(Blob, 20)
        blob.bin = b"\x80\x02}q\x01U\x01aK\x01s."
        self.store.flush()
        del blob

        # Get an existing object and make an unflushed change to it so that
        # a flush hook for the variable is registered with the event system.
        pickle_blob = self.store.get(PickleBlob, 20)
        pickle_blob.bin = "foobin"
        pickle_blob_ref = weakref.ref(pickle_blob)
        del pickle_blob

        for store in self.stores:
            store.close()
        del store
        self.store = None
        self.stores = []

        gc.collect()
        self.assertIsNone(pickle_blob_ref())

    def test_wb_checkpoint_doesnt_override_changed(self):
        """
        This test ensures that we don't uselessly checkpoint when getting
        back objects from the alive cache, which would hide changed
        values from the store.
        """
        foo = self.store.get(Foo, 20)
        foo.title = "changed"
        self.store.block_implicit_flushes()
        foo2 = self.store.find(Foo, Foo.id == 20).one()
        self.store.unblock_implicit_flushes()

        self.store.commit()
        foo3 = self.store.find(Foo, Foo.id == 20).one()
        self.assertEqual(foo3.title, "changed")

    def test_obj_info_with_deleted_object_with_get(self):
        # Same thing, but using get rather than find.

        # Disable the cache, which holds strong references.
self.get_cache(self.store).set_size(0) foo = self.store.get(Foo, 20) foo.tainted = True obj_info = get_obj_info(foo) del foo gc.collect() self.assertEqual(obj_info.get_obj(), None) foo = self.store.get(Foo, 20) self.assertTrue(foo) self.assertFalse(getattr(foo, "tainted", False)) def test_delete_object_when_obj_info_is_dirty(self): """Object should stay in memory if dirty.""" # Disable the cache, which holds strong references. self.get_cache(self.store).set_size(0) foo = self.store.get(Foo, 20) foo.title = "Changed" foo.tainted = True obj_info = get_obj_info(foo) del foo gc.collect() self.assertTrue(obj_info.get_obj()) def test_get_tuple(self): class MyFoo(Foo): __storm_primary__ = "title", "id" foo = self.store.get(MyFoo, ("Title 30", 10)) self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") foo = self.store.get(MyFoo, ("Title 20", 10)) self.assertEqual(foo, None) def test_of(self): foo = self.store.get(Foo, 10) self.assertEqual(Store.of(foo), self.store) self.assertEqual(Store.of(Foo()), None) self.assertEqual(Store.of(object()), None) def test_is_empty(self): result = self.store.find(Foo, id=300) self.assertEqual(result.is_empty(), True) result = self.store.find(Foo, id=30) self.assertEqual(result.is_empty(), False) def test_is_empty_strips_order_by(self): """ L{ResultSet.is_empty} strips the C{ORDER BY} clause, if one is present, since it isn't required to actually determine if a result set has any matching rows. This should provide a performance improvement when the ordered result set would be large. """ stream = StringIO() self.addCleanup(debug, False) debug(True, stream) result = self.store.find(Foo, Foo.id == 300) result.order_by(Foo.id) self.assertEqual(True, result.is_empty()) self.assertNotIn("ORDER BY", stream.getvalue()) def test_is_empty_with_composed_key(self): result = self.store.find(Link, foo_id=300, bar_id=3000) self.assertEqual(result.is_empty(), True) result = self.store.find(Link, foo_id=30, bar_id=300) self.assertEqual(result.is_empty(), False) def test_is_empty_with_expression_find(self): result = self.store.find(Foo.title, Foo.id == 300) self.assertEqual(result.is_empty(), True) result = self.store.find(Foo.title, Foo.id == 30) self.assertEqual(result.is_empty(), False) def test_find_iter(self): result = self.store.find(Foo) lst = [(foo.id, foo.title) for foo in result] lst.sort() self.assertEqual(lst, [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_find_from_cache(self): foo = self.store.get(Foo, 10) self.assertTrue(self.store.find(Foo, id=10).one() is foo) def test_find_expr(self): result = self.store.find(Foo, Foo.id == 20, Foo.title == "Title 20") self.assertEqual([(foo.id, foo.title) for foo in result], [ (20, "Title 20"), ]) result = self.store.find(Foo, Foo.id == 10, Foo.title == "Title 20") self.assertEqual([(foo.id, foo.title) for foo in result], [ ]) def test_find_sql(self): foo = self.store.find(Foo, SQL("foo.id = 20")).one() self.assertEqual(foo.title, "Title 20") def test_find_str(self): foo = self.store.find(Foo, "foo.id = 20").one() self.assertEqual(foo.title, "Title 20") def test_find_keywords(self): result = self.store.find(Foo, id=20, title="Title 20") self.assertEqual([(foo.id, foo.title) for foo in result], [ (20, "Title 20") ]) result = self.store.find(Foo, id=10, title="Title 20") self.assertEqual([(foo.id, foo.title) for foo in result], [ ]) def test_find_order_by(self, *args): result = self.store.find(Foo).order_by(Foo.title) lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [ (30, "Title 
10"), (20, "Title 20"), (10, "Title 30"), ]) def test_find_order_asc(self, *args): result = self.store.find(Foo).order_by(Asc(Foo.title)) lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [ (30, "Title 10"), (20, "Title 20"), (10, "Title 30"), ]) def test_find_order_desc(self, *args): result = self.store.find(Foo).order_by(Desc(Foo.title)) lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_find_default_order_asc(self): class MyFoo(Foo): __storm_order__ = "title" result = self.store.find(MyFoo) lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [ (30, "Title 10"), (20, "Title 20"), (10, "Title 30"), ]) def test_find_default_order_desc(self): class MyFoo(Foo): __storm_order__ = "-title" result = self.store.find(MyFoo) lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_find_default_order_with_tuple(self): class MyLink(Link): __storm_order__ = ("foo_id", "-bar_id") result = self.store.find(MyLink) lst = [(link.foo_id, link.bar_id) for link in result] self.assertEqual(lst, [ (10, 300), (10, 200), (10, 100), (20, 200), (20, 100), (30, 300), ]) def test_find_default_order_with_tuple_and_expr(self): class MyLink(Link): __storm_order__ = ("foo_id", Desc(Link.bar_id)) result = self.store.find(MyLink) lst = [(link.foo_id, link.bar_id) for link in result] self.assertEqual(lst, [ (10, 300), (10, 200), (10, 100), (20, 200), (20, 100), (30, 300), ]) def test_find_index(self): """ L{ResultSet.__getitem__} returns the object at the specified index. if a slice is used, a new L{ResultSet} is returned configured with the appropriate offset and limit. """ foo = self.store.find(Foo).order_by(Foo.title)[0] self.assertEqual(foo.id, 30) self.assertEqual(foo.title, "Title 10") foo = self.store.find(Foo).order_by(Foo.title)[1] self.assertEqual(foo.id, 20) self.assertEqual(foo.title, "Title 20") foo = self.store.find(Foo).order_by(Foo.title)[2] self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") foo = self.store.find(Foo).order_by(Foo.title)[1:][1] self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") result = self.store.find(Foo).order_by(Foo.title) self.assertRaises(IndexError, result.__getitem__, 3) def test_find_slice(self): result = self.store.find(Foo).order_by(Foo.title)[1:2] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(20, "Title 20")]) def test_find_slice_offset(self): result = self.store.find(Foo).order_by(Foo.title)[1:] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(20, "Title 20"), (10, "Title 30")]) def test_find_slice_offset_any(self): foo = self.store.find(Foo).order_by(Foo.title)[1:].any() self.assertEqual(foo.id, 20) self.assertEqual(foo.title, "Title 20") def test_find_slice_offset_one(self): foo = self.store.find(Foo).order_by(Foo.title)[1:2].one() self.assertEqual(foo.id, 20) self.assertEqual(foo.title, "Title 20") def test_find_slice_offset_first(self): foo = self.store.find(Foo).order_by(Foo.title)[1:].first() self.assertEqual(foo.id, 20) self.assertEqual(foo.title, "Title 20") def test_find_slice_offset_last(self): foo = self.store.find(Foo).order_by(Foo.title)[1:].last() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") def test_find_slice_limit(self): result = self.store.find(Foo).order_by(Foo.title)[:2] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(30, "Title 10"), (20, 
"Title 20")]) def test_find_slice_limit_last(self): result = self.store.find(Foo).order_by(Foo.title)[:2] self.assertRaises(FeatureError, result.last) def test_find_slice_slice(self): result = self.store.find(Foo).order_by(Foo.title)[0:2][1:3] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(20, "Title 20")]) result = self.store.find(Foo).order_by(Foo.title)[:2][1:3] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(20, "Title 20")]) result = self.store.find(Foo).order_by(Foo.title)[1:3][0:1] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(20, "Title 20")]) result = self.store.find(Foo).order_by(Foo.title)[1:3][:1] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, [(20, "Title 20")]) result = self.store.find(Foo).order_by(Foo.title)[5:5][1:1] lst = [(foo.id, foo.title) for foo in result] self.assertEqual(lst, []) def test_find_slice_order_by(self): result = self.store.find(Foo)[2:] self.assertRaises(FeatureError, result.order_by, None) result = self.store.find(Foo)[:2] self.assertRaises(FeatureError, result.order_by, None) def test_find_slice_remove(self): result = self.store.find(Foo)[2:] self.assertRaises(FeatureError, result.remove) result = self.store.find(Foo)[:2] self.assertRaises(FeatureError, result.remove) def test_find_contains(self): foo = self.store.get(Foo, 10) result = self.store.find(Foo) self.assertEqual(foo in result, True) result = self.store.find(Foo, Foo.id == 20) self.assertEqual(foo in result, False) result = self.store.find(Foo, "foo.id = 20") self.assertEqual(foo in result, False) def test_find_contains_wrong_type(self): foo = self.store.get(Foo, 10) bar = self.store.get(Bar, 200) self.assertRaises(TypeError, operator.contains, self.store.find(Foo), bar) self.assertRaises(TypeError, operator.contains, self.store.find((Foo,)), foo) self.assertRaises(TypeError, operator.contains, self.store.find(Foo), (foo,)) self.assertRaises(TypeError, operator.contains, self.store.find((Foo, Bar)), (bar, foo)) def test_find_contains_does_not_use_iter(self): def no_iter(self): raise RuntimeError() orig_iter = ResultSet.__iter__ ResultSet.__iter__ = no_iter try: foo = self.store.get(Foo, 10) result = self.store.find(Foo) self.assertEqual(foo in result, True) finally: ResultSet.__iter__ = orig_iter def test_find_contains_with_composed_key(self): link = self.store.get(Link, (10, 100)) result = self.store.find(Link, Link.foo_id == 10) self.assertEqual(link in result, True) result = self.store.find(Link, Link.bar_id == 200) self.assertEqual(link in result, False) def test_find_contains_with_set_expression(self): foo = self.store.get(Foo, 10) result1 = self.store.find(Foo, Foo.id == 10) result2 = self.store.find(Foo, Foo.id != 10) self.assertEqual(foo in result1.union(result2), True) if self.__class__.__name__.startswith("MySQL"): return self.assertEqual(foo in result1.intersection(result2), False) self.assertEqual(foo in result1.intersection(result1), True) self.assertEqual(foo in result1.difference(result2), True) self.assertEqual(foo in result1.difference(result1), False) def test_find_any(self, *args): """ L{ResultSet.any} returns an arbitrary objects from the result set. """ self.assertNotEqual(None, self.store.find(Foo).any()) self.assertEqual(None, self.store.find(Foo, id=40).any()) def test_find_any_strips_order_by(self): """ L{ResultSet.any} strips the C{ORDER BY} clause, if one is present, since it isn't required. 
This should provide a performance improvement when the ordered result set would be large. """ stream = StringIO() self.addCleanup(debug, False) debug(True, stream) result = self.store.find(Foo, Foo.id == 300) result.order_by(Foo.id) result.any() self.assertNotIn("ORDER BY", stream.getvalue()) def test_find_first(self, *args): self.assertRaises(UnorderedError, self.store.find(Foo).first) foo = self.store.find(Foo).order_by(Foo.title).first() self.assertEqual(foo.id, 30) self.assertEqual(foo.title, "Title 10") foo = self.store.find(Foo).order_by(Foo.id).first() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") foo = self.store.find(Foo, id=40).order_by(Foo.id).first() self.assertEqual(foo, None) def test_find_last(self, *args): self.assertRaises(UnorderedError, self.store.find(Foo).last) foo = self.store.find(Foo).order_by(Foo.title).last() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") foo = self.store.find(Foo).order_by(Foo.id).last() self.assertEqual(foo.id, 30) self.assertEqual(foo.title, "Title 10") foo = self.store.find(Foo, id=40).order_by(Foo.id).last() self.assertEqual(foo, None) def test_find_last_desc(self, *args): foo = self.store.find(Foo).order_by(Desc(Foo.title)).last() self.assertEqual(foo.id, 30) self.assertEqual(foo.title, "Title 10") foo = self.store.find(Foo).order_by(Asc(Foo.id)).last() self.assertEqual(foo.id, 30) self.assertEqual(foo.title, "Title 10") def test_find_one(self, *args): self.assertRaises(NotOneError, self.store.find(Foo).one) foo = self.store.find(Foo, id=10).one() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") foo = self.store.find(Foo, id=40).one() self.assertEqual(foo, None) def test_find_count(self): self.assertEqual(self.store.find(Foo).count(), 3) def test_find_count_after_slice(self): """ When we slice a ResultSet obtained after a set operation (like union), we get a fresh select that doesn't modify the limit and offset attribute of the original ResultSet. 
""" result1 = self.store.find(Foo, Foo.id == 10) result2 = self.store.find(Foo, Foo.id == 20) result3 = result1.union(result2) result3.order_by(Foo.id) self.assertEqual(result3.count(), 2) result_slice = list(result3[:2]) self.assertEqual(result3.count(), 2) def test_find_count_column(self): self.assertEqual(self.store.find(Link).count(Link.foo_id), 6) def test_find_count_column_distinct(self): count = self.store.find(Link).count(Link.foo_id, distinct=True) self.assertEqual(count, 3) def test_find_limit_count(self): result = self.store.find(Link.foo_id) result.config(limit=2) count = result.count() self.assertEqual(count, 2) def test_find_offset_count(self): result = self.store.find(Link.foo_id) result.config(offset=3) count = result.count() self.assertEqual(count, 3) def test_find_sliced_count(self): result = self.store.find(Link.foo_id) count = result[2:4].count() self.assertEqual(count, 2) def test_find_distinct_count(self): result = self.store.find(Link.foo_id) result.config(distinct=True) count = result.count() self.assertEqual(count, 3) def test_find_distinct_order_by_limit_count(self): result = self.store.find(Foo) result.order_by(Foo.title) result.config(distinct=True, limit=3) count = result.count() self.assertEqual(count, 3) def test_find_distinct_count_multiple_columns(self): result = self.store.find((Link.foo_id, Link.bar_id)) result.config(distinct=True) count = result.count() self.assertEqual(count, 6) def test_find_count_column_with_implicit_distinct(self): result = self.store.find(Link) result.config(distinct=True) count = result.count(Link.foo_id) self.assertEqual(count, 6) def test_find_max(self): self.assertEqual(self.store.find(Foo).max(Foo.id), 30) def test_find_max_expr(self): self.assertEqual(self.store.find(Foo).max(Foo.id + 1), 31) def test_find_max_unicode(self): title = self.store.find(Foo).max(Foo.title) self.assertEqual(title, "Title 30") self.assertTrue(isinstance(title, str)) def test_find_max_with_empty_result_and_disallow_none(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int(allow_none=False) result = self.store.find(Bar, Bar.id > 1000) self.assertTrue(result.is_empty()) self.assertEqual(result.max(Bar.foo_id), None) def test_find_min(self): self.assertEqual(self.store.find(Foo).min(Foo.id), 10) def test_find_min_expr(self): self.assertEqual(self.store.find(Foo).min(Foo.id - 1), 9) def test_find_min_unicode(self): title = self.store.find(Foo).min(Foo.title) self.assertEqual(title, "Title 10") self.assertTrue(isinstance(title, str)) def test_find_min_with_empty_result_and_disallow_none(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int(allow_none=False) result = self.store.find(Bar, Bar.id > 1000) self.assertTrue(result.is_empty()) self.assertEqual(result.min(Bar.foo_id), None) def test_find_avg(self): self.assertEqual(self.store.find(Foo).avg(Foo.id), 20) def test_find_avg_expr(self): self.assertEqual(self.store.find(Foo).avg(Foo.id + 10), 30) def test_find_avg_float(self): foo = Foo() foo.id = 15 foo.title = "Title 15" self.store.add(foo) self.assertEqual(self.store.find(Foo).avg(Foo.id), 18.75) def test_find_sum(self): self.assertEqual(self.store.find(Foo).sum(Foo.id), 60) def test_find_sum_expr(self): self.assertEqual(self.store.find(Foo).sum(Foo.id * 2), 120) def test_find_sum_with_empty_result_and_disallow_none(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int(allow_none=False) result = self.store.find(Bar, Bar.id > 1000) self.assertTrue(result.is_empty()) 
self.assertEqual(result.sum(Bar.foo_id), None) def test_find_max_order_by(self): """Interaction between order by and aggregation shouldn't break.""" result = self.store.find(Foo) self.assertEqual(result.order_by(Foo.id).max(Foo.id), 30) def test_find_get_select_expr_without_columns(self): """ A L{FeatureError} is raised if L{ResultSet.get_select_expr} is called without a list of L{Column}s. """ result = self.store.find(Foo) self.assertRaises(FeatureError, result.get_select_expr) def test_find_get_select_expr(self): """ Only the specified L{Column}s are included in the L{Select} expression provided by L{ResultSet.get_select_expr}. """ foo = self.store.get(Foo, 10) result1 = self.store.find(Foo, Foo.id <= 10) subselect = result1.get_select_expr(Foo.id) self.assertEqual((Foo.id,), subselect.columns) result2 = self.store.find(Foo, Foo.id.is_in(subselect)) self.assertEqual([foo], list(result2)) def test_find_get_select_expr_with_set_expression(self): """ A L{FeatureError} is raised if L{ResultSet.get_select_expr} is used with a L{ResultSet} that represents a set expression, such as a union. """ result1 = self.store.find(Foo, Foo.id == 10) result2 = self.store.find(Foo, Foo.id == 20) result3 = result1.union(result2) self.assertRaises(FeatureError, result3.get_select_expr, Foo.id) def test_find_values(self): values = self.store.find(Foo).order_by(Foo.id).values(Foo.id) self.assertEqual(list(values), [10, 20, 30]) values = self.store.find(Foo).order_by(Foo.id).values(Foo.title) values = list(values) self.assertEqual(values, ["Title 30", "Title 20", "Title 10"]) self.assertEqual([type(value) for value in values], [str, str, str]) def test_find_multiple_values(self): result = self.store.find(Foo).order_by(Foo.id) values = result.values(Foo.id, Foo.title) self.assertEqual(list(values), [(10, "Title 30"), (20, "Title 20"), (30, "Title 10")]) def test_find_values_with_no_arguments(self): result = self.store.find(Foo).order_by(Foo.id) self.assertRaises(FeatureError, next, result.values()) def test_find_slice_values(self): values = self.store.find(Foo).order_by(Foo.id)[1:2].values(Foo.id) self.assertEqual(list(values), [20]) def test_find_values_with_set_expression(self): """ A L{FeatureError} is raised if L{ResultSet.values} is used with a L{ResultSet} that represents a set expression, such as a union. """ result1 = self.store.find(Foo, Foo.id == 10) result2 = self.store.find(Foo, Foo.id == 20) result3 = result1.union(result2) self.assertRaises(FeatureError, list, result3.values(Foo.id)) def test_find_remove(self): self.store.find(Foo, Foo.id == 20).remove() self.assertEqual(self.get_items(), [ (10, "Title 30"), (30, "Title 10"), ]) def test_find_cached(self): foo = self.store.get(Foo, 20) bar = self.store.get(Bar, 200) self.assertTrue(foo) self.assertTrue(bar) self.assertEqual(self.store.find(Foo).cached(), [foo]) def test_find_cached_where(self): foo1 = self.store.get(Foo, 10) foo2 = self.store.get(Foo, 20) bar = self.store.get(Bar, 200) self.assertTrue(foo1) self.assertTrue(foo2) self.assertTrue(bar) self.assertEqual(self.store.find(Foo, title="Title 20").cached(), [foo2]) def test_find_cached_invalidated(self): foo = self.store.get(Foo, 20) self.store.invalidate(foo) self.assertEqual(self.store.find(Foo).cached(), [foo]) def test_find_cached_invalidated_and_deleted(self): foo = self.store.get(Foo, 20) self.store.execute("DELETE FROM foo WHERE id=20") self.store.invalidate(foo) # Do not look for the primary key (id), since it's able to get # it without touching the database. Use the title instead. 
self.assertEqual(self.store.find(Foo, title="Title 20").cached(), []) def test_find_cached_with_info_alive_and_object_dead(self): # Disable the cache, which holds strong references. self.get_cache(self.store).set_size(0) foo = self.store.get(Foo, 20) foo.tainted = True obj_info = get_obj_info(foo) del foo gc.collect() cached = self.store.find(Foo).cached() self.assertEqual(len(cached), 1) foo = self.store.get(Foo, 20) self.assertFalse(hasattr(foo, "tainted")) def test_using_find_join(self): bar = self.store.get(Bar, 100) bar.foo_id = None tables = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id)) result = tables.find(Bar).order_by(Foo.id, Bar.id) lst = [bar and (bar.id, bar.title) for bar in result] self.assertEqual(lst, [ None, (200, "Title 200"), (300, "Title 100"), ]) def test_using_find_with_strings(self): foo = self.store.using("foo").find(Foo, id=10).one() self.assertEqual(foo.title, "Title 30") foo = self.store.using("foo", "bar").find(Foo, id=10).any() self.assertEqual(foo.title, "Title 30") def test_using_find_join_with_strings(self): bar = self.store.get(Bar, 100) bar.foo_id = None tables = self.store.using(LeftJoin("foo", "bar", "bar.foo_id = foo.id")) result = tables.find(Bar).order_by(Foo.id, Bar.id) lst = [bar and (bar.id, bar.title) for bar in result] self.assertEqual(lst, [ None, (200, "Title 200"), (300, "Title 100"), ]) def test_find_tuple(self): bar = self.store.get(Bar, 200) bar.foo_id = None result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) result = result.order_by(Foo.id) lst = [(foo and (foo.id, foo.title), bar and (bar.id, bar.title)) for (foo, bar) in result] self.assertEqual(lst, [ ((10, "Title 30"), (100, "Title 300")), ((30, "Title 10"), (300, "Title 100")), ]) def test_find_tuple_using(self): bar = self.store.get(Bar, 200) bar.foo_id = None tables = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id)) result = tables.find((Foo, Bar)).order_by(Foo.id) lst = [(foo and (foo.id, foo.title), bar and (bar.id, bar.title)) for (foo, bar) in result] self.assertEqual(lst, [ ((10, "Title 30"), (100, "Title 300")), ((20, "Title 20"), None), ((30, "Title 10"), (300, "Title 100")), ]) def test_find_tuple_using_with_disallow_none(self): class Bar: __storm_table__ = "bar" id = Int(primary=True, allow_none=False) title = Unicode() foo_id = Int() foo = Reference(foo_id, Foo.id) bar = self.store.get(Bar, 200) self.store.remove(bar) tables = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id)) result = tables.find((Foo, Bar)).order_by(Foo.id) lst = [(foo and (foo.id, foo.title), bar and (bar.id, bar.title)) for (foo, bar) in result] self.assertEqual(lst, [ ((10, "Title 30"), (100, "Title 300")), ((20, "Title 20"), None), ((30, "Title 10"), (300, "Title 100")), ]) def test_find_tuple_using_skip_when_none(self): bar = self.store.get(Bar, 200) bar.foo_id = None tables = self.store.using(Foo, LeftJoin(Bar, Bar.foo_id == Foo.id), LeftJoin(Link, Link.bar_id == Bar.id)) result = tables.find((Bar, Link)).order_by(Foo.id, Bar.id, Link.foo_id) lst = [(bar and (bar.id, bar.title), link and (link.bar_id, link.foo_id)) for (bar, link) in result] self.assertEqual(lst, [ ((100, "Title 300"), (100, 10)), ((100, "Title 300"), (100, 20)), (None, None), ((300, "Title 100"), (300, 10)), ((300, "Title 100"), (300, 30)), ]) def test_find_tuple_contains(self): foo = self.store.get(Foo, 10) bar = self.store.get(Bar, 100) bar200 = self.store.get(Bar, 200) result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) self.assertEqual((foo, bar) in result, True) 
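        # bar200 references foo 20, not foo 10, so the pair must fail
        # the join condition.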
self.assertEqual((foo, bar200) in result, False) def test_find_tuple_contains_with_set_expression(self): foo = self.store.get(Foo, 10) bar = self.store.get(Bar, 100) bar200 = self.store.get(Bar, 200) result1 = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) result2 = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) self.assertEqual((foo, bar) in result1.union(result2), True) if self.__class__.__name__.startswith("MySQL"): return self.assertEqual((foo, bar) in result1.intersection(result2), True) self.assertEqual((foo, bar) in result1.difference(result2), False) def test_find_tuple_any(self): bar = self.store.get(Bar, 200) bar.foo_id = None result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) foo, bar = result.order_by(Foo.id).any() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") self.assertEqual(bar.id, 100) self.assertEqual(bar.title, "Title 300") def test_find_tuple_first(self): bar = self.store.get(Bar, 200) bar.foo_id = None result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) foo, bar = result.order_by(Foo.id).first() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") self.assertEqual(bar.id, 100) self.assertEqual(bar.title, "Title 300") def test_find_tuple_last(self): bar = self.store.get(Bar, 200) bar.foo_id = None result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) foo, bar = result.order_by(Foo.id).last() self.assertEqual(foo.id, 30) self.assertEqual(foo.title, "Title 10") self.assertEqual(bar.id, 300) self.assertEqual(bar.title, "Title 100") def test_find_tuple_one(self): bar = self.store.get(Bar, 200) bar.foo_id = None result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id, Foo.id == 10) foo, bar = result.order_by(Foo.id).one() self.assertEqual(foo.id, 10) self.assertEqual(foo.title, "Title 30") self.assertEqual(bar.id, 100) self.assertEqual(bar.title, "Title 300") def test_find_tuple_count(self): bar = self.store.get(Bar, 200) bar.foo_id = None result = self.store.find((Foo, Bar), Bar.foo_id == Foo.id) self.assertEqual(result.count(), 2) def test_find_tuple_remove(self): result = self.store.find((Foo, Bar)) self.assertRaises(FeatureError, result.remove) def test_find_tuple_set(self): result = self.store.find((Foo, Bar)) self.assertRaises(FeatureError, result.set, title="Title 40") def test_find_tuple_kwargs(self): self.assertRaises(FeatureError, self.store.find, (Foo, Bar), title="Title 10") def test_find_tuple_cached(self): result = self.store.find((Foo, Bar)) self.assertRaises(FeatureError, result.cached) def test_find_using_cached(self): result = self.store.using(Foo, Bar).find(Foo) self.assertRaises(FeatureError, result.cached) def test_find_with_expr(self): result = self.store.find(Foo.title) self.assertEqual(sorted(result), ["Title 10", "Title 20", "Title 30"]) def test_find_with_expr_uses_variable_set(self): result = self.store.find(FooVariable.title, FooVariable.id == 10) self.assertEqual(list(result), ["to_py(from_db(Title 30))"]) def test_find_tuple_with_expr(self): result = self.store.find((Foo, Bar.id, Bar.title), Bar.foo_id == Foo.id) result.order_by(Foo.id) self.assertEqual([(foo.id, foo.title, bar_id, bar_title) for foo, bar_id, bar_title in result], [(10, "Title 30", 100, "Title 300"), (20, "Title 20", 200, "Title 200"), (30, "Title 10", 300, "Title 100")]) def test_find_using_with_expr(self): result = self.store.using(Foo).find(Foo.title) self.assertEqual(sorted(result), ["Title 10", "Title 20", "Title 30"]) def test_find_with_expr_contains(self): result = self.store.find(Foo.title) 
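        # Membership is tested against the values of the selected
        # expression, so a plain string can appear on the left of 'in'.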
self.assertEqual("Title 10" in result, True) self.assertEqual("Title 42" in result, False) def test_find_tuple_with_expr_contains(self): foo = self.store.get(Foo, 10) result = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id) self.assertEqual((foo, "Title 300") in result, True) self.assertEqual((foo, "Title 100") in result, False) def test_find_with_expr_contains_with_set_expression(self): result1 = self.store.find(Foo.title) result2 = self.store.find(Foo.title) self.assertEqual("Title 10" in result1.union(result2), True) if self.__class__.__name__.startswith("MySQL"): return self.assertEqual("Title 10" in result1.intersection(result2), True) self.assertEqual("Title 10" in result1.difference(result2), False) def test_find_with_expr_remove_unsupported(self): result = self.store.find(Foo.title) self.assertRaises(FeatureError, result.remove) def test_find_tuple_with_expr_remove_unsupported(self): result = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id) self.assertRaises(FeatureError, result.remove) def test_find_with_expr_count(self): result = self.store.find(Foo.title) self.assertEqual(result.count(), 3) def test_find_tuple_with_expr_count(self): result = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id) self.assertEqual(result.count(), 3) def test_find_with_expr_values(self): result = self.store.find(Foo.title) self.assertEqual(sorted(result.values(Foo.title)), ["Title 10", "Title 20", "Title 30"]) def test_find_tuple_with_expr_values(self): result = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id) self.assertEqual(sorted(result.values(Foo.title)), ["Title 10", "Title 20", "Title 30"]) def test_find_with_expr_set_unsupported(self): result = self.store.find(Foo.title) self.assertRaises(FeatureError, result.set) def test_find_tuple_with_expr_set_unsupported(self): result = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id) self.assertRaises(FeatureError, result.set) def test_find_with_expr_cached_unsupported(self): result = self.store.find(Foo.title) self.assertRaises(FeatureError, result.cached) def test_find_tuple_with_expr_cached_unsupported(self): result = self.store.find((Foo, Bar.title), Bar.foo_id == Foo.id) self.assertRaises(FeatureError, result.cached) def test_find_with_expr_union(self): result1 = self.store.find(Foo.title, Foo.id == 10) result2 = self.store.find(Foo.title, Foo.id != 10) result = result1.union(result2) self.assertEqual(sorted(result), ["Title 10", "Title 20", "Title 30",]) def test_find_with_expr_union_mismatch(self): result1 = self.store.find(Foo.title) result2 = self.store.find(Bar.foo_id) self.assertRaises(FeatureError, result1.union, result2) def test_find_tuple_with_expr_union(self): result1 = self.store.find( (Foo, Bar.title), Bar.foo_id == Foo.id, Bar.title == "Title 100") result2 = self.store.find( (Foo, Bar.title), Bar.foo_id == Foo.id, Bar.title == "Title 200") result = result1.union(result2) self.assertEqual(sorted((foo.id, title) for (foo, title) in result), [(20, "Title 200"), (30, "Title 100")]) def test_get_does_not_validate(self): def validator(object, attr, value): self.fail("validator called with arguments (%r, %r, %r)" % (object, attr, value)) class Foo: __storm_table__ = "foo" id = Int(primary=True) title = Unicode(validator=validator) foo = self.store.get(Foo, 10) self.assertEqual(foo.title, "Title 30") def test_get_does_not_validate_default_value(self): def validator(object, attr, value): self.fail("validator called with arguments (%r, %r, %r)" % (object, attr, value)) class Foo: __storm_table__ = "foo" id = 
Int(primary=True) title = Unicode(validator=validator, default="default value") foo = self.store.get(Foo, 10) self.assertEqual(foo.title, "Title 30") def test_find_does_not_validate(self): def validator(object, attr, value): self.fail("validator called with arguments (%r, %r, %r)" % (object, attr, value)) class Foo: __storm_table__ = "foo" id = Int(primary=True) title = Unicode(validator=validator) foo = self.store.find(Foo, Foo.id == 10).one() self.assertEqual(foo.title, "Title 30") def test_find_group_by(self): result = self.store.find((Count(FooValue.id), Sum(FooValue.value1))) result.group_by(FooValue.value2) result.order_by(Count(FooValue.id), Sum(FooValue.value1)) result = list(result) self.assertEqual(result, [(2, 2), (2, 2), (2, 3), (3, 6)]) def test_find_group_by_table(self): result = self.store.find( (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id) result.group_by(Foo) foo1 = self.store.get(Foo, 10) foo2 = self.store.get(Foo, 20) self.assertEqual(list(result), [(5, foo1), (16, foo2)]) def test_find_group_by_table_contains(self): result = self.store.find( (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id) result.group_by(Foo) foo1 = self.store.get(Foo, 10) self.assertEqual((5, foo1) in result, True) def test_find_group_by_multiple_tables(self): result = self.store.find( Sum(FooValue.value2), Foo.id == FooValue.foo_id) result.group_by(Foo.id) result.order_by(Sum(FooValue.value2)) result = list(result) self.assertEqual(result, [5, 16]) result = self.store.find( (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id) result.group_by(Foo) result.order_by(Sum(FooValue.value2)) result = list(result) foo1 = self.store.get(Foo, 10) foo2 = self.store.get(Foo, 20) self.assertEqual(result, [(5, foo1), (16, foo2)]) result = self.store.find( (Foo.id, Sum(FooValue.value2), Avg(FooValue.value1)), Foo.id == FooValue.foo_id) result.group_by(Foo.id) result.order_by(Foo.id) result = list(result) self.assertEqual(result, [(10, 5, 2), (20, 16, 1)]) def test_find_group_by_having(self): result = self.store.find( Sum(FooValue.value2), Foo.id == FooValue.foo_id) result.group_by(Foo.id) result.having(Sum(FooValue.value2) == 5) self.assertEqual(list(result), [5]) result = self.store.find( Sum(FooValue.value2), Foo.id == FooValue.foo_id) result.group_by(Foo.id) result.having(Count() == 5) self.assertEqual(list(result), [16]) def test_find_having_without_group_by(self): result = self.store.find(FooValue) self.assertRaises(FeatureError, result.having, FooValue.value1 == 1) def test_find_group_by_multiple_having(self): result = self.store.find((Count(), FooValue.value2)) result.group_by(FooValue.value2) result.having(Count() == 2, FooValue.value2 >= 3) result.order_by(Count(), FooValue.value2) list_result = list(result) self.assertEqual(list_result, [(2, 3), (2, 4)]) def test_find_successive_group_by(self): result = self.store.find(Count()) result.group_by(FooValue.value2) result.order_by(Count()) list_result = list(result) self.assertEqual(list_result, [2, 2, 2, 3]) result.group_by(FooValue.value1) list_result = list(result) self.assertEqual(list_result, [4, 5]) def test_find_multiple_group_by(self): result = self.store.find(Count()) result.group_by(FooValue.value2, FooValue.value1) result.order_by(Count()) list_result = list(result) self.assertEqual(list_result, [1, 1, 2, 2, 3]) def test_find_multiple_group_by_with_having(self): result = self.store.find((Count(), FooValue.value2)) result.group_by(FooValue.value2, FooValue.value1).having(Count() == 2) result.order_by(Count(), FooValue.value2) 
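        # Of the five (value2, value1) groups, only those with exactly
        # two rows survive the HAVING clause.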
list_result = list(result) self.assertEqual(list_result, [(2, 3), (2, 4)]) def test_find_group_by_avg(self): result = self.store.find((Count(FooValue.id), Sum(FooValue.value1))) result.group_by(FooValue.value2) self.assertRaises(FeatureError, result.avg, FooValue.value2) def test_find_group_by_values(self): result = self.store.find( (Sum(FooValue.value2), Foo), Foo.id == FooValue.foo_id) result.group_by(Foo) result.order_by(Foo.title) result = list(result.values(Foo.title)) self.assertEqual(result, ['Title 20', 'Title 30']) def test_find_group_by_union(self): result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=10) result3 = result1.union(result2) self.assertRaises(FeatureError, result3.group_by, Foo.title) def test_find_group_by_remove(self): result = self.store.find((Count(FooValue.id), Sum(FooValue.value1))) result.group_by(FooValue.value2) self.assertRaises(FeatureError, result.remove) def test_find_group_by_set(self): result = self.store.find((Count(FooValue.id), Sum(FooValue.value1))) result.group_by(FooValue.value2) self.assertRaises(FeatureError, result.set, FooValue.value1 == 1) def test_add_commit(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.store.add(foo) self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), (40, "Title 40"), ]) def test_add_rollback_commit(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.store.add(foo) self.store.rollback() self.assertEqual(self.store.get(Foo, 3), None) self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_add_get(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.store.add(foo) old_foo = foo foo = self.store.get(Foo, 40) self.assertEqual(foo.id, 40) self.assertEqual(foo.title, "Title 40") self.assertTrue(foo is old_foo) def test_add_find(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.store.add(foo) old_foo = foo foo = self.store.find(Foo, Foo.id == 40).one() self.assertEqual(foo.id, 40) self.assertEqual(foo.title, "Title 40") self.assertTrue(foo is old_foo) def test_add_twice(self): foo = Foo() self.store.add(foo) self.store.add(foo) self.assertEqual(Store.of(foo), self.store) def test_add_loaded(self): foo = self.store.get(Foo, 10) self.store.add(foo) self.assertEqual(Store.of(foo), self.store) def test_add_twice_to_wrong_store(self): foo = Foo() self.store.add(foo) self.assertRaises(WrongStoreError, Store(self.database).add, foo) def test_add_checkpoints(self): bar = Bar() self.store.add(bar) bar.id = 400 bar.title = "Title 400" bar.foo_id = 40 self.store.flush() self.store.execute("UPDATE bar SET title='Title 500' " "WHERE id=400") bar.foo_id = 400 # When not checkpointing, this flush will set title again. 
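        # Since the first flush checkpointed the object, only foo_id is
        # written out here, so the title changed behind the store's back
        # survives the reload below.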
self.store.flush() self.store.reload(bar) self.assertEqual(bar.title, "Title 500") def test_add_completely_undefined(self): foo = Foo() self.store.add(foo) self.store.flush() self.assertEqual(type(foo.id), int) self.assertEqual(foo.title, "Default Title") def test_add_uuid(self): unique_id = self.store.add(UniqueID(uuid4())) self.assertEqual(unique_id, self.store.find(UniqueID).one()) def test_remove_commit(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.assertEqual(Store.of(foo), self.store) self.store.flush() self.assertEqual(Store.of(foo), None) self.assertEqual(self.get_items(), [ (10, "Title 30"), (30, "Title 10"), ]) self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.store.commit() self.assertEqual(Store.of(foo), None) self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (30, "Title 10"), ]) def test_remove_rollback_update(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.rollback() foo.title = "Title 200" self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) def test_remove_flush_rollback_update(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.flush() self.store.rollback() foo.title = "Title 200" self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_remove_add_update(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.add(foo) foo.title = "Title 200" self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) def test_remove_flush_add_update(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.flush() self.store.add(foo) foo.title = "Title 200" self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) def test_remove_twice(self): foo = self.store.get(Foo, 10) self.store.remove(foo) self.store.remove(foo) def test_remove_unknown(self): foo = Foo() self.assertRaises(WrongStoreError, self.store.remove, foo) def test_remove_from_wrong_store(self): foo = self.store.get(Foo, 20) self.assertRaises(WrongStoreError, Store(self.database).remove, foo) def test_wb_remove_flush_update_isnt_dirty(self): foo = self.store.get(Foo, 20) obj_info = get_obj_info(foo) self.store.remove(foo) self.store.flush() foo.title = "Title 200" self.assertTrue(obj_info not in self.store._dirty) def test_wb_remove_rollback_isnt_dirty(self): foo = self.store.get(Foo, 20) obj_info = get_obj_info(foo) self.store.remove(foo) self.store.rollback() self.assertTrue(obj_info not in self.store._dirty) def test_wb_remove_flush_rollback_isnt_dirty(self): foo = self.store.get(Foo, 20) obj_info = get_obj_info(foo) self.store.remove(foo) self.store.flush() self.store.rollback() self.assertTrue(obj_info not in self.store._dirty) def test_add_rollback_not_in_store(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.store.add(foo) self.store.rollback() self.assertEqual(Store.of(foo), None) def test_update_flush_commit(self): foo = self.store.get(Foo, 20) foo.title = "Title 200" self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) self.assertEqual(self.get_committed_items(), 
[ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) def test_update_flush_reload_rollback(self): foo = self.store.get(Foo, 20) foo.title = "Title 200" self.store.flush() self.store.reload(foo) self.store.rollback() self.assertEqual(foo.title, "Title 20") def test_update_commit(self): foo = self.store.get(Foo, 20) foo.title = "Title 200" self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) def test_update_commit_twice(self): foo = self.store.get(Foo, 20) foo.title = "Title 200" self.store.commit() foo.title = "Title 2000" self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (20, "Title 2000"), (30, "Title 10"), ]) def test_update_checkpoints(self): bar = self.store.get(Bar, 200) bar.title = "Title 400" self.store.flush() self.store.execute("UPDATE bar SET title='Title 500' " "WHERE id=200") bar.foo_id = 40 # When not checkpointing, this flush will set title again. self.store.flush() self.store.reload(bar) self.assertEqual(bar.title, "Title 500") def test_update_primary_key(self): foo = self.store.get(Foo, 20) foo.id = 25 self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (25, "Title 20"), (30, "Title 10"), ]) # Update twice to see if the notion of primary key for the # existent object was updated as well. foo.id = 27 self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 30"), (27, "Title 20"), (30, "Title 10"), ]) # Ensure only the right ones are there. self.assertTrue(self.store.get(Foo, 27) is foo) self.assertTrue(self.store.get(Foo, 25) is None) self.assertTrue(self.store.get(Foo, 20) is None) def test_update_primary_key_exchange(self): foo1 = self.store.get(Foo, 10) foo2 = self.store.get(Foo, 30) foo1.id = 40 self.store.flush() foo2.id = 10 self.store.flush() foo1.id = 30 self.assertTrue(self.store.get(Foo, 30) is foo1) self.assertTrue(self.store.get(Foo, 10) is foo2) self.store.commit() self.assertEqual(self.get_committed_items(), [ (10, "Title 10"), (20, "Title 20"), (30, "Title 30"), ]) def test_wb_update_not_dirty_after_flush(self): foo = self.store.get(Foo, 20) foo.title = "Title 200" self.store.flush() # If changes get committed even with the notification disabled, # it means the dirty flag isn't being cleared. 
        self.store._disable_change_notification(get_obj_info(foo))
        foo.title = "Title 2000"
        self.store.flush()
        self.assertEqual(self.get_items(), [
            (10, "Title 30"),
            (20, "Title 200"),
            (30, "Title 10"),
        ])

    def test_update_find(self):
        foo = self.store.get(Foo, 20)
        foo.title = "Title 200"
        result = self.store.find(Foo, Foo.title == "Title 200")
        self.assertTrue(result.one() is foo)

    def test_update_get(self):
        foo = self.store.get(Foo, 20)
        foo.id = 200
        self.assertTrue(self.store.get(Foo, 200) is foo)

    def test_add_update(self):
        foo = Foo()
        foo.id = 40
        foo.title = "Title 40"
        self.store.add(foo)
        foo.title = "Title 400"
        self.store.flush()
        self.assertEqual(self.get_items(), [
            (10, "Title 30"),
            (20, "Title 20"),
            (30, "Title 10"),
            (40, "Title 400"),
        ])

    def test_add_remove_add(self):
        foo = Foo()
        foo.id = 40
        foo.title = "Title 40"
        self.store.add(foo)
        self.store.remove(foo)
        self.assertEqual(Store.of(foo), None)
        foo.title = "Title 400"
        self.store.add(foo)
        foo.id = 400
        self.store.commit()
        self.assertEqual(Store.of(foo), self.store)
        self.assertEqual(self.get_items(), [
            (10, "Title 30"),
            (20, "Title 20"),
            (30, "Title 10"),
            (400, "Title 400"),
        ])
        self.assertTrue(self.store.get(Foo, 400) is foo)

    def test_wb_add_remove_add(self):
        foo = Foo()
        obj_info = get_obj_info(foo)
        self.store.add(foo)
        self.assertTrue(obj_info in self.store._dirty)
        self.store.remove(foo)
        self.assertTrue(obj_info not in self.store._dirty)
        self.store.add(foo)
        self.assertTrue(obj_info in self.store._dirty)
        self.assertTrue(Store.of(foo) is self.store)

    def test_wb_update_remove_add(self):
        foo = self.store.get(Foo, 20)
        foo.title = "Title 200"
        obj_info = get_obj_info(foo)
        self.store.remove(foo)
        self.store.add(foo)
        self.assertTrue(obj_info in self.store._dirty)

    def test_commit_autoreloads(self):
        foo = self.store.get(Foo, 20)
        self.assertEqual(foo.title, "Title 20")
        self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
        self.assertEqual(foo.title, "Title 20")
        self.store.commit()
        self.assertEqual(foo.title, "New Title")

    def test_commit_invalidates(self):
        foo = self.store.get(Foo, 20)
        self.assertTrue(foo)
        self.store.execute("DELETE FROM foo WHERE id=20")
        self.assertEqual(self.store.get(Foo, 20), foo)
        self.store.commit()
        self.assertEqual(self.store.get(Foo, 20), None)

    def test_rollback_autoreloads(self):
        foo = self.store.get(Foo, 20)
        self.assertEqual(foo.title, "Title 20")
        self.store.rollback()
        self.store.execute("UPDATE foo SET title='New Title' WHERE id=20")
        self.assertEqual(foo.title, "New Title")

    def test_rollback_invalidates(self):
        foo = self.store.get(Foo, 20)
        self.assertTrue(foo)
        self.assertEqual(self.store.get(Foo, 20), foo)
        self.store.rollback()
        self.store.execute("DELETE FROM foo WHERE id=20")
        self.assertEqual(self.store.get(Foo, 20), None)

    def test_sub_class(self):
        class SubFoo(Foo):
            id = Float(primary=True)

        foo1 = self.store.get(Foo, 20)
        foo2 = self.store.get(SubFoo, 20)
        self.assertEqual(foo1.id, 20)
        self.assertEqual(foo2.id, 20)
        self.assertEqual(type(foo1.id), int)
        self.assertEqual(type(foo2.id), float)

    def test_join(self):
        class Bar:
            __storm_table__ = "bar"
            id = Int(primary=True)
            title = Unicode()

        bar = Bar()
        bar.id = 40
        bar.title = "Title 20"
        self.store.add(bar)

        # Add a bar object with the same title to ensure DISTINCT
        # is in place.
bar = Bar() bar.id = 400 bar.title = "Title 20" self.store.add(bar) result = self.store.find(Foo, Foo.title == Bar.title) self.assertEqual([(foo.id, foo.title) for foo in result], [ (20, "Title 20"), (20, "Title 20"), ]) def test_join_distinct(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) title = Unicode() bar = Bar() bar.id = 40 bar.title = "Title 20" self.store.add(bar) # Add a bar object with the same title to ensure DISTINCT # is in place. bar = Bar() bar.id = 400 bar.title = "Title 20" self.store.add(bar) result = self.store.find(Foo, Foo.title == Bar.title) result.config(distinct=True) # Make sure that config() with no arguments won't unset distinct, # and that it returns the result set itself. config = result.config() self.assertEqual([(foo.id, foo.title) for foo in result], [ (20, "Title 20"), ]) def test_sub_select(self): foo = self.store.find(Foo, Foo.id == Select(SQL("20"))).one() self.assertTrue(foo) self.assertEqual(foo.id, 20) self.assertEqual(foo.title, "Title 20") def test_cache_has_improper_object(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.commit() self.store.execute("INSERT INTO foo VALUES (20, 'Title 20')") self.assertTrue(self.store.get(Foo, 20) is not foo) def test_cache_has_improper_object_readded(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.flush() old_foo = foo # Keep a reference. foo = Foo() foo.id = 20 foo.title = "Readded" self.store.add(foo) self.store.commit() self.assertTrue(self.store.get(Foo, 20) is foo) def test_loaded_hook(self): loaded = [] class MyFoo(Foo): def __init__(self): loaded.append("NO!") def __storm_loaded__(self): loaded.append((self.id, self.title)) self.title = "Title 200" self.some_attribute = 1 foo = self.store.get(MyFoo, 20) self.assertEqual(loaded, [(20, "Title 20")]) self.assertEqual(foo.title, "Title 200") self.assertEqual(foo.some_attribute, 1) foo.some_attribute = 2 self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 200"), (30, "Title 10"), ]) self.store.rollback() self.assertEqual(foo.title, "Title 20") self.assertEqual(foo.some_attribute, 2) def test_flush_hook(self): class MyFoo(Foo): counter = 0 def __storm_pre_flush__(self): if self.counter == 0: self.title = "Flushing: %s" % self.title self.counter += 1 foo = self.store.get(MyFoo, 20) self.assertEqual(foo.title, "Title 20") self.store.flush() self.assertEqual(foo.title, "Title 20") # It wasn't dirty. foo.title = "Something" self.store.flush() self.assertEqual(foo.title, "Flushing: Something") # It got into the database because it was flushed *twice* (the # title was changed after being flushed, and thus the object got # dirty again). self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Flushing: Something"), (30, "Title 10"), ]) # This shouldn't do anything, because the object is clean again.
foo.counter = 0 self.store.flush() self.assertEqual(foo.title, "Flushing: Something") def test_flush_hook_all(self): class MyFoo(Foo): def __storm_pre_flush__(self): other = [foo1, foo2][foo1 is self] other.title = "Changed in hook: " + other.title foo1 = self.store.get(MyFoo, 10) foo2 = self.store.get(MyFoo, 20) foo1.title = "Changed" self.store.flush() self.assertEqual(foo1.title, "Changed in hook: Changed") self.assertEqual(foo2.title, "Changed in hook: Title 20") def test_flushed_hook(self): class MyFoo(Foo): done = False def __storm_flushed__(self): if not self.done: self.done = True self.title = "Flushed: %s" % self.title foo = self.store.get(MyFoo, 20) self.assertEqual(foo.title, "Title 20") self.store.flush() self.assertEqual(foo.title, "Title 20") # It wasn't dirty. foo.title = "Something" self.store.flush() self.assertEqual(foo.title, "Flushed: Something") # It got into the database because it was flushed *twice* (the # title was changed after being flushed, and thus the object got # dirty again). self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Flushed: Something"), (30, "Title 10"), ]) # This shouldn't do anything, because the object is clean again. foo.done = False self.store.flush() self.assertEqual(foo.title, "Flushed: Something") def test_retrieve_default_primary_key(self): foo = Foo() foo.title = "Title 40" self.store.add(foo) self.store.flush() self.assertNotEqual(foo.id, None) self.assertTrue(self.store.get(Foo, foo.id) is foo) def test_retrieve_default_value(self): foo = Foo() foo.id = 40 self.store.add(foo) self.store.flush() self.assertEqual(foo.title, "Default Title") def test_retrieve_null_when_no_default(self): bar = Bar() bar.id = 400 self.store.add(bar) self.store.flush() self.assertEqual(bar.title, None) def test_wb_remove_prop_not_dirty(self): foo = self.store.get(Foo, 20) obj_info = get_obj_info(foo) del foo.title self.assertTrue(obj_info not in self.store._dirty) def test_flush_with_removed_prop(self): foo = self.store.get(Foo, 20) del foo.title self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_flush_with_removed_prop_forced_dirty(self): foo = self.store.get(Foo, 20) del foo.title foo.id = 40 foo.id = 20 self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_flush_with_removed_prop_really_dirty(self): foo = self.store.get(Foo, 20) del foo.title foo.id = 25 self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (25, "Title 20"), (30, "Title 10"), ]) def test_wb_block_implicit_flushes(self): # Make sure calling store.flush() will fail. def flush(): raise RuntimeError("Flush called") self.store.flush = flush # The following operations do not call flush. self.store.block_implicit_flushes() foo = self.store.get(Foo, 20) foo = self.store.find(Foo, Foo.id == 20).one() self.store.execute("SELECT title FROM foo WHERE id = 20") self.store.unblock_implicit_flushes() self.assertRaises(RuntimeError, self.store.get, Foo, 20) def test_wb_block_implicit_flushes_is_recursive(self): # Make sure calling store.flush() will fail. def flush(): raise RuntimeError("Flush called") self.store.flush = flush self.store.block_implicit_flushes() self.store.block_implicit_flushes() self.store.unblock_implicit_flushes() # Implicit flushes are still blocked until # unblock_implicit_flushes() is called a second time.
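# (Presumably the block is a counter rather than a flag, which is what
# makes nesting work; the get() below must still skip the implicit flush.)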
foo = self.store.get(Foo, 20) self.store.unblock_implicit_flushes() self.assertRaises(RuntimeError, self.store.get, Foo, 20) def test_block_access(self): """Access to the store is blocked by block_access().""" # The block_access() method blocks access to the connection. self.store.block_access() self.assertRaises(ConnectionBlockedError, self.store.execute, "SELECT 1") self.assertRaises(ConnectionBlockedError, self.store.commit) # The rollback method is not blocked. self.store.rollback() self.store.unblock_access() self.store.execute("SELECT 1") def test_reload(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='Title 40' WHERE id=20") self.assertEqual(foo.title, "Title 20") self.store.reload(foo) self.assertEqual(foo.title, "Title 40") def test_reload_not_changed(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='Title 40' WHERE id=20") self.store.reload(foo) for variable in get_obj_info(foo).variables.values(): self.assertFalse(variable.has_changed()) def test_reload_new(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.assertRaises(WrongStoreError, self.store.reload, foo) def test_reload_new_unflushed(self): foo = Foo() foo.id = 40 foo.title = "Title 40" self.store.add(foo) self.assertRaises(NotFlushedError, self.store.reload, foo) def test_reload_removed(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.flush() self.assertRaises(WrongStoreError, self.store.reload, foo) def test_reload_unknown(self): foo = self.store.get(Foo, 20) store = self.create_store() self.assertRaises(WrongStoreError, store.reload, foo) def test_wb_reload_not_dirty(self): foo = self.store.get(Foo, 20) obj_info = get_obj_info(foo) foo.title = "Title 40" self.store.reload(foo) self.assertTrue(obj_info not in self.store._dirty) def test_find_set_empty(self): self.store.find(Foo, title="Title 20").set() foo = self.store.get(Foo, 20) self.assertEqual(foo.title, "Title 20") def test_find_set(self): self.store.find(Foo, title="Title 20").set(title="Title 40") foo = self.store.get(Foo, 20) self.assertEqual(foo.title, "Title 40") def test_find_set_with_func_expr(self): self.store.find(Foo, title="Title 20").set(title=Lower("Title 40")) foo = self.store.get(Foo, 20) self.assertEqual(foo.title, "title 40") def test_find_set_equality_with_func_expr(self): self.store.find(Foo, title="Title 20").set( Foo.title == Lower("Title 40")) foo = self.store.get(Foo, 20) self.assertEqual(foo.title, "title 40") def test_find_set_column(self): self.store.find(Bar, title="Title 200").set(foo_id=Bar.id) bar = self.store.get(Bar, 200) self.assertEqual(bar.foo_id, 200) def test_find_set_expr(self): self.store.find(Foo, title="Title 20").set(Foo.title == "Title 40") foo = self.store.get(Foo, 20) self.assertEqual(foo.title, "Title 40") def test_find_set_none(self): self.store.find(Foo, title="Title 20").set(title=None) foo = self.store.get(Foo, 20) self.assertEqual(foo.title, None) def test_find_set_expr_column(self): self.store.find(Bar, id=200).set(Bar.foo_id == Bar.id) bar = self.store.get(Bar, 200) self.assertEqual(bar.id, 200) self.assertEqual(bar.foo_id, 200) def test_find_set_on_cached(self): foo1 = self.store.get(Foo, 20) foo2 = self.store.get(Foo, 30) self.store.find(Foo, id=20).set(id=40) self.assertEqual(foo1.id, 40) self.assertEqual(foo2.id, 30) def test_find_set_expr_on_cached(self): bar = self.store.get(Bar, 200) self.store.find(Bar, id=200).set(Bar.foo_id == Bar.id) self.assertEqual(bar.id, 200) self.assertEqual(bar.foo_id, 200) def
test_find_set_none_on_cached(self): foo = self.store.get(Foo, 20) self.store.find(Foo, title="Title 20").set(title=None) self.assertEqual(foo.title, None) def test_find_set_on_cached_but_removed(self): foo1 = self.store.get(Foo, 20) foo2 = self.store.get(Foo, 30) self.store.remove(foo1) self.store.find(Foo, id=20).set(id=40) self.assertEqual(foo1.id, 20) self.assertEqual(foo2.id, 30) def test_find_set_on_cached_unsupported_python_expr(self): foo1 = self.store.get(Foo, 20) foo2 = self.store.get(Foo, 30) self.store.find( Foo, Foo.id == Select(SQL("20"))).set(title="Title 40") self.assertEqual(foo1.title, "Title 40") self.assertEqual(foo2.title, "Title 10") def test_find_set_expr_unsupported(self): result = self.store.find(Foo, title="Title 20") self.assertRaises(FeatureError, result.set, Foo.title > "Title 40") def test_find_set_expr_unsupported_without_column(self): result = self.store.find(Foo, title="Title 20") self.assertRaises(FeatureError, result.set, Eq(object(), IntVariable(1))) def test_find_set_expr_unsupported_without_expr_or_variable(self): result = self.store.find(Foo, title="Title 20") self.assertRaises(FeatureError, result.set, Eq(Foo.id, object())) def test_find_set_expr_unsupported_autoreloads(self): bar1 = self.store.get(Bar, 200) bar2 = self.store.get(Bar, 300) self.store.find(Bar, id=Select(SQL("200"))).set(title="Title 400") bar1_vars = get_obj_info(bar1).variables bar2_vars = get_obj_info(bar2).variables self.assertEqual(bar1_vars[Bar.title].get_lazy(), AutoReload) self.assertEqual(bar2_vars[Bar.title].get_lazy(), AutoReload) self.assertEqual(bar1_vars[Bar.foo_id].get_lazy(), None) self.assertEqual(bar2_vars[Bar.foo_id].get_lazy(), None) self.assertEqual(bar1.title, "Title 400") self.assertEqual(bar2.title, "Title 100") def test_find_set_expr_unsupported_mixed_autoreloads(self): # For an expression that does not compile (e.g. ResultSet.cached() # raises a CompileError), cached entries' columns are set to # AutoReload instead. If objects of different types could be found # in the cache, a KeyError would happen when some object did not # have a matching column. See Bug #328603 for more info. foo1 = self.store.get(Foo, 20) bar1 = self.store.get(Bar, 200) self.store.find(Bar, id=Select(SQL("200"))).set(title="Title 400") foo1_vars = get_obj_info(foo1).variables bar1_vars = get_obj_info(bar1).variables self.assertNotEqual(foo1_vars[Foo.title].get_lazy(), AutoReload) self.assertEqual(bar1_vars[Bar.title].get_lazy(), AutoReload) self.assertEqual(bar1_vars[Bar.foo_id].get_lazy(), None) self.assertEqual(foo1.title, "Title 20") self.assertEqual(bar1.title, "Title 400") def test_find_set_autoreloads_with_func_expr(self): # In the process of fixing this bug, we've temporarily # introduced another bug: the expression would be called # twice. We've used an expression that increments the value by # one here to see if that case is triggered. With the buggy # fix, the value would end up being incremented by two due to # the two misfired updates.
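# (After a set() with a server-side expression, the cached column is
# flipped to AutoReload, so the attribute access below re-fetches the
# incremented value instead of guessing it in Python.)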
foo1 = self.store.get(FooValue, 1) self.assertEqual(foo1.value1, 2) self.store.find(FooValue, id=1).set(value1=SQL("value1 + 1")) foo1_vars = get_obj_info(foo1).variables self.assertEqual(foo1_vars[FooValue.value1].get_lazy(), AutoReload) self.assertEqual(foo1.value1, 3) def test_find_set_equality_autoreloads_with_func_expr(self): foo1 = self.store.get(FooValue, 1) self.assertEqual(foo1.value1, 2) self.store.find(FooValue, id=1).set( FooValue.value1 == SQL("value1 + 1")) foo1_vars = get_obj_info(foo1).variables self.assertEqual(foo1_vars[FooValue.value1].get_lazy(), AutoReload) self.assertEqual(foo1.value1, 3) def test_wb_find_set_checkpoints(self): bar = self.store.get(Bar, 200) self.store.find(Bar, id=200).set(title="Title 400") self.store._connection.execute("UPDATE bar SET " "title='Title 500' " "WHERE id=200") # When not checkpointing, this flush will set title again. self.store.flush() self.store.reload(bar) self.assertEqual(bar.title, "Title 500") def test_find_set_with_info_alive_and_object_dead(self): # Disable the cache, which holds strong references. self.get_cache(self.store).set_size(0) foo = self.store.get(Foo, 20) foo.tainted = True obj_info = get_obj_info(foo) del foo gc.collect() self.store.find(Foo, title="Title 20").set(title="Title 40") foo = self.store.get(Foo, 20) self.assertFalse(hasattr(foo, "tainted")) self.assertEqual(foo.title, "Title 40") def test_reference(self): bar = self.store.get(Bar, 100) self.assertTrue(bar.foo) self.assertEqual(bar.foo.title, "Title 30") def test_reference_explicitly_with_wrapper(self): bar = self.store.get(Bar, 100) foo = Bar.foo.__get__(Wrapper(bar)) self.assertTrue(foo) self.assertEqual(foo.title, "Title 30") def test_reference_break_on_local_diverged(self): bar = self.store.get(Bar, 100) self.assertTrue(bar.foo) bar.foo_id = 40 self.assertEqual(bar.foo, None) def test_reference_break_on_remote_diverged(self): bar = self.store.get(Bar, 100) bar.foo.id = 40 self.assertEqual(bar.foo, None) def test_reference_break_on_local_diverged_by_lazy(self): bar = self.store.get(Bar, 100) self.assertEqual(bar.foo.id, 10) bar.foo_id = SQL("20") self.assertEqual(bar.foo.id, 20) def test_reference_remote_leak_on_flush_with_changed(self): """ "changed" events only hold weak references to the remote obj_info objects, thus not creating a leak when unhooked. """ self.get_cache(self.store).set_size(0) bar = self.store.get(Bar, 100) bar.foo.title = "Changed title" bar_ref = weakref.ref(get_obj_info(bar)) foo = bar.foo del bar self.store.flush() gc.collect() self.assertEqual(bar_ref(), None) def test_reference_remote_leak_on_flush_with_removed(self): """ "removed" events only hold weak references to the remote obj_info objects, thus not creating a leak when unhooked.
""" self.get_cache(self.store).set_size(0) class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) foo = self.store.get(MyFoo, 10) foo.bar.title = "Changed title" foo_ref = weakref.ref(get_obj_info(foo)) bar = foo.bar del foo self.store.flush() gc.collect() self.assertEqual(foo_ref(), None) def test_reference_break_on_remote_diverged_by_lazy(self): class MyBar(Bar): pass MyBar.foo = Reference(MyBar.title, Foo.title) bar = self.store.get(MyBar, 100) bar.title = "Title 30" self.store.flush() self.assertEqual(bar.foo.id, 10) bar.foo.title = SQL("'Title 40'") self.assertEqual(bar.foo, None) self.assertEqual(self.store.find(Foo, title="Title 30").one(), None) self.assertEqual(self.store.get(Foo, 10).title, "Title 40") def test_reference_on_non_primary_key(self): self.store.execute("INSERT INTO bar VALUES (400, 40, 'Title 30')") class MyBar(Bar): foo = Reference(Bar.title, Foo.title) bar = self.store.get(Bar, 400) self.assertEqual(bar.title, "Title 30") self.assertEqual(bar.foo, None) mybar = self.store.get(MyBar, 400) self.assertEqual(mybar.title, "Title 30") self.assertNotEqual(mybar.foo, None) self.assertEqual(mybar.foo.id, 10) self.assertEqual(mybar.foo.title, "Title 30") def test_new_reference(self): bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo_id = 10 self.assertEqual(bar.foo, None) self.store.add(bar) self.assertTrue(bar.foo) self.assertEqual(bar.foo.title, "Title 30") def test_set_reference(self): bar = self.store.get(Bar, 100) self.assertEqual(bar.foo.id, 10) foo = self.store.get(Foo, 30) bar.foo = foo self.assertEqual(bar.foo.id, 30) result = self.store.execute("SELECT foo_id FROM bar WHERE id=100") self.assertEqual(result.get_one(), (30,)) def test_set_reference_explicitly_with_wrapper(self): bar = self.store.get(Bar, 100) self.assertEqual(bar.foo.id, 10) foo = self.store.get(Foo, 30) Bar.foo.__set__(Wrapper(bar), Wrapper(foo)) self.assertEqual(bar.foo.id, 30) result = self.store.execute("SELECT foo_id FROM bar WHERE id=100") self.assertEqual(result.get_one(), (30,)) def test_reference_assign_remote_key(self): bar = self.store.get(Bar, 100) self.assertEqual(bar.foo.id, 10) bar.foo = 30 self.assertEqual(bar.foo_id, 30) self.assertEqual(bar.foo.id, 30) result = self.store.execute("SELECT foo_id FROM bar WHERE id=100") self.assertEqual(result.get_one(), (30,)) def test_reference_on_added(self): foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo self.store.add(bar) self.assertEqual(bar.foo.id, None) self.assertEqual(bar.foo.title, "Title 40") self.store.flush() self.assertTrue(bar.foo.id) self.assertEqual(bar.foo.title, "Title 40") result = self.store.execute("SELECT foo.title FROM foo, bar " "WHERE bar.id=400 AND " "foo.id = bar.foo_id") self.assertEqual(result.get_one(), ("Title 40",)) def test_reference_on_added_with_autoreload_key(self): foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo self.store.add(bar) self.assertEqual(bar.foo.id, None) self.assertEqual(bar.foo.title, "Title 40") foo.id = AutoReload # Variable shouldn't be autoreloaded yet. 
obj_info = get_obj_info(foo) self.assertEqual(obj_info.variables[Foo.id].get_lazy(), AutoReload) self.assertEqual(type(foo.id), int) self.store.flush() self.assertTrue(bar.foo.id) self.assertEqual(bar.foo.title, "Title 40") result = self.store.execute("SELECT foo.title FROM foo, bar " "WHERE bar.id=400 AND " "foo.id = bar.foo_id") self.assertEqual(result.get_one(), ("Title 40",)) def test_reference_assign_none(self): foo = Foo() foo.title = "Title 40" bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo bar.foo = None bar.foo = None # Twice to make sure it doesn't blow up. self.store.add(bar) self.store.flush() self.assertEqual(type(bar.id), int) self.assertEqual(foo.id, None) def test_reference_assign_none_with_unseen(self): bar = self.store.get(Bar, 200) bar.foo = None self.assertEqual(bar.foo, None) def test_reference_on_added_composed_key(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int() title = Unicode() foo = Reference((foo_id, title), (Foo.id, Foo.title)) foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.foo = foo self.store.add(bar) self.assertEqual(bar.foo.id, None) self.assertEqual(bar.foo.title, "Title 40") self.assertEqual(bar.title, "Title 40") self.store.flush() self.assertTrue(bar.foo.id) self.assertEqual(bar.foo.title, "Title 40") result = self.store.execute("SELECT foo.title FROM foo, bar " "WHERE bar.id=400 AND " "foo.id = bar.foo_id") self.assertEqual(result.get_one(), ("Title 40",)) def test_reference_assign_composed_remote_key(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int() title = Unicode() foo = Reference((foo_id, title), (Foo.id, Foo.title)) bar = Bar() bar.id = 400 bar.foo = (20, "Title 20") self.store.add(bar) self.assertEqual(bar.foo_id, 20) self.assertEqual(bar.foo.id, 20) self.assertEqual(bar.title, "Title 20") self.assertEqual(bar.foo.title, "Title 20") def test_reference_on_added_unlink_on_flush(self): foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.foo = foo bar.title = "Title 400" self.store.add(bar) foo.id = 40 self.assertEqual(bar.foo_id, 40) foo.id = 50 self.assertEqual(bar.foo_id, 50) foo.id = 60 self.assertEqual(bar.foo_id, 60) self.store.flush() foo.id = 70 self.assertEqual(bar.foo_id, 60) def test_reference_on_added_unsets_original_key(self): foo = Foo() self.store.add(foo) bar = Bar() bar.id = 400 bar.foo_id = 40 bar.foo = foo self.assertEqual(bar.foo_id, None) def test_reference_on_two_added(self): foo1 = Foo() foo1.title = "Title 40" foo2 = Foo() foo2.title = "Title 40" self.store.add(foo1) self.store.add(foo2) bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo1 bar.foo = foo2 self.store.add(bar) foo1.id = 40 self.assertEqual(bar.foo_id, None) foo2.id = 50 self.assertEqual(bar.foo_id, 50) def test_reference_on_added_and_changed_manually(self): foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo self.store.add(bar) bar.foo_id = 40 foo.id = 50 self.assertEqual(bar.foo_id, 40) def test_reference_on_added_composed_key_changed_manually(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int() title = Unicode() foo = Reference((foo_id, title), (Foo.id, Foo.title)) foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.foo = foo self.store.add(bar) bar.title = "Title 50" self.assertEqual(bar.foo, None) foo.id = 40 self.assertEqual(bar.foo_id, None) def 
test_reference_on_added_no_local_store(self): foo = Foo() foo.title = "Title 40" self.store.add(foo) bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo), self.store) def test_reference_on_added_no_remote_store(self): foo = Foo() foo.title = "Title 40" bar = Bar() bar.id = 400 bar.title = "Title 400" self.store.add(bar) bar.foo = foo self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo), self.store) def test_reference_on_added_no_store(self): foo = Foo() foo.title = "Title 40" bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo self.store.add(bar) self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo), self.store) self.store.flush() self.assertEqual(type(bar.foo_id), int) def test_reference_on_added_no_store_2(self): foo = Foo() foo.title = "Title 40" bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo self.store.add(foo) self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo), self.store) self.store.flush() self.assertEqual(type(bar.foo_id), int) def test_reference_on_added_wrong_store(self): store = self.create_store() foo = Foo() foo.title = "Title 40" store.add(foo) bar = Bar() bar.id = 400 bar.title = "Title 400" self.store.add(bar) self.assertRaises(WrongStoreError, setattr, bar, "foo", foo) def test_reference_on_added_no_store_unlink_before_adding(self): foo1 = Foo() foo1.title = "Title 40" bar = Bar() bar.id = 400 bar.title = "Title 400" bar.foo = foo1 bar.foo = None self.store.add(bar) store = self.create_store() store.add(foo1) self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo1), store) def test_reference_on_removed_wont_add_back(self): bar = self.store.get(Bar, 200) foo = self.store.get(Foo, bar.foo_id) self.store.remove(bar) self.assertEqual(bar.foo, foo) self.store.flush() self.assertEqual(Store.of(bar), None) self.assertEqual(Store.of(foo), self.store) def test_reference_equals(self): foo = self.store.get(Foo, 10) bar = self.store.find(Bar, foo=foo).one() self.assertTrue(bar) self.assertEqual(bar.foo, foo) bar = self.store.find(Bar, foo=foo.id).one() self.assertTrue(bar) self.assertEqual(bar.foo, foo) def test_reference_equals_none(self): result = list(self.store.find(SelfRef, selfref=None)) self.assertEqual(len(result), 2) self.assertEqual(result[0].selfref, None) self.assertEqual(result[1].selfref, None) def test_reference_equals_with_composed_key(self): # Interesting case of self-reference. 
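# (Note for the test below: Link's composed primary key (foo_id, bar_id)
# doubles as the reference key, so every LinkWithRef row references
# itself.)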
class LinkWithRef(Link): myself = Reference((Link.foo_id, Link.bar_id), (Link.foo_id, Link.bar_id)) link = self.store.find(LinkWithRef, foo_id=10, bar_id=100).one() myself = self.store.find(LinkWithRef, myself=link).one() self.assertEqual(link, myself) myself = self.store.find(LinkWithRef, myself=(link.foo_id, link.bar_id)).one() self.assertEqual(link, myself) def test_reference_equals_with_wrapped(self): foo = self.store.get(Foo, 10) bar = self.store.find(Bar, foo=Wrapper(foo)).one() self.assertTrue(bar) self.assertEqual(bar.foo, foo) def test_reference_not_equals(self): foo = self.store.get(Foo, 10) result = self.store.find(Bar, Bar.foo != foo) self.assertEqual([200, 300], sorted(bar.id for bar in result)) def test_reference_not_equals_none(self): obj = self.store.find(SelfRef, SelfRef.selfref != None).one() self.assertTrue(obj) self.assertNotEqual(obj.selfref, None) def test_reference_not_equals_with_composed_key(self): class LinkWithRef(Link): myself = Reference((Link.foo_id, Link.bar_id), (Link.foo_id, Link.bar_id)) link = self.store.find(LinkWithRef, foo_id=10, bar_id=100).one() result = list(self.store.find(LinkWithRef, LinkWithRef.myself != link)) self.assertTrue(link not in result, "%r not in %r" % (link, result)) result = list(self.store.find( LinkWithRef, LinkWithRef.myself != (link.foo_id, link.bar_id))) self.assertTrue(link not in result, "%r not in %r" % (link, result)) def test_reference_self(self): selfref = self.store.add(SelfRef()) selfref.id = 400 selfref.title = "Title 400" selfref.selfref_id = 25 self.assertEqual(selfref.selfref.id, 25) self.assertEqual(selfref.selfref.title, "SelfRef 25") def get_bar_200_title(self): connection = self.store._connection result = connection.execute("SELECT title FROM bar WHERE id=200") return result.get_one()[0] def test_reference_wont_touch_store_when_key_is_none(self): bar = self.store.get(Bar, 200) bar.foo_id = None bar.title = "Don't flush this!" self.assertEqual(bar.foo, None) # Bypass the store to prevent flushing. self.assertEqual(self.get_bar_200_title(), "Title 200") def test_reference_wont_touch_store_when_key_is_unset(self): bar = self.store.get(Bar, 200) del bar.foo_id bar.title = "Don't flush this!" self.assertEqual(bar.foo, None) # Bypass the store to prevent flushing. connection = self.store._connection result = connection.execute("SELECT title FROM bar WHERE id=200") self.assertEqual(result.get_one()[0], "Title 200") def test_reference_wont_touch_store_with_composed_key_none(self): class Bar: __storm_table__ = "bar" id = Int(primary=True) foo_id = Int() title = Unicode() foo = Reference((foo_id, title), (Foo.id, Foo.title)) bar = self.store.get(Bar, 200) bar.foo_id = None bar.title = None self.assertEqual(bar.foo, None) # Bypass the store to prevent flushing. 
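# (get_bar_200_title() reads through the raw connection on purpose:
# going through store.execute() would trigger an implicit flush and mask
# the bug being tested.)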
self.assertEqual(self.get_bar_200_title(), "Title 200") def test_reference_will_resolve_auto_reload(self): bar = self.store.get(Bar, 200) bar.foo_id = AutoReload self.assertTrue(bar.foo) def test_back_reference(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) foo = self.store.get(MyFoo, 10) self.assertTrue(foo.bar) self.assertEqual(foo.bar.title, "Title 300") def test_back_reference_setting(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) bar = Bar() bar.title = "Title 400" self.store.add(bar) foo = MyFoo() foo.bar = bar foo.title = "Title 40" self.store.add(foo) self.store.flush() self.assertTrue(foo.id) self.assertEqual(bar.foo_id, foo.id) result = self.store.execute("SELECT bar.title " "FROM foo, bar " "WHERE foo.id = bar.foo_id AND " "foo.title = 'Title 40'") self.assertEqual(result.get_one(), ("Title 400",)) def test_back_reference_setting_changed_manually(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) bar = Bar() bar.title = "Title 400" self.store.add(bar) foo = MyFoo() foo.bar = bar foo.title = "Title 40" self.store.add(foo) foo.id = 40 self.assertEqual(foo.bar, bar) self.store.flush() self.assertEqual(foo.id, 40) self.assertEqual(bar.foo_id, 40) result = self.store.execute("SELECT bar.title " "FROM foo, bar " "WHERE foo.id = bar.foo_id AND " "foo.title = 'Title 40'") self.assertEqual(result.get_one(), ("Title 400",)) def test_back_reference_assign_none_with_unseen(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) foo = self.store.get(MyFoo, 20) foo.bar = None self.assertEqual(foo.bar, None) def test_back_reference_assign_none_from_none(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) self.store.execute("INSERT INTO foo (id, title)" " VALUES (40, 'Title 40')") self.store.commit() foo = self.store.get(MyFoo, 40) foo.bar = None self.assertEqual(foo.bar, None) def test_back_reference_on_added_unsets_original_key(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) foo = MyFoo() bar = Bar() bar.id = 400 bar.foo_id = 40 foo.bar = bar self.assertEqual(bar.foo_id, None) def test_back_reference_on_added_no_store(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) bar = Bar() bar.title = "Title 400" foo = MyFoo() foo.bar = bar foo.title = "Title 40" self.store.add(bar) self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo), self.store) self.store.flush() self.assertEqual(type(bar.foo_id), int) def test_back_reference_on_added_no_store_2(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) bar = Bar() bar.title = "Title 400" foo = MyFoo() foo.bar = bar foo.title = "Title 40" self.store.add(foo) self.assertEqual(Store.of(bar), self.store) self.assertEqual(Store.of(foo), self.store) self.store.flush() self.assertEqual(type(bar.foo_id), int) def test_back_reference_remove_remote(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) bar = Bar() bar.title = "Title 400" foo = MyFoo() foo.title = "Title 40" foo.bar = bar self.store.add(foo) self.store.flush() self.assertEqual(foo.bar, bar) self.store.remove(bar) self.assertEqual(foo.bar, None) def test_back_reference_remove_remote_pending_add(self): class MyFoo(Foo): bar = Reference(Foo.id, Bar.foo_id, on_remote=True) bar = Bar() bar.title = "Title 400" foo = MyFoo() foo.title = "Title 40" foo.bar = bar self.store.add(foo) self.assertEqual(foo.bar, bar) self.store.remove(bar) 
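# (Removing the remote object while its addition is still pending must
# break the back reference immediately, before any flush happens.)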
self.assertEqual(foo.bar, None) def test_reference_loop_with_undefined_keys_fails(self): """A loop of references with undefined keys raises OrderLoopError.""" ref1 = SelfRef() self.store.add(ref1) ref2 = SelfRef() ref2.selfref = ref1 ref1.selfref = ref2 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_with_dirty_keys_fails(self): ref1 = SelfRef() self.store.add(ref1) ref1.id = 42 ref2 = SelfRef() ref2.id = 43 ref2.selfref = ref1 ref1.selfref = ref2 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_with_dirty_keys_changed_later_fails(self): ref1 = SelfRef() ref2 = SelfRef() self.store.add(ref1) self.store.add(ref2) self.store.flush() ref2.selfref = ref1 ref1.selfref = ref2 ref1.id = 42 ref2.id = 43 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_with_dirty_keys_on_remote_fails(self): ref1 = SelfRef() self.store.add(ref1) ref1.id = 42 ref2 = SelfRef() ref2.id = 43 ref2.selfref_on_remote = ref1 ref1.selfref_on_remote = ref2 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_with_dirty_keys_on_remote_changed_later_fails(self): ref1 = SelfRef() ref2 = SelfRef() self.store.add(ref1) self.store.flush() ref2.selfref_on_remote = ref1 ref1.selfref_on_remote = ref2 ref1.id = 42 ref2.id = 43 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_with_unchanged_keys_succeeds(self): ref1 = SelfRef() self.store.add(ref1) ref1.id = 42 ref2 = SelfRef() self.store.add(ref2) ref1.id = 43 self.store.flush() # As ref1 and ref2 have been flushed to the database, these # changes can be flushed. ref2.selfref = ref1 ref1.selfref = ref2 self.store.flush() def test_reference_loop_with_one_unchanged_key_succeeds(self): ref1 = SelfRef() self.store.add(ref1) self.store.flush() ref2 = SelfRef() ref2.selfref = ref1 ref1.selfref = ref2 # As ref1 and ref2 have been flushed to the database, these # changes can be flushed. self.store.flush() def test_reference_loop_with_key_changed_later_succeeds(self): ref1 = SelfRef() self.store.add(ref1) self.store.flush() ref2 = SelfRef() ref1.selfref = ref2 ref2.id = 42 self.store.flush() def test_reference_loop_with_key_changed_later_on_remote_succeeds(self): ref1 = SelfRef() self.store.add(ref1) self.store.flush() ref2 = SelfRef() ref2.selfref_on_remote = ref1 ref2.id = 42 self.store.flush() def test_reference_loop_with_undefined_and_changed_keys_fails(self): ref1 = SelfRef() self.store.add(ref1) self.store.flush() ref1.id = 400 ref2 = SelfRef() ref2.selfref = ref1 ref1.selfref = ref2 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_with_undefined_and_changed_keys_fails2(self): ref1 = SelfRef() self.store.add(ref1) self.store.flush() ref2 = SelfRef() ref2.selfref = ref1 ref1.selfref = ref2 ref1.id = 400 self.assertRaises(OrderLoopError, self.store.flush) def test_reference_loop_broken_by_set(self): ref1 = SelfRef() ref2 = SelfRef() ref1.selfref = ref2 ref2.selfref = ref1 self.store.add(ref1) ref1.selfref = None self.store.flush() def test_reference_loop_set_only_removes_own_flush_order(self): ref1 = SelfRef() ref2 = SelfRef() self.store.add(ref2) self.store.flush() # The following does not create a loop since the keys are # dirty (as shown in another test). ref1.selfref = ref2 ref2.selfref = ref1 # Now add a flush order loop. self.store.add_flush_order(ref1, ref2) self.store.add_flush_order(ref2, ref1) # Now break the reference. This should leave the flush # ordering loop we previously created in place.
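# (add_flush_order()/remove_flush_order() pairs are reference-counted,
# as test_flush_order exercises, so breaking the reference must only
# drop the ordering that the reference itself registered.)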
ref1.selfref = None self.assertRaises(OrderLoopError, self.store.flush) def add_reference_set_bar_400(self): bar = Bar() bar.id = 400 bar.foo_id = 20 bar.title = "Title 100" self.store.add(bar) def test_reference_set(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) items = [] for bar in foo.bars: items.append((bar.id, bar.foo_id, bar.title)) items.sort() self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) def test_reference_set_assign_fails(self): foo = self.store.get(FooRefSet, 20) try: foo.bars = [] except FeatureError: pass else: self.fail("FeatureError not raised") def test_reference_set_explicitly_with_wrapper(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) items = [] for bar in FooRefSet.bars.__get__(Wrapper(foo)): items.append((bar.id, bar.foo_id, bar.title)) items.sort() self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) def test_reference_set_with_added(self): bar1 = Bar() bar1.id = 400 bar1.title = "Title 400" bar2 = Bar() bar2.id = 500 bar2.title = "Title 500" foo = FooRefSet() foo.title = "Title 40" foo.bars.add(bar1) foo.bars.add(bar2) self.store.add(foo) self.assertEqual(foo.id, None) self.assertEqual(bar1.foo_id, None) self.assertEqual(bar2.foo_id, None) self.assertEqual(list(foo.bars.order_by(Bar.id)), [bar1, bar2]) self.assertEqual(type(foo.id), int) self.assertEqual(foo.id, bar1.foo_id) self.assertEqual(foo.id, bar2.foo_id) def test_reference_set_composed(self): self.add_reference_set_bar_400() bar = self.store.get(Bar, 400) bar.title = "Title 20" class FooRefSetComposed(Foo): bars = ReferenceSet((Foo.id, Foo.title), (Bar.foo_id, Bar.title)) foo = self.store.get(FooRefSetComposed, 20) items = [] for bar in foo.bars: items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (400, 20, "Title 20"), ]) bar = self.store.get(Bar, 200) bar.title = "Title 20" del items[:] for bar in foo.bars: items.append((bar.id, bar.foo_id, bar.title)) items.sort() self.assertEqual(items, [ (200, 20, "Title 20"), (400, 20, "Title 20"), ]) def test_reference_set_contains(self): def no_iter(self): raise RuntimeError() from storm.references import BoundReferenceSetBase orig_iter = BoundReferenceSetBase.__iter__ BoundReferenceSetBase.__iter__ = no_iter try: foo = self.store.get(FooRefSet, 20) bar = self.store.get(Bar, 200) self.assertEqual(bar in foo.bars, True) finally: BoundReferenceSetBase.__iter__ = orig_iter def test_reference_set_find(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) items = [] for bar in foo.bars.find(): items.append((bar.id, bar.foo_id, bar.title)) items.sort() self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) # Notice that there's another item with this title in the database, # which isn't part of the reference. objects = list(foo.bars.find(Bar.title == "Title 100")) self.assertEqual(len(objects), 1) self.assertTrue(objects[0] is bar) objects = list(foo.bars.find(title="Title 100")) self.assertEqual(len(objects), 1) self.assertTrue(objects[0] is bar) def test_reference_set_clear(self): foo = self.store.get(FooRefSet, 20) foo.bars.clear() self.assertEqual(list(foo.bars), []) # Object wasn't removed.
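# (clear() only breaks the foreign key -- foo_id is set to NULL -- so
# the Bar row itself must survive, as the get() below confirms.)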
self.assertTrue(self.store.get(Bar, 200)) def test_reference_set_clear_cached(self): foo = self.store.get(FooRefSet, 20) bar = self.store.get(Bar, 200) self.assertEqual(bar.foo_id, 20) foo.bars.clear() self.assertEqual(bar.foo_id, None) def test_reference_set_clear_where(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) foo.bars.clear(Bar.id > 200) items = [(bar.id, bar.foo_id, bar.title) for bar in foo.bars] self.assertEqual(items, [ (200, 20, "Title 200"), ]) bar = self.store.get(Bar, 400) bar.foo_id = 20 foo.bars.clear(id=200) items = [(bar.id, bar.foo_id, bar.title) for bar in foo.bars] self.assertEqual(items, [ (400, 20, "Title 100"), ]) def test_reference_set_is_empty(self): foo = self.store.get(FooRefSet, 20) self.assertFalse(foo.bars.is_empty()) foo.bars.clear() self.assertTrue(foo.bars.is_empty()) def test_reference_set_count(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) self.assertEqual(foo.bars.count(), 2) def test_reference_set_order_by(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) items = [] for bar in foo.bars.order_by(Bar.id): items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) del items[:] for bar in foo.bars.order_by(Bar.title): items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (400, 20, "Title 100"), (200, 20, "Title 200"), ]) def test_reference_set_default_order_by(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSetOrderID, 20) items = [] for bar in foo.bars: items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) items = [] for bar in foo.bars.find(): items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) foo = self.store.get(FooRefSetOrderTitle, 20) del items[:] for bar in foo.bars: items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (400, 20, "Title 100"), (200, 20, "Title 200"), ]) del items[:] for bar in foo.bars.find(): items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (400, 20, "Title 100"), (200, 20, "Title 200"), ]) def test_reference_set_getitem(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSetOrderID, 20) self.assertEqual(foo.bars[0].id, 200) self.assertEqual(foo.bars[1].id, 400) self.assertRaises(IndexError, foo.bars.__getitem__, 2) items = [] for bar in foo.bars[:1]: items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (200, 20, "Title 200"), ]) del items[:] for bar in foo.bars[1:]: items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (400, 20, "Title 100"), ]) del items[:] for bar in foo.bars[:2]: items.append((bar.id, bar.foo_id, bar.title)) self.assertEqual(items, [ (200, 20, "Title 200"), (400, 20, "Title 100"), ]) def test_reference_set_first_last(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSetOrderID, 20) self.assertEqual(foo.bars.first().id, 200) self.assertEqual(foo.bars.last().id, 400) foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.first().id, 400) self.assertEqual(foo.bars.last().id, 200) foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.first(Bar.id > 400), None) self.assertEqual(foo.bars.last(Bar.id > 400), None) foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.first(Bar.id < 400).id, 200) self.assertEqual(foo.bars.last(Bar.id < 
400).id, 200) foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.first(id=200).id, 200) self.assertEqual(foo.bars.last(id=200).id, 200) foo = self.store.get(FooRefSet, 20) self.assertRaises(UnorderedError, foo.bars.first) self.assertRaises(UnorderedError, foo.bars.last) def test_reference_set_any(self): """ L{BoundReferenceSet.any} returns an arbitrary object from the set of referenced objects. """ foo = self.store.get(FooRefSet, 20) self.assertNotEqual(None, foo.bars.any()) def test_reference_set_any_filtered(self): """ L{BoundReferenceSet.any} optionally takes a list of filtering criteria to narrow the set of objects to search. When provided, the criteria are used to filter the set before returning an arbitrary object. """ self.add_reference_set_bar_400() foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.any(Bar.id > 400), None) foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.any(Bar.id < 400).id, 200) foo = self.store.get(FooRefSetOrderTitle, 20) self.assertEqual(foo.bars.any(id=200).id, 200) def test_reference_set_one(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSetOrderID, 20) self.assertRaises(NotOneError, foo.bars.one) foo = self.store.get(FooRefSetOrderID, 30) self.assertEqual(foo.bars.one().id, 300) foo = self.store.get(FooRefSetOrderID, 20) self.assertEqual(foo.bars.one(Bar.id > 400), None) foo = self.store.get(FooRefSetOrderID, 20) self.assertEqual(foo.bars.one(Bar.id < 400).id, 200) foo = self.store.get(FooRefSetOrderID, 20) self.assertEqual(foo.bars.one(id=200).id, 200) def test_reference_set_remove(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSet, 20) for bar in foo.bars: foo.bars.remove(bar) self.assertEqual(bar.foo_id, None) self.assertEqual(list(foo.bars), []) def test_reference_set_after_object_removed(self): class MyBar(Bar): # Make sure that this works even with allow_none=False.
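# (allow_none=False means foo_id can never be set to None, so removing
# the object from the store must still take it out of the set without
# touching the key.)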
foo_id = Int(allow_none=False) class MyFoo(Foo): bars = ReferenceSet(Foo.id, MyBar.foo_id) foo = self.store.get(MyFoo, 20) bar = foo.bars.any() self.store.remove(bar) self.assertTrue(bar not in list(foo.bars)) def test_reference_set_add(self): bar = Bar() bar.id = 400 bar.title = "Title 100" foo = self.store.get(FooRefSet, 20) foo.bars.add(bar) self.assertEqual(bar.foo_id, 20) self.assertEqual(Store.of(bar), self.store) def test_reference_set_add_no_store(self): bar = Bar() bar.id = 400 bar.title = "Title 400" foo = FooRefSet() foo.title = "Title 40" foo.bars.add(bar) self.store.add(foo) self.assertEqual(Store.of(foo), self.store) self.assertEqual(Store.of(bar), self.store) self.store.flush() self.assertEqual(type(bar.foo_id), int) def test_reference_set_add_no_store_2(self): bar = Bar() bar.id = 400 bar.title = "Title 400" foo = FooRefSet() foo.title = "Title 40" foo.bars.add(bar) self.store.add(bar) self.assertEqual(Store.of(foo), self.store) self.assertEqual(Store.of(bar), self.store) self.store.flush() self.assertEqual(type(bar.foo_id), int) def test_reference_set_add_no_store_unlink_after_adding(self): bar1 = Bar() bar1.id = 400 bar1.title = "Title 400" bar2 = Bar() bar2.id = 500 bar2.title = "Title 500" foo = FooRefSet() foo.title = "Title 40" foo.bars.add(bar1) foo.bars.add(bar2) foo.bars.remove(bar1) self.store.add(foo) store = self.create_store() store.add(bar1) self.assertEqual(Store.of(foo), self.store) self.assertEqual(Store.of(bar1), store) self.assertEqual(Store.of(bar2), self.store) def test_reference_set_values(self): self.add_reference_set_bar_400() foo = self.store.get(FooRefSetOrderID, 20) values = list(foo.bars.values(Bar.id, Bar.foo_id, Bar.title)) self.assertEqual(values, [(200, 20, "Title 200"), (400, 20, "Title 100")]) def test_reference_set_order_by_desc_id(self): self.add_reference_set_bar_400() class FooRefSetOrderByDescID(Foo): bars = ReferenceSet(Foo.id, Bar.foo_id, order_by=Desc(Bar.id)) foo = self.store.get(FooRefSetOrderByDescID, 20) values = list(foo.bars.values(Bar.id, Bar.foo_id, Bar.title)) self.assertEqual(values, [(400, 20, "Title 100"), (200, 20, "Title 200")]) self.assertEqual(foo.bars.first().id, 400) self.assertEqual(foo.bars.last().id, 200) def test_indirect_reference_set(self): foo = self.store.get(FooIndRefSet, 20) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [(100, "Title 300"), (200, "Title 200")]) def test_indirect_reference_set_with_added(self): bar1 = Bar() bar1.id = 400 bar1.title = "Title 400" bar2 = Bar() bar2.id = 500 bar2.title = "Title 500" self.store.add(bar1) self.store.add(bar2) foo = FooIndRefSet() foo.title = "Title 40" foo.bars.add(bar1) foo.bars.add(bar2) self.assertEqual(foo.id, None) self.store.add(foo) self.assertEqual(foo.id, None) self.assertEqual(bar1.foo_id, None) self.assertEqual(bar2.foo_id, None) self.assertEqual(list(foo.bars.order_by(Bar.id)), [bar1, bar2]) self.assertEqual(type(foo.id), int) self.assertEqual(type(bar1.id), int) self.assertEqual(type(bar2.id), int) def test_indirect_reference_set_find(self): foo = self.store.get(FooIndRefSet, 20) items = [] for bar in foo.bars.find(Bar.title == "Title 300"): items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (100, "Title 300"), ]) def test_indirect_reference_set_clear(self): foo = self.store.get(FooIndRefSet, 20) foo.bars.clear() self.assertEqual(list(foo.bars), []) def test_indirect_reference_set_clear_where(self): foo = self.store.get(FooIndRefSet, 20) items = [(bar.id, bar.foo_id, 
bar.title) for bar in foo.bars] self.assertEqual(items, [ (100, 10, "Title 300"), (200, 20, "Title 200"), ]) foo = self.store.get(FooIndRefSet, 30) foo.bars.clear(Bar.id < 300) foo.bars.clear(id=200) foo = self.store.get(FooIndRefSet, 20) foo.bars.clear(Bar.id < 200) items = [(bar.id, bar.foo_id, bar.title) for bar in foo.bars] self.assertEqual(items, [ (200, 20, "Title 200"), ]) foo.bars.clear(id=200) items = [(bar.id, bar.foo_id, bar.title) for bar in foo.bars] self.assertEqual(items, []) def test_indirect_reference_set_is_empty(self): foo = self.store.get(FooIndRefSet, 20) self.assertFalse(foo.bars.is_empty()) foo.bars.clear() self.assertTrue(foo.bars.is_empty()) def test_indirect_reference_set_count(self): foo = self.store.get(FooIndRefSet, 20) self.assertEqual(foo.bars.count(), 2) def test_indirect_reference_set_order_by(self): foo = self.store.get(FooIndRefSet, 20) items = [] for bar in foo.bars.order_by(Bar.title): items.append((bar.id, bar.title)) self.assertEqual(items, [ (200, "Title 200"), (100, "Title 300"), ]) del items[:] for bar in foo.bars.order_by(Bar.id): items.append((bar.id, bar.title)) self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), ]) def test_indirect_reference_set_default_order_by(self): foo = self.store.get(FooIndRefSetOrderTitle, 20) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) self.assertEqual(items, [ (200, "Title 200"), (100, "Title 300"), ]) del items[:] for bar in foo.bars.find(): items.append((bar.id, bar.title)) self.assertEqual(items, [ (200, "Title 200"), (100, "Title 300"), ]) foo = self.store.get(FooIndRefSetOrderID, 20) del items[:] for bar in foo.bars: items.append((bar.id, bar.title)) self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), ]) del items[:] for bar in foo.bars.find(): items.append((bar.id, bar.title)) self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), ]) def test_indirect_reference_set_getitem(self): foo = self.store.get(FooIndRefSetOrderID, 20) self.assertEqual(foo.bars[0].id, 100) self.assertEqual(foo.bars[1].id, 200) self.assertRaises(IndexError, foo.bars.__getitem__, 2) items = [] for bar in foo.bars[:1]: items.append((bar.id, bar.title)) self.assertEqual(items, [ (100, "Title 300"), ]) del items[:] for bar in foo.bars[1:]: items.append((bar.id, bar.title)) self.assertEqual(items, [ (200, "Title 200"), ]) del items[:] for bar in foo.bars[:2]: items.append((bar.id, bar.title)) self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), ]) def test_indirect_reference_set_first_last(self): foo = self.store.get(FooIndRefSetOrderID, 20) self.assertEqual(foo.bars.first().id, 100) self.assertEqual(foo.bars.last().id, 200) foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.first().id, 200) self.assertEqual(foo.bars.last().id, 100) foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.first(Bar.id > 200), None) self.assertEqual(foo.bars.last(Bar.id > 200), None) foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.first(Bar.id < 200).id, 100) self.assertEqual(foo.bars.last(Bar.id < 200).id, 100) foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.first(id=200).id, 200) self.assertEqual(foo.bars.last(id=200).id, 200) foo = self.store.get(FooIndRefSet, 20) self.assertRaises(UnorderedError, foo.bars.first) self.assertRaises(UnorderedError, foo.bars.last) def test_indirect_reference_set_any(self): """ L{BoundIndirectReferenceSet.any} returns an arbitrary object from the set of referenced 
objects. """ foo = self.store.get(FooIndRefSet, 20) self.assertNotEqual(None, foo.bars.any()) def test_indirect_reference_set_any_filtered(self): """ L{BoundIndirectReferenceSet.any} optionally takes a list of filtering criteria to narrow the set of objects to search. When provided, the criteria are used to filter the set before returning an arbitrary object. """ foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.any(Bar.id > 200), None) foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.any(Bar.id < 200).id, 100) foo = self.store.get(FooIndRefSetOrderTitle, 20) self.assertEqual(foo.bars.any(id=200).id, 200) def test_indirect_reference_set_one(self): foo = self.store.get(FooIndRefSetOrderID, 20) self.assertRaises(NotOneError, foo.bars.one) foo = self.store.get(FooIndRefSetOrderID, 30) self.assertEqual(foo.bars.one().id, 300) foo = self.store.get(FooIndRefSetOrderID, 20) self.assertEqual(foo.bars.one(Bar.id > 200), None) foo = self.store.get(FooIndRefSetOrderID, 20) self.assertEqual(foo.bars.one(Bar.id < 200).id, 100) foo = self.store.get(FooIndRefSetOrderID, 20) self.assertEqual(foo.bars.one(id=200).id, 200) def test_indirect_reference_set_add(self): foo = self.store.get(FooIndRefSet, 20) bar = self.store.get(Bar, 300) foo.bars.add(bar) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), (300, "Title 100"), ]) def test_indirect_reference_set_remove(self): foo = self.store.get(FooIndRefSet, 20) bar = self.store.get(Bar, 200) foo.bars.remove(bar) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (100, "Title 300"), ]) def test_indirect_reference_set_add_remove(self): foo = self.store.get(FooIndRefSet, 20) bar = self.store.get(Bar, 300) foo.bars.add(bar) foo.bars.remove(bar) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), ]) def test_indirect_reference_set_add_remove_with_wrapper(self): foo = self.store.get(FooIndRefSet, 20) bar300 = self.store.get(Bar, 300) bar200 = self.store.get(Bar, 200) foo.bars.add(Wrapper(bar300)) foo.bars.remove(Wrapper(bar200)) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (100, "Title 300"), (300, "Title 100"), ]) def test_indirect_reference_set_add_remove_with_added(self): foo = FooIndRefSet() foo.id = 40 bar1 = Bar() bar1.id = 400 bar1.title = "Title 400" bar2 = Bar() bar2.id = 500 bar2.title = "Title 500" self.store.add(foo) self.store.add(bar1) self.store.add(bar2) foo.bars.add(bar1) foo.bars.add(bar2) foo.bars.remove(bar1) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (500, "Title 500"), ]) def test_indirect_reference_set_with_added_no_store(self): bar1 = Bar() bar1.id = 400 bar1.title = "Title 400" bar2 = Bar() bar2.id = 500 bar2.title = "Title 500" foo = FooIndRefSet() foo.title = "Title 40" foo.bars.add(bar1) foo.bars.add(bar2) self.store.add(bar1) self.assertEqual(Store.of(foo), self.store) self.assertEqual(Store.of(bar1), self.store) self.assertEqual(Store.of(bar2), self.store) self.assertEqual(foo.id, None) self.assertEqual(bar1.foo_id, None) self.assertEqual(bar2.foo_id, None) self.assertEqual(list(foo.bars.order_by(Bar.id)), [bar1, bar2]) def test_indirect_reference_set_values(self): foo = self.store.get(FooIndRefSetOrderID, 20) values = 
list(foo.bars.values(Bar.id, Bar.foo_id, Bar.title)) self.assertEqual(values, [ (100, 10, "Title 300"), (200, 20, "Title 200"), ]) def test_references_raise_nostore(self): foo1 = FooRefSet() foo2 = FooIndRefSet() self.assertRaises(NoStoreError, foo1.bars.__iter__) self.assertRaises(NoStoreError, foo2.bars.__iter__) self.assertRaises(NoStoreError, foo1.bars.find) self.assertRaises(NoStoreError, foo2.bars.find) self.assertRaises(NoStoreError, foo1.bars.order_by) self.assertRaises(NoStoreError, foo2.bars.order_by) self.assertRaises(NoStoreError, foo1.bars.count) self.assertRaises(NoStoreError, foo2.bars.count) self.assertRaises(NoStoreError, foo1.bars.clear) self.assertRaises(NoStoreError, foo2.bars.clear) self.assertRaises(NoStoreError, foo2.bars.remove, object()) def test_string_reference(self): class Base(metaclass=PropertyPublisherMeta): pass class MyBar(Base): __storm_table__ = "bar" id = Int(primary=True) title = Unicode() foo_id = Int() foo = Reference("foo_id", "MyFoo.id") class MyFoo(Base): __storm_table__ = "foo" id = Int(primary=True) title = Unicode() bar = self.store.get(MyBar, 100) self.assertTrue(bar.foo) self.assertEqual(bar.foo.title, "Title 30") self.assertEqual(type(bar.foo), MyFoo) def test_string_indirect_reference_set(self): """ A L{ReferenceSet} can have its reference keys specified as strings when the class it is a member of uses the L{PropertyPublisherMeta} metaclass. This makes it possible to work around problems with circular dependencies by delaying property resolution. """ class Base(metaclass=PropertyPublisherMeta): pass class MyFoo(Base): __storm_table__ = "foo" id = Int(primary=True) title = Unicode() bars = ReferenceSet("id", "MyLink.foo_id", "MyLink.bar_id", "MyBar.id") class MyBar(Base): __storm_table__ = "bar" id = Int(primary=True) title = Unicode() class MyLink(Base): __storm_table__ = "link" __storm_primary__ = "foo_id", "bar_id" foo_id = Int() bar_id = Int() foo = self.store.get(MyFoo, 20) items = [] for bar in foo.bars: items.append((bar.id, bar.title)) items.sort() self.assertEqual(items, [ (100, "Title 300"), (200, "Title 200"), ]) def test_string_reference_set_order_by(self): """ A L{ReferenceSet} can have its default order by specified as a string when the class it is a member of uses the L{PropertyPublisherMeta} metaclass. This makes it possible to work around problems with circular dependencies by delaying resolution of the order by column.
""" class Base(metaclass=PropertyPublisherMeta): pass class MyFoo(Base): __storm_table__ = "foo" id = Int(primary=True) title = Unicode() bars = ReferenceSet("id", "MyLink.foo_id", "MyLink.bar_id", "MyBar.id", order_by="MyBar.title") class MyBar(Base): __storm_table__ = "bar" id = Int(primary=True) title = Unicode() class MyLink(Base): __storm_table__ = "link" __storm_primary__ = "foo_id", "bar_id" foo_id = Int() bar_id = Int() foo = self.store.get(MyFoo, 20) items = [(bar.id, bar.title) for bar in foo.bars] self.assertEqual(items, [(200, "Title 200"), (100, "Title 300")]) def test_flush_order(self): foo1 = Foo() foo2 = Foo() foo3 = Foo() foo4 = Foo() foo5 = Foo() for i, foo in enumerate([foo1, foo2, foo3, foo4, foo5]): foo.title = "Object %d" % (i+1) self.store.add(foo) self.store.add_flush_order(foo2, foo4) self.store.add_flush_order(foo4, foo1) self.store.add_flush_order(foo1, foo3) self.store.add_flush_order(foo3, foo5) self.store.add_flush_order(foo5, foo2) self.store.add_flush_order(foo5, foo2) self.assertRaises(OrderLoopError, self.store.flush) self.store.remove_flush_order(foo5, foo2) self.assertRaises(OrderLoopError, self.store.flush) self.store.remove_flush_order(foo5, foo2) self.store.flush() self.assertTrue(foo2.id < foo4.id) self.assertTrue(foo4.id < foo1.id) self.assertTrue(foo1.id < foo3.id) self.assertTrue(foo3.id < foo5.id) def test_variable_filter_on_load(self): foo = self.store.get(FooVariable, 20) self.assertEqual(foo.title, "to_py(from_db(Title 20))") def test_variable_filter_on_update(self): foo = self.store.get(FooVariable, 20) foo.title = "Title 20" self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "to_db(from_py(Title 20))"), (30, "Title 10"), ]) def test_variable_filter_on_update_unchanged(self): foo = self.store.get(FooVariable, 20) self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_variable_filter_on_insert(self): foo = FooVariable() foo.id = 40 foo.title = "Title 40" self.store.add(foo) self.store.flush() self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), (40, "to_db(from_py(Title 40))"), ]) def test_variable_filter_on_missing_values(self): foo = FooVariable() foo.id = 40 self.store.add(foo) self.store.flush() self.assertEqual(foo.title, "to_py(from_db(Default Title))") def test_variable_filter_on_set(self): foo = FooVariable() self.store.find(FooVariable, id=20).set(title="Title 20") self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "to_db(from_py(Title 20))"), (30, "Title 10"), ]) def test_variable_filter_on_set_expr(self): foo = FooVariable() result = self.store.find(FooVariable, id=20) result.set(FooVariable.title == "Title 20") self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "to_db(from_py(Title 20))"), (30, "Title 10"), ]) def test_wb_result_set_variable(self): Result = self.store._connection.result_factory class MyResult(Result): def set_variable(self, variable, value): if variable.__class__ is UnicodeVariable: variable.set("set_variable(%s)" % value) elif variable.__class__ is IntVariable: variable.set(value+1) else: variable.set(value) self.store._connection.result_factory = MyResult try: foo = self.store.get(Foo, 20) finally: self.store._connection.result_factory = Result self.assertEqual(foo.id, 21) self.assertEqual(foo.title, "set_variable(Title 20)") def test_default(self): class MyFoo(Foo): title = Unicode(default="Some default value") foo = MyFoo() self.store.add(foo) self.store.flush() 
result = self.store.execute("SELECT title FROM foo WHERE id=?", (foo.id,)) self.assertEqual(result.get_one(), ("Some default value",)) self.assertEqual(foo.title, "Some default value") def test_default_factory(self): class MyFoo(Foo): title = Unicode(default_factory=lambda:"Some default value") foo = MyFoo() self.store.add(foo) self.store.flush() result = self.store.execute("SELECT title FROM foo WHERE id=?", (foo.id,)) self.assertEqual(result.get_one(), ("Some default value",)) self.assertEqual(foo.title, "Some default value") def test_pickle_variable(self): class PickleBlob(Blob): bin = Pickle() blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." self.store.flush() pickle_blob = self.store.get(PickleBlob, 20) self.assertEqual(pickle_blob.bin["a"], 1) pickle_blob.bin["b"] = 2 self.store.flush() self.store.reload(blob) self.assertEqual(pickle.loads(blob.bin), {"a": 1, "b": 2}) def test_pickle_variable_remove(self): """ When an object is removed from a store, it should unhook from the "flush" event emitted by the store, and thus not emit a "changed" event if its content change and that the store is flushed. """ class PickleBlob(Blob): bin = Pickle() blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." self.store.flush() pickle_blob = self.store.get(PickleBlob, 20) self.store.remove(pickle_blob) self.store.flush() # Let's change the object pickle_blob.bin = "foobin" # And subscribe to its changed event obj_info = get_obj_info(pickle_blob) events = [] obj_info.event.hook("changed", lambda *args: events.append(args)) self.store.flush() self.assertEqual(events, []) def test_pickle_variable_unhook(self): """ A variable instance must unhook itself from the store event system when the store invalidates its objects. """ # I create a custom PickleVariable, with no __slots__ definition, to be # able to get a weakref of it, thing that I can't do with # PickleVariable that defines __slots__ *AND* those parent is the C # implementation of Variable class CustomPickleVariable(PickleVariable): pass class CustomPickle(Pickle): variable_class = CustomPickleVariable class PickleBlob(Blob): bin = CustomPickle() blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." self.store.flush() pickle_blob = self.store.get(PickleBlob, 20) self.store.flush() self.store.invalidate() obj_info = get_obj_info(pickle_blob) variable = obj_info.variables[PickleBlob.bin] var_ref = weakref.ref(variable) del variable, blob, pickle_blob, obj_info gc.collect() self.assertTrue(var_ref() is None) def test_pickle_variable_referenceset(self): """ A variable instance must unhook itself from the store event system explcitely when the store invalidates its objects: it's particulary important when a ReferenceSet is used, because it keeps strong references to objects involved. """ class CustomPickleVariable(PickleVariable): pass class CustomPickle(Pickle): variable_class = CustomPickleVariable class PickleBlob(Blob): bin = CustomPickle() foo_id = Int() class FooBlobRefSet(Foo): blobs = ReferenceSet(Foo.id, PickleBlob.foo_id) blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." 
self.store.flush() pickle_blob = self.store.get(PickleBlob, 20) foo = self.store.get(FooBlobRefSet, 10) foo.blobs.add(pickle_blob) self.store.flush() self.store.invalidate() obj_info = get_obj_info(pickle_blob) variable = obj_info.variables[PickleBlob.bin] var_ref = weakref.ref(variable) del variable, blob, pickle_blob, obj_info, foo gc.collect() self.assertTrue(var_ref() is None) def test_pickle_variable_referenceset_several_transactions(self): """ Check that a pickle variable fires the changed event when used across several transactions. """ class PickleBlob(Blob): bin = Pickle() foo_id = Int() class FooBlobRefSet(Foo): blobs = ReferenceSet(Foo.id, PickleBlob.foo_id) blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." self.store.flush() pickle_blob = self.store.get(PickleBlob, 20) foo = self.store.get(FooBlobRefSet, 10) foo.blobs.add(pickle_blob) self.store.flush() self.store.invalidate() self.store.reload(pickle_blob) pickle_blob.bin = "foo" obj_info = get_obj_info(pickle_blob) events = [] obj_info.event.hook("changed", lambda *args: events.append(args)) self.store.flush() self.assertEqual(len(events), 1) def test_undefined_variables_filled_on_find(self): """ Check that when data is fetched from the database on a find, it is used to fill up any undefined variables. """ # We do a first find to get the object_infos into the cache. foos = list(self.store.find(Foo, title="Title 20")) # Commit so that all foos are invalidated and variables are # set back to AutoReload. self.store.commit() # Another find which should reuse in-memory foos. for foo in self.store.find(Foo, title="Title 20"): # Make sure we have all variables defined, because # values were already retrieved by the find's select. obj_info = get_obj_info(foo) for column in obj_info.variables: self.assertTrue(obj_info.variables[column].is_defined()) def test_storm_loaded_after_define(self): """ C{__storm_loaded__} is only called once all the variables are correctly defined in the object. If the object is in the alive cache but its instance has been garbage-collected, it used to be called without its variables defined. """ # Disable the cache, which holds strong references. self.get_cache(self.store).set_size(0) loaded = [] class MyFoo(Foo): def __storm_loaded__(oself): loaded.append(None) obj_info = get_obj_info(oself) for column in obj_info.variables: self.assertTrue(obj_info.variables[column].is_defined()) foo = self.store.get(MyFoo, 20) obj_info = get_obj_info(foo) del foo gc.collect() self.assertEqual(obj_info.get_obj(), None) # Commit so that all foos are invalidated and variables are # set back to AutoReload. self.store.commit() foo = self.store.find(MyFoo, title="Title 20").one() self.assertEqual(foo.id, 20) self.assertEqual(len(loaded), 2) def test_defined_variables_not_overridden_on_find(self): """ Check that the keep_defined=True setting in _load_object() is in place. In practice, it ensures that values which are already defined aren't replaced during a find, while new data coming from the database is still used to fill undefined variables whenever possible. """ blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." class PickleBlob: __storm_table__ = "bin" id = Int(primary=True) pickle = Pickle("bin") blob = self.store.get(PickleBlob, 20) value = blob.pickle # Now the find should not destroy our value pointer.
blob = self.store.find(PickleBlob, id=20).one() self.assertTrue(value is blob.pickle) def test_pickle_variable_with_deleted_object(self): class PickleBlob(Blob): bin = Pickle() blob = self.store.get(Blob, 20) blob.bin = b"\x80\x02}q\x01U\x01aK\x01s." self.store.flush() pickle_blob = self.store.get(PickleBlob, 20) self.assertEqual(pickle_blob.bin["a"], 1) pickle_blob.bin["b"] = 2 del pickle_blob gc.collect() self.store.flush() self.store.reload(blob) self.assertEqual(pickle.loads(blob.bin), {"a": 1, "b": 2}) def test_unhashable_object(self): class DictFoo(Foo, dict): pass foo = self.store.get(DictFoo, 20) foo["a"] = 1 self.assertEqual(list(foo.items()), [("a", 1)]) new_obj = DictFoo() new_obj.id = 40 new_obj.title = "My Title" self.store.add(new_obj) self.store.commit() self.assertTrue(self.store.get(DictFoo, 40) is new_obj) def test_wrapper(self): foo = self.store.get(Foo, 20) wrapper = Wrapper(foo) self.store.remove(wrapper) self.store.flush() self.assertEqual(self.store.get(Foo, 20), None) def test_rollback_loaded_and_still_in_cached(self): # Explore problem found on interaction between caching, commits, # and rollbacks, when they still existed. foo1 = self.store.get(Foo, 20) self.store.commit() self.store.rollback() foo2 = self.store.get(Foo, 20) self.assertTrue(foo1 is foo2) def test_class_alias(self): FooAlias = ClassAlias(Foo) result = self.store.find(FooAlias, FooAlias.id < Foo.id) self.assertEqual([(foo.id, foo.title) for foo in result if type(foo) is Foo], [ (10, "Title 30"), (10, "Title 30"), (20, "Title 20"), ]) def test_expr_values(self): foo = self.store.get(Foo, 20) foo.title = SQL("'New title'") # No commits yet. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.store.flush() # Now it should be there. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "New title"), (30, "Title 10"), ]) self.assertEqual(foo.title, "New title") def test_expr_values_flush_on_demand(self): foo = self.store.get(Foo, 20) foo.title = SQL("'New title'") # No commits yet. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(foo.title, "New title") # Now it should be there. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "New title"), (30, "Title 10"), ]) def test_expr_values_flush_and_load_in_separate_steps(self): foo = self.store.get(Foo, 20) foo.title = SQL("'New title'") self.store.flush() # It's already in the database. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "New title"), (30, "Title 10"), ]) # But our value is now an AutoReload. lazy_value = get_obj_info(foo).variables[Foo.title].get_lazy() self.assertTrue(lazy_value is AutoReload) # Which gets resolved once touched. self.assertEqual(foo.title, "New title") def test_expr_values_flush_on_demand_with_added(self): foo = Foo() foo.id = 40 foo.title = SQL("'New title'") self.store.add(foo) # No commits yet. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(foo.title, "New title") # Now it should be there. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), (40, "New title"), ]) def test_expr_values_flush_on_demand_with_removed_and_added(self): foo = self.store.get(Foo, 20) foo.title = SQL("'New title'") self.store.remove(foo) self.store.add(foo) # No commits yet. 
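# The first assertion below confirms nothing has been flushed yet; reading foo.title afterwards resolves the lazy SQL expression and triggers the flush on demand.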
self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(foo.title, "New title") # Now it should be there. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "New title"), (30, "Title 10"), ]) def test_expr_values_flush_on_demand_with_removed_and_rollbacked(self): foo = self.store.get(Foo, 20) self.store.remove(foo) self.store.rollback() foo.title = SQL("'New title'") # No commits yet. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(foo.title, "New title") # Now it should be there. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "New title"), (30, "Title 10"), ]) def test_expr_values_flush_on_demand_with_added_and_removed(self): # This test tries to trigger a problem in a few different ways. # It uses the same id as an existing object, and adds and removes # the object. This object should never get into the database, nor # update the object that is already there, nor flush any other # pending changes when the lazy value is accessed. foo = Foo() foo.id = 20 foo_dep = Foo() foo_dep.id = 50 self.store.add(foo) self.store.add(foo_dep) foo.title = SQL("'New title'") # Add ordering to see if it helps trigger a bug of # incorrect flushing. self.store.add_flush_order(foo_dep, foo) self.store.remove(foo) # No changes. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(foo.title, None) # Still no changes. There's no reason why foo_dep would be flushed. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_expr_values_flush_on_demand_with_removed(self): # Similar case, but removing an existing object instead. foo = self.store.get(Foo, 20) foo_dep = Foo() foo_dep.id = 50 self.store.add(foo_dep) foo.title = SQL("'New title'") # Add ordering to see if it helps trigger a bug of # incorrect flushing. self.store.add_flush_order(foo_dep, foo) self.store.remove(foo) # No changes. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) self.assertEqual(foo.title, None) # Still no changes. There's no reason why foo_dep would be flushed. self.assertEqual(self.get_items(), [ (10, "Title 30"), (20, "Title 20"), (30, "Title 10"), ]) def test_lazy_value_preserved_with_subsequent_object_initialization(self): """ If a lazy value has been modified on an object that is subsequently initialized from the database, the lazy value is correctly preserved and the object is initialized properly. This tests the fix for the problem reported in bug #620615. """ # Retrieve an object, fully loaded. foo = self.store.get(Foo, 20) # Build and retrieve a result set ahead of time, so that # flushes won't happen when actually loading the object.
result = self.store.find(Foo, Foo.id == 20) # Now, set an unflushed lazy value on an attribute. foo.title = SQL("'New title'") # Give up on this and reload the original object. self.store.reload(foo) # We don't really have to test anything here, since the # explosion happened above, but here it is anyway. self.assertEqual(foo.title, "Title 20") def test_expr_values_with_columns(self): bar = self.store.get(Bar, 200) bar.foo_id = Bar.id+1 self.assertEqual(bar.foo_id, 201) def test_autoreload_attribute(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='New Title' WHERE id=20") self.assertEqual(foo.title, "Title 20") foo.title = AutoReload self.assertEqual(foo.title, "New Title") self.assertFalse(get_obj_info(foo).variables[Foo.title].has_changed()) def test_autoreload_attribute_with_changed_primary_key(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='New Title' WHERE id=20") self.assertEqual(foo.title, "Title 20") foo.id = 40 foo.title = AutoReload self.assertEqual(foo.title, "New Title") self.assertEqual(foo.id, 40) def test_autoreload_object(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='New Title' WHERE id=20") self.assertEqual(foo.title, "Title 20") self.store.autoreload(foo) self.assertEqual(foo.title, "New Title") def test_autoreload_primary_key_of_unflushed_object(self): foo = Foo() self.store.add(foo) foo.id = AutoReload foo.title = "New Title" self.assertTrue(isinstance(foo.id, int)) self.assertEqual(foo.title, "New Title") def test_autoreload_primary_key_doesnt_reload_everything_else(self): foo = self.store.get(Foo, 20) self.store.autoreload(foo) obj_info = get_obj_info(foo) self.assertEqual(obj_info.variables[Foo.id].get_lazy(), None) self.assertEqual(obj_info.variables[Foo.title].get_lazy(), AutoReload) self.assertEqual(foo.id, 20) self.assertEqual(obj_info.variables[Foo.id].get_lazy(), None) self.assertEqual(obj_info.variables[Foo.title].get_lazy(), AutoReload) def test_autoreload_all_objects(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='New Title' WHERE id=20") self.assertEqual(foo.title, "Title 20") self.store.autoreload() self.assertEqual(foo.title, "New Title") def test_autoreload_and_get_will_not_reload(self): foo = self.store.get(Foo, 20) self.store.execute("UPDATE foo SET title='New Title' WHERE id=20") self.store.autoreload(foo) obj_info = get_obj_info(foo) self.assertEqual(obj_info.variables[Foo.title].get_lazy(), AutoReload) self.store.get(Foo, 20) self.assertEqual(obj_info.variables[Foo.title].get_lazy(), AutoReload) self.assertEqual(foo.title, "New Title") def test_autoreload_object_doesnt_tag_as_dirty(self): foo = self.store.get(Foo, 20) self.store.autoreload(foo) self.assertTrue(get_obj_info(foo) not in self.store._dirty) def test_autoreload_missing_columns_on_insertion(self): foo = Foo() self.store.add(foo) self.store.flush() lazy_value = get_obj_info(foo).variables[Foo.title].get_lazy() self.assertEqual(lazy_value, AutoReload) self.assertEqual(foo.title, "Default Title") def test_reference_break_on_local_diverged_doesnt_autoreload(self): foo = self.store.get(Foo, 10) self.store.autoreload(foo) bar = self.store.get(Bar, 100) self.assertTrue(bar.foo) bar.foo_id = 40 self.assertEqual(bar.foo, None) obj_info = get_obj_info(foo) self.assertEqual(obj_info.variables[Foo.title].get_lazy(), AutoReload) def test_primary_key_reference(self): """ When an object references another one using its primary key, it correctly checks for the invalidated state 
after the store has been committed, detecting if the referenced object has been removed behind its back. """ class BarOnRemote: __storm_table__ = "bar" foo_id = Int(primary=True) foo = Reference(foo_id, Foo.id, on_remote=True) foo = self.store.get(Foo, 10) bar = self.store.get(BarOnRemote, 10) self.assertEqual(bar.foo, foo) self.store.execute("DELETE FROM foo WHERE id = 10") self.store.commit() self.assertEqual(bar.foo, None) def test_invalidate_and_get_object(self): foo = self.store.get(Foo, 20) self.store.invalidate(foo) self.assertEqual(self.store.get(Foo, 20), foo) self.assertEqual(self.store.find(Foo, id=20).one(), foo) def test_invalidate_and_get_removed_object(self): foo = self.store.get(Foo, 20) self.store.execute("DELETE FROM foo WHERE id=20") self.store.invalidate(foo) self.assertEqual(self.store.get(Foo, 20), None) self.assertEqual(self.store.find(Foo, id=20).one(), None) def test_invalidate_and_validate_with_find(self): foo = self.store.get(Foo, 20) self.store.invalidate(foo) self.assertEqual(self.store.find(Foo, id=20).one(), foo) # Cache should be considered valid again at this point. self.store.execute("DELETE FROM foo WHERE id=20") self.assertEqual(self.store.get(Foo, 20), foo) def test_invalidate_object_gets_validated(self): foo = self.store.get(Foo, 20) self.store.invalidate(foo) self.assertEqual(self.store.get(Foo, 20), foo) # At this point the object is valid again, so deleting it # from the database directly shouldn't affect caching. self.store.execute("DELETE FROM foo WHERE id=20") self.assertEqual(self.store.get(Foo, 20), foo) def test_invalidate_object_with_only_primary_key(self): link = self.store.get(Link, (20, 200)) self.store.execute("DELETE FROM link WHERE foo_id=20 AND bar_id=200") self.store.invalidate(link) self.assertEqual(self.store.get(Link, (20, 200)), None) def test_invalidate_added_object(self): foo = Foo() self.store.add(foo) self.store.invalidate(foo) foo.id = 40 foo.title = "Title 40" self.store.flush() # Object must have a valid cache at this point, since it was # just added. self.store.execute("DELETE FROM foo WHERE id=40") self.assertEqual(self.store.get(Foo, 40), foo) def test_invalidate_and_update(self): foo = self.store.get(Foo, 20) self.store.execute("DELETE FROM foo WHERE id=20") self.store.invalidate(foo) self.assertRaises(LostObjectError, setattr, foo, "title", "Title 40") def test_invalidated_objects_reloaded_by_get(self): foo = self.store.get(Foo, 20) self.store.invalidate(foo) foo = self.store.get(Foo, 20) title_variable = get_obj_info(foo).variables[Foo.title] self.assertEqual(title_variable.get_lazy(), None) self.assertEqual(title_variable.get(), "Title 20") self.assertEqual(foo.title, "Title 20") def test_invalidated_hook(self): called = [] class MyFoo(Foo): def __storm_invalidated__(self): called.append(True) foo = self.store.get(MyFoo, 20) self.assertEqual(called, []) self.store.autoreload(foo) self.assertEqual(called, []) self.store.invalidate(foo) self.assertEqual(called, [True]) def test_invalidated_hook_called_after_all_invalidated(self): """ Ensure that invalidated hooks are called only when all objects have already been marked as invalidated. See comment in store.py:_mark_autoreload. 
""" called = [] class MyFoo(Foo): def __storm_invalidated__(self): if not called: called.append(get_obj_info(foo1).get("invalidated")) called.append(get_obj_info(foo2).get("invalidated")) foo1 = self.store.get(MyFoo, 10) foo2 = self.store.get(MyFoo, 20) self.store.invalidate() self.assertEqual(called, [True, True]) def test_reset_recreates_objects(self): """ After resetting the store, all queries return fresh objects, even if there are other objects representing the same database rows still in memory. """ foo1 = self.store.get(Foo, 10) foo1.dirty = True self.store.reset() new_foo1 = self.store.get(Foo, 10) self.assertFalse(hasattr(new_foo1, "dirty")) self.assertNotIdentical(new_foo1, foo1) def test_reset_unmarks_dirty(self): """ If an object was dirty when store.reset() is called, its changes will not be affected. """ foo1 = self.store.get(Foo, 10) foo1_title = foo1.title foo1.title = "radix wuz here" self.store.reset() self.store.flush() new_foo1 = self.store.get(Foo, 10) self.assertEqual(new_foo1.title, foo1_title) def test_reset_clears_cache(self): cache = self.get_cache(self.store) foo1 = self.store.get(Foo, 10) self.assertTrue(get_obj_info(foo1) in cache.get_cached()) self.store.reset() self.assertEqual(cache.get_cached(), []) def test_reset_breaks_store_reference(self): """ After resetting the store, all objects that were associated with that store will no longer be. """ foo1 = self.store.get(Foo, 10) self.store.reset() self.assertIdentical(Store.of(foo1), None) def test_result_find(self): result1 = self.store.find(Foo, Foo.id <= 20) result2 = result1.find(Foo.id > 10) foo = result2.one() self.assertTrue(foo) self.assertEqual(foo.id, 20) def test_result_find_kwargs(self): result1 = self.store.find(Foo, Foo.id <= 20) result2 = result1.find(id=20) foo = result2.one() self.assertTrue(foo) self.assertEqual(foo.id, 20) def test_result_find_introduce_join(self): result1 = self.store.find(Foo, Foo.id <= 20) result2 = result1.find(Foo.id == Bar.foo_id, Bar.title == "Title 300") foo = result2.one() self.assertTrue(foo) self.assertEqual(foo.id, 10) def test_result_find_tuple(self): result1 = self.store.find((Foo, Bar), Foo.id == Bar.foo_id) result2 = result1.find(Bar.title == "Title 100") foo_bar = result2.one() self.assertTrue(foo_bar) foo, bar = foo_bar self.assertEqual(foo.id, 30) self.assertEqual(bar.id, 300) def test_result_find_undef_where(self): result = self.store.find(Foo, Foo.id == 20).find() foo = result.one() self.assertTrue(foo) self.assertEqual(foo.id, 20) result = self.store.find(Foo).find(Foo.id == 20) foo = result.one() self.assertTrue(foo) self.assertEqual(foo.id, 20) def test_result_find_fails_on_set_expr(self): result1 = self.store.find(Foo) result2 = self.store.find(Foo) result = result1.union(result2) self.assertRaises(FeatureError, result.find, Foo.id == 20) def test_result_find_fails_on_slice(self): result = self.store.find(Foo)[1:2] self.assertRaises(FeatureError, result.find, Foo.id == 20) def test_result_find_fails_on_group_by(self): result = self.store.find(Foo) result.group_by(Foo) self.assertRaises(FeatureError, result.find, Foo.id == 20) def test_result_union(self): result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=10) result3 = result1.union(result2) result3.order_by(Foo.title) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), (10, "Title 30"), ]) result3.order_by(Desc(Foo.title)) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (10, "Title 30"), (30, "Title 10"), ]) def 
test_result_union_duplicated(self): result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=30) result3 = result1.union(result2) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), ]) def test_result_union_duplicated_with_all(self): result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=30) result3 = result1.union(result2, all=True) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), (30, "Title 10"), ]) def test_result_union_with_empty(self): result1 = self.store.find(Foo, id=30) result2 = EmptyResultSet() result3 = result1.union(result2) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), ]) def test_result_union_class_columns(self): """ It's possible to do a union of two result sets on columns of different classes, as long as their variable classes are the same (e.g. both are IntVariables). """ result1 = self.store.find(Foo.id, Foo.id == 10) result2 = self.store.find(Bar.foo_id, Bar.id == 200) self.assertEqual([10, 20], sorted(result1.union(result2))) def test_result_union_incompatible(self): result1 = self.store.find(Foo, id=10) result2 = self.store.find(Bar, id=100) self.assertRaises(FeatureError, result1.union, result2) def test_result_union_unsupported_methods(self): result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=10) result3 = result1.union(result2) self.assertRaises(FeatureError, result3.set, title="Title 40") self.assertRaises(FeatureError, result3.remove) def test_result_union_count(self): result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=30) result3 = result1.union(result2, all=True) self.assertEqual(result3.count(), 2) def test_result_union_limit_count(self): """ It's possible to count the result of a union that is limited. """ result1 = self.store.find(Foo, id=30) result2 = self.store.find(Foo, id=30) result3 = result1.union(result2, all=True) result3.order_by(Foo.id) result3.config(limit=1) self.assertEqual(result3.count(), 1) self.assertEqual(result3.count(Foo.id), 1) def test_result_union_limit_avg(self): """ It's possible to average the result of a union that is limited. """ result1 = self.store.find(Foo, id=10) result2 = self.store.find(Foo, id=30) result3 = result1.union(result2, all=True) result3.order_by(Foo.id) result3.config(limit=1) # Since 30 was left off because of the limit, the only result will be # 10, and the average of that is 10.
self.assertEqual(result3.avg(Foo.id), 10) def test_result_difference(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.difference tests on MySQL") result1 = self.store.find(Foo) result2 = self.store.find(Foo, id=20) result3 = result1.difference(result2) result3.order_by(Foo.title) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), (10, "Title 30"), ]) result3.order_by(Desc(Foo.title)) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (10, "Title 30"), (30, "Title 10"), ]) def test_result_difference_with_empty(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.difference tests on MySQL") result1 = self.store.find(Foo, id=30) result2 = EmptyResultSet() result3 = result1.difference(result2) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), ]) def test_result_difference_incompatible(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.difference tests on MySQL") result1 = self.store.find(Foo, id=10) result2 = self.store.find(Bar, id=100) self.assertRaises(FeatureError, result1.difference, result2) def test_result_difference_count(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.difference tests on MySQL") result1 = self.store.find(Foo) result2 = self.store.find(Foo, id=20) result3 = result1.difference(result2) self.assertEqual(result3.count(), 2) def test_is_in_empty_result_set(self): result1 = self.store.find(Foo, Foo.id < 10) result2 = self.store.find(Foo, Or(Foo.id > 20, Foo.id.is_in(result1))) self.assertEqual(result2.count(), 1) def test_is_in_empty_list(self): result2 = self.store.find(Foo, Eq(False, And(True, Foo.id.is_in([])))) self.assertEqual(result2.count(), 3) def test_result_intersection(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.intersection tests on MySQL") result1 = self.store.find(Foo) result2 = self.store.find(Foo, Foo.id.is_in((10, 30))) result3 = result1.intersection(result2) result3.order_by(Foo.title) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (30, "Title 10"), (10, "Title 30"), ]) result3.order_by(Desc(Foo.title)) self.assertEqual([(foo.id, foo.title) for foo in result3], [ (10, "Title 30"), (30, "Title 10"), ]) def test_result_intersection_with_empty(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.intersection tests on MySQL") result1 = self.store.find(Foo, id=30) result2 = EmptyResultSet() result3 = result1.intersection(result2) self.assertEqual(len(list(result3)), 0) def test_result_intersection_incompatible(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.intersection tests on MySQL") result1 = self.store.find(Foo, id=10) result2 = self.store.find(Bar, id=100) self.assertRaises(FeatureError, result1.intersection, result2) def test_result_intersection_count(self): if self.__class__.__name__.startswith("MySQL"): self.skipTest("Skipping ResultSet.intersection tests on MySQL") result1 = self.store.find(Foo, Foo.id.is_in((10, 20))) result2 = self.store.find(Foo, Foo.id.is_in((10, 30))) result3 = result1.intersection(result2) self.assertEqual(result3.count(), 1) def test_proxy(self): bar = self.store.get(BarProxy, 200) self.assertEqual(bar.foo_title, "Title 20") def test_proxy_equals(self): bar = self.store.find(BarProxy, BarProxy.foo_title == "Title 20").one() self.assertTrue(bar) self.assertEqual(bar.id, 200) 
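    # BarProxy comes from the shared test fixtures; presumably it is defined
    # along these lines (a sketch, not the actual fixture code), with Proxy
    # exposing the referenced Foo's title as if it were a local column:
    #
    #     class BarProxy:
    #         __storm_table__ = "bar"
    #         id = Int(primary=True)
    #         foo_id = Int()
    #         foo = Reference(foo_id, Foo.id)
    #         foo_title = Proxy(foo, Foo.title)
    #
    # That is why foo_title works both as an instance attribute and as a
    # find() criterion in the tests above and below.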
def test_proxy_as_column(self): result = self.store.find(BarProxy, BarProxy.id == 200) self.assertEqual(list(result.values(BarProxy.foo_title)), ["Title 20"]) def test_proxy_set(self): bar = self.store.get(BarProxy, 200) bar.foo_title = "New Title" foo = self.store.get(Foo, 20) self.assertEqual(foo.title, "New Title") def get_bar_proxy_with_string(self): class Base(metaclass=PropertyPublisherMeta): pass class MyBarProxy(Base): __storm_table__ = "bar" id = Int(primary=True) foo_id = Int() foo = Reference("foo_id", "MyFoo.id") foo_title = Proxy(foo, "MyFoo.title") class MyFoo(Base): __storm_table__ = "foo" id = Int(primary=True) title = Unicode() return MyBarProxy, MyFoo def test_proxy_with_string(self): MyBarProxy, MyFoo = self.get_bar_proxy_with_string() bar = self.store.get(MyBarProxy, 200) self.assertEqual(bar.foo_title, "Title 20") def test_proxy_with_string_variable_factory_attribute(self): MyBarProxy, MyFoo = self.get_bar_proxy_with_string() variable = MyBarProxy.foo_title.variable_factory(value="Hello") self.assertTrue(isinstance(variable, UnicodeVariable)) def test_proxy_with_extra_table(self): """ Proxies use a join on auto_tables. It should work even if we have more tables in the query. """ result = self.store.find((BarProxy, Link), BarProxy.foo_title == "Title 20", BarProxy.foo_id == Link.foo_id) results = list(result) self.assertEqual(len(results), 2) for bar, link in results: self.assertEqual(bar.id, 200) self.assertEqual(bar.foo_title, "Title 20") self.assertEqual(bar.foo_id, 20) self.assertEqual(link.foo_id, 20) def test_get_decimal_property(self): money = self.store.get(Money, 10) self.assertEqual(money.value, decimal.Decimal("12.3455")) def test_set_decimal_property(self): money = self.store.get(Money, 10) money.value = decimal.Decimal("12.3456") self.store.flush() result = self.store.find(Money, value=decimal.Decimal("12.3456")) self.assertEqual(result.one(), money) def test_fill_missing_primary_key_with_lazy_value(self): foo = self.store.get(Foo, 10) foo.id = SQL("40") self.store.flush() self.assertEqual(foo.id, 40) self.assertEqual(self.store.get(Foo, 10), None) self.assertEqual(self.store.get(Foo, 40), foo) def test_fill_missing_primary_key_with_lazy_value_on_creation(self): foo = Foo() foo.id = SQL("40") self.store.add(foo) self.store.flush() self.assertEqual(foo.id, 40) self.assertEqual(self.store.get(Foo, 40), foo) def test_preset_primary_key(self): check = [] def preset_primary_key(primary_columns, primary_variables): check.append([(variable.is_defined(), variable.get_lazy()) for variable in primary_variables]) check.append([column.name for column in primary_columns]) primary_variables[0].set(SQL("40")) class DatabaseWrapper: """Wrapper to inject our custom preset_primary_key hook.""" def __init__(self, database): self.database = database def connect(self, event=None): connection = self.database.connect(event) connection.preset_primary_key = preset_primary_key return connection store = Store(DatabaseWrapper(self.database)) foo = store.add(Foo()) store.flush() try: self.assertEqual(check, [[(False, None)], ["id"]]) self.assertEqual(foo.id, 40) finally: store.close() def test_strong_cache_used(self): """ Objects should be referenced in the cache if not referenced in application code. 
""" foo = self.store.get(Foo, 20) foo.tainted = True obj_info = get_obj_info(foo) del foo gc.collect() cached = self.store.find(Foo).cached() self.assertEqual(len(cached), 1) foo = self.store.get(Foo, 20) self.assertEqual(cached, [foo]) self.assertTrue(hasattr(foo, "tainted")) def test_strong_cache_cleared_on_invalidate_all(self): cache = self.get_cache(self.store) foo = self.store.get(Foo, 20) self.assertEqual(cache.get_cached(), [get_obj_info(foo)]) self.store.invalidate() self.assertEqual(cache.get_cached(), []) def test_strong_cache_loses_object_on_invalidate(self): cache = self.get_cache(self.store) foo = self.store.get(Foo, 20) self.assertEqual(cache.get_cached(), [get_obj_info(foo)]) self.store.invalidate(foo) self.assertEqual(cache.get_cached(), []) def test_strong_cache_loses_object_on_remove(self): """ Make sure an object gets removed from the strong reference cache when removed from the store. """ cache = self.get_cache(self.store) foo = self.store.get(Foo, 20) self.assertEqual(cache.get_cached(), [get_obj_info(foo)]) self.store.remove(foo) self.store.flush() self.assertEqual(cache.get_cached(), []) def test_strong_cache_renews_object_on_get(self): cache = self.get_cache(self.store) foo1 = self.store.get(Foo, 10) foo2 = self.store.get(Foo, 20) foo1 = self.store.get(Foo, 10) self.assertEqual(cache.get_cached(), [get_obj_info(foo1), get_obj_info(foo2)]) def test_strong_cache_renews_object_on_find(self): cache = self.get_cache(self.store) foo1 = self.store.find(Foo, id=10).one() foo2 = self.store.find(Foo, id=20).one() foo1 = self.store.find(Foo, id=10).one() self.assertEqual(cache.get_cached(), [get_obj_info(foo1), get_obj_info(foo2)]) def test_unicode(self): class MyFoo(Foo): pass foo = self.store.get(Foo, 20) myfoo = self.store.get(MyFoo, 20) for title in ['Cừơng', 'Đức', 'Hạnh']: foo.title = title self.store.commit() try: self.assertEqual(myfoo.title, title) except AssertionError as e: raise AssertionError(str(e, 'replace') + " (ensure your database was created with CREATE DATABASE" " ... CHARACTER SET utf8mb3)") def test_creation_order_is_preserved_when_possible(self): foos = [self.store.add(Foo()) for i in range(10)] self.store.flush() for i in range(len(foos)-1): self.assertTrue(foos[i].id < foos[i+1].id) def test_update_order_is_preserved_when_possible(self): class MyFoo(Foo): sequence = 0 def __storm_flushed__(self): self.flush_order = MyFoo.sequence MyFoo.sequence += 1 foos = [self.store.add(MyFoo()) for i in range(10)] self.store.flush() MyFoo.sequence = 0 for foo in foos: foo.title = "Changed Title" self.store.flush() for i, foo in enumerate(foos): self.assertEqual(foo.flush_order, i) def test_removal_order_is_preserved_when_possible(self): class MyFoo(Foo): sequence = 0 def __storm_flushed__(self): self.flush_order = MyFoo.sequence MyFoo.sequence += 1 foos = [self.store.add(MyFoo()) for i in range(10)] self.store.flush() MyFoo.sequence = 0 for foo in foos: self.store.remove(foo) self.store.flush() for i, foo in enumerate(foos): self.assertEqual(foo.flush_order, i) def test_cache_poisoning(self): """ When a object update a field value to the previous value, which is in the cache, it correctly updates the value in the database. Because of change detection, this has been broken in the past, see bug #277095 in launchpad. 
""" store = self.create_store() foo2 = store.get(Foo, 10) self.assertEqual(foo2.title, "Title 30") store.commit() foo1 = self.store.get(Foo, 10) foo1.title = "Title 40" self.store.commit() foo2.title = "Title 30" store.commit() self.assertEqual(foo2.title, "Title 30") def test_execute_sends_event(self): """Statement execution emits the register-transaction event.""" calls = [] def register_transaction(owner): calls.append(owner) self.store._event.hook("register-transaction", register_transaction) self.store.execute("SELECT 1") self.assertEqual(len(calls), 1) self.assertEqual(calls[0], self.store) def test_wb_event_before_check_connection(self): """ The register-transaction event is emitted before checking the state of the connection. """ calls = [] def register_transaction(owner): calls.append(owner) self.store._event.hook("register-transaction", register_transaction) self.store._connection._state = STATE_DISCONNECTED self.assertRaises(DisconnectionError, self.store.execute, "SELECT 1") self.assertEqual(len(calls), 1) self.assertEqual(calls[0], self.store) def test_add_sends_event(self): """Adding an object emits the register-transaction event.""" calls = [] def register_transaction(owner): calls.append(owner) self.store._event.hook("register-transaction", register_transaction) foo = Foo() foo.title = "Foo" self.store.add(foo) self.assertEqual(len(calls), 1) self.assertEqual(calls[0], self.store) def test_remove_sends_event(self): """Adding an object emits the register-transaction event.""" calls = [] def register_transaction(owner): calls.append(owner) self.store._event.hook("register-transaction", register_transaction) foo = self.store.get(Foo, 10) del calls[:] self.store.remove(foo) self.assertEqual(len(calls), 1) self.assertEqual(calls[0], self.store) def test_change_invalidated_object_sends_event(self): """Modifying an object retrieved in a previous transaction emits the register-transaction event.""" calls = [] def register_transaction(owner): calls.append(owner) self.store._event.hook("register-transaction", register_transaction) foo = self.store.get(Foo, 10) self.store.rollback() del calls[:] foo.title = "New title" self.assertEqual(len(calls), 1) self.assertEqual(calls[0], self.store) def test_rowcount_remove(self): # All supported backends support rowcount, so far. result_to_remove = self.store.find(Foo, Foo.id <= 30) self.assertEqual(result_to_remove.remove(), 3) class EmptyResultSetTest: def setUp(self): self.create_database() self.connection = self.database.connect() self.drop_tables() self.create_tables() self.create_store() # Most of the tests here exercise the same functionality using # self.empty and self.result to ensure that EmptyResultSet and # ResultSet behave the same way, in the same situations. self.empty = EmptyResultSet() self.result = self.store.find(Foo) def tearDown(self): self.drop_store() self.drop_tables() self.drop_database() self.connection.close() def create_database(self): raise NotImplementedError def create_tables(self): raise NotImplementedError def create_store(self): self.store = Store(self.database) def drop_database(self): pass def drop_tables(self): for table in ["foo", "bar", "bin", "link"]: try: self.connection.execute("DROP TABLE %s" % table) self.connection.commit() except: self.connection.rollback() def drop_store(self): self.store.rollback() # Closing the store is needed because testcase objects are all # instantiated at once, and thus connections are kept open. 
self.store.close() def test_iter(self): self.assertEqual(list(self.result), list(self.empty)) def test_copy(self): self.assertNotEqual(self.result.copy(), self.result) self.assertNotEqual(self.empty.copy(), self.empty) self.assertEqual(list(self.result.copy()), list(self.empty.copy())) def test_config(self): self.result.config(distinct=True, offset=1, limit=1) self.empty.config(distinct=True, offset=1, limit=1) self.assertEqual(list(self.result), list(self.empty)) def test_config_returns_self(self): self.assertIs(self.result, self.result.config()) self.assertIs(self.empty, self.empty.config()) def test_slice(self): self.assertEqual(list(self.result[:]), []) self.assertEqual(list(self.empty[:]), []) def test_contains(self): self.assertEqual(Foo() in self.empty, False) def test_is_empty(self): self.assertEqual(self.result.is_empty(), True) self.assertEqual(self.empty.is_empty(), True) def test_any(self): self.assertEqual(self.result.any(), None) self.assertEqual(self.empty.any(), None) def test_first_unordered(self): self.assertRaises(UnorderedError, self.result.first) self.assertRaises(UnorderedError, self.empty.first) def test_first_ordered(self): self.result.order_by(Foo.title) self.assertEqual(self.result.first(), None) self.empty.order_by(Foo.title) self.assertEqual(self.empty.first(), None) def test_last_unordered(self): self.assertRaises(UnorderedError, self.result.last) self.assertRaises(UnorderedError, self.empty.last) def test_last_ordered(self): self.result.order_by(Foo.title) self.assertEqual(self.result.last(), None) self.empty.order_by(Foo.title) self.assertEqual(self.empty.last(), None) def test_one(self): self.assertEqual(self.result.one(), None) self.assertEqual(self.empty.one(), None) def test_order_by(self): self.assertEqual(self.result.order_by(Foo.title), self.result) self.assertEqual(self.empty.order_by(Foo.title), self.empty) def test_group_by(self): self.assertEqual(self.result.group_by(Foo.title), self.result) self.assertEqual(self.empty.group_by(Foo.title), self.empty) def test_remove(self): self.assertEqual(self.result.remove(), 0) self.assertEqual(self.empty.remove(), 0) def test_count(self): self.assertEqual(self.result.count(), 0) self.assertEqual(self.empty.count(), 0) self.assertEqual(self.empty.count(expr="abc"), 0) self.assertEqual(self.empty.count(distinct=True), 0) def test_max(self): self.assertEqual(self.result.max(Foo.id), None) self.assertEqual(self.empty.max(Foo.id), None) def test_min(self): self.assertEqual(self.result.min(Foo.id), None) self.assertEqual(self.empty.min(Foo.id), None) def test_avg(self): self.assertEqual(self.result.avg(Foo.id), None) self.assertEqual(self.empty.avg(Foo.id), None) def test_sum(self): self.assertEqual(self.result.sum(Foo.id), None) self.assertEqual(self.empty.sum(Foo.id), None) def test_get_select_expr_without_columns(self): """ A L{FeatureError} is raised if L{EmptyResultSet.get_select_expr} is called without a list of L{Column}s. """ self.assertRaises(FeatureError, self.result.get_select_expr) self.assertRaises(FeatureError, self.empty.get_select_expr) def test_get_select_expr(self): """ L{get_select_expr} returns a select expression for the given L{Column}s that can be used as a subselect; for these result sets it matches no rows.
""" subselect = self.result.get_select_expr(Foo.id) self.assertEqual((Foo.id,), subselect.columns) result = self.store.find(Foo, Foo.id.is_in(subselect)) self.assertEqual(list(result), []) subselect = self.empty.get_select_expr(Foo.id) self.assertEqual((Foo.id,), subselect.columns) result = self.store.find(Foo, Foo.id.is_in(subselect)) self.assertEqual(list(result), []) def test_values_no_columns(self): self.assertRaises(FeatureError, list, self.result.values()) self.assertRaises(FeatureError, list, self.empty.values()) def test_values(self): self.assertEqual(list(self.result.values(Foo.title)), []) self.assertEqual(list(self.empty.values(Foo.title)), []) def test_set_no_args(self): self.assertEqual(self.result.set(), None) self.assertEqual(self.empty.set(), None) def test_cached(self): self.assertEqual(self.result.cached(), []) self.assertEqual(self.empty.cached(), []) def test_find(self): self.assertEqual(list(self.result.find(Foo.title == "foo")), []) self.assertEqual(list(self.empty.find(Foo.title == "foo")), []) def test_union(self): self.assertEqual(self.empty.union(self.empty), self.empty) self.assertEqual(type(self.empty.union(self.result)), type(self.result)) self.assertEqual(type(self.result.union(self.empty)), type(self.result)) def test_difference(self): self.assertEqual(self.empty.difference(self.empty), self.empty) self.assertEqual(self.empty.difference(self.result), self.empty) self.assertEqual(self.result.difference(self.empty), self.result) def test_intersection(self): self.assertEqual(self.empty.intersection(self.empty), self.empty) self.assertEqual(self.empty.intersection(self.result), self.empty) self.assertEqual(self.result.intersection(self.empty), self.empty) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/store/block.py0000644000175000017500000000444314645174376020154 0ustar00cjwatsoncjwatson# Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . from storm.databases.sqlite import SQLite from storm.exceptions import ConnectionBlockedError from storm.store import Store, block_access from storm.tests.helper import TestHelper, MakePath from storm.uri import URI class BlockAccessTest(TestHelper): """Tests for L{block_access}.""" helpers = [MakePath] def setUp(self): super().setUp() database = SQLite(URI("sqlite:")) self.store = Store(database) def test_block_access(self): """ The L{block_access} context manager blocks access to a L{Store}. A L{ConnectionBlockedError} exception is raised if an attempt to access the underlying database is made while a store is blocked. 
""" with block_access(self.store): self.assertRaises(ConnectionBlockedError, self.store.execute, "SELECT 1") result = self.store.execute("SELECT 1") self.assertEqual([(1,)], list(result)) def test_block_access_with_multiple_stores(self): """ If multiple L{Store}s are passed to L{block_access} they will all be blocked until the managed context is left. """ database = SQLite(URI("sqlite:%s" % self.make_path())) store = Store(database) with block_access(self.store, store): self.assertRaises(ConnectionBlockedError, self.store.execute, "SELECT 1") self.assertRaises(ConnectionBlockedError, store.execute, "SELECT 1") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/tests/store/mysql.py0000644000175000017500000001003614571373456020220 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os from storm.database import create_database from storm.tests.store.base import StoreTest, EmptyResultSetTest from storm.tests.helper import TestHelper class MySQLStoreTest(TestHelper, StoreTest): def setUp(self): TestHelper.setUp(self) StoreTest.setUp(self) def tearDown(self): TestHelper.tearDown(self) StoreTest.tearDown(self) def is_supported(self): return bool(os.environ.get("STORM_MYSQL_URI")) def create_database(self): self.database = create_database(os.environ["STORM_MYSQL_URI"]) def create_tables(self): connection = self.connection connection.execute("CREATE TABLE foo " "(id INT PRIMARY KEY AUTO_INCREMENT," " title VARCHAR(50) DEFAULT 'Default Title') " "ENGINE=InnoDB") connection.execute("CREATE TABLE bar " "(id INT PRIMARY KEY AUTO_INCREMENT," " foo_id INTEGER, title VARCHAR(50)) " "ENGINE=InnoDB") connection.execute("CREATE TABLE bin " "(id INT PRIMARY KEY AUTO_INCREMENT," " bin BLOB, foo_id INTEGER) " "ENGINE=InnoDB") connection.execute("CREATE TABLE link " "(foo_id INTEGER, bar_id INTEGER," " PRIMARY KEY (foo_id, bar_id)) " "ENGINE=InnoDB") connection.execute("CREATE TABLE money " "(id INT PRIMARY KEY AUTO_INCREMENT," " value NUMERIC(6,4)) " "ENGINE=InnoDB") connection.execute("CREATE TABLE selfref " "(id INT PRIMARY KEY AUTO_INCREMENT," " title VARCHAR(50)," " selfref_id INTEGER," " INDEX (selfref_id)," " FOREIGN KEY (selfref_id) REFERENCES selfref(id)) " "ENGINE=InnoDB") connection.execute("CREATE TABLE foovalue " "(id INT PRIMARY KEY AUTO_INCREMENT," " foo_id INTEGER," " value1 INTEGER, value2 INTEGER) " "ENGINE=InnoDB") connection.execute("CREATE TABLE unique_id " "(id VARCHAR(36) PRIMARY KEY) " "ENGINE=InnoDB") connection.commit() class MySQLEmptyResultSetTest(TestHelper, EmptyResultSetTest): def setUp(self): TestHelper.setUp(self) EmptyResultSetTest.setUp(self) def tearDown(self): TestHelper.tearDown(self) EmptyResultSetTest.tearDown(self) def is_supported(self): return bool(os.environ.get("STORM_MYSQL_URI")) def 
create_database(self): self.database = create_database(os.environ["STORM_MYSQL_URI"]) def create_tables(self): self.connection.execute("CREATE TABLE foo " "(id INT PRIMARY KEY AUTO_INCREMENT," " title VARCHAR(50) DEFAULT 'Default Title') " "ENGINE=InnoDB") self.connection.commit() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/store/postgres.py0000644000175000017500000001705414645174376020732 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os import gc from storm.database import create_database from storm.properties import Enum, Int, List from storm.info import get_obj_info from storm.tests.store.base import StoreTest, EmptyResultSetTest, Foo from storm.tests.helper import TestHelper class Lst1: __storm_table__ = "lst1" id = Int(primary=True) ints = List(type=Int()) class LstEnum: __storm_table__ = "lst1" id = Int(primary=True) ints = List(type=Enum(map={"one": 1, "two": 2, "three": 3})) class Lst2: __storm_table__ = "lst2" id = Int(primary=True) ints = List(type=List(type=Int())) class FooWithSchema(Foo): __storm_table__ = "public.foo" class PostgresStoreTest(TestHelper, StoreTest): def setUp(self): TestHelper.setUp(self) StoreTest.setUp(self) def tearDown(self): TestHelper.tearDown(self) StoreTest.tearDown(self) def is_supported(self): return bool(os.environ.get("STORM_POSTGRES_URI")) def create_database(self): self.database = create_database(os.environ["STORM_POSTGRES_URI"]) def create_tables(self): connection = self.connection connection.execute("CREATE TABLE foo " "(id SERIAL PRIMARY KEY," " title VARCHAR DEFAULT 'Default Title')") # Prevent dynamically created Foos from having conflicting ids. 
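# setval() below advances the sequence past the ids used by the fixtures, so nextval() hands out ids starting above 1000.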
connection.execute("SELECT setval('foo_id_seq', 1000)") connection.execute("CREATE TABLE bar " "(id SERIAL PRIMARY KEY," " foo_id INTEGER, title VARCHAR)") connection.execute("CREATE TABLE bin " "(id SERIAL PRIMARY KEY, bin BYTEA, foo_id INTEGER)") connection.execute("CREATE TABLE link " "(foo_id INTEGER, bar_id INTEGER," " PRIMARY KEY (foo_id, bar_id))") connection.execute("CREATE TABLE money " "(id SERIAL PRIMARY KEY, value NUMERIC(6,4))") connection.execute("CREATE TABLE selfref " "(id SERIAL PRIMARY KEY, title VARCHAR," " selfref_id INTEGER REFERENCES selfref(id))") connection.execute("CREATE TABLE lst1 " "(id SERIAL PRIMARY KEY, ints INTEGER[])") connection.execute("CREATE TABLE lst2 " "(id SERIAL PRIMARY KEY, ints INTEGER[][])") connection.execute("CREATE TABLE foovalue " "(id SERIAL PRIMARY KEY, foo_id INTEGER," " value1 INTEGER, value2 INTEGER)") connection.execute("CREATE TABLE unique_id " "(id UUID PRIMARY KEY)") connection.commit() def drop_tables(self): StoreTest.drop_tables(self) for table in ["lst1", "lst2"]: try: self.connection.execute("DROP TABLE %s" % table) self.connection.commit() except: self.connection.rollback() def test_list_variable(self): lst = Lst1() lst.id = 1 lst.ints = [1,2,3,4] self.store.add(lst) result = self.store.execute("SELECT ints FROM lst1 WHERE id=1") self.assertEqual(result.get_one(), ([1,2,3,4],)) del lst gc.collect() lst = self.store.find(Lst1, Lst1.ints == [1,2,3,4]).one() self.assertTrue(lst) lst.ints.append(5) result = self.store.execute("SELECT ints FROM lst1 WHERE id=1") self.assertEqual(result.get_one(), ([1,2,3,4,5],)) def test_list_enum_variable(self): lst = LstEnum() lst.id = 1 lst.ints = ["one", "two"] self.store.add(lst) result = self.store.execute("SELECT ints FROM lst1 WHERE id=1") self.assertEqual(result.get_one(), ([1,2],)) del lst gc.collect() lst = self.store.find(LstEnum, LstEnum.ints == ["one", "two"]).one() self.assertTrue(lst) lst.ints.append("three") result = self.store.execute("SELECT ints FROM lst1 WHERE id=1") self.assertEqual(result.get_one(), ([1,2,3],)) def test_list_variable_nested(self): lst = Lst2() lst.id = 1 lst.ints = [[1, 2], [3, 4]] self.store.add(lst) result = self.store.execute("SELECT ints FROM lst2 WHERE id=1") self.assertEqual(result.get_one(), ([[1,2],[3,4]],)) del lst gc.collect() lst = self.store.find(Lst2, Lst2.ints == [[1,2],[3,4]]).one() self.assertTrue(lst) lst.ints.append([5, 6]) result = self.store.execute("SELECT ints FROM lst2 WHERE id=1") self.assertEqual(result.get_one(), ([[1,2],[3,4],[5,6]],)) def test_add_find_with_schema(self): foo = FooWithSchema() foo.title = "Title" self.store.add(foo) self.store.flush() # We use find() here to actually exercise the backend code. # get() would just pick the object from the cache. self.assertEqual(self.store.find(FooWithSchema, id=foo.id).one(), foo) def test_wb_currval_based_identity(self): """ Ensure that the currval()-based identity retrieval continues to work, even if we're currently running on a 8.2+ database. """ self.database._version = 80109 foo1 = self.store.add(Foo()) self.store.flush() foo2 = self.store.add(Foo()) self.store.flush() self.assertEqual(foo2.id-foo1.id, 1) def test_list_unnecessary_update(self): """ Flushing an object with a list variable doesn't create an unnecessary UPDATE statement. 
""" self.store.execute("INSERT INTO lst1 VALUES (1, '{}')", noresult=True) lst = self.store.find(Lst1, id=1).one() self.assertTrue(lst) self.store.invalidate() lst2 = self.store.find(Lst1, id=1).one() self.assertTrue(lst2) obj_info = get_obj_info(lst2) events = [] obj_info.event.hook("changed", lambda *args: events.append(args)) self.store.flush() self.assertEqual(events, []) class PostgresEmptyResultSetTest(TestHelper, EmptyResultSetTest): def setUp(self): TestHelper.setUp(self) EmptyResultSetTest.setUp(self) def tearDown(self): TestHelper.tearDown(self) EmptyResultSetTest.tearDown(self) def is_supported(self): return bool(os.environ.get("STORM_POSTGRES_URI")) def create_database(self): self.database = create_database(os.environ["STORM_POSTGRES_URI"]) def create_tables(self): self.connection.execute("CREATE TABLE foo " "(id SERIAL PRIMARY KEY," " title VARCHAR DEFAULT 'Default Title')") self.connection.commit() ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/tests/store/sqlite.py0000644000175000017500000000671414571373456020364 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from storm.databases.sqlite import SQLite from storm.uri import URI from storm.tests.store.base import StoreTest, EmptyResultSetTest from storm.tests.helper import TestHelper, MakePath class SQLiteStoreTest(TestHelper, StoreTest): helpers = [MakePath] def setUp(self): TestHelper.setUp(self) StoreTest.setUp(self) def tearDown(self): TestHelper.tearDown(self) StoreTest.tearDown(self) def create_database(self): self.database = SQLite(URI("sqlite:%s?synchronous=OFF" % self.make_path())) def create_tables(self): connection = self.connection connection.execute("CREATE TABLE foo " "(id INTEGER PRIMARY KEY," " title VARCHAR DEFAULT 'Default Title')") connection.execute("CREATE TABLE bar " "(id INTEGER PRIMARY KEY," " foo_id INTEGER, title VARCHAR)") connection.execute("CREATE TABLE bin " "(id INTEGER PRIMARY KEY, bin BLOB, foo_id INTEGER)") connection.execute("CREATE TABLE link " "(foo_id INTEGER, bar_id INTEGER)") # We have to use TEXT here, since NUMERIC would cause SQLite # to interpret values as float, and thus lose precision. 
connection.execute("CREATE TABLE money " "(id INTEGER PRIMARY KEY, value TEXT)") connection.execute("CREATE TABLE selfref " "(id INTEGER PRIMARY KEY, title VARCHAR," " selfref_id INTEGER)") connection.execute("CREATE TABLE foovalue " "(id INTEGER PRIMARY KEY, foo_id INTEGER," " value1 INTEGER, value2 INTEGER)") connection.execute("CREATE TABLE unique_id " "(id VARCHAR PRIMARY KEY)") connection.commit() def drop_tables(self): pass class SQLiteEmptyResultSetTest(TestHelper, EmptyResultSetTest): helpers = [MakePath] def setUp(self): TestHelper.setUp(self) EmptyResultSetTest.setUp(self) def tearDown(self): TestHelper.tearDown(self) EmptyResultSetTest.tearDown(self) def create_database(self): self.database = SQLite(URI("sqlite:%s?synchronous=OFF" % self.make_path())) def create_tables(self): self.connection.execute("CREATE TABLE foo " "(id INTEGER PRIMARY KEY," " title VARCHAR DEFAULT 'Default Title')") self.connection.commit() def drop_tables(self): pass ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/tracer.py0000644000175000017500000005422114645174376017205 0ustar00cjwatsoncjwatsonimport datetime import os import sys from unittest import TestCase from storm.tests import has_fixtures # Optional dependency. If missing, Fixture tests are skipped. if has_fixtures: import fixtures.testcase TestWithFixtures = fixtures.testcase.TestWithFixtures from storm.testing import CaptureTracer else: TestWithFixtures = object try: # Optional dependency, if missing TimelineTracer tests are skipped. import timeline has_timeline = True except ImportError: has_timeline = False from storm.tracer import (trace, install_tracer, get_tracers, remove_tracer, remove_tracer_type, remove_all_tracers, debug, BaseStatementTracer, DebugTracer, TimeoutTracer, TimelineTracer, TimeoutError, _tracers) from storm.database import Connection, create_database from storm.expr import Variable from storm.tests.helper import TestHelper class TracerTest(TestHelper): def tearDown(self): super().tearDown() del _tracers[:] def test_install_tracer(self): c = object() d = object() install_tracer(c) install_tracer(d) self.assertEqual(get_tracers(), [c, d]) def test_remove_all_tracers(self): install_tracer(object()) remove_all_tracers() self.assertEqual(get_tracers(), []) def test_remove_tracer(self): """The C{remote_tracer} function removes a specific tracer.""" tracer1 = object() tracer2 = object() install_tracer(tracer1) install_tracer(tracer2) remove_tracer(tracer1) self.assertEqual(get_tracers(), [tracer2]) def test_remove_tracer_with_not_installed_tracer(self): """C{remote_tracer} exits gracefully if the tracer is not installed.""" tracer = object() remove_tracer(tracer) self.assertEqual(get_tracers(), []) def test_remove_tracer_type(self): class C: pass class D(C): pass c = C() d1 = D() d2 = D() install_tracer(d1) install_tracer(c) install_tracer(d2) remove_tracer_type(C) self.assertEqual(get_tracers(), [d1, d2]) remove_tracer_type(D) self.assertEqual(get_tracers(), []) def test_install_debug(self): debug(True) debug(True) self.assertEqual([type(x) for x in get_tracers()], [DebugTracer]) def test_wb_install_debug_with_custom_stream(self): marker = object() debug(True, marker) [tracer] = get_tracers() self.assertEqual(tracer._stream, marker) def test_remove_debug(self): debug(True) debug(True) debug(False) self.assertEqual(get_tracers(), []) def test_trace(self): stash = [] class Tracer: def m1(_, *args, **kwargs): stash.extend(["m1", args, kwargs]) def m2(_, *args, 
**kwargs): stash.extend(["m2", args, kwargs]) install_tracer(Tracer()) trace("m1", 1, 2, c=3) trace("m2") trace("m3") self.assertEqual(stash, ["m1", (1, 2), {"c": 3}, "m2", (), {}]) class MockVariable(Variable): def __init__(self, value): self._value = value def get(self, to_db=False): return self._value class DebugTracerTest(TestHelper): def setUp(self): super().setUp() self.stream = self.mocker.mock(type(sys.stderr)) self.tracer = DebugTracer(self.stream) datetime_mock = self.mocker.replace("datetime.datetime") datetime_mock.now() self.mocker.result(datetime.datetime(1, 2, 3, 4, 5, 6, 7)) self.mocker.count(0, 1) self.variable = MockVariable("PARAM") def tearDown(self): del _tracers[:] super().tearDown() def test_wb_debug_tracer_uses_stderr_by_default(self): self.mocker.replay() tracer = DebugTracer() self.assertEqual(tracer._stream, sys.stderr) def test_wb_debug_tracer_uses_first_arg_as_stream(self): self.mocker.replay() marker = object() tracer = DebugTracer(marker) self.assertEqual(tracer._stream, marker) def test_connection_raw_execute(self): self.stream.write( "[04:05:06.000007] EXECUTE: 'STATEMENT', ('PARAM',)\n") self.stream.flush() self.mocker.replay() connection = "CONNECTION" raw_cursor = "RAW_CURSOR" statement = "STATEMENT" params = [self.variable] self.tracer.connection_raw_execute(connection, raw_cursor, statement, params) def test_connection_raw_execute_with_non_variable(self): self.stream.write( "[04:05:06.000007] EXECUTE: 'STATEMENT', ('PARAM', 1)\n") self.stream.flush() self.mocker.replay() connection = "CONNECTION" raw_cursor = "RAW_CURSOR" statement = "STATEMENT" params = [self.variable, 1] self.tracer.connection_raw_execute(connection, raw_cursor, statement, params) def test_connection_raw_execute_error(self): self.stream.write("[04:05:06.000007] ERROR: ERROR\n") self.stream.flush() self.mocker.replay() connection = "CONNECTION" raw_cursor = "RAW_CURSOR" statement = "STATEMENT" params = "PARAMS" error = "ERROR" self.tracer.connection_raw_execute_error(connection, raw_cursor, statement, params, error) def test_connection_raw_execute_success(self): self.stream.write("[04:05:06.000007] DONE\n") self.stream.flush() self.mocker.replay() connection = "CONNECTION" raw_cursor = "RAW_CURSOR" statement = "STATEMENT" params = "PARAMS" self.tracer.connection_raw_execute_success(connection, raw_cursor, statement, params) def test_connection_commit(self): self.stream.write("[04:05:06.000007] COMMIT xid=None\n") self.stream.flush() self.mocker.replay() connection = "CONNECTION" self.tracer.connection_commit(connection) def test_connection_rollback(self): self.stream.write("[04:05:06.000007] ROLLBACK xid=None\n") self.stream.flush() self.mocker.replay() connection = "CONNECTION" self.tracer.connection_rollback(connection) class TimeoutTracerTestBase(TestHelper): tracer_class = TimeoutTracer def setUp(self): super().setUp() self.tracer = self.tracer_class() self.raw_cursor = self.mocker.mock() self.statement = self.mocker.mock() self.params = self.mocker.mock() # Some data is kept in the connection, so we use a proxy to # allow things we don't care about here to happen. 
class Connection: pass self.connection = self.mocker.proxy(Connection()) def tearDown(self): super().tearDown() del _tracers[:] def execute(self): self.tracer.connection_raw_execute(self.connection, self.raw_cursor, self.statement, self.params) def execute_raising(self): self.assertRaises(TimeoutError, self.tracer.connection_raw_execute, self.connection, self.raw_cursor, self.statement, self.params) class TimeoutTracerTest(TimeoutTracerTestBase): def test_raise_not_implemented(self): """ L{TimeoutTracer.connection_raw_execute_error}, L{TimeoutTracer.set_statement_timeout} and L{TimeoutTracer.get_remaining_time} must all be implemented by backend-specific subclasses. """ self.assertRaises(NotImplementedError, self.tracer.connection_raw_execute_error, None, None, None, None, None) self.assertRaises(NotImplementedError, self.tracer.set_statement_timeout, None, None) self.assertRaises(NotImplementedError, self.tracer.get_remaining_time) def test_raise_timeout_error_when_no_remaining_time(self): """ A L{TimeoutError} is raised if there isn't any time left when a statement is executed. """ tracer_mock = self.mocker.patch(self.tracer) tracer_mock.get_remaining_time() self.mocker.result(0) self.mocker.replay() try: self.execute() except TimeoutError as e: self.assertEqual("0 seconds remaining in time budget", e.message) self.assertEqual(self.statement, e.statement) self.assertEqual(self.params, e.params) else: self.fail("TimeoutError not raised") def test_raise_timeout_on_granularity(self): tracer_mock = self.mocker.patch(self.tracer) self.mocker.order() tracer_mock.get_remaining_time() self.mocker.result(self.tracer.granularity) tracer_mock.set_statement_timeout(self.raw_cursor, self.tracer.granularity) tracer_mock.get_remaining_time() self.mocker.result(0) self.mocker.replay() self.execute() self.execute_raising() def test_wont_raise_timeout_before_granularity(self): tracer_mock = self.mocker.patch(self.tracer) self.mocker.order() tracer_mock.get_remaining_time() self.mocker.result(self.tracer.granularity) tracer_mock.set_statement_timeout(self.raw_cursor, self.tracer.granularity) tracer_mock.get_remaining_time() self.mocker.result(1) self.mocker.replay() self.execute() self.execute() def test_always_set_when_remaining_time_increased(self): tracer_mock = self.mocker.patch(self.tracer) self.mocker.order() tracer_mock.get_remaining_time() self.mocker.result(1) tracer_mock.set_statement_timeout(self.raw_cursor, 1) tracer_mock.get_remaining_time() self.mocker.result(2) tracer_mock.set_statement_timeout(self.raw_cursor, 2) self.mocker.replay() self.execute() self.execute() def test_set_again_on_granularity(self): tracer_mock = self.mocker.patch(self.tracer) self.mocker.order() tracer_mock.get_remaining_time() self.mocker.result(self.tracer.granularity * 2) tracer_mock.set_statement_timeout(self.raw_cursor, self.tracer.granularity * 2) tracer_mock.get_remaining_time() self.mocker.result(self.tracer.granularity) tracer_mock.set_statement_timeout(self.raw_cursor, self.tracer.granularity) self.mocker.replay() self.execute() self.execute() def test_set_again_after_granularity(self): tracer_mock = self.mocker.patch(self.tracer) self.mocker.order() tracer_mock.get_remaining_time() self.mocker.result(self.tracer.granularity * 2) tracer_mock.set_statement_timeout(self.raw_cursor, self.tracer.granularity * 2) tracer_mock.get_remaining_time() self.mocker.result(self.tracer.granularity - 1) tracer_mock.set_statement_timeout(self.raw_cursor, self.tracer.granularity - 1) self.mocker.replay() self.execute() 
self.execute() class TimeoutTracerWithDBTest(TestHelper): def setUp(self): super().setUp() self.tracer = StuckInTimeTimeoutTracer(10) install_tracer(self.tracer) database = create_database(os.environ["STORM_POSTGRES_URI"]) self.connection = database.connect() def tearDown(self): super().tearDown() remove_tracer(self.tracer) self.connection.close() def is_supported(self): return bool(os.environ.get("STORM_POSTGRES_URI")) def test_timeout_set_on_beginning_of_new_transaction__commit(self): """Check that we set the statement timeout before the first query of a transaction regardless of the remaining time left by previous transactions. When we reuse a connection for a different transaction, the remaining time of a previous transaction (which is stored in the connection) could cause the first query in that transaction to run with no timeout. This test makes sure that doesn't happen. """ self.connection.execute('SELECT 1') self.assertEqual([10], self.tracer.set_statement_timeout_calls) self.connection.commit() self.connection.execute('SELECT 1') self.assertEqual([10, 10], self.tracer.set_statement_timeout_calls) def test_timeout_set_on_beginning_of_new_transaction__rollback(self): """Same as the test above, but here we rollback the first tx.""" self.connection.execute('SELECT 1') self.assertEqual([10], self.tracer.set_statement_timeout_calls) self.connection.rollback() self.connection.execute('SELECT 1') self.assertEqual([10, 10], self.tracer.set_statement_timeout_calls) class StuckInTimeTimeoutTracer(TimeoutTracer): def __init__(self, fixed_remaining_time): super().__init__() self.set_statement_timeout_calls = [] self.fixed_remaining_time = fixed_remaining_time def get_remaining_time(self): return self.fixed_remaining_time def set_statement_timeout(self, raw_cursor, remaining_time): self.set_statement_timeout_calls.append(remaining_time) class StubConnection(Connection): def __init__(self): self._database = None self._event = None self._raw_connection = None self.name = 'Foo' class BaseStatementTracerTest(TestCase): class LoggingBaseStatementTracer(BaseStatementTracer): def _expanded_raw_execute(self, connection, raw_cursor, statement): self.__dict__.setdefault('calls', []).append( (connection, raw_cursor, statement)) def test_no_params(self): """With no parameters the statement is passed through verbatim.""" tracer = self.LoggingBaseStatementTracer() tracer.connection_raw_execute('foo', 'bar', 'baz ? %s', ()) self.assertEqual([('foo', 'bar', 'baz ? %s')], tracer.calls) def test_params_substituted_pyformat(self): tracer = self.LoggingBaseStatementTracer() conn = StubConnection() conn.param_mark = '%s' var1 = MockVariable('VAR1') tracer.connection_raw_execute( conn, 'cursor', 'SELECT * FROM person where name = %s', [var1]) self.assertEqual( [(conn, 'cursor', "SELECT * FROM person where name = 'VAR1'")], tracer.calls) def test_params_substituted_single_string(self): """String parameters are formatted as a single quoted string.""" tracer = self.LoggingBaseStatementTracer() conn = StubConnection() var1 = MockVariable('VAR1') tracer.connection_raw_execute( conn, 'cursor', 'SELECT * FROM person where name = ?', [var1]) self.assertEqual( [(conn, 'cursor', "SELECT * FROM person where name = 'VAR1'")], tracer.calls) def test_qmark_percent_s_literal_preserved(self): """With ? parameters %s in the statement can be kept intact.""" tracer = self.LoggingBaseStatementTracer() conn = StubConnection() var1 = MockVariable(1) tracer.connection_raw_execute( conn, 'cursor', "SELECT * FROM person where id > ? 
AND name LIKE '%s'", [var1]) self.assertEqual( [(conn, 'cursor', "SELECT * FROM person where id > 1 AND name LIKE '%s'")], tracer.calls) def test_int_variable_as_int(self): """Int parameters are formatted as an int literal.""" tracer = self.LoggingBaseStatementTracer() conn = StubConnection() var1 = MockVariable(1) tracer.connection_raw_execute( conn, 'cursor', "SELECT * FROM person where id = ?", [var1]) self.assertEqual( [(conn, 'cursor', "SELECT * FROM person where id = 1")], tracer.calls) def test_like_clause_preserved(self): """% operators in LIKE statements are preserved.""" tracer = self.LoggingBaseStatementTracer() conn = StubConnection() var1 = MockVariable('substring') tracer.connection_raw_execute( conn, 'cursor', "SELECT * FROM person WHERE name LIKE '%%' || ? || '-suffix%%'", [var1]) self.assertEqual( [(conn, 'cursor', "SELECT * FROM person WHERE name " "LIKE '%%' || 'substring' || '-suffix%%'")], tracer.calls) def test_unformattable_statements_are_handled(self): tracer = self.LoggingBaseStatementTracer() conn = StubConnection() var1 = MockVariable('substring') tracer.connection_raw_execute( conn, 'cursor', "%s %s", [var1]) self.assertEqual( [(conn, 'cursor', "Unformattable query: '%%s %%s' with params [%r]." % 'substring')], tracer.calls) class TimelineTracerTest(TestHelper): def is_supported(self): return has_timeline def factory(self): self.timeline = timeline.Timeline() return self.timeline def test_separate_tracers_own_state(self): """Check that multiple TimelineTracer's could be used at once.""" tracer1 = TimelineTracer(self.factory) tracer2 = TimelineTracer(self.factory) tracer1.threadinfo.action = 'foo' self.assertEqual(None, getattr(tracer2.threadinfo, 'action', None)) def test_error_finishes_action(self): tracer = TimelineTracer(self.factory) action = timeline.Timeline().start('foo', 'bar') tracer.threadinfo.action = action tracer.connection_raw_execute_error( 'conn', 'cursor', 'statement', 'params', 'error') self.assertNotEqual(None, action.duration) def test_success_finishes_action(self): tracer = TimelineTracer(self.factory) action = timeline.Timeline().start('foo', 'bar') tracer.threadinfo.action = action tracer.connection_raw_execute_success( 'conn', 'cursor', 'statement', 'params') self.assertNotEqual(None, action.duration) def test_finds_timeline_from_factory(self): factory_result = timeline.Timeline() tracer = TimelineTracer(lambda: factory_result) tracer._expanded_raw_execute('conn', 'cursor', 'statement') self.assertEqual(1, len(factory_result.actions)) def test_action_details_are_statement(self): """The detail in the timeline action is the formatted SQL statement.""" tracer = TimelineTracer(self.factory) tracer._expanded_raw_execute('conn', 'cursor', 'statement') self.assertEqual('statement', self.timeline.actions[-1].detail) def test_category_from_prefix_and_connection_name(self): tracer = TimelineTracer(self.factory, prefix='bar-') tracer._expanded_raw_execute(StubConnection(), 'cursor', 'statement') self.assertEqual('bar-Foo', self.timeline.actions[-1].category) def test_unnamed_connection(self): """A connection with no name has put in as a placeholder.""" tracer = TimelineTracer(self.factory, prefix='bar-') tracer._expanded_raw_execute('conn', 'cursor', 'statement') self.assertEqual('bar-', self.timeline.actions[-1].category) def test_default_prefix(self): """By default the prefix "SQL-" is added to the action's category.""" tracer = TimelineTracer(self.factory) tracer._expanded_raw_execute('conn', 'cursor', 'statement') self.assertEqual('SQL-', 
self.timeline.actions[-1].category) class CaptureTracerTest(TestHelper, TestWithFixtures): def is_supported(self): return has_fixtures def tearDown(self): super().tearDown() del _tracers[:] def test_capture(self): """ Using the L{CaptureTracer} fixture starts capturing queries and stops removes the tracer upon cleanup. """ tracer = self.useFixture(CaptureTracer()) self.assertEqual([tracer], get_tracers()) conn = StubConnection() conn.param_mark = '%s' var = MockVariable("var") tracer.connection_raw_execute(conn, "cursor", "select %s", [var]) self.assertEqual(["select 'var'"], tracer.queries) def check(): self.assertEqual([], get_tracers()) self.addCleanup(check) def test_capture_as_context_manager(self): """{CaptureTracer}s can be used as context managers.""" conn = StubConnection() with CaptureTracer() as tracer: self.assertEqual([tracer], get_tracers()) tracer.connection_raw_execute(conn, "cursor", "select", []) self.assertEqual([], get_tracers()) self.assertEqual(["select"], tracer.queries) def test_capture_multiple(self): """L{CaptureTracer}s can be used as nested context managers.""" conn = StubConnection() def trace(statement): for tracer in get_tracers(): tracer.connection_raw_execute(conn, "cursor", statement, []) with CaptureTracer() as tracer1: trace("one") with CaptureTracer() as tracer2: trace("two") trace("three") self.assertEqual([], get_tracers()) self.assertEqual(["one", "two", "three"], tracer1.queries) self.assertEqual(["two"], tracer2.queries) def test_capture_with_exception(self): """ L{CaptureTracer}s re-raise any error when used as context managers. """ errors = [] try: with CaptureTracer(): raise RuntimeError("boom") except RuntimeError as error: errors.append(error) [error] = errors self.assertEqual("boom", str(error)) self.assertEqual([], get_tracers()) ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1721152862.425125 storm-1.0/storm/tests/twisted/0000755000175000017500000000000014645532536017026 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/tests/twisted/__init__.py0000644000175000017500000000020714571373456021140 0ustar00cjwatsoncjwatson__all__ = [ 'has_twisted', ] try: import twisted except ImportError: has_twisted = False else: has_twisted = True ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/tests/twisted/transact.py0000644000175000017500000003014214571373456021221 0ustar00cjwatsoncjwatsonfrom storm.tests import has_psycopg from storm.tests.helper import TestHelper from storm.tests.zope import has_transaction, has_zope_component from storm.tests.twisted import has_twisted if has_transaction and has_zope_component and has_twisted: import transaction from twisted.trial.unittest import TestCase from zope.component import getUtility from storm.zope.interfaces import IZStorm from storm.exceptions import IntegrityError, DisconnectionError from storm.twisted.transact import Transactor, transact from storm.twisted.testing import FakeThreadPool else: # We can't use trial's TestCase as base TestCase = TestHelper TestHelper = object if has_psycopg: from psycopg2.extensions import TransactionRollbackError class TransactorTest(TestCase, TestHelper): def is_supported(self): return has_transaction and has_zope_component and has_twisted def setUp(self): TestCase.setUp(self) TestHelper.setUp(self) self.threadpool = FakeThreadPool() self.transaction = self.mocker.mock() self.transactor = 
Transactor(self.threadpool, self.transaction) self.function = self.mocker.mock() def test_run(self): """ L{Transactor.run} executes a function in a thread, commits the transaction and returns a deferred firing the function result. """ self.mocker.order() self.expect(self.function(1, arg=2)).result(3) self.expect(self.transaction.commit()) self.mocker.replay() deferred = self.transactor.run(self.function, 1, arg=2) deferred.addCallback(self.assertEqual, 3) return deferred def test_run_with_function_failure(self): """ If the given function raises an error, then L{Transactor.run} aborts the transaction and re-raises the same error. """ self.mocker.order() self.expect(self.function()).throw(ZeroDivisionError()) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, ZeroDivisionError) return deferred def test_run_with_disconnection_error(self): """ If the given function raises a L{DisconnectionError}, then a C{SELECT 1} will be executed in each registered store such that C{psycopg} actually detects the disconnection. """ self.transactor.retries = 0 self.mocker.order() zstorm = self.mocker.mock() store1 = self.mocker.mock() store2 = self.mocker.mock() gu = self.mocker.replace(getUtility) self.expect(self.function()).throw(DisconnectionError()) self.expect(gu(IZStorm)).result(zstorm) self.expect(zstorm.iterstores()).result(iter((("store1", store1), ("store2", store2)))) self.expect(store1.execute("SELECT 1")) self.expect(store2.execute("SELECT 1")) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, DisconnectionError) return deferred def test_run_with_disconnection_error_in_execute_is_ignored(self): """ If the given function raises a L{DisconnectionError}, then a C{SELECT 1} will be executed in each registered store such that C{psycopg} actually detects the disconnection. If another L{DisconnectionError} happens during C{execute}, then it is ignored. """ self.transactor.retries = 0 zstorm = self.mocker.mock() store1 = self.mocker.mock() store2 = self.mocker.mock() gu = self.mocker.replace(getUtility) self.mocker.order() self.expect(self.function()).throw(DisconnectionError()) self.expect(gu(IZStorm)).result(zstorm) self.expect(zstorm.iterstores()).result(iter((("store1", store1), ("store2", store2)))) self.expect(store1.execute("SELECT 1")).throw(DisconnectionError()) self.expect(store2.execute("SELECT 1")) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, DisconnectionError) return deferred def test_run_with_commit_failure(self): """ If the given function succeeds but the transaction fails to commit, then L{Transactor.run} aborts the transaction and re-raises the commit exception. """ self.mocker.order() self.expect(self.function()) self.expect(self.transaction.commit()).throw(ZeroDivisionError()) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, ZeroDivisionError) return deferred def test_wb_default_transaction(self): """ By default L{Transact} uses the global transaction manager. 
""" transactor = Transactor(self.threadpool) self.assertIdentical(transaction, transactor._transaction) def test_decorate(self): """ A L{transact} decorator can be used with methods of an object that contains a L{Transactor} instance as a C{transactor} instance variable, ensuring that the decorated function is called via L{Transactor.run}. """ self.mocker.order() self.expect(self.transaction.commit()) self.mocker.replay() @transact def function(self): """docstring""" return "result" # Function metadata is copied to the wrapper. self.assertEqual("docstring", function.__doc__) deferred = function(self) deferred.addCallback(self.assertEqual, "result") return deferred def test_run_with_integrity_error_retries(self): """ If the given function raises a L{IntegrityError}, then the function will be retried another two times before letting the exception bubble up. """ self.transactor.sleep = self.mocker.mock() self.transactor.uniform = self.mocker.mock() self.mocker.order() self.expect(self.function()).throw(IntegrityError()) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 1)).result(1) self.expect(self.transactor.sleep(1)) self.expect(self.function()).throw(IntegrityError()) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 2)).result(2) self.expect(self.transactor.sleep(2)) self.expect(self.function()).throw(IntegrityError()) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, IntegrityError) return deferred def test_run_with_transaction_rollback_error_retries(self): """ If the given function raises a L{TransactionRollbackError}, then the function will be retried another two times before letting the exception bubble up. """ if not has_psycopg: return self.transactor.sleep = self.mocker.mock() self.transactor.uniform = self.mocker.mock() self.mocker.order() self.expect(self.function()).throw(TransactionRollbackError()) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 1)).result(1) self.expect(self.transactor.sleep(1)) self.expect(self.function()).throw(TransactionRollbackError()) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 2)).result(2) self.expect(self.transactor.sleep(2)) self.expect(self.function()).throw(TransactionRollbackError()) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, TransactionRollbackError) return deferred def test_run_with_disconnection_error_retries(self): """ If the given function raises a L{DisconnectionError}, then the function will be retried another two times before letting the exception bubble up. 
""" zstorm = self.mocker.mock() gu = self.mocker.replace(getUtility) self.transactor.sleep = self.mocker.mock() self.transactor.uniform = self.mocker.mock() self.mocker.order() self.expect(self.function()).throw(DisconnectionError()) self.expect(gu(IZStorm)).result(zstorm) self.expect(zstorm.iterstores()).result(iter(())) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 1)).result(1) self.expect(self.transactor.sleep(1)) self.expect(self.function()).throw(DisconnectionError()) self.expect(gu(IZStorm)).result(zstorm) self.expect(zstorm.iterstores()).result(iter(())) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 2)).result(2) self.expect(self.transactor.sleep(2)) self.expect(self.function()).throw(DisconnectionError()) self.expect(gu(IZStorm)).result(zstorm) self.expect(zstorm.iterstores()).result(iter(())) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, DisconnectionError) return deferred def test_run_with_integrity_error_on_commit_retries(self): """ If the given function raises a L{IntegrityError}, then the function will be retried another two times before letting the exception bubble up. """ self.transactor.sleep = self.mocker.mock() self.transactor.uniform = self.mocker.mock() self.mocker.order() self.expect(self.function()) self.expect(self.transaction.commit()).throw(IntegrityError()) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 1)).result(1) self.expect(self.transactor.sleep(1)) self.expect(self.function()) self.expect(self.transaction.commit()).throw(IntegrityError()) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 2)).result(2) self.expect(self.transactor.sleep(2)) self.expect(self.function()) self.expect(self.transaction.commit()).throw(IntegrityError()) self.expect(self.transaction.abort()) self.mocker.replay() deferred = self.transactor.run(self.function) self.assertFailure(deferred, IntegrityError) return deferred def test_run_with_on_retry_callback(self): """ If a retry callback is passed with the C{on_retry} parameter, then it's invoked with the number of retries performed so far. """ calls = [] def on_retry(context): calls.append(context) self.transactor.on_retry = on_retry self.transactor.sleep = self.mocker.mock() self.transactor.uniform = self.mocker.mock() self.mocker.order() self.expect(self.function(1, a=2)) error = IntegrityError() self.expect(self.transaction.commit()).throw(error) self.expect(self.transaction.abort()) self.expect(self.transactor.uniform(1, 2 ** 1)).result(1) self.expect(self.transactor.sleep(1)) self.expect(self.function(1, a=2)) self.expect(self.transaction.commit()) self.mocker.replay() deferred = self.transactor.run(self.function, 1, a=2) def check(_): [context] = calls self.assertEqual(self.function, context.function) self.assertEqual((1,), context.args) self.assertEqual({"a": 2}, context.kwargs) self.assertEqual(1, context.retry) self.assertEqual(1, context.retry) self.assertIs(error, context.error) return deferred.addCallback(check) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/tests/uri.py0000644000175000017500000001465614571373456016532 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from storm.uri import URI, URIError from storm.tests.helper import TestHelper class URITest(TestHelper): def test_parse_defaults(self): uri = URI("scheme:") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.options, {}) self.assertEqual(uri.username, None) self.assertEqual(uri.password, None) self.assertEqual(uri.host, None) self.assertEqual(uri.port, None) self.assertEqual(uri.database, None) def test_parse_no_colon(self): self.assertRaises(URIError, URI, "scheme") def test_parse_just_colon(self): uri = URI("scheme:") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.database, None) def test_parse_just_relative_database(self): uri = URI("scheme:d%61ta/base") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.database, "data/base") def test_parse_just_absolute_database(self): uri = URI("scheme:/d%61ta/base") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.database, "/data/base") def test_parse_host(self): uri = URI("scheme://ho%73t") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.host, "host") def test_parse_username(self): uri = URI("scheme://user%6eame@") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.host, None) def test_parse_username_password(self): uri = URI("scheme://user%6eame:pass%77ord@") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.password, "password") self.assertEqual(uri.host, None) def test_parse_username_host(self): uri = URI("scheme://user%6eame@ho%73t") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.host, "host") def test_parse_username_password_host(self): uri = URI("scheme://user%6eame:pass%77ord@ho%73t") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.password, "password") self.assertEqual(uri.host, "host") def test_parse_username_password_host_port(self): uri = URI("scheme://user%6eame:pass%77ord@ho%73t:1234") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.password, "password") self.assertEqual(uri.host, "host") self.assertEqual(uri.port, 1234) def test_parse_username_password_host_empty_port(self): uri = URI("scheme://user%6eame:pass%77ord@ho%73t:") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.password, "password") self.assertEqual(uri.host, "host") self.assertEqual(uri.port, None) def test_parse_username_password_host_port_database(self): uri = URI("scheme://user%6eame:pass%77ord@ho%73t:1234/d%61tabase") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.password, "password") self.assertEqual(uri.host, "host") self.assertEqual(uri.port, 1234) self.assertEqual(uri.database, "database") def 
test_parse_username_password_database(self): uri = URI("scheme://user%6eame:pass%77ord@/d%61tabase") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.username, "username") self.assertEqual(uri.password, "password") self.assertEqual(uri.host, None) self.assertEqual(uri.port, None) self.assertEqual(uri.database, "database") def test_parse_options(self): uri = URI("scheme:?a%62c=d%65f&ghi=jkl") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.host, None) self.assertEqual(uri.database, None) self.assertEqual(uri.options, {"abc": "def", "ghi": "jkl"}) def test_parse_host_options(self): uri = URI("scheme://ho%73t?a%62c=d%65f&ghi=jkl") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.host, "host") self.assertEqual(uri.database, None) self.assertEqual(uri.options, {"abc": "def", "ghi": "jkl"}) def test_parse_host_database_options(self): uri = URI("scheme://ho%73t/d%61tabase?a%62c=d%65f&ghi=jkl") self.assertEqual(uri.scheme, "scheme") self.assertEqual(uri.host, "host") self.assertEqual(uri.database, "database") self.assertEqual(uri.options, {"abc": "def", "ghi": "jkl"}) def test_copy(self): uri = URI("scheme:///db?opt=value") uri_copy = uri.copy() self.assertTrue(uri_copy is not uri) self.assertTrue(uri_copy.__dict__ == uri.__dict__) self.assertTrue(uri_copy.options is not uri.options) def str(self, uri): self.assertEqual(str(URI(uri)), uri) def test_str_full_with_escaping(self): self.str("scheme://us%2Fer:pa%2Fss@ho%2Fst:0/d%3Fb?a%2Fb=c%2Fd&ghi=jkl") def test_str_no_path_escaping(self): self.str("scheme:/a/b/c") def test_str_scheme_only(self): self.str("scheme:") def test_str_username_only(self): self.str("scheme://username@/") def test_str_password_only(self): self.str("scheme://:password@/") def test_str_port_only(self): self.str("scheme://:0/") def test_str_host_only(self): self.str("scheme://host/") def test_str_database_only(self): self.str("scheme:db") def test_str_option_only(self): self.str("scheme:?a=b") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/variables.py0000644000175000017500000010454314645174376017700 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# from datetime import datetime, date, time, timedelta from decimal import Decimal import gc import json import pickle import weakref import uuid from storm.exceptions import NoneError from storm.variables import * from storm.event import EventSystem from storm.expr import Column, SQLToken from storm.tz import tzutc, tzoffset from storm import Undef from storm.tests.helper import TestHelper class Marker: pass marker = Marker() class CustomVariable(Variable): def __init__(self, *args, **kwargs): self.gets = [] self.sets = [] Variable.__init__(self, *args, **kwargs) def parse_get(self, variable, to_db): self.gets.append((variable, to_db)) return "g", variable def parse_set(self, variable, from_db): self.sets.append((variable, from_db)) return "s", variable class VariableTest(TestHelper): def test_constructor_value(self): variable = CustomVariable(marker) self.assertEqual(variable.sets, [(marker, False)]) def test_constructor_value_from_db(self): variable = CustomVariable(marker, from_db=True) self.assertEqual(variable.sets, [(marker, True)]) def test_constructor_value_factory(self): variable = CustomVariable(value_factory=lambda:marker) self.assertEqual(variable.sets, [(marker, False)]) def test_constructor_value_factory_from_db(self): variable = CustomVariable(value_factory=lambda:marker, from_db=True) self.assertEqual(variable.sets, [(marker, True)]) def test_constructor_column(self): variable = CustomVariable(column=marker) self.assertEqual(variable.column, marker) def test_constructor_event(self): variable = CustomVariable(event=marker) self.assertEqual(variable.event, weakref.proxy(marker)) def test_get_default(self): variable = CustomVariable() self.assertEqual(variable.get(default=marker), marker) def test_set(self): variable = CustomVariable() variable.set(marker) self.assertEqual(variable.sets, [(marker, False)]) variable.set(marker, from_db=True) self.assertEqual(variable.sets, [(marker, False), (marker, True)]) def test_set_leak(self): """When a variable is checkpointed, the value must not leak.""" variable = Variable() m = Marker() m_ref = weakref.ref(m) variable.set(m) variable.checkpoint() variable.set(LazyValue()) del m gc.collect() self.assertIdentical(m_ref(), None) def test_get(self): variable = CustomVariable() variable.set(marker) self.assertEqual(variable.get(), ("g", ("s", marker))) self.assertEqual(variable.gets, [(("s", marker), False)]) variable = CustomVariable() variable.set(marker) self.assertEqual(variable.get(to_db=True), ("g", ("s", marker))) self.assertEqual(variable.gets, [(("s", marker), True)]) def test_is_defined(self): variable = CustomVariable() self.assertFalse(variable.is_defined()) variable.set(marker) self.assertTrue(variable.is_defined()) def test_set_get_none(self): variable = CustomVariable() variable.set(None) self.assertEqual(variable.get(marker), None) self.assertEqual(variable.sets, []) self.assertEqual(variable.gets, []) def test_set_none_with_allow_none(self): variable = CustomVariable(allow_none=False) self.assertRaisesRegex( NoneError, r"^None isn't acceptable as a value$", variable.set, None) def test_set_none_with_allow_none_and_column(self): column = Column("column_name") variable = CustomVariable(allow_none=False, column=column) self.assertRaisesRegex( NoneError, r"^None isn't acceptable as a value for column_name$", variable.set, None) def test_set_none_with_allow_none_and_column_with_table(self): column = Column("column_name", SQLToken("table_name")) variable = CustomVariable(allow_none=False, column=column) self.assertRaisesRegex( 
NoneError, r"^None isn't acceptable as a value for table_name.column_name$", variable.set, None) def test_set_default_none_with_allow_none(self): self.assertRaisesRegex( NoneError, r"^None isn't acceptable as a default value$", CustomVariable, allow_none=False, value=None) def test_set_default_none_with_allow_none_and_column(self): column = Column("column_name") self.assertRaisesRegex( NoneError, r"^None isn't acceptable as a default value for column_name$", CustomVariable, allow_none=False, value=None, column=column) def test_set_default_none_with_allow_none_and_column_with_table(self): column = Column("column_name", SQLToken("table_name")) self.assertRaisesRegex( NoneError, r"^None isn't acceptable as a default value for " r"table_name.column_name$", CustomVariable, allow_none=False, value=None, column=column) def test_set_with_validator(self): args = [] def validator(obj, attr, value): args.append((obj, attr, value)) return value variable = CustomVariable(validator=validator) variable.set(3) self.assertEqual(args, [(None, None, 3)]) def test_set_with_validator_and_validator_arguments(self): args = [] def validator(obj, attr, value): args.append((obj, attr, value)) return value variable = CustomVariable(validator=validator, validator_object_factory=lambda: 1, validator_attribute=2) variable.set(3) self.assertEqual(args, [(1, 2, 3)]) def test_set_with_validator_raising_error(self): args = [] def validator(obj, attr, value): args.append((obj, attr, value)) raise ZeroDivisionError() variable = CustomVariable(validator=validator) self.assertRaises(ZeroDivisionError, variable.set, marker) self.assertEqual(args, [(None, None, marker)]) self.assertEqual(variable.get(), None) def test_set_with_validator_changing_value(self): args = [] def validator(obj, attr, value): args.append((obj, attr, value)) return 42 variable = CustomVariable(validator=validator) variable.set(marker) self.assertEqual(args, [(None, None, marker)]) self.assertEqual(variable.get(), ('g', ('s', 42))) def test_set_from_db_wont_call_validator(self): args = [] def validator(obj, attr, value): args.append((obj, attr, value)) return 42 variable = CustomVariable(validator=validator) variable.set(marker, from_db=True) self.assertEqual(args, []) self.assertEqual(variable.get(), ('g', ('s', marker))) def test_event_changed(self): event = EventSystem(marker) changed_values = [] def changed(owner, variable, old_value, new_value, fromdb): changed_values.append((owner, variable, old_value, new_value, fromdb)) event.hook("changed", changed) variable = CustomVariable(event=event) variable.set("value1") variable.set("value2") variable.set("value3", from_db=True) variable.set(None, from_db=True) variable.set("value4") variable.delete() variable.delete() self.assertEqual(changed_values[0], (marker, variable, Undef, "value1", False)) self.assertEqual(changed_values[1], (marker, variable, ("g", ("s", "value1")), "value2", False)) self.assertEqual(changed_values[2], (marker, variable, ("g", ("s", "value2")), ("g", ("s", "value3")), True)) self.assertEqual(changed_values[3], (marker, variable, ("g", ("s", "value3")), None, True)) self.assertEqual(changed_values[4], (marker, variable, None, "value4", False)) self.assertEqual(changed_values[5], (marker, variable, ("g", ("s", "value4")), Undef, False)) self.assertEqual(len(changed_values), 6) def test_get_state(self): variable = CustomVariable(marker) self.assertEqual(variable.get_state(), (Undef, ("s", marker))) def test_set_state(self): lazy_value = object() variable = CustomVariable() 
variable.set_state((lazy_value, marker)) self.assertEqual(variable.get(), ("g", marker)) self.assertEqual(variable.get_lazy(), lazy_value) def test_checkpoint_and_has_changed(self): variable = CustomVariable() self.assertTrue(variable.has_changed()) variable.set(marker) self.assertTrue(variable.has_changed()) variable.checkpoint() self.assertFalse(variable.has_changed()) variable.set(marker) self.assertFalse(variable.has_changed()) variable.set((marker, marker)) self.assertTrue(variable.has_changed()) variable.checkpoint() self.assertFalse(variable.has_changed()) variable.set((marker, marker)) self.assertFalse(variable.has_changed()) variable.set(marker) self.assertTrue(variable.has_changed()) variable.set((marker, marker)) self.assertFalse(variable.has_changed()) def test_copy(self): variable = CustomVariable() variable.set(marker) variable_copy = variable.copy() variable_copy.gets = [] self.assertTrue(variable is not variable_copy) self.assertVariablesEqual([variable], [variable_copy]) def test_lazy_value_setting(self): variable = CustomVariable() variable.set(LazyValue()) self.assertEqual(variable.sets, []) self.assertTrue(variable.has_changed()) def test_lazy_value_getting(self): variable = CustomVariable() variable.set(LazyValue()) self.assertEqual(variable.get(marker), marker) variable.set(1) variable.set(LazyValue()) self.assertEqual(variable.get(marker), marker) self.assertFalse(variable.is_defined()) def test_lazy_value_resolving(self): event = EventSystem(marker) resolve_values = [] def resolve(owner, variable, value): resolve_values.append((owner, variable, value)) lazy_value = LazyValue() variable = CustomVariable(lazy_value, event=event) event.hook("resolve-lazy-value", resolve) variable.get() self.assertEqual(resolve_values, [(marker, variable, lazy_value)]) def test_lazy_value_changed_event(self): event = EventSystem(marker) changed_values = [] def changed(owner, variable, old_value, new_value, fromdb): changed_values.append((owner, variable, old_value, new_value, fromdb)) event.hook("changed", changed) variable = CustomVariable(event=event) lazy_value = LazyValue() variable.set(lazy_value) self.assertEqual(changed_values, [(marker, variable, Undef, lazy_value, False)]) def test_lazy_value_setting_on_resolving(self): event = EventSystem(marker) def resolve(owner, variable, value): variable.set(marker) event.hook("resolve-lazy-value", resolve) lazy_value = LazyValue() variable = CustomVariable(lazy_value, event=event) self.assertEqual(variable.get(), ("g", ("s", marker))) def test_lazy_value_reset_after_changed(self): event = EventSystem(marker) resolve_called = [] def resolve(owner, variable, value): resolve_called.append(True) event.hook("resolve-lazy-value", resolve) variable = CustomVariable(event=event) variable.set(LazyValue()) variable.set(1) self.assertEqual(variable.get(), ("g", ("s", 1))) self.assertEqual(resolve_called, []) def test_get_lazy_value(self): lazy_value = LazyValue() variable = CustomVariable() self.assertEqual(variable.get_lazy(), None) self.assertEqual(variable.get_lazy(marker), marker) variable.set(lazy_value) self.assertEqual(variable.get_lazy(marker), lazy_value) class BoolVariableTest(TestHelper): def test_set_get(self): variable = BoolVariable() variable.set(1) self.assertTrue(variable.get() is True) variable.set(0) self.assertTrue(variable.get() is False) variable.set(1.1) self.assertTrue(variable.get() is True) variable.set(0.0) self.assertTrue(variable.get() is False) variable.set(Decimal(1)) self.assertTrue(variable.get() is True) 
variable.set(Decimal(0)) self.assertTrue(variable.get() is False) self.assertRaises(TypeError, variable.set, "string") class IntVariableTest(TestHelper): def test_set_get(self): variable = IntVariable() variable.set(1) self.assertEqual(variable.get(), 1) variable.set(1.1) self.assertEqual(variable.get(), 1) variable.set(Decimal(2)) self.assertEqual(variable.get(), 2) self.assertRaises(TypeError, variable.set, "1") class FloatVariableTest(TestHelper): def test_set_get(self): variable = FloatVariable() variable.set(1.1) self.assertEqual(variable.get(), 1.1) variable.set(1) self.assertEqual(variable.get(), 1) self.assertEqual(type(variable.get()), float) variable.set(Decimal("1.1")) self.assertEqual(variable.get(), 1.1) self.assertRaises(TypeError, variable.set, "1") class DecimalVariableTest(TestHelper): def test_set_get(self): variable = DecimalVariable() variable.set(Decimal("1.1")) self.assertEqual(variable.get(), Decimal("1.1")) variable.set(1) self.assertEqual(variable.get(), 1) self.assertEqual(type(variable.get()), Decimal) variable.set(Decimal("1.1")) self.assertEqual(variable.get(), Decimal("1.1")) self.assertRaises(TypeError, variable.set, "1") self.assertRaises(TypeError, variable.set, 1.1) def test_get_set_from_database(self): """Strings used to/from the database.""" variable = DecimalVariable() variable.set("1.1", from_db=True) self.assertEqual(variable.get(), Decimal("1.1")) self.assertEqual(variable.get(to_db=True), "1.1") class BytesVariableTest(TestHelper): def test_set_get(self): variable = BytesVariable() variable.set(b"str") self.assertEqual(variable.get(), b"str") variable.set(memoryview(b"buffer")) self.assertEqual(variable.get(), b"buffer") self.assertRaises(TypeError, variable.set, "unicode") class UnicodeVariableTest(TestHelper): def test_set_get(self): variable = UnicodeVariable() variable.set("unicode") self.assertEqual(variable.get(), "unicode") self.assertRaises(TypeError, variable.set, b"str") class DateTimeVariableTest(TestHelper): def test_get_set(self): epoch = datetime.utcfromtimestamp(0) variable = DateTimeVariable() variable.set(0) self.assertEqual(variable.get(), epoch) variable.set(0.0) self.assertEqual(variable.get(), epoch) variable.set(epoch) self.assertEqual(variable.get(), epoch) self.assertRaises(TypeError, variable.set, marker) def test_get_set_from_database(self): datetime_str = "1977-05-04 12:34:56.78" datetime_uni = str(datetime_str) datetime_obj = datetime(1977, 5, 4, 12, 34, 56, 780000) variable = DateTimeVariable() variable.set(datetime_str, from_db=True) self.assertEqual(variable.get(), datetime_obj) variable.set(datetime_uni, from_db=True) self.assertEqual(variable.get(), datetime_obj) variable.set(datetime_obj, from_db=True) self.assertEqual(variable.get(), datetime_obj) datetime_str = "1977-05-04 12:34:56" datetime_uni = str(datetime_str) datetime_obj = datetime(1977, 5, 4, 12, 34, 56) variable.set(datetime_str, from_db=True) self.assertEqual(variable.get(), datetime_obj) variable.set(datetime_uni, from_db=True) self.assertEqual(variable.get(), datetime_obj) variable.set(datetime_obj, from_db=True) self.assertEqual(variable.get(), datetime_obj) self.assertRaises(TypeError, variable.set, 0, from_db=True) self.assertRaises(TypeError, variable.set, marker, from_db=True) self.assertRaises(ValueError, variable.set, "foobar", from_db=True) self.assertRaises(ValueError, variable.set, "foo bar", from_db=True) def test_get_set_with_tzinfo(self): datetime_str = "1977-05-04 12:34:56.78" datetime_obj = datetime(1977, 5, 4, 12, 34, 56, 780000, 
tzinfo=tzutc()) variable = DateTimeVariable(tzinfo=tzutc()) # Naive timezone, from_db=True. variable.set(datetime_str, from_db=True) self.assertEqual(variable.get(), datetime_obj) variable.set(datetime_obj, from_db=True) self.assertEqual(variable.get(), datetime_obj) # Naive timezone, from_db=False (doesn't work). datetime_obj = datetime(1977, 5, 4, 12, 34, 56, 780000) self.assertRaises(ValueError, variable.set, datetime_obj) # Different timezone, from_db=False. datetime_obj = datetime(1977, 5, 4, 12, 34, 56, 780000, tzinfo=tzoffset("1h", 3600)) variable.set(datetime_obj, from_db=False) converted_obj = variable.get() self.assertEqual(converted_obj, datetime_obj) self.assertEqual(type(converted_obj.tzinfo), tzutc) # Different timezone, from_db=True. datetime_obj = datetime(1977, 5, 4, 12, 34, 56, 780000, tzinfo=tzoffset("1h", 3600)) variable.set(datetime_obj, from_db=True) converted_obj = variable.get() self.assertEqual(converted_obj, datetime_obj) self.assertEqual(type(converted_obj.tzinfo), tzutc) class DateVariableTest(TestHelper): def test_get_set(self): epoch = datetime.utcfromtimestamp(0) epoch_date = epoch.date() variable = DateVariable() variable.set(epoch) self.assertEqual(variable.get(), epoch_date) variable.set(epoch_date) self.assertEqual(variable.get(), epoch_date) self.assertRaises(TypeError, variable.set, marker) def test_get_set_from_database(self): date_str = "1977-05-04" date_uni = str(date_str) date_obj = date(1977, 5, 4) datetime_obj = datetime(1977, 5, 4, 0, 0, 0) variable = DateVariable() variable.set(date_str, from_db=True) self.assertEqual(variable.get(), date_obj) variable.set(date_uni, from_db=True) self.assertEqual(variable.get(), date_obj) variable.set(date_obj, from_db=True) self.assertEqual(variable.get(), date_obj) variable.set(datetime_obj, from_db=True) self.assertEqual(variable.get(), date_obj) self.assertRaises(TypeError, variable.set, 0, from_db=True) self.assertRaises(TypeError, variable.set, marker, from_db=True) self.assertRaises(ValueError, variable.set, "foobar", from_db=True) def test_set_with_datetime(self): datetime_str = "1977-05-04 12:34:56.78" date_obj = date(1977, 5, 4) variable = DateVariable() variable.set(datetime_str, from_db=True) self.assertEqual(variable.get(), date_obj) class TimeVariableTest(TestHelper): def test_get_set(self): epoch = datetime.utcfromtimestamp(0) epoch_time = epoch.time() variable = TimeVariable() variable.set(epoch) self.assertEqual(variable.get(), epoch_time) variable.set(epoch_time) self.assertEqual(variable.get(), epoch_time) self.assertRaises(TypeError, variable.set, marker) def test_get_set_from_database(self): time_str = "12:34:56.78" time_uni = str(time_str) time_obj = time(12, 34, 56, 780000) variable = TimeVariable() variable.set(time_str, from_db=True) self.assertEqual(variable.get(), time_obj) variable.set(time_uni, from_db=True) self.assertEqual(variable.get(), time_obj) variable.set(time_obj, from_db=True) self.assertEqual(variable.get(), time_obj) time_str = "12:34:56" time_uni = str(time_str) time_obj = time(12, 34, 56) variable.set(time_str, from_db=True) self.assertEqual(variable.get(), time_obj) variable.set(time_uni, from_db=True) self.assertEqual(variable.get(), time_obj) variable.set(time_obj, from_db=True) self.assertEqual(variable.get(), time_obj) self.assertRaises(TypeError, variable.set, 0, from_db=True) self.assertRaises(TypeError, variable.set, marker, from_db=True) self.assertRaises(ValueError, variable.set, "foobar", from_db=True) def test_set_with_datetime(self): datetime_str = 
"1977-05-04 12:34:56.78" time_obj = time(12, 34, 56, 780000) variable = TimeVariable() variable.set(datetime_str, from_db=True) self.assertEqual(variable.get(), time_obj) def test_microsecond_error(self): time_str = "15:14:18.598678" time_obj = time(15, 14, 18, 598678) variable = TimeVariable() variable.set(time_str, from_db=True) self.assertEqual(variable.get(), time_obj) def test_microsecond_error_less_digits(self): time_str = "15:14:18.5986" time_obj = time(15, 14, 18, 598600) variable = TimeVariable() variable.set(time_str, from_db=True) self.assertEqual(variable.get(), time_obj) def test_microsecond_error_more_digits(self): time_str = "15:14:18.5986789" time_obj = time(15, 14, 18, 598678) variable = TimeVariable() variable.set(time_str, from_db=True) self.assertEqual(variable.get(), time_obj) class TimeDeltaVariableTest(TestHelper): def test_get_set(self): delta = timedelta(days=42) variable = TimeDeltaVariable() variable.set(delta) self.assertEqual(variable.get(), delta) self.assertRaises(TypeError, variable.set, marker) def test_get_set_from_database(self): delta_str = "42 days 12:34:56.78" delta_uni = str(delta_str) delta_obj = timedelta(days=42, hours=12, minutes=34, seconds=56, microseconds=780000) variable = TimeDeltaVariable() variable.set(delta_str, from_db=True) self.assertEqual(variable.get(), delta_obj) variable.set(delta_uni, from_db=True) self.assertEqual(variable.get(), delta_obj) variable.set(delta_obj, from_db=True) self.assertEqual(variable.get(), delta_obj) delta_str = "1 day, 12:34:56" delta_uni = str(delta_str) delta_obj = timedelta(days=1, hours=12, minutes=34, seconds=56) variable.set(delta_str, from_db=True) self.assertEqual(variable.get(), delta_obj) variable.set(delta_uni, from_db=True) self.assertEqual(variable.get(), delta_obj) variable.set(delta_obj, from_db=True) self.assertEqual(variable.get(), delta_obj) self.assertRaises(TypeError, variable.set, 0, from_db=True) self.assertRaises(TypeError, variable.set, marker, from_db=True) self.assertRaises(ValueError, variable.set, "foobar", from_db=True) # Intervals of months or years can not be converted to a # Python timedelta, so a ValueError exception is raised: self.assertRaises(ValueError, variable.set, "42 months", from_db=True) self.assertRaises(ValueError, variable.set, "42 years", from_db=True) class ParseIntervalTest(TestHelper): def check(self, interval, td): self.assertEqual(TimeDeltaVariable(interval, from_db=True).get(), td) def test_zero(self): self.check("0:00:00", timedelta(0)) def test_one_microsecond(self): self.check("0:00:00.000001", timedelta(0, 0, 1)) def test_twelve_centiseconds(self): self.check("0:00:00.120000", timedelta(0, 0, 120000)) def test_one_second(self): self.check("0:00:01", timedelta(0, 1)) def test_twelve_seconds(self): self.check("0:00:12", timedelta(0, 12)) def test_one_minute(self): self.check("0:01:00", timedelta(0, 60)) def test_twelve_minutes(self): self.check("0:12:00", timedelta(0, 12*60)) def test_one_hour(self): self.check("1:00:00", timedelta(0, 60*60)) def test_twelve_hours(self): self.check("12:00:00", timedelta(0, 12*60*60)) def test_one_day(self): self.check("1 day, 0:00:00", timedelta(1)) def test_twelve_days(self): self.check("12 days, 0:00:00", timedelta(12)) def test_twelve_twelve_twelve_twelve_twelve(self): self.check("12 days, 12:12:12.120000", timedelta(12, 12*60*60 + 12*60 + 12, 120000)) def test_minus_twelve_centiseconds(self): self.check("-1 day, 23:59:59.880000", timedelta(0, 0, -120000)) def test_minus_twelve_days(self): self.check("-12 days, 
0:00:00", timedelta(-12)) def test_minus_twelve_hours(self): self.check("-12:00:00", timedelta(hours=-12)) def test_one_day_and_a_half(self): self.check("1.5 days", timedelta(days=1, hours=12)) def test_seconds_without_unit(self): self.check("1h123", timedelta(hours=1, seconds=123)) def test_d_h_m_s_ms(self): self.check("1d1h1m1s1ms", timedelta(days=1, hours=1, minutes=1, seconds=1, microseconds=1000)) def test_days_without_unit(self): self.check("-12 1:02 3s", timedelta(days=-12, hours=1, minutes=2, seconds=3)) def test_unsupported_unit(self): try: self.check("1 month", None) except ValueError as e: self.assertEqual(str(e), "Unsupported interval unit 'month' " "in interval '1 month'") else: self.fail("ValueError not raised") def test_missing_value(self): try: self.check("day", None) except ValueError as e: self.assertEqual(str(e), "Expected an interval value rather than " "'day' in interval 'day'") else: self.fail("ValueError not raised") class UUIDVariableTest(TestHelper): def test_get_set(self): value = uuid.UUID("{0609f76b-878f-4546-baf5-c1b135e8de72}") variable = UUIDVariable() variable.set(value) self.assertEqual(variable.get(), value) self.assertEqual( variable.get(to_db=True), "0609f76b-878f-4546-baf5-c1b135e8de72") self.assertRaises(TypeError, variable.set, marker) self.assertRaises(TypeError, variable.set, "0609f76b-878f-4546-baf5-c1b135e8de72") self.assertRaises(TypeError, variable.set, "0609f76b-878f-4546-baf5-c1b135e8de72") def test_get_set_from_database(self): value = uuid.UUID("{0609f76b-878f-4546-baf5-c1b135e8de72}") variable = UUIDVariable() # Strings and UUID objects are accepted from the database. variable.set(value, from_db=True) self.assertEqual(variable.get(), value) variable.set("0609f76b-878f-4546-baf5-c1b135e8de72", from_db=True) self.assertEqual(variable.get(), value) variable.set("0609f76b-878f-4546-baf5-c1b135e8de72", from_db=True) self.assertEqual(variable.get(), value) # Some other representations for UUID values. 
variable.set("{0609f76b-878f-4546-baf5-c1b135e8de72}", from_db=True) self.assertEqual(variable.get(), value) variable.set("0609f76b878f4546baf5c1b135e8de72", from_db=True) self.assertEqual(variable.get(), value) class EncodedValueVariableTestMixin: encoding = None variable_type = None def test_get_set(self): d = {"a": 1} d_dump = self.encode(d) variable = self.variable_type() variable.set(d) self.assertEqual(variable.get(), d) self.assertEqual(variable.get(to_db=True), d_dump) variable.set(d_dump, from_db=True) self.assertEqual(variable.get(), d) self.assertEqual(variable.get(to_db=True), d_dump) self.assertEqual(variable.get_state(), (Undef, d_dump)) variable.set(marker) variable.set_state((Undef, d_dump)) self.assertEqual(variable.get(), d) variable.get()["b"] = 2 self.assertEqual(variable.get(), {"a": 1, "b": 2}) def test_pickle_events(self): event = EventSystem(marker) variable = self.variable_type(event=event, value_factory=list) changes = [] def changed(owner, variable, old_value, new_value, fromdb): changes.append((variable, old_value, new_value, fromdb)) event.emit("start-tracking-changes", event) event.hook("changed", changed) variable.checkpoint() event.emit("flush") self.assertEqual(changes, []) lst = variable.get() self.assertEqual(lst, []) self.assertEqual(changes, []) lst.append("a") self.assertEqual(changes, []) event.emit("flush") self.assertEqual(changes, [(variable, None, ["a"], False)]) del changes[:] event.emit("object-deleted") self.assertEqual(changes, [(variable, None, ["a"], False)]) class PickleVariableTest(EncodedValueVariableTestMixin, TestHelper): encode = staticmethod(lambda data: pickle.dumps(data, -1)) variable_type = PickleVariable class JSONVariableTest(EncodedValueVariableTestMixin, TestHelper): encode = staticmethod(lambda data: json.dumps(data)) variable_type = JSONVariable def is_supported(self): return json is not None def test_unicode_from_db_required(self): # JSONVariable._loads() complains loudly if it does not receive a # unicode string because it has no way of knowing its encoding. variable = self.variable_type() self.assertRaises(TypeError, variable.set, b'"abc"', from_db=True) def test_unicode_to_db(self): # JSONVariable._dumps() works around text/bytes handling issues in # json. variable = self.variable_type() variable.set({"a": 1}) self.assertTrue(isinstance(variable.get(to_db=True), str)) class ListVariableTest(TestHelper): def test_get_set(self): # Enumeration variables are used as items so that database # side and python side values can be distinguished. 
get_map = {1: "a", 2: "b", 3: "c"} set_map = {"a": 1, "b": 2, "c": 3} item_factory = VariableFactory( EnumVariable, get_map=get_map, set_map=set_map) l = ["a", "b"] l_dump = pickle.dumps(l, -1) l_vars = [item_factory(value=x) for x in l] variable = ListVariable(item_factory) variable.set(l) self.assertEqual(variable.get(), l) self.assertVariablesEqual(variable.get(to_db=True), l_vars) variable.set([1, 2], from_db=True) self.assertEqual(variable.get(), l) self.assertVariablesEqual(variable.get(to_db=True), l_vars) self.assertEqual(variable.get_state(), (Undef, l_dump)) variable.set([]) variable.set_state((Undef, l_dump)) self.assertEqual(variable.get(), l) variable.get().append("c") self.assertEqual(variable.get(), ["a", "b", "c"]) def test_list_events(self): event = EventSystem(marker) variable = ListVariable(BytesVariable, event=event, value_factory=list) changes = [] def changed(owner, variable, old_value, new_value, fromdb): changes.append((variable, old_value, new_value, fromdb)) event.emit("start-tracking-changes", event) event.hook("changed", changed) variable.checkpoint() event.emit("flush") self.assertEqual(changes, []) lst = variable.get() self.assertEqual(lst, []) self.assertEqual(changes, []) lst.append("a") self.assertEqual(changes, []) event.emit("flush") self.assertEqual(changes, [(variable, None, ["a"], False)]) del changes[:] event.emit("object-deleted") self.assertEqual(changes, [(variable, None, ["a"], False)]) class EnumVariableTest(TestHelper): def test_set_get(self): variable = EnumVariable({1: "foo", 2: "bar"}, {"foo": 1, "bar": 2}) variable.set("foo") self.assertEqual(variable.get(), "foo") self.assertEqual(variable.get(to_db=True), 1) variable.set(2, from_db=True) self.assertEqual(variable.get(), "bar") self.assertEqual(variable.get(to_db=True), 2) self.assertRaises(ValueError, variable.set, "foobar") self.assertRaises(ValueError, variable.set, 2) def test_in_map(self): variable = EnumVariable({1: "foo", 2: "bar"}, {"one": 1, "two": 2}) variable.set("one") self.assertEqual(variable.get(), "foo") self.assertEqual(variable.get(to_db=True), 1) variable.set(2, from_db=True) self.assertEqual(variable.get(), "bar") self.assertEqual(variable.get(to_db=True), 2) self.assertRaises(ValueError, variable.set, "foo") self.assertRaises(ValueError, variable.set, 2) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/wsgi.py0000644000175000017500000001032514645174376016673 0ustar00cjwatsoncjwatsonimport queue from unittest import TestCase import threading import time from storm.wsgi import make_app class TestMakeApp(TestCase): def stub_app(self, environ, start_response): if getattr(self, 'in_request', None): self.in_request() getattr(self, 'calls', []).append('stub_app') start_response('200 OK', []) yield '' if getattr(self, 'in_generator', None): self.in_generator() def stub_start_response(self, status, headers): pass def test_find_timeline_outside_request(self): app, find_timeline = make_app(self.stub_app) # outside a request, find_timeline returns nothing: self.assertEqual(None, find_timeline()) def test_find_timeline_in_request_not_set(self): # In a request, with no timeline object in the environ, find_timeline # returns None: app, find_timeline = make_app(self.stub_app) self.in_request = lambda:self.assertEqual(None, find_timeline()) self.calls = [] list(app({}, self.stub_start_response)) # And we definitely got into the call: self.assertEqual(['stub_app'], self.calls) def test_find_timeline_set_in_environ(self): 
# If a timeline object is known, find_timeline finds it: app, find_timeline = make_app(self.stub_app) timeline = FakeTimeline() self.in_request = lambda:self.assertEqual(timeline, find_timeline()) list(app({'timeline.timeline': timeline}, self.stub_start_response)) def test_find_timeline_set_in_environ_during_generator(self): # If a timeline object is known, find_timeline finds it: app, find_timeline = make_app(self.stub_app) timeline = FakeTimeline() self.in_generator = lambda:self.assertEqual(timeline, find_timeline()) list(app({'timeline.timeline': timeline}, self.stub_start_response)) def test_timeline_is_replaced_in_subsequent_request(self): app, find_timeline = make_app(self.stub_app) timeline = FakeTimeline() self.in_request = lambda:self.assertEqual(timeline, find_timeline()) list(app({'timeline.timeline': timeline}, self.stub_start_response)) # Having left the request, the timeline is left behind... self.assertEqual(timeline, find_timeline()) # ... but only until the next request comes through. timeline2 = FakeTimeline() self.in_request = lambda:self.assertEqual(timeline2, find_timeline()) list(app({'timeline.timeline': timeline2}, self.stub_start_response)) def test_lookups_are_threaded(self): # with two threads in a request at once, each only sees their own # timeline. app, find_timeline = make_app(self.stub_app) errors = queue.Queue() sync = threading.Condition() waiting = [] def check_timeline(): timeline = FakeTimeline() def start_response(status, headers): # Block on the condition, so all test threads are in # start_response when the test resumes. sync.acquire() waiting.append('x') sync.wait() sync.release() found_timeline = find_timeline() if found_timeline != timeline: errors.put((found_timeline, timeline)) list(app({'timeline.timeline': timeline}, start_response)) t1 = threading.Thread(target=check_timeline) t2 = threading.Thread(target=check_timeline) t1.start() try: t2.start() try: while True: sync.acquire() if len(waiting) == 2: break sync.release() time.sleep(0) sync.notify() sync.notify() sync.release() finally: t2.join() finally: t1.join() if errors.qsize(): found_timeline, timeline = errors.get(False) self.assertEqual(timeline, found_timeline) class FakeTimeline: """A fake Timeline. We need this because we can't use plain object instances as they can't be weakreferenced. """ ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1721152862.425125 storm-1.0/storm/tests/zope/0000755000175000017500000000000014645532536016320 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/tests/zope/__init__.py0000644000175000017500000000253714571373456020442 0ustar00cjwatsoncjwatson# Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
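# (A sketch of the optional-dependency pattern used below: each flag
# is computed as
#     try: import transaction
#     except ImportError: has_transaction = False
#     else: has_transaction = True
# so that dependent test classes can skip themselves via
# is_supported().)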
# __all__ = [ 'has_transaction', 'has_zope_component', 'has_zope_security', 'has_testresources', ] try: import transaction except ImportError: has_transaction = False else: has_transaction = True try: import zope.component except ImportError: has_zope_component = False else: has_zope_component = True try: import zope.security except ImportError: has_zope_security = False else: has_zope_security = True try: import testresources except ImportError: has_testresources = False else: has_testresources = True ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/zope/adapters.py0000644000175000017500000000364714645174376020513 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from storm.tests.helper import TestHelper from storm.tests.zope import has_zope_component if has_zope_component: from zope.interface import implementer from zope.component import getGlobalSiteManager from storm.store import EmptyResultSet from storm.zope.adapters import sqlobject_result_set_to_storm_result_set from storm.zope.interfaces import IResultSet, ISQLObjectResultSet @implementer(ISQLObjectResultSet) class TestSQLObjectResultSet: _result_set = EmptyResultSet() class AdaptersTest(TestHelper): def is_supported(self): return has_zope_component def setUp(self): getGlobalSiteManager().registerAdapter( sqlobject_result_set_to_storm_result_set) def tearDown(self): getGlobalSiteManager().unregisterAdapter( sqlobject_result_set_to_storm_result_set) def test_adapt_sqlobject_to_storm(self): so_result_set = TestSQLObjectResultSet() self.assertTrue( ISQLObjectResultSet.providedBy(so_result_set)) result_set = IResultSet(so_result_set) self.assertTrue( IResultSet.providedBy(result_set)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/zope/testing.py0000644000175000017500000003740014645174376020357 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
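# A sketch of the configuration format these tests hand to
# ZStormResourceManager (see setUp below; the URI value is
# illustrative):
#
#     databases = [{"name": "test", "uri": "sqlite:///path",
#                   "schema": schema}]
#     resource = ZStormResourceManager(databases)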
# import os import sys from storm.tests.helper import TestHelper from storm.tests.zope import ( has_testresources, has_transaction, has_zope_component, ) from storm.locals import create_database, Store, Unicode, Int from storm.exceptions import IntegrityError from storm.testing import CaptureTracer from storm.schema.patch import PatchSet if has_transaction and has_zope_component and has_testresources: from zope.component import provideUtility, getUtility from storm.zope.zstorm import ZStorm, global_zstorm from storm.zope.interfaces import IZStorm from storm.zope.schema import ZSchema from storm.zope.testing import ZStormResourceManager PATCH = """ def apply(store): store.execute('ALTER TABLE test ADD COLUMN bar INT') """ class ZStormResourceManagerTest(TestHelper): def is_supported(self): return has_transaction and has_zope_component and has_testresources def setUp(self): super().setUp() package_dir = self.makeDir() sys.path.append(package_dir) self.patch_dir = os.path.join(package_dir, "patch_package") os.mkdir(self.patch_dir) self.makeFile(path=os.path.join(self.patch_dir, "__init__.py"), content="") self.makeFile(path=os.path.join(self.patch_dir, "patch_1.py"), content=PATCH) import patch_package create = ["CREATE TABLE test (foo TEXT UNIQUE, bar INT)"] drop = ["DROP TABLE test"] delete = ["DELETE FROM test"] uri = "sqlite:///%s" % self.makeFile() schema = ZSchema(create, drop, delete, PatchSet(patch_package)) self.databases = [{"name": "test", "uri": uri, "schema": schema}] self.resource = ZStormResourceManager(self.databases) self.resource.vertical_patching = False self.store = Store(create_database(uri)) def tearDown(self): global_zstorm._reset() del sys.modules["patch_package"] sys.modules.pop("patch_package.patch_1", None) super().tearDown() def test_make(self): """ L{ZStormResourceManager.make} returns a L{ZStorm} resource that can be used to get the registered L{Store}s. """ zstorm = self.resource.make([]) store = zstorm.get("test") self.assertEqual([], list(store.execute("SELECT foo, bar FROM test"))) def test_make_lazy(self): """ L{ZStormResourceManager.make} does not create all stores upfront, but only when they're actually used, likewise L{ZStorm.get}. """ zstorm = self.resource.make([]) self.assertEqual([], list(zstorm.iterstores())) store = zstorm.get("test") self.assertEqual([("test", store)], list(zstorm.iterstores())) def test_make_upgrade(self): """ L{ZStormResourceManager.make} upgrades the schema if needed. """ self.store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") self.store.execute("CREATE TABLE test (foo TEXT)") self.store.commit() zstorm = self.resource.make([]) store = zstorm.get("test") self.assertEqual([], list(store.execute("SELECT bar FROM test"))) def test_make_upgrade_unknown_patch(self): """ L{ZStormResourceManager.make} resets the schema if an unknown patch is found """ self.store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") self.store.execute("INSERT INTO patch VALUES (2)") self.store.execute("CREATE TABLE test (foo TEXT, egg BOOL)") self.store.commit() zstorm = self.resource.make([]) store = zstorm.get("test") self.assertEqual([], list(store.execute("SELECT foo, bar FROM test"))) self.assertEqual([(1,)], list(store.execute("SELECT version FROM patch"))) def test_make_delete(self): """ L{ZStormResourceManager.make} deletes the data from all tables to make sure that tests run against a clean database. 
""" self.store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") self.store.execute("CREATE TABLE test (foo TEXT)") self.store.execute("INSERT INTO test (foo) VALUES ('data')") self.store.commit() zstorm = self.resource.make([]) store = zstorm.get("test") self.assertEqual([], list(store.execute("SELECT foo FROM test"))) def test_make_commits_transaction_once(self): """ L{ZStormResourceManager.make} commits schema changes only once across all stores, after all patch and delete statements have been executed. """ database2 = {"name": "test2", "uri": "sqlite:///%s" % self.makeFile(), "schema": self.databases[0]["schema"]} self.databases.append(database2) other_store = Store(create_database(database2["uri"])) for store in [self.store, other_store]: store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") store.execute("CREATE TABLE test (foo TEXT)") store.execute("INSERT INTO test (foo) VALUES ('data')") store.commit() with CaptureTracer() as tracer: zstorm = self.resource.make([]) self.assertEqual(["COMMIT", "COMMIT"], tracer.queries[-2:]) store1 = zstorm.get("test") store2 = zstorm.get("test2") self.assertEqual([], list(store1.execute("SELECT foo FROM test"))) self.assertEqual([], list(store2.execute("SELECT foo FROM test"))) def test_make_zstorm_overwritten(self): """ L{ZStormResourceManager.make} registers its own ZStorm again if a test has registered a new ZStorm utility overwriting the resource one. """ zstorm = self.resource.make([]) provideUtility(ZStorm()) self.resource.make([]) self.assertIs(zstorm, getUtility(IZStorm)) def test_clean_flush(self): """ L{ZStormResourceManager.clean} tries to flush the stores to make sure that they are all in a consistent state. """ class Test: __storm_table__ = "test" foo = Unicode() bar = Int(primary=True) def __init__(self, foo, bar): self.foo = foo self.bar = bar zstorm = self.resource.make([]) store = zstorm.get("test") store.add(Test("data", 1)) store.add(Test("data", 2)) self.assertRaises(IntegrityError, self.resource.clean, zstorm) def test_clean_delete(self): """ L{ZStormResourceManager.clean} cleans the database tables from the data created by the tests. """ zstorm = self.resource.make([]) store = zstorm.get("test") store.execute("INSERT INTO test (foo, bar) VALUES ('data', 123)") store.commit() self.resource.clean(zstorm) self.assertEqual([], list(self.store.execute("SELECT * FROM test"))) def test_clean_with_force_delete(self): """ If L{ZStormResourceManager.force_delete} is C{True}, L{Schema.delete} is always invoked upon test cleanup. """ zstorm = self.resource.make([]) zstorm.get("test") # Force the creation of the store self.store.execute("INSERT INTO test (foo, bar) VALUES ('data', 123)") self.store.commit() self.resource.force_delete = True self.resource.clean(zstorm) self.assertEqual([], list(self.store.execute("SELECT * FROM test"))) def test_wb_clean_clears_alive_cache_before_abort(self): """ L{ZStormResourceManager.clean} clears the alive cache before aborting the transaction. 
""" class Test: __storm_table__ = "test" bar = Int(primary=True) def __init__(self, bar): self.bar = bar zstorm = self.resource.make([]) store = zstorm.get("test") store.add(Test(1)) store.add(Test(2)) real_invalidate = store.invalidate def invalidate_proxy(): self.assertEqual(0, len(list(store._alive.values()))) real_invalidate() store.invalidate = invalidate_proxy self.resource.clean(zstorm) def test_schema_uri(self): """ It's possible to specify an alternate URI for applying the schema and cleaning up tables after a test. """ schema_uri = "sqlite:///%s" % self.makeFile() self.databases[0]["schema-uri"] = schema_uri zstorm = self.resource.make([]) store = zstorm.get("test") schema_store = Store(create_database(schema_uri)) # The schema was applied using the alternate schema URI statement = "SELECT name FROM sqlite_master WHERE name='patch'" self.assertEqual([], list(store.execute(statement))) self.assertEqual([("patch",)], list(schema_store.execute(statement))) # The cleanup is performed with the alternate schema URI store.commit() schema_store.execute("INSERT INTO test (foo) VALUES ('data')") schema_store.commit() self.resource.clean(zstorm) self.assertEqual([], list(schema_store.execute("SELECT * FROM test"))) def test_schema_uri_with_schema_stamp_dir(self): """ If a schema stamp directory is set, and the stamp indicates there's no need to update the schema, the resource clean up code will still connect as schema user if it needs to run the schema delete statements because of a commit. """ self.resource.schema_stamp_dir = self.makeFile() self.databases[0]["schema-uri"] = self.databases[0]["uri"] self.resource.make([]) # Simulate a second test run that initializes the zstorm resource # from scratch, using the same schema stamp directory resource2 = ZStormResourceManager(self.databases) resource2.schema_stamp_dir = self.resource.schema_stamp_dir zstorm = resource2.make([]) store = zstorm.get("test") store.execute("INSERT INTO test (foo) VALUES ('data')") store.commit() # Committing will force a schema.delete() run resource2.clean(zstorm) self.assertEqual([], list(store.execute("SELECT * FROM test"))) def test_no_schema(self): """ A particular database may have no schema associated. """ self.databases[0]["schema"] = None zstorm = self.resource.make([]) store = zstorm.get("test") self.assertEqual([], list(store.execute("SELECT * FROM sqlite_master"))) def test_no_schema_clean(self): """ A particular database may have no schema associated. If it's committed during tests, it will just be skipped when cleaning up tables. """ self.databases[0]["schema"] = None zstorm = self.resource.make([]) store = zstorm.get("test") store.commit() with CaptureTracer() as tracer: self.resource.clean(zstorm) self.assertEqual([], tracer.queries) def test_deprecated_database_format(self): """ The old deprecated format of the 'database' constructor parameter is still supported. """ import patch_package uri = "sqlite:///%s" % self.makeFile() schema = ZSchema([], [], [], patch_package) resource = ZStormResourceManager({"test": (uri, schema)}) zstorm = resource.make([]) store = zstorm.get("test") self.assertIsNot(None, store) def test_use_global_zstorm(self): """ If the C{use_global_zstorm} attribute is C{True} then the global L{ZStorm} will be used. """ self.resource.use_global_zstorm = True zstorm = self.resource.make([]) self.assertIs(global_zstorm, zstorm) def test_provide_utility_before_patches(self): """ The L{IZStorm} utility is provided before patches are applied, in order to let them get it if they need. 
""" content = ("from zope.component import getUtility\n" "from storm.zope.interfaces import IZStorm\n" "def apply(store):\n" " getUtility(IZStorm)\n") self.makeFile(path=os.path.join(self.patch_dir, "patch_2.py"), content=content) self.store.execute("CREATE TABLE patch " "(version INTEGER NOT NULL PRIMARY KEY)") self.store.execute("CREATE TABLE test (foo TEXT)") self.store.commit() zstorm = self.resource.make([]) store = zstorm.get("test") self.assertEqual([(1,), (2,)], sorted(store.execute("SELECT version FROM patch"))) def test_create_schema_stamp_dir(self): """ If a schema stamp directory is set, it's created automatically if it doesn't exist yet. """ self.resource.schema_stamp_dir = self.makeFile() self.resource.make([]) self.assertTrue(os.path.exists(self.resource.schema_stamp_dir)) def test_use_schema_stamp(self): """ If a schema stamp directory is set, then it's used to decide whether to upgrade the schema or not. In case the patch directory hasn't been changed since the last known upgrade, no schema upgrade is run. """ self.resource.schema_stamp_dir = self.makeFile() self.resource.make([]) # Simulate a second test run that initializes the zstorm resource # from scratch, using the same schema stamp directory resource2 = ZStormResourceManager(self.databases) resource2.schema_stamp_dir = self.resource.schema_stamp_dir with CaptureTracer() as tracer: resource2.make([]) self.assertEqual([], tracer.queries) def test_use_schema_stamp_out_of_date(self): """ If a schema stamp directory is set, then it's used to decide whether to upgrade the schema or not. In case the patch directory has changed a schema upgrade is run. """ self.resource.schema_stamp_dir = self.makeFile() self.resource.make([]) # Simulate a second test run that initializes the zstorm resource # from scratch, using the same schema stamp directory resource2 = ZStormResourceManager(self.databases) resource2.schema_stamp_dir = self.resource.schema_stamp_dir self.makeFile(path=os.path.join(self.patch_dir, "patch_2.py"), content="def apply(store): pass") class FakeStat: st_mtime = os.stat(self.patch_dir).st_mtime + 1 stat_mock = self.mocker.replace(os.stat) stat_mock(self.patch_dir) self.mocker.result(FakeStat()) self.mocker.replay() resource2.make([]) result = self.store.execute("SELECT version FROM patch") self.assertEqual([(1,), (2,)], sorted(result.get_all())) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tests/zope/zstorm.py0000644000175000017500000003675514645174376020254 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# import threading import weakref import gc from storm.tests.helper import TestHelper from storm.tests.zope import has_transaction, has_zope_component if has_transaction: import transaction from transaction import ThreadTransactionManager from storm.zope.interfaces import IZStorm, ZStormError from storm.zope.zstorm import ZStorm, StoreDataManager if has_zope_component: from zope.component import provideUtility, getUtility from storm.exceptions import OperationalError from storm.locals import Store class ZStormTest(TestHelper): def is_supported(self): return has_transaction def setUp(self): self.zstorm = ZStorm() def tearDown(self): # Reset the utility to cleanup the StoreSynchronizer's from the # transaction. self.zstorm._reset() # Free the transaction to avoid having errors that cross # test cases. # XXX cjwatson 2019-05-29: transaction 2.4.0 changed # ThreadTransactionManager to wrap TransactionManager rather than # inheriting from it. For now, cope with either. Simplify this # once transaction 2.4.0 is old enough that we can reasonably just # test-depend on it. manager = transaction.manager if isinstance(manager, ThreadTransactionManager): try: manager.free except AttributeError: # transaction >= 2.4.0 manager = manager.manager manager.free(transaction.get()) def test_create(self): store = self.zstorm.create(None, "sqlite:") self.assertTrue(isinstance(store, Store)) def test_create_twice_unnamed(self): store = self.zstorm.create(None, "sqlite:") store.execute("CREATE TABLE test (id INTEGER)") store.commit() store = self.zstorm.create(None, "sqlite:") self.assertRaises(OperationalError, store.execute, "SELECT * FROM test") def test_create_twice_same_name(self): store = self.zstorm.create("name", "sqlite:") self.assertRaises(ZStormError, self.zstorm.create, "name", "sqlite:") def test_create_and_get_named(self): store = self.zstorm.create("name", "sqlite:") self.assertTrue(self.zstorm.get("name") is store) def test_create_and_get_named_another_thread(self): store = self.zstorm.create("name", "sqlite:") raised = [] def f(): try: self.zstorm.get("name") except ZStormError: raised.append(True) thread = threading.Thread(target=f) thread.start() thread.join() self.assertTrue(raised) def test_get_unexistent(self): self.assertRaises(ZStormError, self.zstorm.get, "name") def test_get_with_uri(self): store = self.zstorm.get("name", "sqlite:") self.assertTrue(isinstance(store, Store)) self.assertTrue(self.zstorm.get("name") is store) self.assertTrue(self.zstorm.get("name", "sqlite:") is store) def test_set_default_uri(self): self.zstorm.set_default_uri("name", "sqlite:") store = self.zstorm.get("name") self.assertTrue(isinstance(store, Store)) def test_create_default(self): self.zstorm.set_default_uri("name", "sqlite:") store = self.zstorm.create("name") self.assertTrue(isinstance(store, Store)) def test_create_default_twice(self): self.zstorm.set_default_uri("name", "sqlite:") self.zstorm.create("name") self.assertRaises(ZStormError, self.zstorm.create, "name") def test_iterstores(self): store1 = self.zstorm.create(None, "sqlite:") store2 = self.zstorm.create(None, "sqlite:") store3 = self.zstorm.create("name", "sqlite:") stores = [] for name, store in self.zstorm.iterstores(): stores.append((name, store)) self.assertEqual(len(stores), 3) self.assertEqual(set(stores), {(None, store1), (None, store2), ("name", store3)}) def test_get_name(self): store = self.zstorm.create("name", "sqlite:") self.assertEqual(self.zstorm.get_name(store), "name") def test_get_name_with_removed_store(self): store = 
self.zstorm.create("name", "sqlite:") self.assertEqual(self.zstorm.get_name(store), "name") self.zstorm.remove(store) self.assertEqual(self.zstorm.get_name(store), None) def test_default_databases(self): self.zstorm.set_default_uri("name1", "sqlite:1") self.zstorm.set_default_uri("name2", "sqlite:2") self.zstorm.set_default_uri("name3", "sqlite:3") default_uris = self.zstorm.get_default_uris() self.assertEqual(default_uris, {"name1": "sqlite:1", "name2": "sqlite:2", "name3": "sqlite:3"}) def test_register_store_for_tpc_transaction(self): """ Setting a store to use two-phase-commit mode, makes ZStorm call its begin() method when it joins the transaction. """ self.zstorm.set_default_uri("name", "sqlite:") self.zstorm.set_default_tpc("name", True) store = self.zstorm.get("name") xids = [] store.begin = lambda xid: xids.append(xid) store.execute("SELECT 1") [xid] = xids self.assertEqual(0, xid.format_id) self.assertEqual("_storm", xid.global_transaction_id[:6]) self.assertEqual("name", xid.branch_qualifier) def test_register_store_for_tpc_transaction_uses_per_transaction_id(self): """ Two stores in two-phase-commit mode joining the same transaction share the same global transaction ID. """ self.zstorm.set_default_uri("name1", "sqlite:///%s" % self.makeFile()) self.zstorm.set_default_uri("name2", "sqlite:///%s" % self.makeFile()) self.zstorm.set_default_tpc("name1", True) self.zstorm.set_default_tpc("name2", True) store1 = self.zstorm.get("name1") store2 = self.zstorm.get("name2") xids = [] store1.begin = lambda xid: xids.append(xid) store2.begin = lambda xid: xids.append(xid) store1.execute("SELECT 1") store2.execute("SELECT 1") [xid1, xid2] = xids self.assertEqual(xid1.global_transaction_id, xid2.global_transaction_id) def test_register_store_for_tpc_transaction_uses_unique_global_ids(self): """ Each global transaction gets assigned a unique ID. """ self.zstorm.set_default_uri("name", "sqlite:") self.zstorm.set_default_tpc("name", True) store = self.zstorm.get("name") xids = [] store.begin = lambda xid: xids.append(xid) store.execute("SELECT 1") transaction.abort() store.execute("SELECT 1") transaction.abort() [xid1, xid2] = xids self.assertNotEqual(xid1.global_transaction_id, xid2.global_transaction_id) def test_transaction_with_two_phase_commit(self): """ If a store is set to use TPC, than the associated data manager will call its prepare() and commit() methods when committing. """ self.zstorm.set_default_uri("name", "sqlite:") self.zstorm.set_default_tpc("name", True) store = self.zstorm.get("name") calls = [] store.begin = lambda xid: calls.append("begin") store.prepare = lambda: calls.append("prepare") store.commit = lambda: calls.append("commit") store.execute("SELECT 1") transaction.commit() self.assertEqual(["begin", "prepare", "commit"], calls) def test_transaction_with_single_and_two_phase_commit_stores(self): """ When there are both stores in single-phase and two-phase mode, the ones in single-phase mode are committed first. This makes it possible to actually achieve two-phase commit behavior when only one store doesn't support TPC. 
""" self.zstorm.set_default_uri("name1", "sqlite:///%s" % self.makeFile()) self.zstorm.set_default_uri("name2", "sqlite:///%s" % self.makeFile()) self.zstorm.set_default_tpc("name1", True) self.zstorm.set_default_tpc("name2", False) store1 = self.zstorm.get("name1") store2 = self.zstorm.get("name2") commits = [] store1.begin = lambda xid: None store1.prepare = lambda: None store1.commit = lambda: commits.append("commit1") store2.commit = lambda: commits.append("commit2") store1.execute("SELECT 1") store2.execute("SELECT 1") transaction.commit() self.assertEqual(["commit2", "commit1"], commits) def _isInTransaction(self, store): """Check if a Store is part of the current transaction.""" for dm in transaction.get()._resources: if isinstance(dm, StoreDataManager) and dm._store is store: return True return False def assertInTransaction(self, store): """Check that the given store is joined to the transaction.""" self.assertTrue(self._isInTransaction(store), "%r should be joined to the transaction" % store) def assertNotInTransaction(self, store): """Check that the given store is not joined to the transaction.""" self.assertTrue(not self._isInTransaction(store), "%r should not be joined to the transaction" % store) def test_wb_store_joins_transaction_on_register_event(self): """The Store joins the transaction when register-transaction is emitted. The Store tests check the various operations that trigger this event. """ store = self.zstorm.get("name", "sqlite:") self.assertNotInTransaction(store) store._event.emit("register-transaction") self.assertInTransaction(store) def test_wb_store_joins_transaction_on_use_after_commit(self): store = self.zstorm.get("name", "sqlite:") store.execute("SELECT 1") transaction.commit() self.assertNotInTransaction(store) store.execute("SELECT 1") self.assertInTransaction(store) def test_wb_store_joins_transaction_on_use_after_abort(self): store = self.zstorm.get("name", "sqlite:") store.execute("SELECT 1") transaction.abort() self.assertNotInTransaction(store) store.execute("SELECT 1") self.assertInTransaction(store) def test_wb_store_joins_transaction_on_use_after_tpc_commit(self): """ A store used after a two-phase commit re-joins the new transaction. """ self.zstorm.set_default_uri("name", "sqlite:") self.zstorm.set_default_tpc("name", True) store = self.zstorm.get("name") store.begin = lambda xid: None store.prepare = lambda: None store.commit = lambda: None store.execute("SELECT 1") transaction.commit() self.assertNotInTransaction(store) store.execute("SELECT 1") self.assertInTransaction(store) def test_wb_store_joins_transaction_on_use_after_tpc_abort(self): """ A store used after a rollback during a two-phase commit re-joins the new transaction. 
""" self.zstorm.set_default_uri("name", "sqlite:") self.zstorm.set_default_tpc("name", True) store = self.zstorm.get("name") store.begin = lambda xid: None store.prepare = lambda: None store.rollback = lambda: None store.execute("SELECT 1") transaction.abort() self.assertNotInTransaction(store) store.execute("SELECT 1") self.assertInTransaction(store) def test_remove(self): removed_store = self.zstorm.get("name", "sqlite:") self.zstorm.remove(removed_store) for name, store in self.zstorm.iterstores(): self.assertNotEqual(store, removed_store) self.assertRaises(ZStormError, self.zstorm.get, "name") def test_wb_removed_store_does_not_join_transaction(self): """If a store has been removed, it will not join the transaction.""" store = self.zstorm.get("name", "sqlite:") self.zstorm.remove(store) store.execute("SELECT 1") self.assertNotInTransaction(store) def test_wb_removed_store_does_not_join_future_transactions(self): """If a store has been removed after joining a transaction, it will not join new transactions.""" store = self.zstorm.get("name", "sqlite:") store.execute("SELECT 1") self.zstorm.remove(store) self.assertInTransaction(store) transaction.abort() store.execute("SELECT 1") self.assertNotInTransaction(store) def test_wb_cross_thread_store_does_not_join_transaction(self): """If a zstorm registered thread crosses over to another thread, it will not be usable.""" store = self.zstorm.get("name", "sqlite:") failures = [] def f(): # We perform this twice to show that ZStormError is raised # consistently (i.e. not just the first time). for i in range(2): try: store.execute("SELECT 1") except ZStormError: failures.append("ZStormError raised") except Exception as exc: failures.append("Expected ZStormError, got %r" % exc) else: failures.append("Expected ZStormError, nothing raised") if self._isInTransaction(store): failures.append("store was joined to transaction") thread = threading.Thread(target=f) thread.start() thread.join() self.assertEqual(failures, ["ZStormError raised"] * 2) def test_wb_reset(self): """_reset is used to reset the zstorm utility between zope test runs. """ store = self.zstorm.get("name", "sqlite:") self.zstorm._reset() self.assertEqual(list(self.zstorm.iterstores()), []) def test_store_strong_reference(self): """ The zstorm utility should be a strong reference to named stores so that it doesn't recreate stores uselessly. """ store = self.zstorm.get("name", "sqlite:") store_ref = weakref.ref(store) transaction.abort() del store gc.collect() self.assertNotIdentical(store_ref(), None) store = self.zstorm.get("name") self.assertIdentical(store_ref(), store) class ZStormUtilityTest(TestHelper): def is_supported(self): return has_transaction and has_zope_component def test_utility(self): provideUtility(ZStorm()) self.assertTrue(isinstance(getUtility(IZStorm), ZStorm)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tracer.py0000644000175000017500000002517514645174376016051 0ustar00cjwatsoncjwatsonfrom datetime import datetime import re import sys import threading # Circular import: imported at the end of the module. 
# from storm.database import convert_param_marks from storm.exceptions import TimeoutError from storm.expr import Variable class DebugTracer: def __init__(self, stream=None): if stream is None: stream = sys.stderr self._stream = stream def connection_raw_execute(self, connection, raw_cursor, statement, params): time = datetime.now().isoformat()[11:] raw_params = [] for param in params: if isinstance(param, Variable): raw_params.append(param.get()) else: raw_params.append(param) raw_params = tuple(raw_params) self._stream.write( "[%s] EXECUTE: %r, %r\n" % (time, statement, raw_params)) self._stream.flush() def connection_raw_execute_error(self, connection, raw_cursor, statement, params, error): time = datetime.now().isoformat()[11:] self._stream.write("[%s] ERROR: %s\n" % (time, error)) self._stream.flush() def connection_raw_execute_success(self, connection, raw_cursor, statement, params): time = datetime.now().isoformat()[11:] self._stream.write("[%s] DONE\n" % time) self._stream.flush() def connection_commit(self, connection, xid=None): time = datetime.now().isoformat()[11:] self._stream.write("[%s] COMMIT xid=%s\n" % (time, xid)) self._stream.flush() def connection_rollback(self, connection, xid=None): time = datetime.now().isoformat()[11:] self._stream.write("[%s] ROLLBACK xid=%s\n" % (time, xid)) self._stream.flush() class TimeoutTracer: """Provide a timeout facility for connections to prevent rogue operations. This tracer must be subclassed by backend-specific implementations that override C{connection_raw_execute_error}, C{set_statement_timeout} and C{get_remaining_time} methods. """ def __init__(self, granularity=5): self.granularity = granularity def connection_raw_execute(self, connection, raw_cursor, statement, params): """Check timeout conditions before a statement is executed. @param connection: The L{Connection} to the database. @param raw_cursor: A cursor object, specific to the backend being used. @param statement: The SQL statement to execute. @param params: The parameters to use with C{statement}. @raises TimeoutError: Raised if there isn't enough time left to execute C{statement}. """ remaining_time = self.get_remaining_time() if remaining_time <= 0: raise TimeoutError( statement, params, "%d seconds remaining in time budget" % remaining_time) last_remaining_time = getattr(connection, "_timeout_tracer_remaining_time", 0) if (remaining_time > last_remaining_time or last_remaining_time - remaining_time >= self.granularity): self.set_statement_timeout(raw_cursor, remaining_time) connection._timeout_tracer_remaining_time = remaining_time def connection_raw_execute_error(self, connection, raw_cursor, statement, params, error): """Raise L{TimeoutError} if the given error was a timeout issue. Must be specialized in the backend. """ raise NotImplementedError("%s.connection_raw_execute_error() must be " "implemented" % self.__class__.__name__) def connection_commit(self, connection, xid=None): """Reset C{Connection._timeout_tracer_remaining_time}. @param connection: The L{Connection} to the database. @param xid: Optionally the L{Xid} of a previously prepared transaction to commit. """ self._reset_timeout_tracer_remaining_time(connection) def connection_rollback(self, connection, xid=None): """Reset C{Connection._timeout_tracer_remaining_time}. @param connection: The L{Connection} to the database. @param xid: Optionally the L{Xid} of a previously prepared transaction to rollback. 
""" self._reset_timeout_tracer_remaining_time(connection) def _reset_timeout_tracer_remaining_time(self, connection): """Set connection._timeout_tracer_remaining_time to 0.""" connection._timeout_tracer_remaining_time = 0 def set_statement_timeout(self, raw_cursor, remaining_time): """Perform the timeout setup in the raw cursor. The database should raise an error if the next statement takes more than the number of seconds provided in C{remaining_time}. Must be specialized in the backend. """ raise NotImplementedError("%s.set_statement_timeout() must be " "implemented" % self.__class__.__name__) def get_remaining_time(self): """Tells how much time the current context (HTTP request, etc) has. Must be specialized with application logic. @return: Number of seconds allowed for the next statement. """ raise NotImplementedError("%s.get_remaining_time() must be implemented" % self.__class__.__name__) class BaseStatementTracer: """Storm tracer base class that does query interpolation.""" def connection_raw_execute(self, connection, raw_cursor, statement, params): statement_to_log = statement if params: # There are some bind parameters so we want to insert them into # the sql statement so we can log the statement. query_params = list(connection.to_database(params)) if connection.param_mark == '%s': # Double the %'s in the string so that python string formatting # can restore them to the correct number. Note that %s needs to # be preserved as that is where we are substituting values in. quoted_statement = re.sub( "%%%", "%%%%", re.sub("%([^s])", r"%%\1", statement)) else: # Double all the %'s in the statement so that python string # formatting can restore them to the correct number. Any %s in # the string should be preserved because the param_mark is not # %s. quoted_statement = re.sub("%", "%%", statement) quoted_statement = convert_param_marks( quoted_statement, connection.param_mark, "%s") # We need to massage the query parameters a little to deal with # string parameters which represent encoded binary data. render_params = [] for param in query_params: if isinstance(param, str): render_params.append(ascii(param)) else: render_params.append(repr(param)) try: statement_to_log = quoted_statement % tuple(render_params) except TypeError: statement_to_log = \ "Unformattable query: %r with params %r." % ( statement, query_params) self._expanded_raw_execute(connection, raw_cursor, statement_to_log) def _expanded_raw_execute(self, connection, raw_cursor, statement): """Called by connection_raw_execute after parameter substitution.""" raise NotImplementedError(self._expanded_raw_execute) class TimelineTracer(BaseStatementTracer): """Storm tracer class to insert executed statements into a L{Timeline}. For more information on timelines see the module at U{https://pypi.org/project/timeline/}. The timeline to use is obtained by calling the timeline_factory supplied to the constructor. This simple function takes no parameters and returns a timeline to use. If it returns None, the tracer is bypassed. """ def __init__(self, timeline_factory, prefix='SQL-'): """Create a TimelineTracer. @param timeline_factory: A factory function to produce the timeline to record a query against. @param prefix: A prefix to give the connection name when starting an action. Connection names are found by trying a getattr for 'name' on the connection object. If no name has been assigned, '' is used instead. 
""" super().__init__() self.timeline_factory = timeline_factory self.prefix = prefix # Stores the action in progress in a given thread. self.threadinfo = threading.local() def _expanded_raw_execute(self, connection, raw_cursor, statement): timeline = self.timeline_factory() if timeline is None: return connection_name = getattr(connection, 'name', '') action = timeline.start(self.prefix + connection_name, statement) self.threadinfo.action = action def connection_raw_execute_success(self, connection, raw_cursor, statement, params): # action may be None if the tracer was installed after the statement # was submitted. action = getattr(self.threadinfo, 'action', None) if action is not None: action.finish() def connection_raw_execute_error(self, connection, raw_cursor, statement, params, error): # Since we are just logging durations, we execute the same # hook code for errors as successes. self.connection_raw_execute_success( connection, raw_cursor, statement, params) _tracers = [] def trace(name, *args, **kwargs): for tracer in _tracers: attr = getattr(tracer, name, None) if attr: attr(*args, **kwargs) def install_tracer(tracer): _tracers.append(tracer) def get_tracers(): return _tracers[:] def remove_all_tracers(): del _tracers[:] def remove_tracer(tracer): try: _tracers.remove(tracer) except ValueError: pass # The tracer is not installed, succeed gracefully def remove_tracer_type(tracer_type): for i in range(len(_tracers) - 1, -1, -1): if type(_tracers[i]) is tracer_type: del _tracers[i] def debug(flag, stream=None): remove_tracer_type(DebugTracer) if flag: install_tracer(DebugTracer(stream=stream)) # Deal with circular import. from storm.database import convert_param_marks ././@PaxHeader0000000000000000000000000000003300000000000010211 xustar0027 mtime=1721152862.425125 storm-1.0/storm/twisted/0000755000175000017500000000000014645532536015664 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/storm/twisted/__init__.py0000644000175000017500000000000011752263216017753 0ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/twisted/testing.py0000644000175000017500000000321314645174376017716 0ustar00cjwatsoncjwatsonfrom twisted.python.failure import Failure from twisted.internet.defer import execute from storm.twisted.transact import Transactor class FakeThreadPool: """ A fake L{twisted.python.threadpool.ThreadPool}, running functions inside the main thread instead for easing tests. """ def callInThreadWithCallback(self, onResult, func, *args, **kw): success = True try: result = func(*args, **kw) except: result = Failure() success = False onResult(success, result) class FakeTransaction: def commit(self): pass def abort(self): pass class FakeTransactor(Transactor): """ A fake C{Transactor} wrapper that runs the given function in the main thread and performs basic checks on its return value. If it has a C{__storm_table__} property a C{RuntimeError} is raised because Storm objects cannot be used outside the thread in which they were created. @seealso: L{Transactor}. 
""" retries = 0 on_retry = None sleep = lambda *args, **kwargs: None def __init__(self, transaction=None): if transaction is None: transaction = FakeTransaction() self._transaction = transaction def run(self, function, *args, **kwargs): deferred = execute(self._wrap, function, *args, **kwargs) return deferred.addCallback(self._check_result) def _check_result(self, result): if getattr(result, "__storm_table__", None) is not None: raise RuntimeError("Attempted to return a Storm object from a " "transaction") return result ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/twisted/transact.py0000644000175000017500000001317414645174376020067 0ustar00cjwatsoncjwatsonimport time import random import transaction as zope_transaction from functools import wraps from zope.component import getUtility from storm.zope.interfaces import IZStorm from storm.exceptions import IntegrityError, DisconnectionError from twisted.internet.threads import deferToThreadPool RETRIABLE_ERRORS = (DisconnectionError, IntegrityError) try: from psycopg2.extensions import TransactionRollbackError RETRIABLE_ERRORS = RETRIABLE_ERRORS + (TransactionRollbackError,) except ImportError: pass class Transactor: """Run in a thread code that needs to interact with the database. This class makes sure that code interacting with the database is run in a separate thread and that the associated transaction is aborted or committed in the same thread. @param threadpool: The C{ThreadPool} to get threads from. @param _transaction: The C{TransactionManager} to use, for test cases only. @ivar retries: Maximum number of retries upon retriable exceptions. The default is to retry a function up to 2 times upon possibly transient or spurious errors like L{IntegrityError} and L{DisconnectionError}. @ivar on_retry: If not C{None}, a callable that will be called before retrying to run a function, and passed a L{RetryContext} instance with the details about the retry. @see: C{twisted.python.threadpool.ThreadPool} """ retries = 2 on_retry = None sleep = time.sleep uniform = random.uniform def __init__(self, threadpool, transaction=None): self._threadpool = threadpool if transaction is None: transaction = zope_transaction self._transaction = transaction def run(self, function, *args, **kwargs): """Run C{function} in a thread. The function is run in a thread by a function wrapper, which commits the transaction if the function runs successfully. If it raises an exception the transaction is aborted. @param function: The function to run. @param args: Positional arguments to pass to C{function}. @param kwargs: Keyword arguments to pass to C{function}. @return: A C{Deferred} that will fire after the function has been run. """ # Inline the reactor import here for sake of safeness, in case a # custom reactor needs to be installed from twisted.internet import reactor return deferToThreadPool( reactor, self._threadpool, self._wrap, function, *args, **kwargs) def _wrap(self, function, *args, **kwargs): retries = 0 while True: try: result = function(*args, **kwargs) self._transaction.commit() except RETRIABLE_ERRORS as error: if isinstance(error, DisconnectionError): # If we got a disconnection, calling rollback may not be # enough because psycopg2 doesn't necessarily use the # connection, so we call a dummy query to be sure that all # the stores are correct. 
zstorm = getUtility(IZStorm) for name, store in zstorm.iterstores(): try: store.execute("SELECT 1") except DisconnectionError: pass self._transaction.abort() if retries < self.retries: retries += 1 if self.on_retry is not None: context = RetryContext(function, args, kwargs, retries, error) self.on_retry(context) self.sleep(self.uniform(1, 2 ** retries)) continue else: raise except: self._transaction.abort() raise else: return result class RetryContext: """Hold details about a function that is going to be retried. @ivar function: The function that is going to be retried. @ivar args: The positional arguments passed to the function. @ivar kwargs: The keyword arguments passed to the function. @ivar retry: The sequential number of the retry that is going to be performed. @ivar error: The Exception instance that caused a retry to be scheduled. """ def __init__(self, function, args, kwargs, retry, error): self.function = function self.args = args self.kwargs = kwargs self.retry = retry self.error = error def transact(method): """Decorate L{method} so that it is invoked via L{Transactor.run}. Example:: from twisted.python.threadpool import ThreadPool from storm.twisted.transact import Transactor, transact class Foo(object): def __init__(self, transactor): self.transactor = transactor @transact def bar(self): # code that uses Storm threadpool = ThreadPool(0, 10) threadpool.start() transactor = Transactor(threadpool) foo = Foo(transactor) deferred = foo.bar() deferred.addCallback(...) @param method: The method to decorate. @return: A decorated method. @note: The return value of the decorated method should *not* contain any reference to Storm objects, because they were retrieved in a different thread and cannot be used outside it. """ @wraps(method) def wrapper(self, *args, **kwargs): return self.transactor.run(method, self, *args, **kwargs) return wrapper ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/tz.py0000644000175000017500000007700614645174376015226 0ustar00cjwatsoncjwatson# Copyright (c) 2003-2005 Gustavo Niemeyer """ This module offers extensions to the standard python 2.3+ datetime module. 
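
A small sketch of the simplest classes defined below::

    from datetime import datetime
    utc = tzutc()
    plus_one = tzoffset("UTC+1", 3600)  # one hour east of UTC
    datetime(1977, 5, 4, tzinfo=plus_one).astimezone(utc)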
""" __author__ = "Gustavo Niemeyer " __license__ = "PSF License" import datetime import struct import time import sys import os relativedelta = None parser = None rrule = None __all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange", "tzstr", "tzical", "tzwin", "tzwinlocal", "gettz"] try: from dateutil.tzwin import tzwin, tzwinlocal except (ImportError, OSError): tzwin, tzwinlocal = None, None ZERO = datetime.timedelta(0) EPOCHORDINAL = datetime.datetime.utcfromtimestamp(0).toordinal() class tzutc(datetime.tzinfo): def utcoffset(self, dt): return ZERO def dst(self, dt): return ZERO def tzname(self, dt): return "UTC" def __eq__(self, other): return (isinstance(other, tzutc) or (isinstance(other, tzoffset) and other._offset == ZERO)) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "%s()" % self.__class__.__name__ __reduce__ = object.__reduce__ class tzoffset(datetime.tzinfo): def __init__(self, name, offset): self._name = name self._offset = datetime.timedelta(seconds=offset) def utcoffset(self, dt): return self._offset def dst(self, dt): return ZERO def tzname(self, dt): return self._name def __eq__(self, other): return (isinstance(other, tzoffset) and self._offset == other._offset) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "%s(%r, %s)" % (self.__class__.__name__, self._name, self._offset.days*86400+self._offset.seconds) __reduce__ = object.__reduce__ class tzlocal(datetime.tzinfo): _std_offset = datetime.timedelta(seconds=-time.timezone) if time.daylight: _dst_offset = datetime.timedelta(seconds=-time.altzone) else: _dst_offset = _std_offset def utcoffset(self, dt): if self._isdst(dt): return self._dst_offset else: return self._std_offset def dst(self, dt): if self._isdst(dt): return self._dst_offset-self._std_offset else: return ZERO def tzname(self, dt): return time.tzname[self._isdst(dt)] def _isdst(self, dt): # We can't use mktime here. It is unstable when deciding if # the hour near to a change is DST or not. 
# # timestamp = time.mktime((dt.year, dt.month, dt.day, dt.hour, # dt.minute, dt.second, dt.weekday(), 0, -1)) # return time.localtime(timestamp).tm_isdst # # The code above yields the following result: # #>>> import tz, datetime #>>> t = tz.tzlocal() #>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() #'BRDT' #>>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname() #'BRST' #>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() #'BRST' #>>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname() #'BRDT' #>>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() #'BRDT' # # Here is a more stable implementation: # timestamp = ((dt.toordinal() - EPOCHORDINAL) * 86400 + dt.hour * 3600 + dt.minute * 60 + dt.second) return time.localtime(timestamp+time.timezone).tm_isdst def __eq__(self, other): if not isinstance(other, tzlocal): return False return (self._std_offset == other._std_offset and self._dst_offset == other._dst_offset) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "%s()" % self.__class__.__name__ __reduce__ = object.__reduce__ class _ttinfo: __slots__ = ["offset", "delta", "isdst", "abbr", "isstd", "isgmt"] def __init__(self): for attr in self.__slots__: setattr(self, attr, None) def __repr__(self): l = [] for attr in self.__slots__: value = getattr(self, attr) if value is not None: l.append("%s=%r" % (attr, value)) return "%s(%s)" % (self.__class__.__name__, ", ".join(l)) def __eq__(self, other): if not isinstance(other, _ttinfo): return False return (self.offset == other.offset and self.delta == other.delta and self.isdst == other.isdst and self.abbr == other.abbr and self.isstd == other.isstd and self.isgmt == other.isgmt) def __ne__(self, other): return not self.__eq__(other) def __getstate__(self): state = {} for name in self.__slots__: state[name] = getattr(self, name, None) return state def __setstate__(self, state): for name in self.__slots__: if name in state: setattr(self, name, state[name]) class tzfile(datetime.tzinfo): # http://www.twinsun.com/tz/tz-link.htm # ftp://elsie.nci.nih.gov/pub/tz*.tar.gz def __init__(self, fileobj): if isinstance(fileobj, str): self._filename = fileobj # Open in binary mode: the tzfile format is binary data. fileobj = open(fileobj, "rb") elif hasattr(fileobj, "name"): self._filename = fileobj.name else: self._filename = repr(fileobj) # From tzfile(5): # # The time zone information files used by tzset(3) # begin with the magic characters "TZif" to identify # them as time zone information files, followed by # sixteen bytes reserved for future use, followed by # six four-byte values of type long, written in a # ``standard'' byte order (the high-order byte # of the value is written first). if fileobj.read(4) != b"TZif": raise ValueError("magic not found") fileobj.read(16) ( # The number of UTC/local indicators stored in the file. ttisgmtcnt, # The number of standard/wall indicators stored in the file. ttisstdcnt, # The number of leap seconds for which data is # stored in the file. leapcnt, # The number of "transition times" for which data # is stored in the file. timecnt, # The number of "local time types" for which data # is stored in the file (must not be zero). typecnt, # The number of characters of "time zone # abbreviation strings" stored in the file. charcnt, ) = struct.unpack(">6l", fileobj.read(24)) # The above header is followed by tzh_timecnt four-byte # values of type long, sorted in ascending order. # These values are written in ``standard'' byte order.
# Each is used as a transition time (as returned by # time(2)) at which the rules for computing local time # change. if timecnt: self._trans_list = struct.unpack(">%dl" % timecnt, fileobj.read(timecnt*4)) else: self._trans_list = [] # Next come tzh_timecnt one-byte values of type unsigned # char; each one tells which of the different types of # ``local time'' types described in the file is associated # with the same-indexed transition time. These values # serve as indices into an array of ttinfo structures that # appears next in the file. if timecnt: self._trans_idx = struct.unpack(">%dB" % timecnt, fileobj.read(timecnt)) else: self._trans_idx = [] # Each ttinfo structure is written as a four-byte value # for tt_gmtoff of type long, in a standard byte # order, followed by a one-byte value for tt_isdst # and a one-byte value for tt_abbrind. In each # structure, tt_gmtoff gives the number of # seconds to be added to UTC, tt_isdst tells whether # tm_isdst should be set by localtime(3), and # tt_abbrind serves as an index into the array of # time zone abbreviation characters that follow the # ttinfo structure(s) in the file. ttinfo = [] for i in range(typecnt): ttinfo.append(struct.unpack(">lbb", fileobj.read(6))) # Decode so the abbreviation slicing below yields text. abbr = fileobj.read(charcnt).decode() # Then there are tzh_leapcnt pairs of four-byte # values, written in standard byte order; the # first value of each pair gives the time (as # returned by time(2)) at which a leap second # occurs; the second gives the total number of # leap seconds to be applied after the given time. # The pairs of values are sorted in ascending order # by time. # Not used, for now if leapcnt: leap = struct.unpack(">%dl" % (leapcnt*2), fileobj.read(leapcnt*8)) # Then there are tzh_ttisstdcnt standard/wall # indicators, each stored as a one-byte value; # they tell whether the transition times associated # with local time types were specified as standard # time or wall clock time, and are used when # a time zone file is used in handling POSIX-style # time zone environment variables. if ttisstdcnt: isstd = struct.unpack(">%db" % ttisstdcnt, fileobj.read(ttisstdcnt)) # Finally, there are tzh_ttisgmtcnt UTC/local # indicators, each stored as a one-byte value; # they tell whether the transition times associated # with local time types were specified as UTC or # local time, and are used when a time zone file # is used in handling POSIX-style time zone envi- # ronment variables. if ttisgmtcnt: isgmt = struct.unpack(">%db" % ttisgmtcnt, fileobj.read(ttisgmtcnt)) # ** Everything has been read ** # Build ttinfo list self._ttinfo_list = [] for i in range(typecnt): gmtoff, isdst, abbrind = ttinfo[i] # Round to full-minutes if that's not the case. Python's # datetime doesn't accept sub-minute timezones. Check # http://python.org/sf/1447945 for some information. gmtoff = (gmtoff+30)//60*60 tti = _ttinfo() tti.offset = gmtoff tti.delta = datetime.timedelta(seconds=gmtoff) tti.isdst = isdst tti.abbr = abbr[abbrind:abbr.find('\x00', abbrind)] tti.isstd = (ttisstdcnt > i and isstd[i] != 0) tti.isgmt = (ttisgmtcnt > i and isgmt[i] != 0) self._ttinfo_list.append(tti) # Replace ttinfo indexes for ttinfo objects. trans_idx = [] for idx in self._trans_idx: trans_idx.append(self._ttinfo_list[idx]) self._trans_idx = tuple(trans_idx) # Set standard, dst, and before ttinfos. before will be # used when a given time is before any transitions, # and will be set to the first non-dst ttinfo, or to # the first dst, if all of them are dst.
self._ttinfo_std = None self._ttinfo_dst = None self._ttinfo_before = None if self._ttinfo_list: if not self._trans_list: self._ttinfo_std = self._ttinfo_first = self._ttinfo_list[0] else: for i in range(timecnt-1,-1,-1): tti = self._trans_idx[i] if not self._ttinfo_std and not tti.isdst: self._ttinfo_std = tti elif not self._ttinfo_dst and tti.isdst: self._ttinfo_dst = tti if self._ttinfo_std and self._ttinfo_dst: break else: if self._ttinfo_dst and not self._ttinfo_std: self._ttinfo_std = self._ttinfo_dst for tti in self._ttinfo_list: if not tti.isdst: self._ttinfo_before = tti break else: self._ttinfo_before = self._ttinfo_list[0] # Now fix transition times to become relative to wall time. # # I'm not sure about this. In my tests, the tz source file # is set up to wall time, and in the binary file isstd and # isgmt are off, so it should be in wall time. OTOH, it's # always in gmt time. Let me know if you have comments # about this. laststdoffset = 0 self._trans_list = list(self._trans_list) for i in range(len(self._trans_list)): tti = self._trans_idx[i] if not tti.isdst: # This is std time. self._trans_list[i] += tti.offset laststdoffset = tti.offset else: # This is dst time. Convert to std. self._trans_list[i] += laststdoffset self._trans_list = tuple(self._trans_list) def _find_ttinfo(self, dt, laststd=0): timestamp = ((dt.toordinal() - EPOCHORDINAL) * 86400 + dt.hour * 3600 + dt.minute * 60 + dt.second) idx = 0 for trans in self._trans_list: if timestamp < trans: break idx += 1 else: return self._ttinfo_std if idx == 0: return self._ttinfo_before if laststd: while idx > 0: tti = self._trans_idx[idx-1] if not tti.isdst: return tti idx -= 1 else: return self._ttinfo_std else: return self._trans_idx[idx-1] def utcoffset(self, dt): if not self._ttinfo_std: return ZERO return self._find_ttinfo(dt).delta def dst(self, dt): if not self._ttinfo_dst: return ZERO tti = self._find_ttinfo(dt) if not tti.isdst: return ZERO # The documentation says that utcoffset()-dst() must # be constant for every dt. return self._find_ttinfo(dt, laststd=1).delta-tti.delta # An alternative for that would be: # # return self._ttinfo_dst.offset-self._ttinfo_std.offset # # However, this class stores historical changes in the # dst offset, so I believe that this wouldn't be the right # way to implement this.
def tzname(self, dt): if not self._ttinfo_std: return None return self._find_ttinfo(dt).abbr def __eq__(self, other): if not isinstance(other, tzfile): return False return (self._trans_list == other._trans_list and self._trans_idx == other._trans_idx and self._ttinfo_list == other._ttinfo_list) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self._filename) def __reduce__(self): if not os.path.isfile(self._filename): raise ValueError("Unpicklable %s class" % self.__class__.__name__) return (self.__class__, (self._filename,)) class tzrange(datetime.tzinfo): def __init__(self, stdabbr, stdoffset=None, dstabbr=None, dstoffset=None, start=None, end=None): global relativedelta if not relativedelta: from dateutil import relativedelta self._std_abbr = stdabbr self._dst_abbr = dstabbr if stdoffset is not None: self._std_offset = datetime.timedelta(seconds=stdoffset) else: self._std_offset = ZERO if dstoffset is not None: self._dst_offset = datetime.timedelta(seconds=dstoffset) elif dstabbr and stdoffset is not None: self._dst_offset = self._std_offset+datetime.timedelta(hours=+1) else: self._dst_offset = ZERO if start is None: self._start_delta = relativedelta.relativedelta( hours=+2, month=4, day=1, weekday=relativedelta.SU(+1)) else: self._start_delta = start if end is None: self._end_delta = relativedelta.relativedelta( hours=+1, month=10, day=31, weekday=relativedelta.SU(-1)) else: self._end_delta = end def utcoffset(self, dt): if self._isdst(dt): return self._dst_offset else: return self._std_offset def dst(self, dt): if self._isdst(dt): return self._dst_offset-self._std_offset else: return ZERO def tzname(self, dt): if self._isdst(dt): return self._dst_abbr else: return self._std_abbr def _isdst(self, dt): if not self._start_delta: return False year = datetime.date(dt.year,1,1) start = year+self._start_delta end = year+self._end_delta dt = dt.replace(tzinfo=None) if start < end: return dt >= start and dt < end else: return dt >= start or dt < end def __eq__(self, other): if not isinstance(other, tzrange): return False return (self._std_abbr == other._std_abbr and self._dst_abbr == other._dst_abbr and self._std_offset == other._std_offset and self._dst_offset == other._dst_offset and self._start_delta == other._start_delta and self._end_delta == other._end_delta) def __ne__(self, other): return not self.__eq__(other) def __repr__(self): return "%s(...)" % self.__class__.__name__ __reduce__ = object.__reduce__ class tzstr(tzrange): def __init__(self, s): global parser if not parser: from dateutil import parser self._s = s res = parser._parsetz(s) if res is None: raise ValueError("unknown string format") # We must initialize it first, since _delta() needs # _std_offset and _dst_offset set. Use False in start/end # to avoid building it twice.
tzrange.__init__(self, res.stdabbr, res.stdoffset, res.dstabbr, res.dstoffset, start=False, end=False) self._start_delta = self._delta(res.start) if self._start_delta: self._end_delta = self._delta(res.end, isend=1) def _delta(self, x, isend=0): kwargs = {} if x.month is not None: kwargs["month"] = x.month if x.weekday is not None: kwargs["weekday"] = relativedelta.weekday(x.weekday, x.week) if x.week > 0: kwargs["day"] = 1 else: kwargs["day"] = 31 elif x.day: kwargs["day"] = x.day elif x.yday is not None: kwargs["yearday"] = x.yday elif x.jyday is not None: kwargs["nlyearday"] = x.jyday if not kwargs: # Default is to start on first sunday of april, and end # on last sunday of october. if not isend: kwargs["month"] = 4 kwargs["day"] = 1 kwargs["weekday"] = relativedelta.SU(+1) else: kwargs["month"] = 10 kwargs["day"] = 31 kwargs["weekday"] = relativedelta.SU(-1) if x.time is not None: kwargs["seconds"] = x.time else: # Default is 2AM. kwargs["seconds"] = 7200 if isend: # Convert to standard time, to follow the documented way # of working with the extra hour. See the documentation # of the tzinfo class. delta = self._dst_offset-self._std_offset kwargs["seconds"] -= delta.seconds+delta.days*86400 return relativedelta.relativedelta(**kwargs) def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self._s) class _tzicalvtzcomp: def __init__(self, tzoffsetfrom, tzoffsetto, isdst, tzname=None, rrule=None): self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom) self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto) self.tzoffsetdiff = self.tzoffsetto-self.tzoffsetfrom self.isdst = isdst self.tzname = tzname self.rrule = rrule class _tzicalvtz(datetime.tzinfo): def __init__(self, tzid, comps=[]): self._tzid = tzid self._comps = comps self._cachedate = [] self._cachecomp = [] def _find_comp(self, dt): if len(self._comps) == 1: return self._comps[0] dt = dt.replace(tzinfo=None) try: return self._cachecomp[self._cachedate.index(dt)] except ValueError: pass lastcomp = None lastcompdt = None for comp in self._comps: if not comp.isdst: # Handle the extra hour in DST -> STD compdt = comp.rrule.before(dt-comp.tzoffsetdiff, inc=True) else: compdt = comp.rrule.before(dt, inc=True) if compdt and (not lastcompdt or lastcompdt < compdt): lastcompdt = compdt lastcomp = comp if not lastcomp: # RFC says nothing about what to do when a given # time is before the first onset date. We'll look for the # first standard component, or the first component, if # none is found. 
for comp in self._comps: if not comp.isdst: lastcomp = comp break else: lastcomp = self._comps[0] self._cachedate.insert(0, dt) self._cachecomp.insert(0, lastcomp) if len(self._cachedate) > 10: self._cachedate.pop() self._cachecomp.pop() return lastcomp def utcoffset(self, dt): return self._find_comp(dt).tzoffsetto def dst(self, dt): comp = self._find_comp(dt) if comp.isdst: return comp.tzoffsetdiff else: return ZERO def tzname(self, dt): return self._find_comp(dt).tzname def __repr__(self): return "<tzicalvtz %s>" % repr(self._tzid) __reduce__ = object.__reduce__ class tzical: def __init__(self, fileobj): global rrule if not rrule: from dateutil import rrule if isinstance(fileobj, str): self._s = fileobj fileobj = open(fileobj) elif hasattr(fileobj, "name"): self._s = fileobj.name else: self._s = repr(fileobj) self._vtz = {} self._parse_rfc(fileobj.read()) def keys(self): return self._vtz.keys() def get(self, tzid=None): if tzid is None: keys = list(self._vtz) if len(keys) == 0: raise Exception("no timezones defined") elif len(keys) > 1: raise Exception("more than one timezone available") tzid = keys[0] return self._vtz.get(tzid) def _parse_offset(self, s): s = s.strip() if not s: raise ValueError("empty offset") if s[0] in ('+', '-'): signal = (-1,+1)[s[0]=='+'] s = s[1:] else: signal = +1 if len(s) == 4: return (int(s[:2])*3600+int(s[2:])*60)*signal elif len(s) == 6: return (int(s[:2])*3600+int(s[2:4])*60+int(s[4:]))*signal else: raise ValueError("invalid offset: "+s) def _parse_rfc(self, s): lines = s.splitlines() if not lines: raise ValueError("empty string") # Unfold i = 0 while i < len(lines): line = lines[i].rstrip() if not line: del lines[i] elif i > 0 and line[0] == " ": lines[i-1] += line[1:] del lines[i] else: i += 1 tzid = None comps = [] invtz = False comptype = None for line in lines: if not line: continue name, value = line.split(':', 1) parms = name.split(';') if not parms: raise ValueError("empty property name") name = parms[0].upper() parms = parms[1:] if invtz: if name == "BEGIN": if value in ("STANDARD", "DAYLIGHT"): # Process component pass else: raise ValueError("unknown component: "+value) comptype = value founddtstart = False tzoffsetfrom = None tzoffsetto = None rrulelines = [] tzname = None elif name == "END": if value == "VTIMEZONE": if comptype: raise ValueError( "component not closed: "+comptype) if not tzid: raise ValueError( "mandatory TZID not found") if not comps: raise ValueError( "at least one component is needed") # Process vtimezone self._vtz[tzid] = _tzicalvtz(tzid, comps) invtz = False elif value == comptype: if not founddtstart: raise ValueError( "mandatory DTSTART not found") if tzoffsetfrom is None: raise ValueError( "mandatory TZOFFSETFROM not found") if tzoffsetto is None: raise ValueError( "mandatory TZOFFSETTO not found") # Process component rr = None if rrulelines: rr = rrule.rrulestr("\n".join(rrulelines), compatible=True, ignoretz=True, cache=True) comp = _tzicalvtzcomp(tzoffsetfrom, tzoffsetto, (comptype == "DAYLIGHT"), tzname, rr) comps.append(comp) comptype = None else: raise ValueError( "invalid component end: "+value) elif comptype: if name == "DTSTART": rrulelines.append(line) founddtstart = True elif name in ("RRULE", "RDATE", "EXRULE", "EXDATE"): rrulelines.append(line) elif name == "TZOFFSETFROM": if parms: raise ValueError( "unsupported %s parm: %s" % (name, parms[0])) tzoffsetfrom = self._parse_offset(value) elif name == "TZOFFSETTO": if parms: raise ValueError( "unsupported TZOFFSETTO parm: "+parms[0]) tzoffsetto = self._parse_offset(value) elif name ==
"TZNAME": if parms: raise ValueError( "unsupported TZNAME parm: "+parms[0]) tzname = value elif name == "COMMENT": pass else: raise ValueError("unsupported property: "+name) else: if name == "TZID": if parms: raise ValueError( "unsupported TZID parm: "+parms[0]) tzid = value elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"): pass else: raise ValueError("unsupported property: "+name) elif name == "BEGIN" and value == "VTIMEZONE": tzid = None comps = [] invtz = True def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self._s) if sys.platform != "win32": TZFILES = ["/etc/localtime", "localtime"] TZPATHS = ["/usr/share/zoneinfo", "/usr/lib/zoneinfo", "/etc/zoneinfo"] else: TZFILES = [] TZPATHS = [] def gettz(name=None): tz = None if not name: try: name = os.environ["TZ"] except KeyError: pass if name is None or name == ":": for filepath in TZFILES: if not os.path.isabs(filepath): filename = filepath for path in TZPATHS: filepath = os.path.join(path, filename) if os.path.isfile(filepath): break else: continue if os.path.isfile(filepath): try: tz = tzfile(filepath) break except (OSError, ValueError): pass else: tz = tzlocal() else: if name.startswith(":"): name = name[:-1] if os.path.isabs(name): if os.path.isfile(name): tz = tzfile(name) else: tz = None else: for path in TZPATHS: filepath = os.path.join(path, name) if not os.path.isfile(filepath): filepath = filepath.replace(' ','_') if not os.path.isfile(filepath): continue try: tz = tzfile(filepath) break except (OSError, ValueError): pass else: tz = None if tzwin: try: tz = tzwin(name) except OSError: pass if not tz: from dateutil.zoneinfo import gettz tz = gettz(name) if not tz: for c in name: # name must have at least one offset to be a tzstr if c in "0123456789": try: tz = tzstr(name) except ValueError: pass break else: if name in ("GMT", "UTC"): tz = tzutc() elif name in time.tzname: tz = tzlocal() return tz # vim:ts=4:sw=4:et ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/uri.py0000644000175000017500000001112714645174376015360 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from urllib.parse import quote from storm.exceptions import URIError class URI: """A representation of a Uniform Resource Identifier (URI). This is intended exclusively for database connection URIs. @ivar username: The username part of the URI, or C{None}. @ivar password: The password part of the URI, or C{None}. @ivar host: The host part of the URI, or C{None}. @type port: L{int} @ivar port: The port part of the URI, or C{None}. @ivar database: The part of the URI representing the database name, or C{None}. 
""" username = None password = None host = None port = None database = None def __init__(self, uri_str): try: self.scheme, rest = uri_str.split(":", 1) except ValueError: raise URIError("URI has no scheme: %s" % repr(uri_str)) self.options = {} if "?" in rest: rest, options = rest.split("?", 1) for pair in options.split("&"): key, value = pair.split("=", 1) self.options[unescape(key)] = unescape(value) if rest: if not rest.startswith("//"): self.database = unescape(rest) else: rest = rest[2:] if "/" in rest: rest, database = rest.split("/", 1) self.database = unescape(database) if "@" in rest: userpass, hostport = rest.split("@", 1) else: userpass = None hostport = rest if hostport: if ":" in hostport: host, port = hostport.rsplit(":", 1) self.host = unescape(host) if port: self.port = int(port) else: self.host = unescape(hostport) if userpass is not None: if ":" in userpass: username, password = userpass.rsplit(":", 1) self.username = unescape(username) self.password = unescape(password) else: self.username = unescape(userpass) def copy(self): uri = object.__new__(self.__class__) uri.__dict__.update(self.__dict__) uri.options = self.options.copy() return uri def __str__(self): tokens = [self.scheme, ":"] append = tokens.append if (self.username is not None or self.password is not None or self.host is not None or self.port is not None): append("//") if self.username is not None or self.password is not None: if self.username is not None: append(escape(self.username)) if self.password is not None: append(":") append(escape(self.password)) append("@") if self.host is not None: append(escape(self.host)) if self.port is not None: append(":") append(str(self.port)) append("/") if self.database is not None: append(escape(self.database, "/")) if self.options: options = ["%s=%s" % (escape(key), escape(value)) for key, value in sorted(self.options.items())] append("?") append("&".join(options)) return "".join(tokens) def escape(s, safe=""): return quote(s, safe) def unescape(s): if "%" not in s: return s i = 0 j = s.find("%") r = [] while j != -1: r.append(s[i:j]) i = j+3 r.append(chr(int(s[j+1:i], 16))) j = s.find("%", i) r.append(s[i:]) return "".join(r) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/variables.py0000644000175000017500000006352014645174376016535 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# from datetime import datetime, date, time, timedelta from decimal import Decimal from functools import partial import json import pickle import re import uuid import weakref from storm.exceptions import NoneError from storm import Undef, has_cextensions __all__ = [ "VariableFactory", "Variable", "LazyValue", "BoolVariable", "IntVariable", "FloatVariable", "DecimalVariable", "BytesVariable", "RawStrVariable", "UnicodeVariable", "DateTimeVariable", "DateVariable", "TimeVariable", "TimeDeltaVariable", "EnumVariable", "UUIDVariable", "PickleVariable", "JSONVariable", "ListVariable", ] class LazyValue: """Marker to be used as a base class on lazily evaluated values.""" __slots__ = () def raise_none_error(column, default=False): description = "default value" if default else "value" if not column: raise NoneError("None isn't acceptable as a %s" % description) else: from storm.expr import compile, CompileError name = column.name if column.table is not Undef: try: table = compile(column.table) name = "%s.%s" % (table, name) except CompileError: pass raise NoneError( "None isn't acceptable as a %s for %s" % (description, name)) VariableFactory = partial class Variable: """Basic representation of a database value in Python. @type column: L{storm.expr.Column} @ivar column: The column this variable represents. @type event: L{storm.event.EventSystem} @ivar event: The event system on which to broadcast events. If None, no events will be emitted. """ _value = Undef _lazy_value = Undef _checkpoint_state = Undef _allow_none = True _validator = None _validator_object_factory = None _validator_attribute = None column = None event = None def __init__(self, value=Undef, value_factory=Undef, from_db=False, allow_none=True, column=None, event=None, validator=None, validator_object_factory=None, validator_attribute=None): """ @param value: The initial value of this variable. The default behavior is for the value to stay undefined until it is set with L{set}. @param value_factory: If specified, this will immediately be called to get the initial value. @param from_db: A boolean value indicating where the initial value comes from, if C{value} or C{value_factory} are specified. @param allow_none: A boolean indicating whether None should be allowed to be set as the value of this variable. @param validator: Validation function called whenever trying to set the variable to a non-db value. The function should look like validator(object, attr, value), where the first and second arguments are the result of validator_object_factory() (or None, if this parameter isn't provided) and the value of validator_attribute, respectively. When called, the function should raise an error if the value is unacceptable, or return the value to be used in place of the original value otherwise. @type column: L{storm.expr.Column} @param column: The column that this variable represents. It's used for reporting better error messages. @type event: L{storm.event.EventSystem} @param event: The event system to broadcast messages with. If not specified, then no events will be broadcast. 
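Example (a minimal sketch using a plain L{Variable} directly; real code
normally uses subclasses such as L{IntVariable})::

    var = Variable(value=1, allow_none=False)
    var.get()          # -> 1
    var.checkpoint()
    var.set(2)
    var.has_changed()  # -> True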
""" if not allow_none: self._allow_none = False if value is None: raise_none_error(column, default=True) if value is not Undef: self.set(value, from_db) elif value_factory is not Undef: self.set(value_factory(), from_db) if validator is not None: self._validator = validator self._validator_object_factory = validator_object_factory self._validator_attribute = validator_attribute self.column = column self.event = weakref.proxy(event) if event is not None else None def get_lazy(self, default=None): """Get the current L{LazyValue} without resolving its value. @param default: If no L{LazyValue} was previously specified, return this value. Defaults to None. """ if self._lazy_value is Undef: return default return self._lazy_value def get(self, default=None, to_db=False): """Get the value, resolving it from a L{LazyValue} if necessary. If the current value is an instance of L{LazyValue}, then the C{resolve-lazy-value} event will be emitted, to give third parties the chance to resolve the lazy value to a real value. @param default: Returned if no value has been set. @param to_db: A boolean flag indicating whether this value is destined for the database. """ if self._lazy_value is not Undef and self.event is not None: self.event.emit("resolve-lazy-value", self, self._lazy_value) value = self._value if value is Undef: return default if value is None: return None return self.parse_get(value, to_db) def set(self, value, from_db=False): """Set a new value. Generally this will be called when an attribute was set in Python, or data is being loaded from the database. If the value is different from the previous value (or it is a L{LazyValue}), then the C{changed} event will be emitted. @param value: The value to set. If this is an instance of L{LazyValue}, then later calls to L{get} will try to resolve the value. @param from_db: A boolean indicating whether this value has come from the database. """ # FASTPATH This method is part of the fast path. Be careful when # changing it (try to profile any changes). if isinstance(value, LazyValue): self._lazy_value = value self._checkpoint_state = new_value = Undef else: if not from_db and self._validator is not None: # We use a factory rather than the object itself to prevent # the cycle object => obj_info => variable => object value = self._validator(self._validator_object_factory and self._validator_object_factory(), self._validator_attribute, value) self._lazy_value = Undef if value is None: if self._allow_none is False: raise_none_error(self.column) new_value = None else: new_value = self.parse_set(value, from_db) if from_db: # Prepare it for being used by the hook below. value = self.parse_get(new_value, False) old_value = self._value self._value = new_value if (self.event is not None and (self._lazy_value is not Undef or new_value != old_value)): if old_value is not None and old_value is not Undef: old_value = self.parse_get(old_value, False) self.event.emit("changed", self, old_value, value, from_db) def delete(self): """Delete the internal value. If there was a value set, then emit the C{changed} event. """ old_value = self._value if old_value is not Undef: self._value = Undef if self.event is not None: if old_value is not None and old_value is not Undef: old_value = self.parse_get(old_value, False) self.event.emit("changed", self, old_value, Undef, False) def is_defined(self): """Check whether there is currently a value. @return: boolean indicating whether there is currently a value for this variable. 
Note that if a L{LazyValue} was previously set, this returns False; it only returns True if there is currently a real value set. """ return self._value is not Undef def has_changed(self): """Check whether the value has changed. @return: boolean indicating whether the value has changed since the last call to L{checkpoint}. """ return (self._lazy_value is not Undef or self.get_state() != self._checkpoint_state) def get_state(self): """Get the internal state of this object. @return: A value which can later be passed to L{set_state}. """ return (self._lazy_value, self._value) def set_state(self, state): """Set the internal state of this object. @param state: A result from a previous call to L{get_state}. The internal state of this variable will be set to the state of the variable which get_state was called on. """ self._lazy_value, self._value = state def checkpoint(self): """"Checkpoint" the internal state. See L{has_changed}. """ self._checkpoint_state = self.get_state() def copy(self): """Make a new copy of this Variable with the same internal state.""" variable = self.__class__.__new__(self.__class__) variable.set_state(self.get_state()) return variable def parse_get(self, value, to_db): """Convert the internal value to an external value. Get a representation of this value either for Python or for the database. This method is only intended to be overridden in subclasses, not called from external code. @param value: The value to be converted. @param to_db: Whether or not this value is destined for the database. """ return value def parse_set(self, value, from_db): """Convert an external value to an internal value. A value is being set either from Python code or from the database. Parse it into its internal representation. This method is only intended to be overridden in subclasses, not called from external code. @param value: The value, either from Python code setting an attribute or from a column in a database. @param from_db: A boolean flag indicating whether this value is from the database. """ return value if has_cextensions: from storm.cextensions import Variable class BoolVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if not isinstance(value, (int, float, Decimal)): raise TypeError("Expected bool, found %r: %r" % (type(value), value)) return bool(value) class IntVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if not isinstance(value, (int, float, Decimal)): raise TypeError("Expected int, found %r: %r" % (type(value), value)) return int(value) class FloatVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if not isinstance(value, (int, float, Decimal)): raise TypeError("Expected float, found %r: %r" % (type(value), value)) return float(value) class DecimalVariable(Variable): __slots__ = () @staticmethod def parse_set(value, from_db): if (from_db and isinstance(value, str)) or isinstance(value, int): value = Decimal(value) elif not isinstance(value, Decimal): raise TypeError("Expected Decimal, found %r: %r" % (type(value), value)) return value @staticmethod def parse_get(value, to_db): if to_db: return str(value) return value class BytesVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if isinstance(value, memoryview): value = bytes(value) elif not isinstance(value, bytes): raise TypeError("Expected bytes, found %r: %r" % (type(value), value)) return value # DEPRECATED: BytesVariable was RawStrVariable until 0.22. 
RawStrVariable = BytesVariable class UnicodeVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if not isinstance(value, str): raise TypeError("Expected text, found %r: %r" % (type(value), value)) return value class DateTimeVariable(Variable): __slots__ = ("_tzinfo",) def __init__(self, *args, **kwargs): self._tzinfo = kwargs.pop("tzinfo", None) super().__init__(*args, **kwargs) def parse_set(self, value, from_db): if from_db: if isinstance(value, datetime): pass elif isinstance(value, str): if " " not in value: raise ValueError("Unknown date/time format: %r" % value) date_str, time_str = value.split(" ") value = datetime(*(_parse_date(date_str) + _parse_time(time_str))) else: raise TypeError("Expected datetime, found %s" % repr(value)) if self._tzinfo is not None: if value.tzinfo is None: value = value.replace(tzinfo=self._tzinfo) else: value = value.astimezone(self._tzinfo) else: if type(value) in (int, float): value = datetime.utcfromtimestamp(value) elif not isinstance(value, datetime): raise TypeError("Expected datetime, found %s" % repr(value)) if self._tzinfo is not None: # Python 3.6 gained support for calling the astimezone # method on naive datetime objects, presuming them to # represent system local time. This is probably # inappropriate for most uses of Storm, since depending on # what the system local time happens to be is usually a # mistake, so forbid this explicitly for now; we can always # open it up later if there's a good reason. if (value.tzinfo is None or value.tzinfo.utcoffset(value) is None): raise ValueError( "Expected aware datetime, found naive: %r" % value) value = value.astimezone(self._tzinfo) return value class DateVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if from_db: if value is None: return None if isinstance(value, datetime): return value.date() if isinstance(value, date): return value if not isinstance(value, str): raise TypeError("Expected date, found %s" % repr(value)) if " " in value: value, time_str = value.split(" ") return date(*_parse_date(value)) else: if isinstance(value, datetime): return value.date() if not isinstance(value, date): raise TypeError("Expected date, found %s" % repr(value)) return value class TimeVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if from_db: # XXX Can None ever get here, considering that set() checks for it? if value is None: return None if isinstance(value, time): return value if not isinstance(value, str): raise TypeError("Expected time, found %s" % repr(value)) if " " in value: date_str, value = value.split(" ") return time(*_parse_time(value)) else: if isinstance(value, datetime): return value.time() if not isinstance(value, time): raise TypeError("Expected time, found %s" % repr(value)) return value class TimeDeltaVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if from_db: # XXX Can None ever get here, considering that set() checks for it? 
if value is None: return None if isinstance(value, timedelta): return value if not isinstance(value, str): raise TypeError("Expected timedelta, found %s" % repr(value)) return _parse_interval(value) else: if not isinstance(value, timedelta): raise TypeError("Expected timedelta, found %s" % repr(value)) return value class UUIDVariable(Variable): __slots__ = () def parse_set(self, value, from_db): if from_db and isinstance(value, str): value = uuid.UUID(value) elif not isinstance(value, uuid.UUID): raise TypeError("Expected UUID, found %r: %r" % (type(value), value)) return value def parse_get(self, value, to_db): if to_db: return str(value) return value class EnumVariable(Variable): __slots__ = ("_get_map", "_set_map") def __init__(self, get_map, set_map, *args, **kwargs): self._get_map = get_map self._set_map = set_map Variable.__init__(self, *args, **kwargs) def parse_set(self, value, from_db): if from_db: return value try: return self._set_map[value] except KeyError: raise ValueError("Invalid enum value: %s" % repr(value)) def parse_get(self, value, to_db): if to_db: return value try: return self._get_map[value] except KeyError: raise ValueError("Invalid enum value: %s" % repr(value)) class MutableValueVariable(Variable): """ A variable which contains a reference to mutable content. For this kind of variable, we can't simply detect when a modification has been made, so we have to synchronize the content of the variable when the store is flushing current objects, to check if the state has changed. """ __slots__ = ("_event_system") def __init__(self, *args, **kwargs): self._event_system = None Variable.__init__(self, *args, **kwargs) if self.event is not None: self.event.hook("start-tracking-changes", self._start_tracking) self.event.hook("object-deleted", self._detect_changes_and_stop) def _start_tracking(self, obj_info, event_system): self._event_system = weakref.proxy(event_system) self.event.hook("stop-tracking-changes", self._stop_tracking) def _stop_tracking(self, obj_info, event_system): event_system.unhook("flush", self._detect_changes) self._event_system = None def _detect_changes(self, obj_info): if (self._checkpoint_state is not Undef and self.get_state() != self._checkpoint_state): self.event.emit("changed", self, None, self._value, False) def _detect_changes_and_stop(self, obj_info): self._detect_changes(obj_info) if self._event_system is not None: self._stop_tracking(obj_info, self._event_system) def get(self, default=None, to_db=False): if self._event_system is not None: self._event_system.hook("flush", self._detect_changes) return super().get(default, to_db) def set(self, value, from_db=False): if self._event_system is not None: if isinstance(value, LazyValue): self._event_system.unhook("flush", self._detect_changes) else: self._event_system.hook("flush", self._detect_changes) super().set(value, from_db) class EncodedValueVariable(MutableValueVariable): __slots__ = () def parse_set(self, value, from_db): if from_db: if isinstance(value, memoryview): value = bytes(value) return self._loads(value) else: return value def parse_get(self, value, to_db): if to_db: return self._dumps(value) else: return value def get_state(self): return (self._lazy_value, self._dumps(self._value)) def set_state(self, state): self._lazy_value = state[0] self._value = self._loads(state[1]) class PickleVariable(EncodedValueVariable): def _loads(self, value): return pickle.loads(value) def _dumps(self, value): return pickle.dumps(value, -1) class JSONVariable(EncodedValueVariable): __slots__ = () def 
_loads(self, value): if not isinstance(value, str): raise TypeError( "Cannot safely assume encoding of byte string %r." % value) return json.loads(value) def _dumps(self, value): # http://www.ietf.org/rfc/rfc4627.txt states that JSON is text-based # and so we treat it as such here. In other words, this method returns # Unicode text and never bytes. dump = json.dumps(value, ensure_ascii=False) if not isinstance(dump, str): # json.dumps() does not always return unicode. See # http://code.google.com/p/simplejson/issues/detail?id=40 for one # of many discussions of str/unicode handling in simplejson. dump = dump.decode("utf-8") return dump class ListVariable(MutableValueVariable): __slots__ = ("_item_factory",) def __init__(self, item_factory, *args, **kwargs): self._item_factory = item_factory MutableValueVariable.__init__(self, *args, **kwargs) def parse_set(self, value, from_db): if from_db: item_factory = self._item_factory return [item_factory(value=val, from_db=from_db).get() for val in value] else: return value def parse_get(self, value, to_db): if to_db: item_factory = self._item_factory return [item_factory(value=val, from_db=False) for val in value] else: return value def get_state(self): return (self._lazy_value, pickle.dumps(self._value, -1)) def set_state(self, state): self._lazy_value = state[0] self._value = pickle.loads(state[1]) def _parse_time(time_str): # TODO Add support for timezones. colons = time_str.count(":") if not 1 <= colons <= 2: raise ValueError("Unknown time format: %r" % time_str) if colons == 2: hour, minute, second = time_str.split(":") else: hour, minute = time_str.split(":") second = "0" if "." in second: second, microsecond = second.split(".") second = int(second) microsecond = int(int(microsecond) * 10 ** (6 - len(microsecond))) return int(hour), int(minute), second, microsecond return int(hour), int(minute), int(second), 0 def _parse_date(date_str): if "-" not in date_str: raise ValueError("Unknown date format: %r" % date_str) year, month, day = date_str.split("-") return int(year), int(month), int(day) def _parse_interval_table(): table = {} for units, delta in ( ("d day days", timedelta), ("h hour hours", lambda x: timedelta(hours=x)), ("m min minute minutes", lambda x: timedelta(minutes=x)), ("s sec second seconds", lambda x: timedelta(seconds=x)), ("ms millisecond milliseconds", lambda x: timedelta(milliseconds=x)), ("microsecond microseconds", lambda x: timedelta(microseconds=x)) ): for unit in units.split(): table[unit] = delta return table _parse_interval_table = _parse_interval_table() _parse_interval_re = re.compile(r"[\s,]*" r"([-+]?(?:\d\d?:\d\d?(?::\d\d?)?(?:\.\d+)?" 
r"|\d+(?:\.\d+)?))" r"[\s,]*") def _parse_interval(interval): result = timedelta(0) value = None for token in _parse_interval_re.split(interval): if not token: pass elif ":" in token: if value is not None: result += timedelta(days=value) value = None h, m, s, ms = _parse_time(token) result += timedelta(hours=h, minutes=m, seconds=s, microseconds=ms) elif value is None: try: value = float(token) except ValueError: raise ValueError("Expected an interval value rather than " "%r in interval %r" % (token, interval)) else: unit = _parse_interval_table.get(token) if unit is None: raise ValueError("Unsupported interval unit %r in interval %r" % (token, interval)) result += unit(value) value = None if value is not None: result += timedelta(seconds=value) return result ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/wsgi.py0000644000175000017500000000456214571373456015535 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Robert Collins # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # """Glue to wire a storm timeline tracer up to a WSGI app.""" import threading import weakref __all__ = ['make_app'] def make_app(app): """Capture the per-request timeline object needed for Storm tracing. To use firstly make your app and then wrap it with this C{make_app}:: >>> app, find_timeline = make_app(app) Then wrap the returned app with the C{timeline} app (or anything that sets C{environ['timeline.timeline']}):: >>> app = timeline.wsgi.make_app(app) Finally install a timeline tracer to capture Storm queries:: >>> install_tracer(TimelineTracer(find_timeline)) @return: A wrapped WSGI app and a timeline factory function for use with L{TimelineTracer }. """ timeline_map = threading.local() def wrapper(environ, start_response): timeline = environ.get('timeline.timeline') timeline_map.timeline = None if timeline is not None: timeline_map.timeline = weakref.ref(timeline) # We could clean up timeline_map.timeline after we're done with the # request, but for that we'd have to consume all the data from the # underlying app and it wouldn't play well with some non-standard # tricks (e.g. let the reactor consume IBodyProducers asynchronously # when returning large files) that some people may want to do. return app(environ, start_response) def get_timeline(): timeline = getattr(timeline_map, 'timeline', None) if timeline is not None: return timeline() return wrapper, get_timeline ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/xid.py0000644000175000017500000000214214645174376015342 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2012 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # class Xid: """ Represent a transaction identifier compliant with the XA specification. """ def __init__(self, format_id, global_transaction_id, branch_qualifier): self.format_id = format_id self.global_transaction_id = global_transaction_id self.branch_qualifier = branch_qualifier ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4291248 storm-1.0/storm/zope/0000755000175000017500000000000014645532536015156 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/zope/__init__.py0000644000175000017500000000333014571373456017270 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from zope.interface import classImplements from storm.info import ObjectInfo from storm.store import EmptyResultSet, ResultSet from storm.zope.interfaces import IResultSet, ISQLObjectResultSet from storm import sqlobject as storm_sqlobject classImplements(storm_sqlobject.SQLObjectResultSet, ISQLObjectResultSet) classImplements(ResultSet, IResultSet) classImplements(EmptyResultSet, IResultSet) try: from zope.security.checker import NoProxy, BasicTypes, _available_by_default except ImportError: # We don't have zope.security installed. pass else: # The following is required for storm.info.get_obj_info() to have # access to a proxied object which is already in the store (IOW, has # the object info set already). With this, Storm is able to # gracefully handle situations when a proxied object is passed to a # Store. _available_by_default.append("__storm_object_info__") BasicTypes[ObjectInfo] = NoProxy ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/zope/adapters.py0000644000175000017500000000210714571373456017335 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. 
# # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from zope.component import adapter from zope.interface import implementer from storm.zope.interfaces import IResultSet, ISQLObjectResultSet @adapter(ISQLObjectResultSet) @implementer(IResultSet) def sqlobject_result_set_to_storm_result_set(so_result_set): return so_result_set._result_set ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1589984001.0 storm-1.0/storm/zope/configure.zcml0000644000175000017500000000314113661235401020011 0ustar00cjwatsoncjwatson ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/zope/interfaces.py0000644000175000017500000001375414571373456017667 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from zope.interface import Interface from storm.expr import Undef class ZStormError(Exception): """ Raised when situations such as duplicate store creation, unknown store name, etc., arise. """ class IZStorm(Interface): """A flag interface used to lookup the ZStorm utility.""" class IResultSet(Interface): """The interface for a L{ResultSet}. The rationale behind the exposed attributes is: 1. Model code constructs a L{ResultSet} and returns a security proxied object to the view code. 2. View code treats the L{ResultSet} as an immutable sequence/iterable and presents the data to the user. Therefore several attributes of L{ResultSet} are not included here: - Both C{set()} and C{remove()} can be used to modify the contained objects, which will bypass the security proxies on those objects. - C{get_select_expr()} will return a L{Select} object, which has no security declarations (and it isn't clear that any would be desirable). - C{find()}, C{group_by()} and C{having()} are really used to configure result sets, so are mostly intended for use on the model side. - There may be an argument for exposing C{difference()}, C{intersection()} and C{union()} as a way for view code to combine multiple results, but it isn't clear how often it makes sense to do this on the view side rather than model side. """ def copy(): """ Return a copy of this result set object, with the same configuration. """ def config(distinct=None, offset=None, limit=None): """Configure the result set. @param distinct: Optionally, when true, only return distinct rows. @param offset: Optionally, the offset to start retrieving records from. @param limit: Optionally, the maximum number of rows to return. 
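Example (a minimal sketch; C{result_set} is any object providing this
interface)::

    result_set.config(distinct=True, offset=10, limit=5)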
""" def __iter__(): """Iterate the result set.""" def __getitem__(index): """ Get the value at C{index} in the result set if C{index} is a single interger. If C{index} is a slice a new C{ResultSet} will be returned. """ def __contains__(item): """Check if C{item} is contained in the result set.""" def any(): """ Get a random object from the result set or C{None} if the result set is empty. """ def first(): """Return the first item from an ordered result set. @raises UnorderedError: Raised if the result set isn't ordered. """ def last(): """Return the last item from an ordered result set. @raises UnorderedError: Raised if the result set isn't ordered. """ def one(): """ Return one item from a result set containing at most one item or None if the result set is empty. @raises NotOneError: Raised if the result set contains more than one item. """ def order_by(*args): """Order the result set based on expressions in C{args}.""" def count(column=Undef, distinct=False): """Returns the number of rows in the result set. @param column: Optionally, the column to count. @param distinct: Optionally, when true, count only distinct rows. """ def max(column): """Returns the maximum C{column} value in the result set.""" def min(): """Returns the minimum C{column} value in the result set.""" def avg(): """Returns the average of C{column} values in the result set.""" def sum(): """Returns the sum of C{column} values in the result set.""" def values(*args): """Generator yields values for the columns specified in C{args}.""" def cached(): """Return matching objects from the cache for the current query.""" def is_empty(): """Return true if the result set contains no results.""" class ISQLObjectResultSet(Interface): def __getitem__(item): """List emulation.""" def __getslice__(slice): """Slice support.""" def __iter__(): """List emulation.""" def count(): """Return the number of items in the result set.""" def __bool__(): """Return C{True} if this result set contains any results. @note: This method is provided for compatibility with SQL Object. For new code, prefer L{is_empty}. It's compatible with L{ResultSet} which doesn't have a C{__bool__} implementation. """ def __contains__(item): """Support C{if FooObject in Foo.select(query)}.""" def intersect(otherSelect, intersectAll=False, orderBy=None): """Return the intersection of this result and C{otherSelect} @param otherSelect: the other L{ISQLObjectResultSet} @param intersectAll: whether to use INTERSECT ALL behaviour @param orderBy: the order the result set should use. """ def is_empty(): """Return C{True} if this result set doesn't contain any results.""" def prejoin(prejoins): """Return a new L{SelectResults} with the list of attributes prejoined. @param prejoins: The list of attribute names to prejoin. """ ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1336501902.0 storm-1.0/storm/zope/meta.zcml0000644000175000017500000000040611752263216016763 0ustar00cjwatsoncjwatson ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1709569838.0 storm-1.0/storm/zope/metaconfigure.py0000644000175000017500000000231114571373456020357 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. 
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from zope import component from storm.zope.interfaces import IZStorm def set_default_uri(name, uri): """Register C{uri} as the default URI for stores called C{name}.""" zstorm = component.getUtility(IZStorm) zstorm.set_default_uri(name, uri) def store(_context, name, uri): _context.action(discriminator=("store", name), callable=set_default_uri, args=(name, uri)) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/zope/metadirectives.py0000644000175000017500000000200114645174376020535 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # from zope.interface import Interface from zope.schema import TextLine class IStoreDirective(Interface): name = TextLine(title="Name", description="Store name") uri = TextLine(title="URI", description="Database URI") ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/zope/schema.py0000644000175000017500000000250714645174376017000 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . 
# """ZStorm-aware schema manager.""" import transaction from storm.schema import Schema class ZCommitter: """A L{Schema} committer that uses Zope's transaction manager.""" def commit(self): transaction.commit() def rollback(self): transaction.abort() class ZSchema(Schema): """Convenience for creating L{Schema}s that use a L{ZCommitter}.""" def __init__(self, creates, drops, deletes, patch_package): committer = ZCommitter() super().__init__(creates, drops, deletes, patch_package, committer) ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/zope/testing.py0000644000175000017500000002672014645174376017220 0ustar00cjwatsoncjwatson# # Copyright (c) 2006, 2007 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper. # # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import os import shutil import transaction from testresources import TestResourceManager from zope.component import provideUtility, getUtility from storm.schema.patch import UnknownPatchError from storm.schema.sharding import Sharding from storm.zope.zstorm import ZStorm, global_zstorm from storm.zope.interfaces import IZStorm class ZStormResourceManager(TestResourceManager): """Provide a L{ZStorm} resource to be used in test cases. The constructor is passed the details of the L{Store}s to be registered in the provided L{ZStore} resource. Then the C{make} and C{clean} methods make sure that such L{Store}s are properly setup and cleaned for each test. @param databases: A C{list} of C{dict}s holding the following keys: - 'name', the name of the store to be registered. - 'uri', the database URI to use to create the store. - 'schema', optionally, the L{Schema} for the tables in the store, if not given no schema will be applied. - 'schema-uri', optionally an alternate URI to use for applying the schema, if not given it defaults to 'uri'. @ivar force_delete: If C{True} for running L{Schema.delete} on a L{Store} even if no commit was performed by the test. Useful when running a test in a subprocess that might commit behind our back. @ivar use_global_zstorm: If C{True} then the C{global_zstorm} object from C{storm.zope.zstorm} will be used, instead of creating a new one. This is useful for code loading the zcml directives of C{storm.zope}. @ivar schema_stamp_dir: Optionally, a path to a directory that will be used to save timestamps of the schema's patch packages, so schema upgrades will be performed only when needed. This is just an optimisation to let the resource setup a bit faster. @ivar vertical_patching: If C{True}, patches will be applied "vertically", meaning that all patches for the first store will be applied, then all patches for the second store etc. Otherwise, if set to C{False} patches will be applied "horizontally" (see L{Sharding.upgrade}). The default is C{True} just because of backward-compatibility, but normally you should set it to C{False}. 
""" force_delete = False use_global_zstorm = False schema_stamp_dir = None vertical_patching = True def __init__(self, databases): super().__init__() self._databases = databases self._zstorm = None self._schema_zstorm = None self._commits = {} self._schemas = {} self._sharding = [] def make(self, dependencies): """Create a L{ZStorm} resource to be used by tests. @return: A L{ZStorm} object that will be shared among all tests using this resource manager. """ if self._zstorm is None: if self.use_global_zstorm: self._zstorm = global_zstorm else: self._zstorm = ZStorm() self._schema_zstorm = ZStorm() databases = self._databases # Adapt the old databases format to the new one, for backward # compatibility. This should be eventually dropped. if isinstance(databases, dict): databases = [ {"name": name, "uri": uri, "schema": schema} for name, (uri, schema) in databases.items()] # Provide the global IZStorm utility before applying patches, so # patch code can get the ztorm object if needed (e.g. looking up # other stores). provideUtility(self._zstorm) self._set_create_hook() enforce_schema = False for database in databases: name = database["name"] uri = database["uri"] schema = database.get("schema") schema_uri = database.get("schema-uri", uri) self._zstorm.set_default_uri(name, uri) if schema is not None: # The configuration for this database does not include a # schema definition, so we just setup the store (the user # code should apply the schema elsewhere, if any) self._schemas[name] = schema self._schema_zstorm.set_default_uri(name, schema_uri) schema.autocommit(False) store = self._schema_zstorm.get(name) if not self._sharding or self.vertical_patching: self._sharding.append(Sharding()) sharding = self._sharding[-1] sharding.add(store, schema) if self._has_patch_package_changed(name, schema): enforce_schema = True if enforce_schema: for sharding in self._sharding: try: sharding.upgrade() except UnknownPatchError: sharding.drop() sharding.create() except: # An unknown error occured, let's drop all timestamps # so subsequent runs won't assume that everything is # fine self._purge_schema_stamp_dir() raise else: sharding.delete() # Commit all schema changes across all stores transaction.commit() elif getUtility(IZStorm) is not self._zstorm: # This probably means that the test code has overwritten our # utility, let's re-register it. provideUtility(self._zstorm) return self._zstorm def _set_create_hook(self): """ Set a hook in ZStorm.create, so we can lazily set commit proxies. """ self._zstorm.__real_create__ = self._zstorm.create def create_hook(name, uri=None): store = self._zstorm.__real_create__(name, uri=uri) if self._schemas.get(name) is not None: # Only set commit proxies for databases that have a schema # that we can use for cleanup self._set_commit_proxy(store) return store self._zstorm.create = create_hook def _set_commit_proxy(self, store): """Set a commit proxy to keep track of commits and clean up the tables. @param store: The L{Store} to set the commit proxy on. Any commit on this store will result in the associated tables to be cleaned upon tear down. """ store.__real_commit__ = store.commit def commit_proxy(): self._commits[store] = True store.__real_commit__() store.commit = commit_proxy def _has_patch_package_changed(self, name, schema): """Whether the schema for the given database is up-to-date. 
As an optimisation, if the C{schema_stamp_dir} attribute is set, then this method performs a fast check based on the patch directory timestamp rather than the database patch table, so connections and upgrade queries can be skipped if there's no need. @param name: The name of the database to check. @param schema: The schema to be ensured. @return: C{True} if the patch directory has changed and the schema needs to be updated, C{False} otherwise. """ # If a schema stamp directory is set, then figure out whether there's # any need to upgrade the schema by looking at timestamps. if self.schema_stamp_dir is not None: schema_mtime = self._get_schema_mtime(schema) schema_stamp_mtime = self._get_schema_stamp_mtime(name) # The modification time of the schema's patch directory matches our # timestamp, so the schema is already up-to-date if schema_mtime == schema_stamp_mtime: return False # Save the modification time of the schema's patch directory so in # subsequent runs we'll know if we're already up-to-date self._set_schema_stamp_mtime(name, schema_mtime) return True def _get_schema_mtime(self, schema): """ Return the modification time of the C{schema}'s patch directory. """ patch_directory = os.path.dirname(schema._patch_set._package.__file__) schema_stat = os.stat(patch_directory) return int(schema_stat.st_mtime) def _get_schema_stamp_mtime(self, name): """ Return the modification time of the schema's patch directory, as saved in the stamp directory. """ # Let's create the stamp directory if it doesn't exist if not os.path.exists(self.schema_stamp_dir): os.makedirs(self.schema_stamp_dir) schema_stamp_path = os.path.join(self.schema_stamp_dir, name) # Get the last schema modification time we ran the upgrade for, or -1 # if this is our first run if os.path.exists(schema_stamp_path): with open(schema_stamp_path) as fd: schema_stamp_mtime = int(fd.read()) else: schema_stamp_mtime = -1 return schema_stamp_mtime def _set_schema_stamp_mtime(self, name, schema_mtime): """ Save the schema's modification time in the stamp directory. """ schema_stamp_path = os.path.join(self.schema_stamp_dir, name) with open(schema_stamp_path, "w") as fd: fd.write("%d" % schema_mtime) def _purge_schema_stamp_dir(self): """Remove the stamp directory.""" if self.schema_stamp_dir and os.path.exists(self.schema_stamp_dir): shutil.rmtree(self.schema_stamp_dir) def clean(self, resource): """Clean up the stores after a test.""" try: for name, store in self._zstorm.iterstores(): # Ensure that the store is in a consistent state store.flush() # Clear the alive cache *before* abort is called, # to prevent a useless loop in Store.invalidate # over the alive objects store._alive.clear() finally: transaction.abort() # Clean up tables after each test if a commit was made needs_commit = False for name, store in self._zstorm.iterstores(): if self.force_delete or store in self._commits: schema_store = self._schema_zstorm.get(name) schema = self._schemas[name] schema.delete(schema_store) needs_commit = True if needs_commit: transaction.commit() self._commits = {} ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039102.0 storm-1.0/storm/zope/zstorm.py0000644000175000017500000002631014645174376017074 0ustar00cjwatsoncjwatson"""ZStorm integrates Storm with Zope 3. @var global_zstorm: A global L{ZStorm} instance. It is registered as the L{IZStorm} utility in C{configure.zcml}. """ # # Copyright (c) 2006-2009 Canonical # # Written by Gustavo Niemeyer # # This file is part of Storm Object Relational Mapper.
# # Storm is free software; you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as # published by the Free Software Foundation; either version 2.1 of # the License, or (at your option) any later version. # # Storm is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see . # import threading import weakref from uuid import uuid4 from zope.interface import implementer import transaction from transaction.interfaces import IDataManager try: from transaction.interfaces import TransactionFailedError except ImportError: from ZODB.POSException import TransactionFailedError from storm.zope.interfaces import IZStorm, ZStormError from storm.database import create_database from storm.store import Store from storm.xid import Xid @implementer(IZStorm) class ZStorm: """A utility which integrates Storm with Zope. Typically, applications will register stores using ZCML similar to:: <store name='main' uri='sqlite:' /> Application code can then acquire the store by name using code similar to:: from zope.component import getUtility from storm.zope.interfaces import IZStorm store = getUtility(IZStorm).get('main') """ transaction_manager = transaction.manager _databases = {} def __init__(self): self._local = threading.local() self._default_databases = {} self._default_uris = {} self._default_tpcs = {} def _reset(self): for name, store in list(self.iterstores()): self.remove(store) store.close() self._local = threading.local() self._databases.clear() self._default_databases.clear() self._default_uris.clear() self._default_tpcs.clear() @property def _stores(self): try: return self._local.stores except AttributeError: stores = weakref.WeakValueDictionary() return self._local.__dict__.setdefault("stores", stores) @property def _named(self): try: return self._local.named except AttributeError: return self._local.__dict__.setdefault("named", {}) @property def _name_index(self): try: return self._local.name_index except AttributeError: return self._local.__dict__.setdefault( "name_index", weakref.WeakKeyDictionary()) @property def _txn_ids(self): """ A thread-local weak-key dict used to keep track of transaction IDs. """ try: return self._local.txn_ids except AttributeError: txn_ids = weakref.WeakKeyDictionary() return self._local.__dict__.setdefault("txn_ids", txn_ids) def _get_database(self, uri): database = self._databases.get(uri) if database is None: return self._databases.setdefault(uri, create_database(uri)) return database def set_default_uri(self, name, default_uri): """Set C{default_uri} as the default URI for stores called C{name}.""" self._default_databases[name] = self._get_database(default_uri) self._default_uris[name] = default_uri def set_default_tpc(self, name, default_flag): """Set the default two-phase mode for stores with the given C{name}.""" self._default_tpcs[name] = default_flag def create(self, name, uri=None): """Create a new store called C{name}. @param uri: Optionally, the URI to use. @raises ZStormError: Raised if C{uri} is None and no default URI exists for C{name}. Also raised if a store with C{name} already exists.
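For example (an illustrative sketch; C{zstorm} names an instance of this utility and the URI is a placeholder):: store = zstorm.create('main', 'sqlite:')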
""" if uri is None: database = self._default_databases.get(name) if database is None: raise ZStormError("Store named '%s' not found" % name) else: database = self._get_database(uri) if name is not None and self._named.get(name) is not None: raise ZStormError("Store named '%s' already exists" % name) store = Store(database) store._register_for_txn = True store._tpc = self._default_tpcs.get(name, False) store._event.hook( "register-transaction", register_store_with_transaction, weakref.ref(self)) self._stores[id(store)] = store if name is not None: self._named[name] = store self._name_index[store] = name return store def get(self, name, default_uri=None): """Get the store called C{name}, creating it first if necessary. @param default_uri: Optionally, the URI to use to create a store called C{name} when one doesn't already exist. @raises ZStormError: Raised if C{uri} is None and no default URI exists for C{name}. """ store = self._named.get(name) if not store: return self.create(name, default_uri) return store def remove(self, store): """Remove the given store from ZStorm. This removes any management of the store from ZStorm. Notice that if the store was used inside the current transaction, it's probably joined the transaction system as a resource already, and thus it will commit/rollback when the transaction system requests so. This method will unlink the *synchronizer* from the transaction system, so that once the current transaction is over it won't link back to it in future transactions. """ del self._stores[id(store)] name = self._name_index[store] del self._name_index[store] if name in self._named: del self._named[name] # Make sure the store isn't hooked up to future transactions. store._register_for_txn = False store._event.unhook( "register-transaction", register_store_with_transaction, weakref.ref(self)) def iterstores(self): """Iterate C{name, store} 2-tuples.""" # We explicitly copy the list of items before iterating over # it to avoid the problem where a store is deallocated during # iteration causing RuntimeError: dictionary changed size # during iteration. for store, name in list(self._name_index.items()): yield name, store def get_name(self, store): """Returns the name for C{store} or None if one isn't available.""" return self._name_index.get(store) def get_default_uris(self): """ Return a list of name, uri tuples that are named as the default databases for those names. """ return self._default_uris.copy() def register_store_with_transaction(store, zstorm_ref): zstorm = zstorm_ref() if zstorm is None: # zstorm object does not exist any more. return False # Check if the store is known. This could indicate a store being # used outside of its thread. if id(store) not in zstorm._stores: raise ZStormError("Store not registered with ZStorm, or registered " "with another thread.") txn = zstorm.transaction_manager.get() if store._tpc: global_transaction_id = zstorm._txn_ids.get(txn) if global_transaction_id is None: # The the global transaction doesn't have an ID yet, let's create # one in a way that it will be unique global_transaction_id = "_storm_%s" % str(uuid4()) zstorm._txn_ids[txn] = global_transaction_id xid = Xid(0, global_transaction_id, zstorm.get_name(store)) store.begin(xid) data_manager = StoreDataManager(store, zstorm) txn.join(data_manager) # Unhook the event handler. It will be rehooked for the next transaction. 
return False @implementer(IDataManager) class StoreDataManager: """An L{IDataManager} implementation for C{ZStorm}.""" def __init__(self, store, zstorm): self._store = store self._zstorm = zstorm self.transaction_manager = zstorm.transaction_manager def _hook_register_transaction_event(self): if self._store._register_for_txn: self._store._event.hook( "register-transaction", register_store_with_transaction, weakref.ref(self._zstorm)) def abort(self, txn): try: self._store.rollback() finally: self._hook_register_transaction_event() def tpc_begin(self, txn): # Zope's transaction system will call tpc_begin() on all # managers before calling commit, so flushing here may help # in cases where there are two stores with changes, and one # of them will fail. In such cases, flushing earlier will # ensure that both transactions will be rolled back, instead # of one committed and one rolled back. # # If TPC support is on, we still want to perform this flush for a # couple of reasons. Firstly the queries flush() runs couldn't # be run after calling prepare(), because the transaction is frozen # waiting for the final commit(), and secondly because if the flush # fails the entire transaction will be aborted with a normal rollback # as opposed to a TPC rollback, that would happen after prepare(). self._store.flush() def commit(self, txn): if self._store._tpc: self._store.prepare() def tpc_vote(self, txn): pass def tpc_finish(self, txn): # If commit raises an exception, we let the exception propagate, as # the transaction manager will then call tpc_abort, and we will # register the hook there self._store.commit() self._hook_register_transaction_event() def tpc_abort(self, txn): if self._store._tpc: try: self._store.rollback() finally: self._hook_register_transaction_event() def sortKey(self): # Stores in TPC mode should be the last to be committed; this makes # it possible to have TPC behavior when there's only a single store # not in TPC mode. if self._store._tpc: prefix = "zz" else: prefix = "aa" return "%s_store_%d" % (prefix, id(self)) global_zstorm = ZStorm() try: from zope.testing.cleanup import addCleanUp except ImportError: # We don't have zope.testing installed. pass else: addCleanUp(global_zstorm._reset) del addCleanUp ././@PaxHeader0000000000000000000000000000003400000000000010212 xustar0028 mtime=1721152862.4291248 storm-1.0/storm.egg-info/0000755000175000017500000000000014645532536015673 5ustar00cjwatsoncjwatson././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721152862.0 storm-1.0/storm.egg-info/PKG-INFO0000644000175000017500000001727414645532536016773 0ustar00cjwatsoncjwatsonMetadata-Version: 2.1 Name: storm Version: 1.0 Summary: Storm is an object-relational mapper (ORM) for Python developed at Canonical.
Home-page: https://storm.canonical.com Download-URL: https://launchpad.net/storm/+download Author: Gustavo Niemeyer Author-email: gustavo@niemeyer.net Maintainer: Storm Developers Maintainer-email: storm@lists.canonical.com License: LGPL Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL) Classifier: Programming Language :: Python Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: 3.9 Classifier: Programming Language :: Python :: 3.10 Classifier: Programming Language :: Python :: 3.11 Classifier: Programming Language :: Python :: 3.12 Classifier: Topic :: Database Classifier: Topic :: Database :: Front-Ends Classifier: Topic :: Software Development :: Libraries :: Python Modules Requires-Python: >=3.6 Description-Content-Type: text/x-rst License-File: LICENSE Requires-Dist: packaging>=14.1 Provides-Extra: doc Requires-Dist: fixtures; extra == "doc" Requires-Dist: sphinx; extra == "doc" Requires-Dist: sphinx-epytext; extra == "doc" Provides-Extra: test Requires-Dist: fixtures>=1.3.0; extra == "test" Requires-Dist: mysqlclient; extra == "test" Requires-Dist: pgbouncer>=0.0.7; extra == "test" Requires-Dist: postgresfixture; extra == "test" Requires-Dist: psycopg2>=2.3.0; extra == "test" Requires-Dist: testresources>=0.2.4; extra == "test" Requires-Dist: testtools>=0.9.8; extra == "test" Requires-Dist: timeline>=0.0.2; extra == "test" Requires-Dist: transaction>=1.0.0; extra == "test" Requires-Dist: Twisted>=10.0.0; extra == "test" Requires-Dist: zope.component>=3.8.0; extra == "test" Requires-Dist: zope.configuration; extra == "test" Requires-Dist: zope.interface>=4.0.0; extra == "test" Requires-Dist: zope.security>=3.7.2; extra == "test" Storm is an Object Relational Mapper for Python developed at Canonical. API docs, a manual, and a tutorial are available from: https://storm.canonical.com/ Introduction ============ The project was in development for more than a year for use in Canonical projects such as Launchpad and Landscape before being released as free software on July 9th, 2007. Design: * Clean and lightweight API offers a short learning curve and long-term maintainability. * Storm is developed in a test-driven manner. An untested line of code is considered a bug. * Storm needs no special class constructors, nor imperative base classes. * Storm is well designed (different classes have very clear boundaries, with small and clean public APIs). * Designed from day one to work both with thin relational databases, such as SQLite, and big iron systems like PostgreSQL and MySQL. * Storm is easy to debug, since its code is written with a KISS principle, and thus is easy to understand. * Designed from day one to work both at the low end, with trivial small databases, and the high end, with applications accessing billion row tables and committing to multiple database backends. * It's very easy to write and support backends for Storm (current backends have around 100 lines of code). Features: * Storm is fast. * Storm lets you efficiently access and update large datasets by allowing you to formulate complex queries spanning multiple tables using Python. 
* Storm allows you to fall back to SQL if needed (or if you just prefer), allowing you to mix "old school" code and ORM code. * Storm handles composed primary keys with ease (no need for surrogate keys). * Storm doesn't do schema management, and as a result you're free to manage the schema as wanted, and creating classes that work with Storm is clean and simple. * Storm works very well connecting to several databases and using the same Python types (or different ones) with all of them. * Storm can handle obj.attr = <A SQL expression> assignments, when that's really needed (the expression is executed at INSERT/UPDATE time). * Storm handles relationships between objects even before they are added to a database. * Storm works well with existing database schemas. * Storm will flush changes to the database automatically when needed, so that queries made affect recently modified objects. License ======= Copyright (C) 2006-2020 Canonical, Ltd. All contributions must have copyright assigned to Canonical. This library is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. This library is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with this library; if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA On Ubuntu systems, the complete text of the GNU Lesser General Public Version 2.1 License is in /usr/share/common-licenses/LGPL-2.1 Developing Storm ================ SHORT VERSION: If you are running Ubuntu, or probably Debian, the following should work. If not, and for reference, the long version is below. $ dev/ubuntu-deps $ echo "$PWD/** rwk," | sudo tee /etc/apparmor.d/local/usr.sbin.mysqld >/dev/null $ sudo aa-enforce /usr/sbin/mysqld $ make develop $ make check LONG VERSION: The following instructions describe the procedure for setting up a development environment and running the test suite. Installing dependencies ----------------------- The following instructions assume that you're using Ubuntu. The same procedure will probably work without changes on a Debian system and with minimal changes on a non-Debian-based Linux distribution. In order to run the test suite, and exercise all supported backends, you will need to install MySQL and PostgreSQL, along with the related Python database drivers: $ sudo apt-get install \ mysql-server \ postgresql pgbouncer \ build-essential These will take a few minutes to download. The Python dependencies for running tests can be installed with apt-get: $ apt-get install \ python3-fixtures \ python3-pgbouncer \ python3-psycopg2 \ python3-testresources \ python3-timeline \ python3-transaction \ python3-twisted \ python3-zope.component \ python3-zope.security Alternatively, dependencies can be downloaded as eggs into the current directory with: $ make develop This ensures that all dependencies are available, downloading from PyPI as appropriate. Database setup -------------- Most database setup is done automatically by the test suite. However, Ubuntu's default MySQL packaging ships an AppArmor profile that prevents it from writing to a local data directory.
To allow the test suite to do this, you will need to grant it access, which is most easily done by adding a line such as this to /etc/apparmor.d/local/usr.sbin.mysqld: /path/to/storm/** rwk, Then reload the profile: $ sudo aa-enforce /usr/sbin/mysqld Running the tests ----------------- Finally, it's time to run the tests! Go into the base directory of the storm branch you want to test, and run: $ make check They'll take a while to run. All tests should pass: failures mean there's a problem with your environment or a bug in Storm. ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721152862.0 storm-1.0/storm.egg-info/SOURCES.txt0000644000175000017500000000440214645532536017557 0ustar00cjwatsoncjwatsonLICENSE MANIFEST.in Makefile NEWS README TODO setup.cfg setup.py tox.ini dev/test storm/__init__.py storm/base.py storm/cache.py storm/cextensions.c storm/database.py storm/event.py storm/exceptions.py storm/expr.py storm/info.py storm/locals.py storm/properties.py storm/references.py storm/sqlobject.py storm/store.py storm/testing.py storm/tracer.py storm/tz.py storm/uri.py storm/variables.py storm/wsgi.py storm/xid.py storm.egg-info/PKG-INFO storm.egg-info/SOURCES.txt storm.egg-info/dependency_links.txt storm.egg-info/not-zip-safe storm.egg-info/requires.txt storm.egg-info/top_level.txt storm/databases/__init__.py storm/databases/mysql.py storm/databases/postgres.py storm/databases/sqlite.py storm/docs/Makefile storm/docs/__init__.py storm/docs/api.rst storm/docs/conf.py storm/docs/index.rst storm/docs/infoheritance.rst storm/docs/tutorial.rst storm/docs/zope.rst storm/schema/__init__.py storm/schema/patch.py storm/schema/schema.py storm/schema/sharding.py storm/tests/__init__.py storm/tests/base.py storm/tests/cache.py storm/tests/database.py storm/tests/event.py storm/tests/expr.py storm/tests/helper.py storm/tests/info.py storm/tests/mocker.py storm/tests/properties.py storm/tests/sqlobject.py storm/tests/tracer.py storm/tests/uri.py storm/tests/variables.py storm/tests/wsgi.py storm/tests/databases/__init__.py storm/tests/databases/base.py storm/tests/databases/mysql.py storm/tests/databases/postgres.py storm/tests/databases/proxy.py storm/tests/databases/sqlite.py storm/tests/django/__init__.py storm/tests/schema/__init__.py storm/tests/schema/patch.py storm/tests/schema/schema.py storm/tests/schema/sharding.py storm/tests/store/__init__.py storm/tests/store/base.py storm/tests/store/block.py storm/tests/store/mysql.py storm/tests/store/postgres.py storm/tests/store/sqlite.py storm/tests/twisted/__init__.py storm/tests/twisted/transact.py storm/tests/zope/__init__.py storm/tests/zope/adapters.py storm/tests/zope/testing.py storm/tests/zope/zstorm.py storm/twisted/__init__.py storm/twisted/testing.py storm/twisted/transact.py storm/zope/__init__.py storm/zope/adapters.py storm/zope/configure.zcml storm/zope/interfaces.py storm/zope/meta.zcml storm/zope/metaconfigure.py storm/zope/metadirectives.py storm/zope/schema.py storm/zope/testing.py storm/zope/zstorm.py././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721152862.0 storm-1.0/storm.egg-info/dependency_links.txt0000644000175000017500000000000114645532536021741 0ustar00cjwatsoncjwatson ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1586716248.0 storm-1.0/storm.egg-info/not-zip-safe0000644000175000017500000000000113644657130020115 0ustar00cjwatsoncjwatson ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022
mtime=1721152862.0 storm-1.0/storm.egg-info/requires.txt0000644000175000017500000000047014645532536020274 0ustar00cjwatsoncjwatsonpackaging>=14.1 [doc] fixtures sphinx sphinx-epytext [test] fixtures>=1.3.0 mysqlclient pgbouncer>=0.0.7 postgresfixture psycopg2>=2.3.0 testresources>=0.2.4 testtools>=0.9.8 timeline>=0.0.2 transaction>=1.0.0 Twisted>=10.0.0 zope.component>=3.8.0 zope.configuration zope.interface>=4.0.0 zope.security>=3.7.2 ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721152862.0 storm-1.0/storm.egg-info/top_level.txt0000644000175000017500000000000614645532536020421 0ustar00cjwatsoncjwatsonstorm ././@PaxHeader0000000000000000000000000000002600000000000010213 xustar0022 mtime=1721039171.0 storm-1.0/tox.ini0000644000175000017500000000121214645174503014340 0ustar00cjwatsoncjwatson[tox] envlist = py36-{cextensions,nocextensions} py37-{cextensions,nocextensions} py38-{cextensions,nocextensions} py39-{cextensions,nocextensions} py310-{cextensions,nocextensions} py311-{cextensions,nocextensions} py312-{cextensions,nocextensions} docs [testenv] deps = .[test] passenv = STORM_TEST_RUNNER USER setenv = cextensions: STORM_CEXTENSIONS = 1 nocextensions: STORM_CEXTENSIONS = 0 commands = python dev/test {posargs} [testenv:docs] basepython = python3.12 commands = sphinx-build -b html -d storm/docs/_build/doctrees storm/docs storm/docs/_build/html deps = .[doc]
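# Example (illustrative): run the test suite for a single environment from the envlist above with:
#   tox -e py312-nocextensions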