pax_global_header00006660000000000000000000000064137425623700014523gustar00rootroot0000000000000052 comment=84d0fcf94740dbcd5ae7eff40cc0d1f4d2d3aead pymongo-3.11.0/000077500000000000000000000000001374256237000132755ustar00rootroot00000000000000pymongo-3.11.0/LICENSE000066400000000000000000000261351374256237000143110ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. 
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. pymongo-3.11.0/MANIFEST.in000066400000000000000000000006041374256237000150330ustar00rootroot00000000000000include README.rst include LICENSE include THIRD-PARTY-NOTICES include ez_setup.py recursive-include doc *.rst recursive-include doc *.py recursive-include doc *.conf recursive-include doc *.css recursive-include doc *.js recursive-include doc *.png recursive-include tools *.py include tools/README.rst recursive-include test *.pem recursive-include test *.py recursive-include bson *.h pymongo-3.11.0/PKG-INFO000066400000000000000000000251201374256237000143720ustar00rootroot00000000000000Metadata-Version: 2.1 Name: pymongo Version: 3.11.0 Summary: Python driver for MongoDB Home-page: http://github.com/mongodb/mongo-python-driver Author: Mike Dirolf Author-email: mongodb-user@googlegroups.com Maintainer: Bernie Hackett Maintainer-email: bernie@mongodb.com License: Apache License, Version 2.0 Description: ======= PyMongo ======= :Info: See `the mongo site `_ for more information. See `GitHub `_ for the latest source. :Documentation: Available at `pymongo.readthedocs.io `_ :Author: Mike Dirolf :Maintainer: Bernie Hackett About ===== The PyMongo distribution contains tools for interacting with MongoDB database from Python. The ``bson`` package is an implementation of the `BSON format `_ for Python. The ``pymongo`` package is a native Python driver for MongoDB. The ``gridfs`` package is a `gridfs `_ implementation on top of ``pymongo``. PyMongo supports MongoDB 2.6, 3.0, 3.2, 3.4, 3.6, 4.0, 4.2, and 4.4. Support / Feedback ================== For issues with, questions about, or feedback for PyMongo, please look into our `support channels `_. Please do not email any of the PyMongo developers directly with issues or questions - you're more likely to get an answer on the `MongoDB Community Forums `_. Bugs / Feature Requests ======================= Think you’ve found a bug? Want to see a new feature in PyMongo? Please open a case in our issue management tool, JIRA: - `Create an account and login `_. - Navigate to `the PYTHON project `_. - Click **Create Issue** - Please provide as much information as possible about the issue type and how to reproduce it. Bug reports in JIRA for all driver projects (i.e. PYTHON, CSHARP, JAVA) and the Core Server (i.e. SERVER) project are **public**. How To Ask For Help ------------------- Please include all of the following information when opening an issue: - Detailed steps to reproduce the problem, including full traceback, if possible. - The exact python version used, with patch level:: $ python -c "import sys; print(sys.version)" - The exact version of PyMongo used, with patch level:: $ python -c "import pymongo; print(pymongo.version); print(pymongo.has_c())" - The operating system and version (e.g. Windows 7, OSX 10.8, ...) - Web framework or asynchronous network library used, if any, with version (e.g. Django 1.7, mod_wsgi 4.3.0, gevent 1.0.1, Tornado 4.0.2, ...) Security Vulnerabilities ------------------------ If you’ve identified a security vulnerability in a driver or any other MongoDB project, please report it according to the `instructions here `_. 
Installation ============ PyMongo can be installed with `pip `_:: $ python -m pip install pymongo Or ``easy_install`` from `setuptools `_:: $ python -m easy_install pymongo You can also download the project source and do:: $ python setup.py install Do **not** install the "bson" package from pypi. PyMongo comes with its own bson package; doing "easy_install bson" installs a third-party package that is incompatible with PyMongo. Dependencies ============ PyMongo supports CPython 2.7, 3.4+, PyPy, and PyPy3.5+. Optional dependencies: GSSAPI authentication requires `pykerberos `_ on Unix or `WinKerberos `_ on Windows. The correct dependency can be installed automatically along with PyMongo:: $ python -m pip install pymongo[gssapi] MONGODB-AWS authentication requires `pymongo-auth-aws `_:: $ python -m pip install pymongo[aws] Support for mongodb+srv:// URIs requires `dnspython `_:: $ python -m pip install pymongo[srv] TLS / SSL support may require `ipaddress `_ and `certifi `_ or `wincertstore `_ depending on the Python version in use. The necessary dependencies can be installed along with PyMongo:: $ python -m pip install pymongo[tls] .. note:: Users of Python versions older than 2.7.9 will also receive the dependencies for OCSP when using the tls extra. OCSP (Online Certificate Status Protocol) requires `PyOpenSSL `_, `requests `_ and `service_identity `_:: $ python -m pip install pymongo[ocsp] Wire protocol compression with snappy requires `python-snappy `_:: $ python -m pip install pymongo[snappy] Wire protocol compression with zstandard requires `zstandard `_:: $ python -m pip install pymongo[zstd] Client-Side Field Level Encryption requires `pymongocrypt `_:: $ python -m pip install pymongo[encryption] You can install all dependencies automatically with the following command:: $ python -m pip install pymongo[gssapi,aws,ocsp,snappy,srv,tls,zstd,encryption] Other optional packages: - `backports.pbkdf2 `_, improves authentication performance with SCRAM-SHA-1 and SCRAM-SHA-256. It especially improves performance on Python versions older than 2.7.8. - `monotonic `_ adds support for a monotonic clock, which improves reliability in environments where clock adjustments are frequent. Not needed in Python 3. Additional dependencies are: - (to generate documentation) sphinx_ Examples ======== Here's a basic example (for more see the *examples* section of the docs): .. code-block:: python >>> import pymongo >>> client = pymongo.MongoClient("localhost", 27017) >>> db = client.test >>> db.name u'test' >>> db.my_collection Collection(Database(MongoClient('localhost', 27017), u'test'), u'my_collection') >>> db.my_collection.insert_one({"x": 10}).inserted_id ObjectId('4aba15ebe23f6b53b0000000') >>> db.my_collection.insert_one({"x": 8}).inserted_id ObjectId('4aba160ee23f6b543e000000') >>> db.my_collection.insert_one({"x": 11}).inserted_id ObjectId('4aba160ee23f6b543e000002') >>> db.my_collection.find_one() {u'x': 10, u'_id': ObjectId('4aba15ebe23f6b53b0000000')} >>> for item in db.my_collection.find(): ... print(item["x"]) ... 10 8 11 >>> db.my_collection.create_index("x") u'x_1' >>> for item in db.my_collection.find().sort("x", pymongo.ASCENDING): ... print(item["x"]) ... 8 10 11 >>> [item["x"] for item in db.my_collection.find().limit(2).skip(1)] [8, 11] Documentation ============= Documentation is available at `pymongo.readthedocs.io `_. To build the documentation, you will need to install sphinx_. Documentation can be generated by running **python setup.py doc**. 
Generated documentation can be found in the *doc/build/html/* directory. Testing ======= The easiest way to run the tests is to run **python setup.py test** in the root of the distribution. To verify that PyMongo works with Gevent's monkey-patching:: $ python green_framework_test.py gevent Or with Eventlet's:: $ python green_framework_test.py eventlet .. _sphinx: http://sphinx.pocoo.org/ Keywords: mongo,mongodb,pymongo,gridfs,bson Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Operating System :: MacOS :: MacOS X Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: POSIX Classifier: Programming Language :: Python :: 2 Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Database Requires-Python: >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.* Provides-Extra: tls Provides-Extra: encryption Provides-Extra: aws Provides-Extra: gssapi Provides-Extra: snappy Provides-Extra: srv Provides-Extra: zstd Provides-Extra: ocsp pymongo-3.11.0/README.rst000066400000000000000000000165121374256237000147710ustar00rootroot00000000000000======= PyMongo ======= :Info: See `the mongo site `_ for more information. See `GitHub `_ for the latest source. :Documentation: Available at `pymongo.readthedocs.io `_ :Author: Mike Dirolf :Maintainer: Bernie Hackett About ===== The PyMongo distribution contains tools for interacting with MongoDB database from Python. The ``bson`` package is an implementation of the `BSON format `_ for Python. The ``pymongo`` package is a native Python driver for MongoDB. The ``gridfs`` package is a `gridfs `_ implementation on top of ``pymongo``. PyMongo supports MongoDB 2.6, 3.0, 3.2, 3.4, 3.6, 4.0, 4.2, and 4.4. Support / Feedback ================== For issues with, questions about, or feedback for PyMongo, please look into our `support channels `_. Please do not email any of the PyMongo developers directly with issues or questions - you're more likely to get an answer on the `MongoDB Community Forums `_. Bugs / Feature Requests ======================= Think you’ve found a bug? Want to see a new feature in PyMongo? Please open a case in our issue management tool, JIRA: - `Create an account and login `_. - Navigate to `the PYTHON project `_. - Click **Create Issue** - Please provide as much information as possible about the issue type and how to reproduce it. Bug reports in JIRA for all driver projects (i.e. PYTHON, CSHARP, JAVA) and the Core Server (i.e. SERVER) project are **public**. How To Ask For Help ------------------- Please include all of the following information when opening an issue: - Detailed steps to reproduce the problem, including full traceback, if possible. - The exact python version used, with patch level:: $ python -c "import sys; print(sys.version)" - The exact version of PyMongo used, with patch level:: $ python -c "import pymongo; print(pymongo.version); print(pymongo.has_c())" - The operating system and version (e.g. Windows 7, OSX 10.8, ...) 
- Web framework or asynchronous network library used, if any, with version (e.g. Django 1.7, mod_wsgi 4.3.0, gevent 1.0.1, Tornado 4.0.2, ...) Security Vulnerabilities ------------------------ If you’ve identified a security vulnerability in a driver or any other MongoDB project, please report it according to the `instructions here `_. Installation ============ PyMongo can be installed with `pip `_:: $ python -m pip install pymongo Or ``easy_install`` from `setuptools `_:: $ python -m easy_install pymongo You can also download the project source and do:: $ python setup.py install Do **not** install the "bson" package from pypi. PyMongo comes with its own bson package; doing "easy_install bson" installs a third-party package that is incompatible with PyMongo. Dependencies ============ PyMongo supports CPython 2.7, 3.4+, PyPy, and PyPy3.5+. Optional dependencies: GSSAPI authentication requires `pykerberos `_ on Unix or `WinKerberos `_ on Windows. The correct dependency can be installed automatically along with PyMongo:: $ python -m pip install pymongo[gssapi] MONGODB-AWS authentication requires `pymongo-auth-aws `_:: $ python -m pip install pymongo[aws] Support for mongodb+srv:// URIs requires `dnspython `_:: $ python -m pip install pymongo[srv] TLS / SSL support may require `ipaddress `_ and `certifi `_ or `wincertstore `_ depending on the Python version in use. The necessary dependencies can be installed along with PyMongo:: $ python -m pip install pymongo[tls] .. note:: Users of Python versions older than 2.7.9 will also receive the dependencies for OCSP when using the tls extra. OCSP (Online Certificate Status Protocol) requires `PyOpenSSL `_, `requests `_ and `service_identity `_:: $ python -m pip install pymongo[ocsp] Wire protocol compression with snappy requires `python-snappy `_:: $ python -m pip install pymongo[snappy] Wire protocol compression with zstandard requires `zstandard `_:: $ python -m pip install pymongo[zstd] Client-Side Field Level Encryption requires `pymongocrypt `_:: $ python -m pip install pymongo[encryption] You can install all dependencies automatically with the following command:: $ python -m pip install pymongo[gssapi,aws,ocsp,snappy,srv,tls,zstd,encryption] Other optional packages: - `backports.pbkdf2 `_, improves authentication performance with SCRAM-SHA-1 and SCRAM-SHA-256. It especially improves performance on Python versions older than 2.7.8. - `monotonic `_ adds support for a monotonic clock, which improves reliability in environments where clock adjustments are frequent. Not needed in Python 3. Additional dependencies are: - (to generate documentation) sphinx_ Examples ======== Here's a basic example (for more see the *examples* section of the docs): .. code-block:: python >>> import pymongo >>> client = pymongo.MongoClient("localhost", 27017) >>> db = client.test >>> db.name u'test' >>> db.my_collection Collection(Database(MongoClient('localhost', 27017), u'test'), u'my_collection') >>> db.my_collection.insert_one({"x": 10}).inserted_id ObjectId('4aba15ebe23f6b53b0000000') >>> db.my_collection.insert_one({"x": 8}).inserted_id ObjectId('4aba160ee23f6b543e000000') >>> db.my_collection.insert_one({"x": 11}).inserted_id ObjectId('4aba160ee23f6b543e000002') >>> db.my_collection.find_one() {u'x': 10, u'_id': ObjectId('4aba15ebe23f6b53b0000000')} >>> for item in db.my_collection.find(): ... print(item["x"]) ... 10 8 11 >>> db.my_collection.create_index("x") u'x_1' >>> for item in db.my_collection.find().sort("x", pymongo.ASCENDING): ... print(item["x"]) ... 
8 10 11 >>> [item["x"] for item in db.my_collection.find().limit(2).skip(1)] [8, 11] Documentation ============= Documentation is available at `pymongo.readthedocs.io `_. To build the documentation, you will need to install sphinx_. Documentation can be generated by running **python setup.py doc**. Generated documentation can be found in the *doc/build/html/* directory. Testing ======= The easiest way to run the tests is to run **python setup.py test** in the root of the distribution. To verify that PyMongo works with Gevent's monkey-patching:: $ python green_framework_test.py gevent Or with Eventlet's:: $ python green_framework_test.py eventlet .. _sphinx: http://sphinx.pocoo.org/ pymongo-3.11.0/THIRD-PARTY-NOTICES000066400000000000000000000151561374256237000161010ustar00rootroot00000000000000PyMongo uses third-party libraries or other resources that may be distributed under licenses different than the PyMongo software. In the event that we accidentally failed to list a required notice, please bring it to our attention through any of the ways detailed here: mongodb-dev@googlegroups.com The attached notices are provided for information only. For any licenses that require disclosure of source, sources are available at https://github.com/mongodb/mongo-python-driver. 1) License Notice for time64.c ------------------------------ Copyright (c) 2007-2010 Michael G Schwern This software originally derived from Paul Sheer's pivotal_gmtime_r.c. The MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 2) License Notice for bson-stdint-win32.h ----------------------------------------- ISO C9x compliant stdint.h for Microsoft Visual Studio Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124 Copyright (c) 2006-2013 Alexander Chemeris Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the product nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 3) License Notice for encoding_helpers.c ---------------------------------------- Portions Copyright 2001 Unicode, Inc. Disclaimer This source code is provided as is by Unicode, Inc. No claims are made as to fitness for any particular purpose. No warranties of any kind are expressed or implied. The recipient agrees to determine applicability of information provided. If this file has been purchased on magnetic or optical media from Unicode, Inc., the sole remedy for any claim will be exchange of defective media within 90 days of receipt. Limitations on Rights to Redistribute This Code Unicode, Inc. hereby grants the right to freely use the information supplied in this file in the creation of products supporting the Unicode Standard, and to make copies of this file in any form for internal or external distribution as long as this notice remains attached. 4) License Notice for ssl_match_hostname.py ------------------------------------------- Python License (Python-2.0) Python License, Version 2 (Python-2.0) PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2 -------------------------------------------- 1. This LICENSE AGREEMENT is between the Python Software Foundation ("PSF"), and the Individual or Organization ("Licensee") accessing and otherwise using this software ("Python") in source or binary form and its associated documentation. 2. Subject to the terms and conditions of this License Agreement, PSF hereby grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, analyze, test, perform and/or display publicly, prepare derivative works, distribute, and otherwise use Python alone or in any derivative version, provided, however, that PSF's License Agreement and PSF's notice of copyright, i.e., "Copyright (c) 2001-2013 Python Software Foundation; All Rights Reserved" are retained in Python alone or in any derivative version prepared by Licensee. 3. In the event Licensee prepares a derivative work that is based on or incorporates Python or any part thereof, and wants to make the derivative work available to others as provided herein, then Licensee hereby agrees to include in any such work a brief summary of the changes made to Python. 4. PSF is making Python available to Licensee on an "AS IS" basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT INFRINGE ANY THIRD PARTY RIGHTS. 5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. 6. 
This License Agreement will automatically terminate upon a material breach of its terms and conditions. 7. Nothing in this License Agreement shall be deemed to create any relationship of agency, partnership, or joint venture between PSF and Licensee. This License Agreement does not grant permission to use PSF trademarks or trade name in a trademark sense to endorse or promote products or services of Licensee, or any third party. 8. By copying, installing or otherwise using Python, Licensee agrees to be bound by the terms and conditions of this License Agreement. pymongo-3.11.0/bson/000077500000000000000000000000001374256237000142365ustar00rootroot00000000000000pymongo-3.11.0/bson/__init__.py000066400000000000000000001314561374256237000163610ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """BSON (Binary JSON) encoding and decoding. The mapping from Python types to BSON types is as follows: ======================================= ============= =================== Python Type BSON Type Supported Direction ======================================= ============= =================== None null both bool boolean both int [#int]_ int32 / int64 py -> bson long int64 py -> bson `bson.int64.Int64` int64 both float number (real) both string string py -> bson unicode string both list array both dict / `SON` object both datetime.datetime [#dt]_ [#dt2]_ date both `bson.regex.Regex` regex both compiled re [#re]_ regex py -> bson `bson.binary.Binary` binary both `bson.objectid.ObjectId` oid both `bson.dbref.DBRef` dbref both None undefined bson -> py unicode code bson -> py `bson.code.Code` code py -> bson unicode symbol bson -> py bytes (Python 3) [#bytes]_ binary both ======================================= ============= =================== Note that, when using Python 2.x, to save binary data it must be wrapped as an instance of `bson.binary.Binary`. Otherwise it will be saved as a BSON string and retrieved as unicode. Users of Python 3.x can use the Python bytes type. .. [#int] A Python int will be saved as a BSON int32 or BSON int64 depending on its size. A BSON int32 will always decode to a Python int. A BSON int64 will always decode to a :class:`~bson.int64.Int64`. .. [#dt] datetime.datetime instances will be rounded to the nearest millisecond when saved .. [#dt2] all datetime.datetime instances are treated as *naive*. clients should always use UTC. .. [#re] :class:`~bson.regex.Regex` instances and regular expression objects from ``re.compile()`` are both saved as BSON regular expressions. BSON regular expressions are decoded as :class:`~bson.regex.Regex` instances. .. [#bytes] The bytes type from Python 3.x is encoded as BSON binary with subtype 0. In Python 3.x it will be decoded back to bytes. In Python 2.x it will be decoded to an instance of :class:`~bson.binary.Binary` with subtype 0. 
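As a minimal sketch of the mapping above (assuming the default :class:`~bson.codec_options.CodecOptions`): a document built from several of these types round-trips through ``encode`` and ``decode`` unchanged, provided any datetimes carry no sub-millisecond precision, since BSON dates have millisecond resolution.

.. code-block:: python

  >>> import datetime
  >>> import bson
  >>> from bson.int64 import Int64
  >>> document = {"int32": 42,
  ...             "int64": Int64(2 ** 40),
  ...             "when": datetime.datetime(2020, 1, 1, 12, 30, 15, 123000),
  ...             "nested": {"values": [1.5, u"text", None]}}
  >>> data = bson.encode(document)
  >>> bson.decode(data) == document
  True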
""" import calendar import datetime import itertools import platform import re import struct import sys import uuid from codecs import (utf_8_decode as _utf_8_decode, utf_8_encode as _utf_8_encode) from bson.binary import (Binary, UuidRepresentation, ALL_UUID_SUBTYPES, OLD_UUID_SUBTYPE, JAVA_LEGACY, CSHARP_LEGACY, UUIDLegacy, UUID_SUBTYPE) from bson.code import Code from bson.codec_options import ( CodecOptions, DEFAULT_CODEC_OPTIONS, _raw_document_class) from bson.dbref import DBRef from bson.decimal128 import Decimal128 from bson.errors import (InvalidBSON, InvalidDocument, InvalidStringData) from bson.int64 import Int64 from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId from bson.py3compat import (abc, b, PY3, iteritems, text_type, string_type, reraise) from bson.regex import Regex from bson.son import SON, RE_TYPE from bson.timestamp import Timestamp from bson.tz_util import utc try: from bson import _cbson _USE_C = True except ImportError: _USE_C = False EPOCH_AWARE = datetime.datetime.fromtimestamp(0, utc) EPOCH_NAIVE = datetime.datetime.utcfromtimestamp(0) BSONNUM = b"\x01" # Floating point BSONSTR = b"\x02" # UTF-8 string BSONOBJ = b"\x03" # Embedded document BSONARR = b"\x04" # Array BSONBIN = b"\x05" # Binary BSONUND = b"\x06" # Undefined BSONOID = b"\x07" # ObjectId BSONBOO = b"\x08" # Boolean BSONDAT = b"\x09" # UTC Datetime BSONNUL = b"\x0A" # Null BSONRGX = b"\x0B" # Regex BSONREF = b"\x0C" # DBRef BSONCOD = b"\x0D" # Javascript code BSONSYM = b"\x0E" # Symbol BSONCWS = b"\x0F" # Javascript code with scope BSONINT = b"\x10" # 32bit int BSONTIM = b"\x11" # Timestamp BSONLON = b"\x12" # 64bit int BSONDEC = b"\x13" # Decimal128 BSONMIN = b"\xFF" # Min key BSONMAX = b"\x7F" # Max key _UNPACK_FLOAT_FROM = struct.Struct("= obj_end: raise InvalidBSON("invalid object length") # If this is the top-level document, validate the total size too. if position == 0 and obj_size != obj_end: raise InvalidBSON("invalid object length") return obj_size, end def _get_object(data, view, position, obj_end, opts, dummy): """Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef.""" obj_size, end = _get_object_size(data, position, obj_end) if _raw_document_class(opts.document_class): return (opts.document_class(data[position:end + 1], opts), position + obj_size) obj = _elements_to_dict(data, view, position + 4, end, opts) position += obj_size if "$ref" in obj: return (DBRef(obj.pop("$ref"), obj.pop("$id", None), obj.pop("$db", None), obj), position) return obj, position def _get_array(data, view, position, obj_end, opts, element_name): """Decode a BSON array to python list.""" size = _UNPACK_INT_FROM(data, position)[0] end = position + size - 1 if data[end] != _OBJEND: raise InvalidBSON("bad eoo") position += 4 end -= 1 result = [] # Avoid doing global and attribute lookups in the loop. append = result.append index = data.index getter = _ELEMENT_GETTER decoder_map = opts.type_registry._decoder_map while position < end: element_type = data[position] # Just skip the keys. 
position = index(b'\x00', position) + 1 try: value, position = getter[element_type]( data, view, position, obj_end, opts, element_name) except KeyError: _raise_unknown_type(element_type, element_name) if decoder_map: custom_decoder = decoder_map.get(type(value)) if custom_decoder is not None: value = custom_decoder(value) append(value) if position != end + 1: raise InvalidBSON('bad array length') return result, position + 1 def _get_binary(data, view, position, obj_end, opts, dummy1): """Decode a BSON binary to bson.binary.Binary or python UUID.""" length, subtype = _UNPACK_LENGTH_SUBTYPE_FROM(data, position) position += 5 if subtype == 2: length2 = _UNPACK_INT_FROM(data, position)[0] position += 4 if length2 != length - 4: raise InvalidBSON("invalid binary (st 2) - lengths don't match!") length = length2 end = position + length if length < 0 or end > obj_end: raise InvalidBSON('bad binary object length') # Convert UUID subtypes to native UUIDs. # TODO: PYTHON-2245 Decoding should follow UUID spec in PyMongo 4.0+ if subtype in ALL_UUID_SUBTYPES: uuid_representation = opts.uuid_representation binary_value = Binary(data[position:end], subtype) if uuid_representation == UuidRepresentation.UNSPECIFIED: return binary_value, end if subtype == UUID_SUBTYPE: # Legacy behavior: use STANDARD with binary subtype 4. uuid_representation = UuidRepresentation.STANDARD elif uuid_representation == UuidRepresentation.STANDARD: # subtype == OLD_UUID_SUBTYPE # Legacy behavior: STANDARD is the same as PYTHON_LEGACY. uuid_representation = UuidRepresentation.PYTHON_LEGACY return binary_value.as_uuid(uuid_representation), end # Python3 special case. Decode subtype 0 to 'bytes'. if PY3 and subtype == 0: value = data[position:end] else: value = Binary(data[position:end], subtype) return value, end def _get_oid(data, view, position, dummy0, dummy1, dummy2): """Decode a BSON ObjectId to bson.objectid.ObjectId.""" end = position + 12 return ObjectId(data[position:end]), end def _get_boolean(data, view, position, dummy0, dummy1, dummy2): """Decode a BSON true/false to python True/False.""" end = position + 1 boolean_byte = data[position:end] if boolean_byte == b'\x00': return False, end elif boolean_byte == b'\x01': return True, end raise InvalidBSON('invalid boolean value: %r' % boolean_byte) def _get_date(data, view, position, dummy0, opts, dummy1): """Decode a BSON datetime to python datetime.datetime.""" return _millis_to_datetime( _UNPACK_LONG_FROM(data, position)[0], opts), position + 8 def _get_code(data, view, position, obj_end, opts, element_name): """Decode a BSON code to bson.code.Code.""" code, position = _get_string(data, view, position, obj_end, opts, element_name) return Code(code), position def _get_code_w_scope(data, view, position, obj_end, opts, element_name): """Decode a BSON code_w_scope to bson.code.Code.""" code_end = position + _UNPACK_INT_FROM(data, position)[0] code, position = _get_string( data, view, position + 4, code_end, opts, element_name) scope, position = _get_object(data, view, position, code_end, opts, element_name) if position != code_end: raise InvalidBSON('scope outside of javascript code boundaries') return Code(code, scope), position def _get_regex(data, view, position, dummy0, opts, dummy1): """Decode a BSON regex to bson.regex.Regex or a python pattern object.""" pattern, position = _get_c_string(data, view, position, opts) bson_flags, position = _get_c_string(data, view, position, opts) bson_re = Regex(pattern, bson_flags) return bson_re, position def _get_ref(data, view, 
position, obj_end, opts, element_name): """Decode (deprecated) BSON DBPointer to bson.dbref.DBRef.""" collection, position = _get_string( data, view, position, obj_end, opts, element_name) oid, position = _get_oid(data, view, position, obj_end, opts, element_name) return DBRef(collection, oid), position def _get_timestamp(data, view, position, dummy0, dummy1, dummy2): """Decode a BSON timestamp to bson.timestamp.Timestamp.""" inc, timestamp = _UNPACK_TIMESTAMP_FROM(data, position) return Timestamp(timestamp, inc), position + 8 def _get_int64(data, view, position, dummy0, dummy1, dummy2): """Decode a BSON int64 to bson.int64.Int64.""" return Int64(_UNPACK_LONG_FROM(data, position)[0]), position + 8 def _get_decimal128(data, view, position, dummy0, dummy1, dummy2): """Decode a BSON decimal128 to bson.decimal128.Decimal128.""" end = position + 16 return Decimal128.from_bid(data[position:end]), end # Each decoder function's signature is: # - data: bytes # - view: memoryview that references `data` # - position: int, beginning of object in 'data' to decode # - obj_end: int, end of object to decode in 'data' if variable-length type # - opts: a CodecOptions _ELEMENT_GETTER = { _maybe_ord(BSONNUM): _get_float, _maybe_ord(BSONSTR): _get_string, _maybe_ord(BSONOBJ): _get_object, _maybe_ord(BSONARR): _get_array, _maybe_ord(BSONBIN): _get_binary, _maybe_ord(BSONUND): lambda u, v, w, x, y, z: (None, w), # Deprecated undefined _maybe_ord(BSONOID): _get_oid, _maybe_ord(BSONBOO): _get_boolean, _maybe_ord(BSONDAT): _get_date, _maybe_ord(BSONNUL): lambda u, v, w, x, y, z: (None, w), _maybe_ord(BSONRGX): _get_regex, _maybe_ord(BSONREF): _get_ref, # Deprecated DBPointer _maybe_ord(BSONCOD): _get_code, _maybe_ord(BSONSYM): _get_string, # Deprecated symbol _maybe_ord(BSONCWS): _get_code_w_scope, _maybe_ord(BSONINT): _get_int, _maybe_ord(BSONTIM): _get_timestamp, _maybe_ord(BSONLON): _get_int64, _maybe_ord(BSONDEC): _get_decimal128, _maybe_ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w), _maybe_ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w)} if _USE_C: def _element_to_dict(data, view, position, obj_end, opts): return _cbson._element_to_dict(data, position, obj_end, opts) else: def _element_to_dict(data, view, position, obj_end, opts): """Decode a single key, value pair.""" element_type = data[position] position += 1 element_name, position = _get_c_string(data, view, position, opts) try: value, position = _ELEMENT_GETTER[element_type](data, view, position, obj_end, opts, element_name) except KeyError: _raise_unknown_type(element_type, element_name) if opts.type_registry._decoder_map: custom_decoder = opts.type_registry._decoder_map.get(type(value)) if custom_decoder is not None: value = custom_decoder(value) return element_name, value, position def _raw_to_dict(data, position, obj_end, opts, result): data, view = get_data_and_view(data) return _elements_to_dict(data, view, position, obj_end, opts, result) def _elements_to_dict(data, view, position, obj_end, opts, result=None): """Decode a BSON document into result.""" if result is None: result = opts.document_class() end = obj_end - 1 while position < end: key, value, position = _element_to_dict(data, view, position, obj_end, opts) result[key] = value if position != obj_end: raise InvalidBSON('bad object or element length') return result def _bson_to_dict(data, opts): """Decode a BSON string to document_class.""" data, view = get_data_and_view(data) try: if _raw_document_class(opts.document_class): return opts.document_class(data, opts) _, end = 
_get_object_size(data, 0, len(data)) return _elements_to_dict(data, view, 4, end, opts) except InvalidBSON: raise except Exception: # Change exception type to InvalidBSON but preserve traceback. _, exc_value, exc_tb = sys.exc_info() reraise(InvalidBSON, exc_value, exc_tb) if _USE_C: _bson_to_dict = _cbson._bson_to_dict _PACK_FLOAT = struct.Struct(">> import collections # From Python standard library. >>> import bson >>> from bson.codec_options import CodecOptions >>> data = bson.encode({'a': 1}) >>> decoded_doc = bson.decode(data) >>> options = CodecOptions(document_class=collections.OrderedDict) >>> decoded_doc = bson.decode(data, codec_options=options) >>> type(decoded_doc) :Parameters: - `data`: the BSON to decode. Any bytes-like object that implements the buffer protocol. - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. .. versionadded:: 3.9 """ if not isinstance(codec_options, CodecOptions): raise _CODEC_OPTIONS_TYPE_ERROR return _bson_to_dict(data, codec_options) def decode_all(data, codec_options=DEFAULT_CODEC_OPTIONS): """Decode BSON data to multiple documents. `data` must be a bytes-like object implementing the buffer protocol that provides concatenated, valid, BSON-encoded documents. :Parameters: - `data`: BSON data - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. .. versionchanged:: 3.9 Supports bytes-like objects that implement the buffer protocol. .. versionchanged:: 3.0 Removed `compile_re` option: PyMongo now always represents BSON regular expressions as :class:`~bson.regex.Regex` objects. Use :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a BSON regular expression to a Python regular expression object. Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with `codec_options`. .. versionchanged:: 2.7 Added `compile_re` option. If set to False, PyMongo represented BSON regular expressions as :class:`~bson.regex.Regex` objects instead of attempting to compile BSON regular expressions as Python native regular expressions, thus preventing errors for some incompatible patterns, see `PYTHON-500`_. .. _PYTHON-500: https://jira.mongodb.org/browse/PYTHON-500 """ data, view = get_data_and_view(data) if not isinstance(codec_options, CodecOptions): raise _CODEC_OPTIONS_TYPE_ERROR data_len = len(data) docs = [] position = 0 end = data_len - 1 use_raw = _raw_document_class(codec_options.document_class) try: while position < end: obj_size = _UNPACK_INT_FROM(data, position)[0] if data_len - position < obj_size: raise InvalidBSON("invalid object size") obj_end = position + obj_size - 1 if data[obj_end] != _OBJEND: raise InvalidBSON("bad eoo") if use_raw: docs.append( codec_options.document_class( data[position:obj_end + 1], codec_options)) else: docs.append(_elements_to_dict(data, view, position + 4, obj_end, codec_options)) position += obj_size return docs except InvalidBSON: raise except Exception: # Change exception type to InvalidBSON but preserve traceback. _, exc_value, exc_tb = sys.exc_info() reraise(InvalidBSON, exc_value, exc_tb) if _USE_C: decode_all = _cbson.decode_all def _decode_selective(rawdoc, fields, codec_options): if _raw_document_class(codec_options.document_class): # If document_class is RawBSONDocument, use vanilla dictionary for # decoding command response. doc = {} else: # Else, use the specified document_class. 
doc = codec_options.document_class() for key, value in iteritems(rawdoc): if key in fields: if fields[key] == 1: doc[key] = _bson_to_dict(rawdoc.raw, codec_options)[key] else: doc[key] = _decode_selective(value, fields[key], codec_options) else: doc[key] = value return doc def _decode_all_selective(data, codec_options, fields): """Decode BSON data to a single document while using user-provided custom decoding logic. `data` must be a string representing a valid, BSON-encoded document. :Parameters: - `data`: BSON data - `codec_options`: An instance of :class:`~bson.codec_options.CodecOptions` with user-specified type decoders. If no decoders are found, this method is the same as ``decode_all``. - `fields`: Map of document namespaces where data that needs to be custom decoded lives or None. For example, to custom decode a list of objects in 'field1.subfield1', the specified value should be ``{'field1': {'subfield1': 1}}``. If ``fields`` is an empty map or None, this method is the same as ``decode_all``. :Returns: - `document_list`: Single-member list containing the decoded document. .. versionadded:: 3.8 """ if not codec_options.type_registry._decoder_map: return decode_all(data, codec_options) if not fields: return decode_all(data, codec_options.with_options(type_registry=None)) # Decode documents for internal use. from bson.raw_bson import RawBSONDocument internal_codec_options = codec_options.with_options( document_class=RawBSONDocument, type_registry=None) _doc = _bson_to_dict(data, internal_codec_options) return [_decode_selective(_doc, fields, codec_options,)] def decode_iter(data, codec_options=DEFAULT_CODEC_OPTIONS): """Decode BSON data to multiple documents as a generator. Works similarly to the decode_all function, but yields one document at a time. `data` must be a string of concatenated, valid, BSON-encoded documents. :Parameters: - `data`: BSON data - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. .. versionchanged:: 3.0 Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with `codec_options`. .. versionadded:: 2.8 """ if not isinstance(codec_options, CodecOptions): raise _CODEC_OPTIONS_TYPE_ERROR position = 0 end = len(data) - 1 while position < end: obj_size = _UNPACK_INT_FROM(data, position)[0] elements = data[position:position + obj_size] position += obj_size yield _bson_to_dict(elements, codec_options) def decode_file_iter(file_obj, codec_options=DEFAULT_CODEC_OPTIONS): """Decode bson data from a file to multiple documents as a generator. Works similarly to the decode_all function, but reads from the file object in chunks and parses bson in chunks, yielding one document at a time. :Parameters: - `file_obj`: A file object containing BSON data. - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. .. versionchanged:: 3.0 Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with `codec_options`. .. versionadded:: 2.8 """ while True: # Read size of next object. size_data = file_obj.read(4) if not size_data: break # Finished with file normaly. elif len(size_data) != 4: raise InvalidBSON("cut off in middle of objsize") obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4 elements = size_data + file_obj.read(max(0, obj_size)) yield _bson_to_dict(elements, codec_options) def is_valid(bson): """Check that the given string represents valid :class:`BSON` data. Raises :class:`TypeError` if `bson` is not an instance of :class:`str` (:class:`bytes` in python 3). 
Returns ``True`` if `bson` is valid :class:`BSON`, ``False`` otherwise. :Parameters: - `bson`: the data to be validated """ if not isinstance(bson, bytes): raise TypeError("BSON data must be an instance of a subclass of bytes") try: _bson_to_dict(bson, DEFAULT_CODEC_OPTIONS) return True except Exception: return False class BSON(bytes): """BSON (Binary JSON) data. .. warning:: Using this class to encode and decode BSON adds a performance cost. For better performance use the module level functions :func:`encode` and :func:`decode` instead. """ @classmethod def encode(cls, document, check_keys=False, codec_options=DEFAULT_CODEC_OPTIONS): """Encode a document to a new :class:`BSON` instance. A document can be any mapping type (like :class:`dict`). Raises :class:`TypeError` if `document` is not a mapping type, or contains keys that are not instances of :class:`basestring` (:class:`str` in python 3). Raises :class:`~bson.errors.InvalidDocument` if `document` cannot be converted to :class:`BSON`. :Parameters: - `document`: mapping type representing a document - `check_keys` (optional): check if keys start with '$' or contain '.', raising :class:`~bson.errors.InvalidDocument` in either case - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. .. versionchanged:: 3.0 Replaced `uuid_subtype` option with `codec_options`. """ return cls(encode(document, check_keys, codec_options)) def decode(self, codec_options=DEFAULT_CODEC_OPTIONS): """Decode this BSON data. By default, returns a BSON document represented as a Python :class:`dict`. To use a different :class:`MutableMapping` class, configure a :class:`~bson.codec_options.CodecOptions`:: >>> import collections # From Python standard library. >>> import bson >>> from bson.codec_options import CodecOptions >>> data = bson.BSON.encode({'a': 1}) >>> decoded_doc = bson.BSON(data).decode() >>> options = CodecOptions(document_class=collections.OrderedDict) >>> decoded_doc = bson.BSON(data).decode(codec_options=options) >>> type(decoded_doc) :Parameters: - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. .. versionchanged:: 3.0 Removed `compile_re` option: PyMongo now always represents BSON regular expressions as :class:`~bson.regex.Regex` objects. Use :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a BSON regular expression to a Python regular expression object. Replaced `as_class`, `tz_aware`, and `uuid_subtype` options with `codec_options`. .. versionchanged:: 2.7 Added `compile_re` option. If set to False, PyMongo represented BSON regular expressions as :class:`~bson.regex.Regex` objects instead of attempting to compile BSON regular expressions as Python native regular expressions, thus preventing errors for some incompatible patterns, see `PYTHON-500`_. .. _PYTHON-500: https://jira.mongodb.org/browse/PYTHON-500 """ return decode(self, codec_options) def has_c(): """Is the C extension installed? """ return _USE_C pymongo-3.11.0/bson/_cbsonmodule.c000066400000000000000000003017451374256237000170650ustar00rootroot00000000000000/* * Copyright 2009-present MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * This file contains C implementations of some of the functions * needed by the bson module. If possible, these implementations * should be used to speed up BSON encoding and decoding. */ #define PY_SSIZE_T_CLEAN #include "Python.h" #include "datetime.h" #include "buffer.h" #include "time64.h" #include "encoding_helpers.h" #define _CBSON_MODULE #include "_cbsonmodule.h" /* New module state and initialization code. * See the module-initialization-and-state * section in the following doc: * http://docs.python.org/release/3.1.3/howto/cporting.html * which references the following pep: * http://www.python.org/dev/peps/pep-3121/ * */ struct module_state { PyObject* Binary; PyObject* Code; PyObject* ObjectId; PyObject* DBRef; PyObject* Regex; PyObject* UUID; PyObject* Timestamp; PyObject* MinKey; PyObject* MaxKey; PyObject* UTC; PyTypeObject* REType; PyObject* BSONInt64; PyObject* Decimal128; PyObject* Mapping; PyObject* CodecOptions; }; /* The Py_TYPE macro was introduced in CPython 2.6 */ #ifndef Py_TYPE #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) #endif #if PY_MAJOR_VERSION >= 3 #define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) #else #define GETSTATE(m) (&_state) static struct module_state _state; #endif /* Maximum number of regex flags */ #define FLAGS_SIZE 7 /* Default UUID representation type code. */ #define PYTHON_LEGACY 3 /* Other UUID representations. */ #define STANDARD 4 #define JAVA_LEGACY 5 #define CSHARP_LEGACY 6 #define UNSPECIFIED 0 #define BSON_MAX_SIZE 2147483647 /* The smallest possible BSON document, i.e. "{}" */ #define BSON_MIN_SIZE 5 /* Get an error class from the bson.errors module. * * Returns a new ref */ static PyObject* _error(char* name) { PyObject* error; PyObject* errors = PyImport_ImportModule("bson.errors"); if (!errors) { return NULL; } error = PyObject_GetAttrString(errors, name); Py_DECREF(errors); return error; } /* Safely downcast from Py_ssize_t to int, setting an * exception and returning -1 on error. */ static int _downcast_and_check(Py_ssize_t size, uint8_t extra) { if (size > BSON_MAX_SIZE || ((BSON_MAX_SIZE - extra) < size)) { PyObject* InvalidStringData = _error("InvalidStringData"); if (InvalidStringData) { PyErr_SetString(InvalidStringData, "String length must be <= 2147483647"); Py_DECREF(InvalidStringData); } return -1; } return (int)size + extra; } static PyObject* elements_to_dict(PyObject* self, const char* string, unsigned max, const codec_options_t* options); static int _write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char in_custom_call, unsigned char in_fallback_call); /* Write a RawBSONDocument to the buffer. * Returns the number of bytes written or 0 on failure. */ static int write_raw_doc(buffer_t buffer, PyObject* raw); /* Date stuff */ static PyObject* datetime_from_millis(long long millis) { /* To encode a datetime instance like datetime(9999, 12, 31, 23, 59, 59, 999999) * we follow these steps: * 1. Calculate a timestamp in seconds: 253402300799 * 2. Multiply that by 1000: 253402300799000 * 3. 
Add in microseconds divided by 1000 253402300799999 * * (Note: BSON doesn't support microsecond accuracy, hence the rounding.) * * To decode we could do: * 1. Get seconds: timestamp / 1000: 253402300799 * 2. Get micros: (timestamp % 1000) * 1000: 999000 * Resulting in datetime(9999, 12, 31, 23, 59, 59, 999000) -- the expected result * * Now what if the we encode (1, 1, 1, 1, 1, 1, 111111)? * 1. and 2. gives: -62135593139000 * 3. Gives us: -62135593138889 * * Now decode: * 1. Gives us: -62135593138 * 2. Gives us: -889000 * Resulting in datetime(1, 1, 1, 1, 1, 2, 15888216) -- an invalid result * * If instead to decode we do: * diff = ((millis % 1000) + 1000) % 1000: 111 * seconds = (millis - diff) / 1000: -62135593139 * micros = diff * 1000 111000 * Resulting in datetime(1, 1, 1, 1, 1, 1, 111000) -- the expected result */ int diff = (int)(((millis % 1000) + 1000) % 1000); int microseconds = diff * 1000; Time64_T seconds = (millis - diff) / 1000; struct TM timeinfo; gmtime64_r(&seconds, &timeinfo); return PyDateTime_FromDateAndTime(timeinfo.tm_year + 1900, timeinfo.tm_mon + 1, timeinfo.tm_mday, timeinfo.tm_hour, timeinfo.tm_min, timeinfo.tm_sec, microseconds); } static long long millis_from_datetime(PyObject* datetime) { struct TM timeinfo; long long millis; timeinfo.tm_year = PyDateTime_GET_YEAR(datetime) - 1900; timeinfo.tm_mon = PyDateTime_GET_MONTH(datetime) - 1; timeinfo.tm_mday = PyDateTime_GET_DAY(datetime); timeinfo.tm_hour = PyDateTime_DATE_GET_HOUR(datetime); timeinfo.tm_min = PyDateTime_DATE_GET_MINUTE(datetime); timeinfo.tm_sec = PyDateTime_DATE_GET_SECOND(datetime); millis = timegm64(&timeinfo) * 1000; millis += PyDateTime_DATE_GET_MICROSECOND(datetime) / 1000; return millis; } /* Just make this compatible w/ the old API. */ int buffer_write_bytes(buffer_t buffer, const char* data, int size) { if (buffer_write(buffer, data, size)) { return 0; } return 1; } int buffer_write_double(buffer_t buffer, double data) { double data_le = BSON_DOUBLE_TO_LE(data); return buffer_write_bytes(buffer, (const char*)&data_le, 8); } int buffer_write_int32(buffer_t buffer, int32_t data) { uint32_t data_le = BSON_UINT32_TO_LE(data); return buffer_write_bytes(buffer, (const char*)&data_le, 4); } int buffer_write_int64(buffer_t buffer, int64_t data) { uint64_t data_le = BSON_UINT64_TO_LE(data); return buffer_write_bytes(buffer, (const char*)&data_le, 8); } void buffer_write_int32_at_position(buffer_t buffer, int position, int32_t data) { uint32_t data_le = BSON_UINT32_TO_LE(data); memcpy(buffer_get_buffer(buffer) + position, &data_le, 4); } static int write_unicode(buffer_t buffer, PyObject* py_string) { int size; const char* data; PyObject* encoded = PyUnicode_AsUTF8String(py_string); if (!encoded) { return 0; } #if PY_MAJOR_VERSION >= 3 data = PyBytes_AS_STRING(encoded); #else data = PyString_AS_STRING(encoded); #endif if (!data) goto unicodefail; #if PY_MAJOR_VERSION >= 3 if ((size = _downcast_and_check(PyBytes_GET_SIZE(encoded), 1)) == -1) #else if ((size = _downcast_and_check(PyString_GET_SIZE(encoded), 1)) == -1) #endif goto unicodefail; if (!buffer_write_int32(buffer, (int32_t)size)) goto unicodefail; if (!buffer_write_bytes(buffer, data, size)) goto unicodefail; Py_DECREF(encoded); return 1; unicodefail: Py_DECREF(encoded); return 0; } /* returns 0 on failure */ static int write_string(buffer_t buffer, PyObject* py_string) { int size; const char* data; #if PY_MAJOR_VERSION >= 3 if (PyUnicode_Check(py_string)){ return write_unicode(buffer, py_string); } data = PyBytes_AsString(py_string); #else 
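    /* For reference: write_string and write_unicode both emit the BSON
     * string layout -- an int32 byte count that includes the trailing
     * NUL, followed by the UTF-8 bytes and the NUL itself.  Encoding
     * "abc", for example, produces 04 00 00 00 61 62 63 00. */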
data = PyString_AsString(py_string); #endif if (!data) { return 0; } #if PY_MAJOR_VERSION >= 3 if ((size = _downcast_and_check(PyBytes_Size(py_string), 1)) == -1) #else if ((size = _downcast_and_check(PyString_Size(py_string), 1)) == -1) #endif return 0; if (!buffer_write_int32(buffer, (int32_t)size)) { return 0; } if (!buffer_write_bytes(buffer, data, size)) { return 0; } return 1; } /* * Are we in the main interpreter or a sub-interpreter? * Useful for deciding if we can use cached pure python * types in mod_wsgi. */ static int _in_main_interpreter(void) { static PyInterpreterState* main_interpreter = NULL; PyInterpreterState* interpreter; if (main_interpreter == NULL) { interpreter = PyInterpreterState_Head(); while (PyInterpreterState_Next(interpreter)) interpreter = PyInterpreterState_Next(interpreter); main_interpreter = interpreter; } return (main_interpreter == PyThreadState_Get()->interp); } /* * Get a reference to a pure python type. If we are in the * main interpreter return the cached object, otherwise import * the object we need and return it instead. */ static PyObject* _get_object(PyObject* object, char* module_name, char* object_name) { if (_in_main_interpreter()) { Py_XINCREF(object); return object; } else { PyObject* imported = NULL; PyObject* module = PyImport_ImportModule(module_name); if (!module) return NULL; imported = PyObject_GetAttrString(module, object_name); Py_DECREF(module); return imported; } } /* Load a Python object to cache. * * Returns non-zero on failure. */ static int _load_object(PyObject** object, char* module_name, char* object_name) { PyObject* module; module = PyImport_ImportModule(module_name); if (!module) { return 1; } *object = PyObject_GetAttrString(module, object_name); Py_DECREF(module); return (*object) ? 0 : 2; } /* Load all Python objects to cache. * * Returns non-zero on failure. */ static int _load_python_objects(PyObject* module) { PyObject* empty_string = NULL; PyObject* re_compile = NULL; PyObject* compiled = NULL; struct module_state *state = GETSTATE(module); if (_load_object(&state->Binary, "bson.binary", "Binary") || _load_object(&state->Code, "bson.code", "Code") || _load_object(&state->ObjectId, "bson.objectid", "ObjectId") || _load_object(&state->DBRef, "bson.dbref", "DBRef") || _load_object(&state->Timestamp, "bson.timestamp", "Timestamp") || _load_object(&state->MinKey, "bson.min_key", "MinKey") || _load_object(&state->MaxKey, "bson.max_key", "MaxKey") || _load_object(&state->UTC, "bson.tz_util", "utc") || _load_object(&state->Regex, "bson.regex", "Regex") || _load_object(&state->BSONInt64, "bson.int64", "Int64") || _load_object(&state->Decimal128, "bson.decimal128", "Decimal128") || _load_object(&state->UUID, "uuid", "UUID") || #if PY_MAJOR_VERSION >= 3 _load_object(&state->Mapping, "collections.abc", "Mapping") || #else _load_object(&state->Mapping, "collections", "Mapping") || #endif _load_object(&state->CodecOptions, "bson.codec_options", "CodecOptions")) { return 1; } /* Reload our REType hack too. 
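 * The cached type is captured by compiling an empty pattern and taking
 * its type -- roughly the C equivalent of the Python expression
 * type(re.compile("")).  PyObject_TypeCheck against this REType later
 * lets the encoder recognize native compiled regular expressions
 * without importing the re module again.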
*/ #if PY_MAJOR_VERSION >= 3 empty_string = PyBytes_FromString(""); #else empty_string = PyString_FromString(""); #endif if (empty_string == NULL) { state->REType = NULL; return 1; } if (_load_object(&re_compile, "re", "compile")) { state->REType = NULL; Py_DECREF(empty_string); return 1; } compiled = PyObject_CallFunction(re_compile, "O", empty_string); Py_DECREF(re_compile); if (compiled == NULL) { state->REType = NULL; Py_DECREF(empty_string); return 1; } Py_INCREF(Py_TYPE(compiled)); state->REType = Py_TYPE(compiled); Py_DECREF(empty_string); Py_DECREF(compiled); return 0; } /* * Get the _type_marker from an Object. * * Return the type marker, 0 if there is no marker, or -1 on failure. */ static long _type_marker(PyObject* object) { PyObject* type_marker = NULL; long type = 0; if (PyObject_HasAttrString(object, "_type_marker")) { type_marker = PyObject_GetAttrString(object, "_type_marker"); if (type_marker == NULL) { return -1; } } /* * Python objects with broken __getattr__ implementations could return * arbitrary types for a call to PyObject_GetAttrString. For example * pymongo.database.Database returns a new Collection instance for * __getattr__ calls with names that don't match an existing attribute * or method. In some cases "value" could be a subtype of something * we know how to serialize. Make a best effort to encode these types. */ #if PY_MAJOR_VERSION >= 3 if (type_marker && PyLong_CheckExact(type_marker)) { type = PyLong_AsLong(type_marker); #else if (type_marker && PyInt_CheckExact(type_marker)) { type = PyInt_AsLong(type_marker); #endif Py_DECREF(type_marker); /* * Py(Long|Int)_AsLong returns -1 for error but -1 is a valid value * so we call PyErr_Occurred to differentiate. */ if (type == -1 && PyErr_Occurred()) { return -1; } } else { Py_XDECREF(type_marker); } return type; } /* Fill out a type_registry_t* from a TypeRegistry object. * * Return 1 on success. options->document_class is a new reference. * Return 0 on failure. */ int convert_type_registry(PyObject* registry_obj, type_registry_t* registry) { registry->encoder_map = NULL; registry->decoder_map = NULL; registry->fallback_encoder = NULL; registry->registry_obj = NULL; registry->encoder_map = PyObject_GetAttrString(registry_obj, "_encoder_map"); if (registry->encoder_map == NULL) { goto fail; } registry->is_encoder_empty = (PyDict_Size(registry->encoder_map) == 0); registry->decoder_map = PyObject_GetAttrString(registry_obj, "_decoder_map"); if (registry->decoder_map == NULL) { goto fail; } registry->is_decoder_empty = (PyDict_Size(registry->decoder_map) == 0); registry->fallback_encoder = PyObject_GetAttrString(registry_obj, "_fallback_encoder"); if (registry->fallback_encoder == NULL) { goto fail; } registry->has_fallback_encoder = (registry->fallback_encoder != Py_None); registry->registry_obj = registry_obj; Py_INCREF(registry->registry_obj); return 1; fail: Py_XDECREF(registry->encoder_map); Py_XDECREF(registry->decoder_map); Py_XDECREF(registry->fallback_encoder); return 0; } /* Fill out a codec_options_t* from a CodecOptions object. Use with the "O&" * format spec in PyArg_ParseTuple. * * Return 1 on success. options->document_class is a new reference. * Return 0 on failure. 
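 *
 * For reference, the CodecOptions object is treated as a sequence and
 * unpacked with the "ObbzOO" format below, i.e. in the order
 * (document_class, tz_aware, uuid_representation,
 *  unicode_decode_error_handler, tzinfo, type_registry).
 * With the "O&" converter spec this looks like, e.g.:
 *     PyArg_ParseTuple(args, "OO&", &obj, convert_codec_options, &options);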
*/ int convert_codec_options(PyObject* options_obj, void* p) { codec_options_t* options = (codec_options_t*)p; PyObject* type_registry_obj = NULL; long type_marker; options->unicode_decode_error_handler = NULL; if (!PyArg_ParseTuple(options_obj, "ObbzOO", &options->document_class, &options->tz_aware, &options->uuid_rep, &options->unicode_decode_error_handler, &options->tzinfo, &type_registry_obj)) return 0; type_marker = _type_marker(options->document_class); if (type_marker < 0) { return 0; } if (!convert_type_registry(type_registry_obj, &options->type_registry)) { return 0; } options->is_raw_bson = (101 == type_marker); options->options_obj = options_obj; Py_INCREF(options->options_obj); Py_INCREF(options->document_class); Py_INCREF(options->tzinfo); return 1; } /* Fill out a codec_options_t* with default options. * * Return 1 on success. * Return 0 on failure. */ int default_codec_options(struct module_state* state, codec_options_t* options) { PyObject* options_obj = NULL; PyObject* codec_options_func = _get_object( state->CodecOptions, "bson.codec_options", "CodecOptions"); if (codec_options_func == NULL) { return 0; } options_obj = PyObject_CallFunctionObjArgs(codec_options_func, NULL); Py_DECREF(codec_options_func); if (options_obj == NULL) { return 0; } return convert_codec_options(options_obj, options); } void destroy_codec_options(codec_options_t* options) { Py_CLEAR(options->document_class); Py_CLEAR(options->tzinfo); Py_CLEAR(options->options_obj); Py_CLEAR(options->type_registry.registry_obj); Py_CLEAR(options->type_registry.encoder_map); Py_CLEAR(options->type_registry.decoder_map); Py_CLEAR(options->type_registry.fallback_encoder); } static int write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char in_custom_call, unsigned char in_fallback_call) { int result = 0; if(Py_EnterRecursiveCall(" while encoding an object to BSON ")) { return 0; } result = _write_element_to_buffer(self, buffer, type_byte, value, check_keys, options, in_custom_call, in_fallback_call); Py_LeaveRecursiveCall(); return result; } static void _set_cannot_encode(PyObject* value) { PyObject* type = NULL; PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument == NULL) { goto error; } type = PyObject_Type(value); if (type == NULL) { goto error; } #if PY_MAJOR_VERSION >= 3 PyErr_Format(InvalidDocument, "cannot encode object: %R, of type: %R", value, type); #else else { PyObject* value_repr = NULL; PyObject* type_repr = NULL; char* value_str = NULL; char* type_str = NULL; value_repr = PyObject_Repr(value); if (value_repr == NULL) { goto py2error; } value_str = PyString_AsString(value_repr); if (value_str == NULL) { goto py2error; } type_repr = PyObject_Repr(type); if (type_repr == NULL) { goto py2error; } type_str = PyString_AsString(type_repr); if (type_str == NULL) { goto py2error; } PyErr_Format(InvalidDocument, "cannot encode object: %s, of type: %s", value_str, type_str); py2error: Py_XDECREF(type_repr); Py_XDECREF(value_repr); } #endif error: Py_XDECREF(type); Py_XDECREF(InvalidDocument); } /* * Encode a builtin Python regular expression or our custom Regex class. * * Sets exception and returns 0 on failure. 
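 *
 * The element is written as two C strings -- the pattern followed by the
 * flags -- and the reserved type byte is set to 0x0B.  Python flag bits
 * are translated to the BSON flag characters i, l, m, s, u and x in that
 * (alphabetical) order; e.g. a pattern compiled with
 * re.IGNORECASE | re.MULTILINE is written with the flag string "im".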
*/ static int _write_regex_to_buffer( buffer_t buffer, int type_byte, PyObject* value) { PyObject* py_flags; PyObject* py_pattern; PyObject* encoded_pattern; long int_flags; char flags[FLAGS_SIZE]; char check_utf8 = 0; const char* pattern_data; int pattern_length, flags_length; result_t status; /* * Both the builtin re type and our Regex class have attributes * "flags" and "pattern". */ py_flags = PyObject_GetAttrString(value, "flags"); if (!py_flags) { return 0; } #if PY_MAJOR_VERSION >= 3 int_flags = PyLong_AsLong(py_flags); #else int_flags = PyInt_AsLong(py_flags); #endif Py_DECREF(py_flags); if (int_flags == -1 && PyErr_Occurred()) { return 0; } py_pattern = PyObject_GetAttrString(value, "pattern"); if (!py_pattern) { return 0; } if (PyUnicode_Check(py_pattern)) { encoded_pattern = PyUnicode_AsUTF8String(py_pattern); Py_DECREF(py_pattern); if (!encoded_pattern) { return 0; } } else { encoded_pattern = py_pattern; check_utf8 = 1; } #if PY_MAJOR_VERSION >= 3 if (!(pattern_data = PyBytes_AsString(encoded_pattern))) { Py_DECREF(encoded_pattern); return 0; } if ((pattern_length = _downcast_and_check(PyBytes_Size(encoded_pattern), 0)) == -1) { Py_DECREF(encoded_pattern); return 0; } #else if (!(pattern_data = PyString_AsString(encoded_pattern))) { Py_DECREF(encoded_pattern); return 0; } if ((pattern_length = _downcast_and_check(PyString_Size(encoded_pattern), 0)) == -1) { Py_DECREF(encoded_pattern); return 0; } #endif status = check_string((const unsigned char*)pattern_data, pattern_length, check_utf8, 1); if (status == NOT_UTF_8) { PyObject* InvalidStringData = _error("InvalidStringData"); if (InvalidStringData) { PyErr_SetString(InvalidStringData, "regex patterns must be valid UTF-8"); Py_DECREF(InvalidStringData); } Py_DECREF(encoded_pattern); return 0; } else if (status == HAS_NULL) { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { PyErr_SetString(InvalidDocument, "regex patterns must not contain the NULL byte"); Py_DECREF(InvalidDocument); } Py_DECREF(encoded_pattern); return 0; } if (!buffer_write_bytes(buffer, pattern_data, pattern_length + 1)) { Py_DECREF(encoded_pattern); return 0; } Py_DECREF(encoded_pattern); flags[0] = 0; if (int_flags & 2) { STRCAT(flags, FLAGS_SIZE, "i"); } if (int_flags & 4) { STRCAT(flags, FLAGS_SIZE, "l"); } if (int_flags & 8) { STRCAT(flags, FLAGS_SIZE, "m"); } if (int_flags & 16) { STRCAT(flags, FLAGS_SIZE, "s"); } if (int_flags & 32) { STRCAT(flags, FLAGS_SIZE, "u"); } if (int_flags & 64) { STRCAT(flags, FLAGS_SIZE, "x"); } flags_length = (int)strlen(flags) + 1; if (!buffer_write_bytes(buffer, flags, flags_length)) { return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x0B; return 1; } /* Write a single value to the buffer (also write its type_byte, for which * space has already been reserved. * * returns 0 on failure */ static int _write_element_to_buffer(PyObject* self, buffer_t buffer, int type_byte, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char in_custom_call, unsigned char in_fallback_call) { struct module_state *state = GETSTATE(self); PyObject* mapping_type; PyObject* new_value = NULL; int retval; PyObject* uuid_type; /* * Don't use PyObject_IsInstance for our custom types. It causes * problems with python sub interpreters. Our custom types should * have a _type_marker attribute, which we can switch on instead. 
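 *
 * The switch below covers the markers used by our types: 5 Binary,
 * 7 ObjectId, 11 Regex, 13 Code, 17 Timestamp, 18 Int64, 19 Decimal128,
 * 100 DBRef, 101 RawBSONDocument, 127 MaxKey and 255 MinKey.  A marker
 * of 0 means "not one of ours" and falls through to the standard Python
 * type checks further down.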
*/ long type = _type_marker(value); if (type < 0) { return 0; } switch (type) { case 5: { /* Binary */ PyObject* subtype_object; char subtype; const char* data; int size; *(buffer_get_buffer(buffer) + type_byte) = 0x05; subtype_object = PyObject_GetAttrString(value, "subtype"); if (!subtype_object) { return 0; } #if PY_MAJOR_VERSION >= 3 subtype = (char)PyLong_AsLong(subtype_object); #else subtype = (char)PyInt_AsLong(subtype_object); #endif if (subtype == -1) { Py_DECREF(subtype_object); return 0; } #if PY_MAJOR_VERSION >= 3 size = _downcast_and_check(PyBytes_Size(value), 0); #else size = _downcast_and_check(PyString_Size(value), 0); #endif if (size == -1) { Py_DECREF(subtype_object); return 0; } Py_DECREF(subtype_object); if (subtype == 2) { #if PY_MAJOR_VERSION >= 3 int other_size = _downcast_and_check(PyBytes_Size(value), 4); #else int other_size = _downcast_and_check(PyString_Size(value), 4); #endif if (other_size == -1) return 0; if (!buffer_write_int32(buffer, other_size)) { return 0; } if (!buffer_write_bytes(buffer, &subtype, 1)) { return 0; } } if (!buffer_write_int32(buffer, size)) { return 0; } if (subtype != 2) { if (!buffer_write_bytes(buffer, &subtype, 1)) { return 0; } } #if PY_MAJOR_VERSION >= 3 data = PyBytes_AsString(value); #else data = PyString_AsString(value); #endif if (!data) { return 0; } if (!buffer_write_bytes(buffer, data, size)) { return 0; } return 1; } case 7: { /* ObjectId */ const char* data; PyObject* pystring = PyObject_GetAttrString(value, "binary"); if (!pystring) { return 0; } #if PY_MAJOR_VERSION >= 3 data = PyBytes_AsString(pystring); #else data = PyString_AsString(pystring); #endif if (!data) { Py_DECREF(pystring); return 0; } if (!buffer_write_bytes(buffer, data, 12)) { Py_DECREF(pystring); return 0; } Py_DECREF(pystring); *(buffer_get_buffer(buffer) + type_byte) = 0x07; return 1; } case 11: { /* Regex */ return _write_regex_to_buffer(buffer, type_byte, value); } case 13: { /* Code */ int start_position, length_location, length; PyObject* scope = PyObject_GetAttrString(value, "scope"); if (!scope) { return 0; } if (scope == Py_None) { Py_DECREF(scope); *(buffer_get_buffer(buffer) + type_byte) = 0x0D; return write_string(buffer, value); } *(buffer_get_buffer(buffer) + type_byte) = 0x0F; start_position = buffer_get_position(buffer); /* save space for length */ length_location = buffer_save_space(buffer, 4); if (length_location == -1) { Py_DECREF(scope); return 0; } if (!write_string(buffer, value)) { Py_DECREF(scope); return 0; } if (!write_dict(self, buffer, scope, 0, options, 0)) { Py_DECREF(scope); return 0; } Py_DECREF(scope); length = buffer_get_position(buffer) - start_position; buffer_write_int32_at_position( buffer, length_location, (int32_t)length); return 1; } case 17: { /* Timestamp */ PyObject* obj; unsigned long i; obj = PyObject_GetAttrString(value, "inc"); if (!obj) { return 0; } i = PyLong_AsUnsignedLong(obj); Py_DECREF(obj); if (i == (unsigned long)-1 && PyErr_Occurred()) { return 0; } if (!buffer_write_int32(buffer, (int32_t)i)) { return 0; } obj = PyObject_GetAttrString(value, "time"); if (!obj) { return 0; } i = PyLong_AsUnsignedLong(obj); Py_DECREF(obj); if (i == (unsigned long)-1 && PyErr_Occurred()) { return 0; } if (!buffer_write_int32(buffer, (int32_t)i)) { return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x11; return 1; } case 18: { /* Int64 */ const long long ll = PyLong_AsLongLong(value); if (PyErr_Occurred()) { /* Overflow */ PyErr_SetString(PyExc_OverflowError, "MongoDB can only handle up to 8-byte ints"); return 
0; } if (!buffer_write_int64(buffer, (int64_t)ll)) { return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x12; return 1; } case 19: { /* Decimal128 */ const char* data; PyObject* pystring = PyObject_GetAttrString(value, "bid"); if (!pystring) { return 0; } #if PY_MAJOR_VERSION >= 3 data = PyBytes_AsString(pystring); #else data = PyString_AsString(pystring); #endif if (!data) { Py_DECREF(pystring); return 0; } if (!buffer_write_bytes(buffer, data, 16)) { Py_DECREF(pystring); return 0; } Py_DECREF(pystring); *(buffer_get_buffer(buffer) + type_byte) = 0x13; return 1; } case 100: { /* DBRef */ PyObject* as_doc = PyObject_CallMethod(value, "as_doc", NULL); if (!as_doc) { return 0; } if (!write_dict(self, buffer, as_doc, 0, options, 0)) { Py_DECREF(as_doc); return 0; } Py_DECREF(as_doc); *(buffer_get_buffer(buffer) + type_byte) = 0x03; return 1; } case 101: { /* RawBSONDocument */ if (!write_raw_doc(buffer, value)) { return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x03; return 1; } case 255: { /* MinKey */ *(buffer_get_buffer(buffer) + type_byte) = 0xFF; return 1; } case 127: { /* MaxKey */ *(buffer_get_buffer(buffer) + type_byte) = 0x7F; return 1; } } /* No _type_marker attibute or not one of our types. */ if (PyBool_Check(value)) { const char c = (value == Py_True) ? 0x01 : 0x00; *(buffer_get_buffer(buffer) + type_byte) = 0x08; return buffer_write_bytes(buffer, &c, 1); } #if PY_MAJOR_VERSION >= 3 else if (PyLong_Check(value)) { const long long_value = PyLong_AsLong(value); #else else if (PyInt_Check(value)) { const long long_value = PyInt_AsLong(value); #endif const int int_value = (int)long_value; if (PyErr_Occurred() || long_value != int_value) { /* Overflow */ long long long_long_value; PyErr_Clear(); long_long_value = PyLong_AsLongLong(value); if (PyErr_Occurred()) { /* Overflow AGAIN */ PyErr_SetString(PyExc_OverflowError, "MongoDB can only handle up to 8-byte ints"); return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x12; return buffer_write_int64(buffer, (int64_t)long_long_value); } *(buffer_get_buffer(buffer) + type_byte) = 0x10; return buffer_write_int32(buffer, (int32_t)int_value); #if PY_MAJOR_VERSION < 3 } else if (PyLong_Check(value)) { const long long long_long_value = PyLong_AsLongLong(value); if (PyErr_Occurred()) { /* Overflow */ PyErr_SetString(PyExc_OverflowError, "MongoDB can only handle up to 8-byte ints"); return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x12; return buffer_write_int64(buffer, (int64_t)long_long_value); #endif } else if (PyFloat_Check(value)) { const double d = PyFloat_AsDouble(value); *(buffer_get_buffer(buffer) + type_byte) = 0x01; return buffer_write_double(buffer, d); } else if (value == Py_None) { *(buffer_get_buffer(buffer) + type_byte) = 0x0A; return 1; } else if (PyDict_Check(value)) { *(buffer_get_buffer(buffer) + type_byte) = 0x03; return write_dict(self, buffer, value, check_keys, options, 0); } else if (PyList_Check(value) || PyTuple_Check(value)) { Py_ssize_t items, i; int start_position, length_location, length; char zero = 0; *(buffer_get_buffer(buffer) + type_byte) = 0x04; start_position = buffer_get_position(buffer); /* save space for length */ length_location = buffer_save_space(buffer, 4); if (length_location == -1) { return 0; } if ((items = PySequence_Size(value)) > BSON_MAX_SIZE) { PyObject* BSONError = _error("BSONError"); if (BSONError) { PyErr_SetString(BSONError, "Too many items to serialize."); Py_DECREF(BSONError); } return 0; } for(i = 0; i < items; i++) { int list_type_byte = buffer_save_space(buffer, 1); 
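        /* BSON has no separate array header: a list or tuple is written
         * as an embedded document whose keys are the decimal indices
         * "0", "1", ... so each iteration reserves a type byte, writes
         * the index produced by INT2STRING as a C string, then writes
         * the element itself.  Encoding ["a", 1], for example, yields
         * the same bytes as the document {"0": "a", "1": 1}. */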
char name[16]; PyObject* item_value; if (list_type_byte == -1) { return 0; } INT2STRING(name, (int)i); if (!buffer_write_bytes(buffer, name, (int)strlen(name) + 1)) { return 0; } if (!(item_value = PySequence_GetItem(value, i))) return 0; if (!write_element_to_buffer(self, buffer, list_type_byte, item_value, check_keys, options, 0, 0)) { Py_DECREF(item_value); return 0; } Py_DECREF(item_value); } /* write null byte and fill in length */ if (!buffer_write_bytes(buffer, &zero, 1)) { return 0; } length = buffer_get_position(buffer) - start_position; buffer_write_int32_at_position( buffer, length_location, (int32_t)length); return 1; #if PY_MAJOR_VERSION >= 3 /* Python3 special case. Store bytes as BSON binary subtype 0. */ } else if (PyBytes_Check(value)) { char subtype = 0; int size; const char* data = PyBytes_AS_STRING(value); if (!data) return 0; if ((size = _downcast_and_check(PyBytes_GET_SIZE(value), 0)) == -1) return 0; *(buffer_get_buffer(buffer) + type_byte) = 0x05; if (!buffer_write_int32(buffer, (int32_t)size)) { return 0; } if (!buffer_write_bytes(buffer, &subtype, 1)) { return 0; } if (!buffer_write_bytes(buffer, data, size)) { return 0; } return 1; #else /* PyString_Check only works in Python 2.x. */ } else if (PyString_Check(value)) { result_t status; const char* data; int size; if (!(data = PyString_AS_STRING(value))) return 0; if ((size = _downcast_and_check(PyString_GET_SIZE(value), 1)) == -1) return 0; *(buffer_get_buffer(buffer) + type_byte) = 0x02; status = check_string((const unsigned char*)data, size - 1, 1, 0); if (status == NOT_UTF_8) { PyObject* InvalidStringData = _error("InvalidStringData"); if (InvalidStringData) { PyObject* repr = PyObject_Repr(value); char* repr_as_cstr = repr ? PyString_AsString(repr) : NULL; if (repr_as_cstr) { PyObject *message = PyString_FromFormat( "strings in documents must be valid UTF-8: %s", repr_as_cstr); if (message) { PyErr_SetObject(InvalidStringData, message); Py_DECREF(message); } } else { /* repr(value) failed, use a generic message. */ PyErr_SetString( InvalidStringData, "strings in documents must be valid UTF-8"); } Py_XDECREF(repr); Py_DECREF(InvalidStringData); } return 0; } if (!buffer_write_int32(buffer, (int32_t)size)) { return 0; } if (!buffer_write_bytes(buffer, data, size)) { return 0; } return 1; #endif } else if (PyUnicode_Check(value)) { *(buffer_get_buffer(buffer) + type_byte) = 0x02; return write_unicode(buffer, value); } else if (PyDateTime_Check(value)) { long long millis; PyObject* utcoffset = PyObject_CallMethod(value, "utcoffset", NULL); if (utcoffset == NULL) return 0; if (utcoffset != Py_None) { PyObject* result = PyNumber_Subtract(value, utcoffset); Py_DECREF(utcoffset); if (!result) { return 0; } millis = millis_from_datetime(result); Py_DECREF(result); } else { millis = millis_from_datetime(value); } *(buffer_get_buffer(buffer) + type_byte) = 0x09; return buffer_write_int64(buffer, (int64_t)millis); } else if (PyObject_TypeCheck(value, state->REType)) { return _write_regex_to_buffer(buffer, type_byte, value); } /* * Try Mapping and UUID last since we have to import * them if we're in a sub-interpreter. 
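 *
 * A uuid.UUID value is not written directly: it is first converted with
 * Binary.from_uuid(value, uuid_representation) -- producing a BSON
 * binary value (subtype 3 for the legacy representations, subtype 4 for
 * STANDARD) -- and that Binary is then encoded by a recursive call.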
*/ #if PY_MAJOR_VERSION >= 3 mapping_type = _get_object(state->Mapping, "collections.abc", "Mapping"); #else mapping_type = _get_object(state->Mapping, "collections", "Mapping"); #endif if (mapping_type && PyObject_IsInstance(value, mapping_type)) { Py_DECREF(mapping_type); /* PyObject_IsInstance returns -1 on error */ if (PyErr_Occurred()) { return 0; } *(buffer_get_buffer(buffer) + type_byte) = 0x03; return write_dict(self, buffer, value, check_keys, options, 0); } uuid_type = _get_object(state->UUID, "uuid", "UUID"); if (uuid_type && PyObject_IsInstance(value, uuid_type)) { PyObject* binary_type = NULL; PyObject* binary_value = NULL; int result; Py_DECREF(uuid_type); /* PyObject_IsInstance returns -1 on error */ if (PyErr_Occurred()) { return 0; } binary_type = _get_object(state->Binary, "bson", "Binary"); if (binary_type == NULL) { return 0; } binary_value = PyObject_CallMethod(binary_type, "from_uuid", "(Oi)", value, options->uuid_rep); if (binary_value == NULL) { Py_DECREF(binary_type); return 0; } result = _write_element_to_buffer(self, buffer, type_byte, binary_value, check_keys, options, in_custom_call, in_fallback_call); Py_DECREF(binary_type); Py_DECREF(binary_value); return result; } Py_XDECREF(mapping_type); Py_XDECREF(uuid_type); /* Try a custom encoder if one is provided and we have not already * attempted to use a type encoder. */ if (!in_custom_call && !options->type_registry.is_encoder_empty) { PyObject* value_type = NULL; PyObject* converter = NULL; value_type = PyObject_Type(value); if (value_type == NULL) { return 0; } converter = PyDict_GetItem(options->type_registry.encoder_map, value_type); Py_XDECREF(value_type); if (converter != NULL) { /* Transform types that have a registered converter. * A new reference is created upon transformation. */ new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); if (new_value == NULL) { return 0; } retval = write_element_to_buffer(self, buffer, type_byte, new_value, check_keys, options, 1, 0); Py_XDECREF(new_value); return retval; } } /* Try the fallback encoder if one is provided and we have not already * attempted to use the fallback encoder. */ if (!in_fallback_call && options->type_registry.has_fallback_encoder) { new_value = PyObject_CallFunctionObjArgs( options->type_registry.fallback_encoder, value, NULL); if (new_value == NULL) { // propagate any exception raised by the callback return 0; } retval = write_element_to_buffer(self, buffer, type_byte, new_value, check_keys, options, 0, 1); Py_XDECREF(new_value); return retval; } /* We can't determine value's type. Fail. 
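 * Built-in types, our _type_marker types, any registered type encoder
 * and the optional fallback encoder have all been tried by now, so the
 * only thing left to do is raise InvalidDocument.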
*/ _set_cannot_encode(value); return 0; } static int check_key_name(const char* name, int name_length) { if (name_length > 0 && name[0] == '$') { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { #if PY_MAJOR_VERSION >= 3 PyObject* errmsg = PyUnicode_FromFormat( "key '%s' must not start with '$'", name); #else PyObject* errmsg = PyString_FromFormat( "key '%s' must not start with '$'", name); #endif if (errmsg) { PyErr_SetObject(InvalidDocument, errmsg); Py_DECREF(errmsg); } Py_DECREF(InvalidDocument); } return 0; } if (strchr(name, '.')) { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { #if PY_MAJOR_VERSION >= 3 PyObject* errmsg = PyUnicode_FromFormat( "key '%s' must not contain '.'", name); #else PyObject* errmsg = PyString_FromFormat( "key '%s' must not contain '.'", name); #endif if (errmsg) { PyErr_SetObject(InvalidDocument, errmsg); Py_DECREF(errmsg); } Py_DECREF(InvalidDocument); } return 0; } return 1; } /* Write a (key, value) pair to the buffer. * * Returns 0 on failure */ int write_pair(PyObject* self, buffer_t buffer, const char* name, int name_length, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char allow_id) { int type_byte; /* Don't write any _id elements unless we're explicitly told to - * _id has to be written first so we do so, but don't bother * deleting it from the dictionary being written. */ if (!allow_id && strcmp(name, "_id") == 0) { return 1; } type_byte = buffer_save_space(buffer, 1); if (type_byte == -1) { return 0; } if (check_keys && !check_key_name(name, name_length)) { return 0; } if (!buffer_write_bytes(buffer, name, name_length + 1)) { return 0; } if (!write_element_to_buffer(self, buffer, type_byte, value, check_keys, options, 0, 0)) { return 0; } return 1; } int decode_and_write_pair(PyObject* self, buffer_t buffer, PyObject* key, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char top_level) { PyObject* encoded; const char* data; int size; if (PyUnicode_Check(key)) { encoded = PyUnicode_AsUTF8String(key); if (!encoded) { return 0; } #if PY_MAJOR_VERSION >= 3 if (!(data = PyBytes_AS_STRING(encoded))) { Py_DECREF(encoded); return 0; } if ((size = _downcast_and_check(PyBytes_GET_SIZE(encoded), 1)) == -1) { Py_DECREF(encoded); return 0; } #else if (!(data = PyString_AS_STRING(encoded))) { Py_DECREF(encoded); return 0; } if ((size = _downcast_and_check(PyString_GET_SIZE(encoded), 1)) == -1) { Py_DECREF(encoded); return 0; } #endif if (strlen(data) != (size_t)(size - 1)) { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { PyErr_SetString(InvalidDocument, "Key names must not contain the NULL byte"); Py_DECREF(InvalidDocument); } Py_DECREF(encoded); return 0; } #if PY_MAJOR_VERSION < 3 } else if (PyString_Check(key)) { result_t status; encoded = key; Py_INCREF(encoded); if (!(data = PyString_AS_STRING(encoded))) { Py_DECREF(encoded); return 0; } if ((size = _downcast_and_check(PyString_GET_SIZE(encoded), 1)) == -1) { Py_DECREF(encoded); return 0; } status = check_string((const unsigned char*)data, size - 1, 1, 1); if (status == NOT_UTF_8) { PyObject* InvalidStringData = _error("InvalidStringData"); if (InvalidStringData) { PyErr_SetString(InvalidStringData, "strings in documents must be valid UTF-8"); Py_DECREF(InvalidStringData); } Py_DECREF(encoded); return 0; } else if (status == HAS_NULL) { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { PyErr_SetString(InvalidDocument, 
"Key names must not contain the NULL byte"); Py_DECREF(InvalidDocument); } Py_DECREF(encoded); return 0; } #endif } else { PyObject* InvalidDocument = _error("InvalidDocument"); if (InvalidDocument) { PyObject* repr = PyObject_Repr(key); if (repr) { #if PY_MAJOR_VERSION >= 3 PyObject* errmsg = PyUnicode_FromString( "documents must have only string keys, key was "); #else PyObject* errmsg = PyString_FromString( "documents must have only string keys, key was "); #endif if (errmsg) { #if PY_MAJOR_VERSION >= 3 PyObject* error = PyUnicode_Concat(errmsg, repr); if (error) { PyErr_SetObject(InvalidDocument, error); Py_DECREF(error); } Py_DECREF(errmsg); Py_DECREF(repr); #else PyString_ConcatAndDel(&errmsg, repr); if (errmsg) { PyErr_SetObject(InvalidDocument, errmsg); Py_DECREF(errmsg); } #endif } else { Py_DECREF(repr); } } Py_DECREF(InvalidDocument); } return 0; } /* If top_level is True, don't allow writing _id here - it was already written. */ if (!write_pair(self, buffer, data, size - 1, value, check_keys, options, !top_level)) { Py_DECREF(encoded); return 0; } Py_DECREF(encoded); return 1; } /* Write a RawBSONDocument to the buffer. * Returns the number of bytes written or 0 on failure. */ static int write_raw_doc(buffer_t buffer, PyObject* raw) { char* bytes; Py_ssize_t len; int len_int; int bytes_written = 0; PyObject* bytes_obj = NULL; bytes_obj = PyObject_GetAttrString(raw, "raw"); if (!bytes_obj) { goto fail; } if (-1 == PyBytes_AsStringAndSize(bytes_obj, &bytes, &len)) { goto fail; } len_int = _downcast_and_check(len, 0); if (-1 == len_int) { goto fail; } if (!buffer_write_bytes(buffer, bytes, len_int)) { goto fail; } bytes_written = len_int; fail: Py_XDECREF(bytes_obj); return bytes_written; } /* returns the number of bytes written or 0 on failure */ int write_dict(PyObject* self, buffer_t buffer, PyObject* dict, unsigned char check_keys, const codec_options_t* options, unsigned char top_level) { PyObject* key; PyObject* iter; char zero = 0; int length; int length_location; struct module_state *state = GETSTATE(self); PyObject* mapping_type; long type_marker; /* check for RawBSONDocument */ type_marker = _type_marker(dict); if (type_marker < 0) { return 0; } if (101 == type_marker) { return write_raw_doc(buffer, dict); } #if PY_MAJOR_VERSION >= 3 mapping_type = _get_object(state->Mapping, "collections.abc", "Mapping"); #else mapping_type = _get_object(state->Mapping, "collections", "Mapping"); #endif if (mapping_type) { if (!PyObject_IsInstance(dict, mapping_type)) { PyObject* repr; Py_DECREF(mapping_type); if ((repr = PyObject_Repr(dict))) { #if PY_MAJOR_VERSION >= 3 PyObject* errmsg = PyUnicode_FromString( "encoder expected a mapping type but got: "); if (errmsg) { PyObject* error = PyUnicode_Concat(errmsg, repr); if (error) { PyErr_SetObject(PyExc_TypeError, error); Py_DECREF(error); } Py_DECREF(errmsg); Py_DECREF(repr); } #else PyObject* errmsg = PyString_FromString( "encoder expected a mapping type but got: "); if (errmsg) { PyString_ConcatAndDel(&errmsg, repr); if (errmsg) { PyErr_SetObject(PyExc_TypeError, errmsg); Py_DECREF(errmsg); } } #endif else { Py_DECREF(repr); } } else { PyErr_SetString(PyExc_TypeError, "encoder expected a mapping type"); } return 0; } Py_DECREF(mapping_type); /* PyObject_IsInstance returns -1 on error */ if (PyErr_Occurred()) { return 0; } } length_location = buffer_save_space(buffer, 4); if (length_location == -1) { return 0; } /* Write _id first if this is a top level doc. 
*/ if (top_level) { /* * If "dict" is a defaultdict we don't want to call * PyMapping_GetItemString on it. That would **create** * an _id where one didn't previously exist (PYTHON-871). */ if (PyDict_Check(dict)) { /* PyDict_GetItemString returns a borrowed reference. */ PyObject* _id = PyDict_GetItemString(dict, "_id"); if (_id) { if (!write_pair(self, buffer, "_id", 3, _id, check_keys, options, 1)) { return 0; } } } else if (PyMapping_HasKeyString(dict, "_id")) { PyObject* _id = PyMapping_GetItemString(dict, "_id"); if (!_id) { return 0; } if (!write_pair(self, buffer, "_id", 3, _id, check_keys, options, 1)) { Py_DECREF(_id); return 0; } /* PyMapping_GetItemString returns a new reference. */ Py_DECREF(_id); } } iter = PyObject_GetIter(dict); if (iter == NULL) { return 0; } while ((key = PyIter_Next(iter)) != NULL) { PyObject* value = PyObject_GetItem(dict, key); if (!value) { PyErr_SetObject(PyExc_KeyError, key); Py_DECREF(key); Py_DECREF(iter); return 0; } if (!decode_and_write_pair(self, buffer, key, value, check_keys, options, top_level)) { Py_DECREF(key); Py_DECREF(value); Py_DECREF(iter); return 0; } Py_DECREF(key); Py_DECREF(value); } Py_DECREF(iter); if (PyErr_Occurred()) { return 0; } /* write null byte and fill in length */ if (!buffer_write_bytes(buffer, &zero, 1)) { return 0; } length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)length); return length; } static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) { PyObject* dict; PyObject* result; unsigned char check_keys; unsigned char top_level = 1; codec_options_t options; buffer_t buffer; PyObject* raw_bson_document_bytes_obj; long type_marker; if (!PyArg_ParseTuple(args, "ObO&|b", &dict, &check_keys, convert_codec_options, &options, &top_level)) { return NULL; } /* check for RawBSONDocument */ type_marker = _type_marker(dict); if (type_marker < 0) { destroy_codec_options(&options); return NULL; } else if (101 == type_marker) { destroy_codec_options(&options); raw_bson_document_bytes_obj = PyObject_GetAttrString(dict, "raw"); if (NULL == raw_bson_document_bytes_obj) { return NULL; } return raw_bson_document_bytes_obj; } buffer = buffer_new(); if (!buffer) { destroy_codec_options(&options); return NULL; } if (!write_dict(self, buffer, dict, check_keys, &options, top_level)) { destroy_codec_options(&options); buffer_free(buffer); return NULL; } /* objectify buffer */ result = Py_BuildValue(BYTES_FORMAT_STRING, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer)); destroy_codec_options(&options); buffer_free(buffer); return result; } static PyObject* get_value(PyObject* self, PyObject* name, const char* buffer, unsigned* position, unsigned char type, unsigned max, const codec_options_t* options) { struct module_state *state = GETSTATE(self); PyObject* value = NULL; switch (type) { case 1: { double d; if (max < 8) { goto invalid; } memcpy(&d, buffer + *position, 8); value = PyFloat_FromDouble(BSON_DOUBLE_FROM_LE(d)); *position += 8; break; } case 2: case 14: { uint32_t value_length; if (max < 4) { goto invalid; } memcpy(&value_length, buffer + *position, 4); value_length = BSON_UINT32_FROM_LE(value_length); /* Encoded string length + string */ if (!value_length || max < value_length || max < 4 + value_length) { goto invalid; } *position += 4; /* Strings must end in \0 */ if (buffer[*position + value_length - 1]) { goto invalid; } value = PyUnicode_DecodeUTF8( buffer + *position, value_length - 1, 
options->unicode_decode_error_handler); if (!value) { goto invalid; } *position += value_length; break; } case 3: { PyObject* collection; uint32_t size; if (max < 4) { goto invalid; } memcpy(&size, buffer + *position, 4); size = BSON_UINT32_FROM_LE(size); if (size < BSON_MIN_SIZE || max < size) { goto invalid; } /* Check for bad eoo */ if (buffer[*position + size - 1]) { goto invalid; } if (options->is_raw_bson) { value = PyObject_CallFunction( options->document_class, BYTES_FORMAT_STRING "O", buffer + *position, (Py_ssize_t)size, options->options_obj); if (!value) { goto invalid; } *position += size; break; } value = elements_to_dict(self, buffer + *position + 4, size - 5, options); if (!value) { goto invalid; } /* Decoding for DBRefs */ if (PyMapping_HasKeyString(value, "$ref")) { /* DBRef */ PyObject* dbref = NULL; PyObject* dbref_type; PyObject* id; PyObject* database; collection = PyMapping_GetItemString(value, "$ref"); /* PyMapping_GetItemString returns NULL to indicate error. */ if (!collection) { goto invalid; } PyMapping_DelItemString(value, "$ref"); if (PyMapping_HasKeyString(value, "$id")) { id = PyMapping_GetItemString(value, "$id"); if (!id) { Py_DECREF(collection); goto invalid; } PyMapping_DelItemString(value, "$id"); } else { id = Py_None; Py_INCREF(id); } if (PyMapping_HasKeyString(value, "$db")) { database = PyMapping_GetItemString(value, "$db"); if (!database) { Py_DECREF(collection); Py_DECREF(id); goto invalid; } PyMapping_DelItemString(value, "$db"); } else { database = Py_None; Py_INCREF(database); } if ((dbref_type = _get_object(state->DBRef, "bson.dbref", "DBRef"))) { dbref = PyObject_CallFunctionObjArgs(dbref_type, collection, id, database, value, NULL); Py_DECREF(dbref_type); } Py_DECREF(value); value = dbref; Py_DECREF(id); Py_DECREF(collection); Py_DECREF(database); } *position += size; break; } case 4: { uint32_t size, end; if (max < 4) { goto invalid; } memcpy(&size, buffer + *position, 4); size = BSON_UINT32_FROM_LE(size); if (size < BSON_MIN_SIZE || max < size) { goto invalid; } end = *position + size - 1; /* Check for bad eoo */ if (buffer[end]) { goto invalid; } *position += 4; value = PyList_New(0); if (!value) { goto invalid; } while (*position < end) { PyObject* to_append; unsigned char bson_type = (unsigned char)buffer[(*position)++]; size_t key_size = strlen(buffer + *position); if (max < key_size) { Py_DECREF(value); goto invalid; } /* just skip the key, they're in order. */ *position += (unsigned)key_size + 1; if (Py_EnterRecursiveCall(" while decoding a list value")) { Py_DECREF(value); goto invalid; } to_append = get_value(self, name, buffer, position, bson_type, max - (unsigned)key_size, options); Py_LeaveRecursiveCall(); if (!to_append) { Py_DECREF(value); goto invalid; } if (PyList_Append(value, to_append) < 0) { Py_DECREF(value); Py_DECREF(to_append); goto invalid; } Py_DECREF(to_append); } if (*position != end) { goto invalid; } (*position)++; break; } case 5: { PyObject* data; PyObject* st; PyObject* type_to_create; uint32_t length, length2; unsigned char subtype; if (max < 5) { goto invalid; } memcpy(&length, buffer + *position, 4); length = BSON_UINT32_FROM_LE(length); if (max < length) { goto invalid; } subtype = (unsigned char)buffer[*position + 4]; *position += 5; if (subtype == 2) { if (length < 4) { goto invalid; } memcpy(&length2, buffer + *position, 4); length2 = BSON_UINT32_FROM_LE(length2); if (length2 != length - 4) { goto invalid; } } #if PY_MAJOR_VERSION >= 3 /* Python3 special case. Decode BSON binary subtype 0 to bytes. 
*/ if (subtype == 0) { value = PyBytes_FromStringAndSize(buffer + *position, length); *position += length; break; } if (subtype == 2) { data = PyBytes_FromStringAndSize(buffer + *position + 4, length - 4); } else { data = PyBytes_FromStringAndSize(buffer + *position, length); } #else if (subtype == 2) { data = PyString_FromStringAndSize(buffer + *position + 4, length - 4); } else { data = PyString_FromStringAndSize(buffer + *position, length); } #endif if (!data) { goto invalid; } /* Encode as UUID or Binary based on options->uuid_rep * TODO: PYTHON-2245 Decoding should follow UUID spec in PyMongo 4.0 */ if (subtype == 3 || subtype == 4) { PyObject* binary_type = NULL; PyObject* binary_value = NULL; char uuid_rep = options->uuid_rep; /* UUID should always be 16 bytes */ if (length != 16) { goto uuiderror; } binary_type = _get_object(state->Binary, "bson", "Binary"); if (binary_type == NULL) { goto uuiderror; } binary_value = PyObject_CallFunction(binary_type, "(Oi)", data, subtype); if (binary_value == NULL) { goto uuiderror; } if (uuid_rep == UNSPECIFIED) { value = binary_value; Py_INCREF(value); } else { if (subtype == 4) { uuid_rep = STANDARD; } else if (uuid_rep == STANDARD) { uuid_rep = PYTHON_LEGACY; } value = PyObject_CallMethod(binary_value, "as_uuid", "(i)", uuid_rep); } uuiderror: Py_XDECREF(binary_type); Py_XDECREF(binary_value); Py_DECREF(data); if (!value) { goto invalid; } *position += length; break; } #if PY_MAJOR_VERSION >= 3 st = PyLong_FromLong(subtype); #else st = PyInt_FromLong(subtype); #endif if (!st) { Py_DECREF(data); goto invalid; } if ((type_to_create = _get_object(state->Binary, "bson.binary", "Binary"))) { value = PyObject_CallFunctionObjArgs(type_to_create, data, st, NULL); Py_DECREF(type_to_create); } Py_DECREF(st); Py_DECREF(data); if (!value) { goto invalid; } *position += length; break; } case 6: case 10: { value = Py_None; Py_INCREF(value); break; } case 7: { PyObject* objectid_type; if (max < 12) { goto invalid; } if ((objectid_type = _get_object(state->ObjectId, "bson.objectid", "ObjectId"))) { value = PyObject_CallFunction(objectid_type, BYTES_FORMAT_STRING, buffer + *position, (Py_ssize_t)12); Py_DECREF(objectid_type); } *position += 12; break; } case 8: { char boolean_raw = buffer[(*position)++]; if (0 == boolean_raw) { value = Py_False; } else if (1 == boolean_raw) { value = Py_True; } else { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_Format(InvalidBSON, "invalid boolean value: %x", boolean_raw); Py_DECREF(InvalidBSON); } return NULL; } Py_INCREF(value); break; } case 9: { PyObject* utc_type; PyObject* naive; PyObject* replace; PyObject* args; PyObject* kwargs; PyObject* astimezone; int64_t millis; if (max < 8) { goto invalid; } memcpy(&millis, buffer + *position, 8); millis = (int64_t)BSON_UINT64_FROM_LE(millis); naive = datetime_from_millis(millis); *position += 8; if (!options->tz_aware) { /* In the naive case, we're done here. 
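 * Otherwise the naive datetime (already expressed in UTC by
 * datetime_from_millis) is made timezone-aware via
 * naive.replace(tzinfo=bson.tz_util.utc) and, if a tzinfo was configured
 * in the CodecOptions, converted to local time with astimezone().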
*/ value = naive; break; } if (!naive) { goto invalid; } replace = PyObject_GetAttrString(naive, "replace"); Py_DECREF(naive); if (!replace) { goto invalid; } args = PyTuple_New(0); if (!args) { Py_DECREF(replace); goto invalid; } kwargs = PyDict_New(); if (!kwargs) { Py_DECREF(replace); Py_DECREF(args); goto invalid; } utc_type = _get_object(state->UTC, "bson.tz_util", "utc"); if (!utc_type || PyDict_SetItemString(kwargs, "tzinfo", utc_type) == -1) { Py_DECREF(replace); Py_DECREF(args); Py_DECREF(kwargs); Py_XDECREF(utc_type); goto invalid; } Py_XDECREF(utc_type); value = PyObject_Call(replace, args, kwargs); if (!value) { Py_DECREF(replace); Py_DECREF(args); Py_DECREF(kwargs); goto invalid; } /* convert to local time */ if (options->tzinfo != Py_None) { astimezone = PyObject_GetAttrString(value, "astimezone"); Py_DECREF(value); if (!astimezone) { Py_DECREF(replace); Py_DECREF(args); Py_DECREF(kwargs); goto invalid; } value = PyObject_CallFunctionObjArgs(astimezone, options->tzinfo, NULL); Py_DECREF(astimezone); } Py_DECREF(replace); Py_DECREF(args); Py_DECREF(kwargs); break; } case 11: { PyObject* regex_class; PyObject* pattern; int flags; size_t flags_length, i; size_t pattern_length = strlen(buffer + *position); if (pattern_length > BSON_MAX_SIZE || max < pattern_length) { goto invalid; } pattern = PyUnicode_DecodeUTF8( buffer + *position, pattern_length, options->unicode_decode_error_handler); if (!pattern) { goto invalid; } *position += (unsigned)pattern_length + 1; flags_length = strlen(buffer + *position); if (flags_length > BSON_MAX_SIZE || (BSON_MAX_SIZE - pattern_length) < flags_length) { Py_DECREF(pattern); goto invalid; } if (max < pattern_length + flags_length) { Py_DECREF(pattern); goto invalid; } flags = 0; for (i = 0; i < flags_length; i++) { if (buffer[*position + i] == 'i') { flags |= 2; } else if (buffer[*position + i] == 'l') { flags |= 4; } else if (buffer[*position + i] == 'm') { flags |= 8; } else if (buffer[*position + i] == 's') { flags |= 16; } else if (buffer[*position + i] == 'u') { flags |= 32; } else if (buffer[*position + i] == 'x') { flags |= 64; } } *position += (unsigned)flags_length + 1; regex_class = _get_object(state->Regex, "bson.regex", "Regex"); if (regex_class) { value = PyObject_CallFunction(regex_class, "Oi", pattern, flags); Py_DECREF(regex_class); } Py_DECREF(pattern); break; } case 12: { uint32_t coll_length; PyObject* collection; PyObject* id = NULL; PyObject* objectid_type; PyObject* dbref_type; if (max < 4) { goto invalid; } memcpy(&coll_length, buffer + *position, 4); coll_length = BSON_UINT32_FROM_LE(coll_length); /* Encoded string length + string + 12 byte ObjectId */ if (!coll_length || max < coll_length || max < 4 + coll_length + 12) { goto invalid; } *position += 4; /* Strings must end in \0 */ if (buffer[*position + coll_length - 1]) { goto invalid; } collection = PyUnicode_DecodeUTF8( buffer + *position, coll_length - 1, options->unicode_decode_error_handler); if (!collection) { goto invalid; } *position += coll_length; if ((objectid_type = _get_object(state->ObjectId, "bson.objectid", "ObjectId"))) { id = PyObject_CallFunction(objectid_type, BYTES_FORMAT_STRING, buffer + *position, (Py_ssize_t)12); Py_DECREF(objectid_type); } if (!id) { Py_DECREF(collection); goto invalid; } *position += 12; if ((dbref_type = _get_object(state->DBRef, "bson.dbref", "DBRef"))) { value = PyObject_CallFunctionObjArgs(dbref_type, collection, id, NULL); Py_DECREF(dbref_type); } Py_DECREF(collection); Py_DECREF(id); break; } case 13: { PyObject* code; 
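        /* Type 0x0D is JavaScript code without a scope: the payload has
         * the same layout as a BSON string (int32 byte count including
         * the trailing NUL, UTF-8 bytes, NUL) and is wrapped in
         * bson.code.Code below. */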
PyObject* code_type; uint32_t value_length; if (max < 4) { goto invalid; } memcpy(&value_length, buffer + *position, 4); value_length = BSON_UINT32_FROM_LE(value_length); /* Encoded string length + string */ if (!value_length || max < value_length || max < 4 + value_length) { goto invalid; } *position += 4; /* Strings must end in \0 */ if (buffer[*position + value_length - 1]) { goto invalid; } code = PyUnicode_DecodeUTF8( buffer + *position, value_length - 1, options->unicode_decode_error_handler); if (!code) { goto invalid; } *position += value_length; if ((code_type = _get_object(state->Code, "bson.code", "Code"))) { value = PyObject_CallFunctionObjArgs(code_type, code, NULL, NULL); Py_DECREF(code_type); } Py_DECREF(code); break; } case 15: { uint32_t c_w_s_size; uint32_t code_size; uint32_t scope_size; PyObject* code; PyObject* scope; PyObject* code_type; if (max < 8) { goto invalid; } memcpy(&c_w_s_size, buffer + *position, 4); c_w_s_size = BSON_UINT32_FROM_LE(c_w_s_size); *position += 4; if (max < c_w_s_size) { goto invalid; } memcpy(&code_size, buffer + *position, 4); code_size = BSON_UINT32_FROM_LE(code_size); /* code_w_scope length + code length + code + scope length */ if (!code_size || max < code_size || max < 4 + 4 + code_size + 4) { goto invalid; } *position += 4; /* Strings must end in \0 */ if (buffer[*position + code_size - 1]) { goto invalid; } code = PyUnicode_DecodeUTF8( buffer + *position, code_size - 1, options->unicode_decode_error_handler); if (!code) { goto invalid; } *position += code_size; memcpy(&scope_size, buffer + *position, 4); scope_size = BSON_UINT32_FROM_LE(scope_size); if (scope_size < BSON_MIN_SIZE) { Py_DECREF(code); goto invalid; } /* code length + code + scope length + scope */ if ((4 + code_size + 4 + scope_size) != c_w_s_size) { Py_DECREF(code); goto invalid; } /* Check for bad eoo */ if (buffer[*position + scope_size - 1]) { goto invalid; } scope = elements_to_dict(self, buffer + *position + 4, scope_size - 5, options); if (!scope) { Py_DECREF(code); goto invalid; } *position += scope_size; if ((code_type = _get_object(state->Code, "bson.code", "Code"))) { value = PyObject_CallFunctionObjArgs(code_type, code, scope, NULL); Py_DECREF(code_type); } Py_DECREF(code); Py_DECREF(scope); break; } case 16: { int32_t i; if (max < 4) { goto invalid; } memcpy(&i, buffer + *position, 4); i = (int32_t)BSON_UINT32_FROM_LE(i); #if PY_MAJOR_VERSION >= 3 value = PyLong_FromLong(i); #else value = PyInt_FromLong(i); #endif if (!value) { goto invalid; } *position += 4; break; } case 17: { uint32_t time, inc; PyObject* timestamp_type; if (max < 8) { goto invalid; } memcpy(&inc, buffer + *position, 4); memcpy(&time, buffer + *position + 4, 4); inc = BSON_UINT32_FROM_LE(inc); time = BSON_UINT32_FROM_LE(time); if ((timestamp_type = _get_object(state->Timestamp, "bson.timestamp", "Timestamp"))) { value = PyObject_CallFunction(timestamp_type, "II", time, inc); Py_DECREF(timestamp_type); } *position += 8; break; } case 18: { int64_t ll; PyObject* bson_int64_type = _get_object(state->BSONInt64, "bson.int64", "Int64"); if (!bson_int64_type) goto invalid; if (max < 8) { Py_DECREF(bson_int64_type); goto invalid; } memcpy(&ll, buffer + *position, 8); ll = (int64_t)BSON_UINT64_FROM_LE(ll); value = PyObject_CallFunction(bson_int64_type, "L", ll); *position += 8; Py_DECREF(bson_int64_type); break; } case 19: { PyObject* dec128; if (max < 16) { goto invalid; } if ((dec128 = _get_object(state->Decimal128, "bson.decimal128", "Decimal128"))) { value = PyObject_CallMethod(dec128, 
"from_bid", BYTES_FORMAT_STRING, buffer + *position, (Py_ssize_t)16); Py_DECREF(dec128); } *position += 16; break; } case 255: { PyObject* minkey_type = _get_object(state->MinKey, "bson.min_key", "MinKey"); if (!minkey_type) goto invalid; value = PyObject_CallFunctionObjArgs(minkey_type, NULL); Py_DECREF(minkey_type); break; } case 127: { PyObject* maxkey_type = _get_object(state->MaxKey, "bson.max_key", "MaxKey"); if (!maxkey_type) goto invalid; value = PyObject_CallFunctionObjArgs(maxkey_type, NULL); Py_DECREF(maxkey_type); break; } default: { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyObject* bobj = PyBytes_FromFormat("%c", type); if (bobj) { PyObject* repr = PyObject_Repr(bobj); Py_DECREF(bobj); /* * See http://bugs.python.org/issue22023 for why we can't * just use PyUnicode_FromFormat with %S or %R to do this * work. */ if (repr) { PyObject* left = PyUnicode_FromString( "Detected unknown BSON type "); if (left) { PyObject* lmsg = PyUnicode_Concat(left, repr); Py_DECREF(left); if (lmsg) { PyObject* errmsg = PyUnicode_FromFormat( "%U for fieldname '%U'. Are you using the " "latest driver version?", lmsg, name); if (errmsg) { PyErr_SetObject(InvalidBSON, errmsg); Py_DECREF(errmsg); } Py_DECREF(lmsg); } } Py_DECREF(repr); } } Py_DECREF(InvalidBSON); } goto invalid; } } if (value) { if (!options->type_registry.is_decoder_empty) { PyObject* value_type = NULL; PyObject* converter = NULL; value_type = PyObject_Type(value); if (value_type == NULL) { goto invalid; } converter = PyDict_GetItem(options->type_registry.decoder_map, value_type); if (converter != NULL) { PyObject* new_value = PyObject_CallFunctionObjArgs(converter, value, NULL); Py_DECREF(value_type); Py_DECREF(value); return new_value; } else { Py_DECREF(value_type); return value; } } return value; } invalid: /* * Wrap any non-InvalidBSON errors in InvalidBSON. */ if (PyErr_Occurred()) { PyObject *etype, *evalue, *etrace; PyObject *InvalidBSON; /* * Calling _error clears the error state, so fetch it first. */ PyErr_Fetch(&etype, &evalue, &etrace); /* Dont reraise anything but PyExc_Exceptions as InvalidBSON. */ if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) { InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { if (!PyErr_GivenExceptionMatches(etype, InvalidBSON)) { /* * Raise InvalidBSON(str(e)). */ Py_DECREF(etype); etype = InvalidBSON; if (evalue) { PyObject *msg = PyObject_Str(evalue); Py_DECREF(evalue); evalue = msg; } PyErr_NormalizeException(&etype, &evalue, &etrace); } else { /* * The current exception matches InvalidBSON, so we don't * need this reference after all. */ Py_DECREF(InvalidBSON); } } } /* Steals references to args. */ PyErr_Restore(etype, evalue, etrace); } else { PyObject *InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "invalid length or type code"); Py_DECREF(InvalidBSON); } } return NULL; } /* * Get the next 'name' and 'value' from a document in a string, whose position * is provided. * * Returns the position of the next element in the document, or -1 on error. 
*/ static int _element_to_dict(PyObject* self, const char* string, unsigned position, unsigned max, const codec_options_t* options, PyObject** name, PyObject** value) { unsigned char type = (unsigned char)string[position++]; size_t name_length = strlen(string + position); if (name_length > BSON_MAX_SIZE || position + name_length >= max) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetNone(InvalidBSON); Py_DECREF(InvalidBSON); } return -1; } *name = PyUnicode_DecodeUTF8( string + position, name_length, options->unicode_decode_error_handler); if (!*name) { /* If NULL is returned then wrap the UnicodeDecodeError in an InvalidBSON error */ PyObject *etype, *evalue, *etrace; PyObject *InvalidBSON; PyErr_Fetch(&etype, &evalue, &etrace); if (PyErr_GivenExceptionMatches(etype, PyExc_Exception)) { InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { Py_DECREF(etype); etype = InvalidBSON; if (evalue) { PyObject *msg = PyObject_Str(evalue); Py_DECREF(evalue); evalue = msg; } PyErr_NormalizeException(&etype, &evalue, &etrace); } } PyErr_Restore(etype, evalue, etrace); return -1; } position += (unsigned)name_length + 1; *value = get_value(self, *name, string, &position, type, max - position, options); if (!*value) { Py_DECREF(*name); return -1; } return position; } static PyObject* _cbson_element_to_dict(PyObject* self, PyObject* args) { /* TODO: Support buffer protocol */ char* string; PyObject* bson; codec_options_t options; unsigned position; unsigned max; int new_position; PyObject* name; PyObject* value; PyObject* result_tuple; if (!PyArg_ParseTuple(args, "OII|O&", &bson, &position, &max, convert_codec_options, &options)) { return NULL; } if (PyTuple_GET_SIZE(args) < 4) { if (!default_codec_options(GETSTATE(self), &options)) { return NULL; } } #if PY_MAJOR_VERSION >= 3 if (!PyBytes_Check(bson)) { PyErr_SetString(PyExc_TypeError, "argument to _element_to_dict must be a bytes object"); #else if (!PyString_Check(bson)) { PyErr_SetString(PyExc_TypeError, "argument to _element_to_dict must be a string"); #endif return NULL; } #if PY_MAJOR_VERSION >= 3 string = PyBytes_AS_STRING(bson); #else string = PyString_AS_STRING(bson); #endif new_position = _element_to_dict(self, string, position, max, &options, &name, &value); if (new_position < 0) { return NULL; } result_tuple = Py_BuildValue("NNi", name, value, new_position); if (!result_tuple) { Py_DECREF(name); Py_DECREF(value); return NULL; } return result_tuple; } static PyObject* _elements_to_dict(PyObject* self, const char* string, unsigned max, const codec_options_t* options) { unsigned position = 0; PyObject* dict = PyObject_CallObject(options->document_class, NULL); if (!dict) { return NULL; } while (position < max) { PyObject* name = NULL; PyObject* value = NULL; int new_position; new_position = _element_to_dict( self, string, position, max, options, &name, &value); if (new_position < 0) { Py_DECREF(dict); return NULL; } else { position = (unsigned)new_position; } PyObject_SetItem(dict, name, value); Py_DECREF(name); Py_DECREF(value); } return dict; } static PyObject* elements_to_dict(PyObject* self, const char* string, unsigned max, const codec_options_t* options) { PyObject* result; if (Py_EnterRecursiveCall(" while decoding a BSON document")) return NULL; result = _elements_to_dict(self, string, max, options); Py_LeaveRecursiveCall(); return result; } static int _get_buffer(PyObject *exporter, Py_buffer *view) { if (PyObject_GetBuffer(exporter, view, PyBUF_SIMPLE) == -1) { return 0; } if 
(!PyBuffer_IsContiguous(view, 'C')) { PyErr_SetString(PyExc_ValueError, "must be a contiguous buffer"); goto fail; } if (!view->buf || view->len < 0) { PyErr_SetString(PyExc_ValueError, "invalid buffer"); goto fail; } if (view->itemsize != 1) { PyErr_SetString(PyExc_ValueError, "buffer data must be ascii or utf8"); goto fail; } return 1; fail: PyBuffer_Release(view); return 0; } static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* args) { int32_t size; Py_ssize_t total_size; const char* string; PyObject* bson; codec_options_t options; PyObject* result = NULL; PyObject* options_obj; Py_buffer view; if (! (PyArg_ParseTuple(args, "OO", &bson, &options_obj) && convert_codec_options(options_obj, &options))) { return result; } if (!_get_buffer(bson, &view)) { destroy_codec_options(&options); return result; } total_size = view.len; if (total_size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "not enough data for a BSON document"); Py_DECREF(InvalidBSON); } goto done;; } string = (char*)view.buf; memcpy(&size, string, 4); size = (int32_t)BSON_UINT32_FROM_LE(size); if (size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "invalid message size"); Py_DECREF(InvalidBSON); } goto done; } if (total_size < size || total_size > BSON_MAX_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "objsize too large"); Py_DECREF(InvalidBSON); } goto done; } if (size != total_size || string[size - 1]) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "bad eoo"); Py_DECREF(InvalidBSON); } goto done; } /* No need to decode fields if using RawBSONDocument */ if (options.is_raw_bson) { result = PyObject_CallFunction( options.document_class, BYTES_FORMAT_STRING "O", string, (Py_ssize_t)size, options_obj); } else { result = elements_to_dict(self, string + 4, (unsigned)size - 5, &options); } done: PyBuffer_Release(&view); destroy_codec_options(&options); return result; } static PyObject* _cbson_decode_all(PyObject* self, PyObject* args) { int32_t size; Py_ssize_t total_size; const char* string; PyObject* bson; PyObject* dict; PyObject* result = NULL; codec_options_t options; PyObject* options_obj; Py_buffer view; if (!PyArg_ParseTuple(args, "O|O", &bson, &options_obj)) { return NULL; } if (PyTuple_GET_SIZE(args) < 2) { if (!default_codec_options(GETSTATE(self), &options)) { return NULL; } } else if (!convert_codec_options(options_obj, &options)) { return NULL; } if (!_get_buffer(bson, &view)) { destroy_codec_options(&options); return NULL; } total_size = view.len; string = (char*)view.buf; if (!(result = PyList_New(0))) { goto fail; } while (total_size > 0) { if (total_size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "not enough data for a BSON document"); Py_DECREF(InvalidBSON); } Py_DECREF(result); goto fail; } memcpy(&size, string, 4); size = (int32_t)BSON_UINT32_FROM_LE(size); if (size < BSON_MIN_SIZE) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "invalid message size"); Py_DECREF(InvalidBSON); } Py_DECREF(result); goto fail; } if (total_size < size) { PyObject* InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "objsize too large"); Py_DECREF(InvalidBSON); } Py_DECREF(result); goto fail; } if (string[size - 1]) { PyObject* 
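/* A well-formed BSON document ends in a single 0x00 byte; a non-zero
 * terminator here means the stream is corrupt and is reported as a
 * "bad eoo" (end of object) error. */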
InvalidBSON = _error("InvalidBSON"); if (InvalidBSON) { PyErr_SetString(InvalidBSON, "bad eoo"); Py_DECREF(InvalidBSON); } Py_DECREF(result); goto fail; } /* No need to decode fields if using RawBSONDocument. */ if (options.is_raw_bson) { dict = PyObject_CallFunction( options.document_class, BYTES_FORMAT_STRING "O", string, (Py_ssize_t)size, options_obj); } else { dict = elements_to_dict(self, string + 4, (unsigned)size - 5, &options); } if (!dict) { Py_DECREF(result); goto fail; } if (PyList_Append(result, dict) < 0) { Py_DECREF(dict); Py_DECREF(result); goto fail; } Py_DECREF(dict); string += size; total_size -= size; } goto done; fail: result = NULL; done: PyBuffer_Release(&view); destroy_codec_options(&options); return result; } static PyMethodDef _CBSONMethods[] = { {"_dict_to_bson", _cbson_dict_to_bson, METH_VARARGS, "convert a dictionary to a string containing its BSON representation."}, {"_bson_to_dict", _cbson_bson_to_dict, METH_VARARGS, "convert a BSON string to a SON object."}, {"decode_all", _cbson_decode_all, METH_VARARGS, "convert binary data to a sequence of documents."}, {"_element_to_dict", _cbson_element_to_dict, METH_VARARGS, "Decode a single key, value pair."}, {NULL, NULL, 0, NULL} }; #if PY_MAJOR_VERSION >= 3 #define INITERROR return NULL static int _cbson_traverse(PyObject *m, visitproc visit, void *arg) { Py_VISIT(GETSTATE(m)->Binary); Py_VISIT(GETSTATE(m)->Code); Py_VISIT(GETSTATE(m)->ObjectId); Py_VISIT(GETSTATE(m)->DBRef); Py_VISIT(GETSTATE(m)->Regex); Py_VISIT(GETSTATE(m)->UUID); Py_VISIT(GETSTATE(m)->Timestamp); Py_VISIT(GETSTATE(m)->MinKey); Py_VISIT(GETSTATE(m)->MaxKey); Py_VISIT(GETSTATE(m)->UTC); Py_VISIT(GETSTATE(m)->REType); return 0; } static int _cbson_clear(PyObject *m) { Py_CLEAR(GETSTATE(m)->Binary); Py_CLEAR(GETSTATE(m)->Code); Py_CLEAR(GETSTATE(m)->ObjectId); Py_CLEAR(GETSTATE(m)->DBRef); Py_CLEAR(GETSTATE(m)->Regex); Py_CLEAR(GETSTATE(m)->UUID); Py_CLEAR(GETSTATE(m)->Timestamp); Py_CLEAR(GETSTATE(m)->MinKey); Py_CLEAR(GETSTATE(m)->MaxKey); Py_CLEAR(GETSTATE(m)->UTC); Py_CLEAR(GETSTATE(m)->REType); return 0; } static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "_cbson", NULL, sizeof(struct module_state), _CBSONMethods, NULL, _cbson_traverse, _cbson_clear, NULL }; PyMODINIT_FUNC PyInit__cbson(void) #else #define INITERROR return PyMODINIT_FUNC init_cbson(void) #endif { PyObject *m; PyObject *c_api_object; static void *_cbson_API[_cbson_API_POINTER_COUNT]; PyDateTime_IMPORT; if (PyDateTimeAPI == NULL) { INITERROR; } /* Export C API */ _cbson_API[_cbson_buffer_write_bytes_INDEX] = (void *) buffer_write_bytes; _cbson_API[_cbson_write_dict_INDEX] = (void *) write_dict; _cbson_API[_cbson_write_pair_INDEX] = (void *) write_pair; _cbson_API[_cbson_decode_and_write_pair_INDEX] = (void *) decode_and_write_pair; _cbson_API[_cbson_convert_codec_options_INDEX] = (void *) convert_codec_options; _cbson_API[_cbson_destroy_codec_options_INDEX] = (void *) destroy_codec_options; _cbson_API[_cbson_buffer_write_double_INDEX] = (void *) buffer_write_double; _cbson_API[_cbson_buffer_write_int32_INDEX] = (void *) buffer_write_int32; _cbson_API[_cbson_buffer_write_int64_INDEX] = (void *) buffer_write_int64; _cbson_API[_cbson_buffer_write_int32_at_position_INDEX] = (void *) buffer_write_int32_at_position; _cbson_API[_cbson_downcast_and_check_INDEX] = (void *) _downcast_and_check; #if PY_VERSION_HEX >= 0x03010000 /* PyCapsule is new in python 3.1 */ c_api_object = PyCapsule_New((void *) _cbson_API, "_cbson._C_API", NULL); #else c_api_object = 
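/* Python < 3.1 has no PyCapsule, so the C API table is exported through
 * the legacy PyCObject interface instead. */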
PyCObject_FromVoidPtr((void *) _cbson_API, NULL); #endif if (c_api_object == NULL) INITERROR; #if PY_MAJOR_VERSION >= 3 m = PyModule_Create(&moduledef); #else m = Py_InitModule("_cbson", _CBSONMethods); #endif if (m == NULL) { Py_DECREF(c_api_object); INITERROR; } /* Import several python objects */ if (_load_python_objects(m)) { Py_DECREF(c_api_object); #if PY_MAJOR_VERSION >= 3 Py_DECREF(m); #endif INITERROR; } if (PyModule_AddObject(m, "_C_API", c_api_object) < 0) { Py_DECREF(c_api_object); #if PY_MAJOR_VERSION >= 3 Py_DECREF(m); #endif INITERROR; } #if PY_MAJOR_VERSION >= 3 return m; #endif } pymongo-3.11.0/bson/_cbsonmodule.h000066400000000000000000000170571374256237000170720ustar00rootroot00000000000000/* * Copyright 2009-present MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "bson-endian.h" #ifndef _CBSONMODULE_H #define _CBSONMODULE_H #if defined(WIN32) || defined(_MSC_VER) /* * This macro is basically an implementation of asprintf for win32 * We print to the provided buffer to get the string value as an int. */ #if defined(_MSC_VER) && (_MSC_VER >= 1400) #define INT2STRING(buffer, i) \ _snprintf_s((buffer), \ _scprintf("%d", (i)) + 1, \ _scprintf("%d", (i)) + 1, \ "%d", \ (i)) #define STRCAT(dest, n, src) strcat_s((dest), (n), (src)) #else #define INT2STRING(buffer, i) \ _snprintf((buffer), \ _scprintf("%d", (i)) + 1, \ "%d", \ (i)) #define STRCAT(dest, n, src) strcat((dest), (src)) #endif #else #define INT2STRING(buffer, i) snprintf((buffer), sizeof((buffer)), "%d", (i)) #define STRCAT(dest, n, src) strcat((dest), (src)) #endif #if PY_MAJOR_VERSION >= 3 #define BYTES_FORMAT_STRING "y#" #else #define BYTES_FORMAT_STRING "s#" #endif typedef struct type_registry_t { PyObject* encoder_map; PyObject* decoder_map; PyObject* fallback_encoder; PyObject* registry_obj; unsigned char is_encoder_empty; unsigned char is_decoder_empty; unsigned char has_fallback_encoder; } type_registry_t; typedef struct codec_options_t { PyObject* document_class; unsigned char tz_aware; unsigned char uuid_rep; char* unicode_decode_error_handler; PyObject* tzinfo; type_registry_t type_registry; PyObject* options_obj; unsigned char is_raw_bson; } codec_options_t; /* C API functions */ #define _cbson_buffer_write_bytes_INDEX 0 #define _cbson_buffer_write_bytes_RETURN int #define _cbson_buffer_write_bytes_PROTO (buffer_t buffer, const char* data, int size) #define _cbson_write_dict_INDEX 1 #define _cbson_write_dict_RETURN int #define _cbson_write_dict_PROTO (PyObject* self, buffer_t buffer, PyObject* dict, unsigned char check_keys, const codec_options_t* options, unsigned char top_level) #define _cbson_write_pair_INDEX 2 #define _cbson_write_pair_RETURN int #define _cbson_write_pair_PROTO (PyObject* self, buffer_t buffer, const char* name, int name_length, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char allow_id) #define _cbson_decode_and_write_pair_INDEX 3 #define _cbson_decode_and_write_pair_RETURN int #define _cbson_decode_and_write_pair_PROTO (PyObject* 
self, buffer_t buffer, PyObject* key, PyObject* value, unsigned char check_keys, const codec_options_t* options, unsigned char top_level) #define _cbson_convert_codec_options_INDEX 4 #define _cbson_convert_codec_options_RETURN int #define _cbson_convert_codec_options_PROTO (PyObject* options_obj, void* p) #define _cbson_destroy_codec_options_INDEX 5 #define _cbson_destroy_codec_options_RETURN void #define _cbson_destroy_codec_options_PROTO (codec_options_t* options) #define _cbson_buffer_write_double_INDEX 6 #define _cbson_buffer_write_double_RETURN int #define _cbson_buffer_write_double_PROTO (buffer_t buffer, double data) #define _cbson_buffer_write_int32_INDEX 7 #define _cbson_buffer_write_int32_RETURN int #define _cbson_buffer_write_int32_PROTO (buffer_t buffer, int32_t data) #define _cbson_buffer_write_int64_INDEX 8 #define _cbson_buffer_write_int64_RETURN int #define _cbson_buffer_write_int64_PROTO (buffer_t buffer, int64_t data) #define _cbson_buffer_write_int32_at_position_INDEX 9 #define _cbson_buffer_write_int32_at_position_RETURN void #define _cbson_buffer_write_int32_at_position_PROTO (buffer_t buffer, int position, int32_t data) #define _cbson_downcast_and_check_INDEX 10 #define _cbson_downcast_and_check_RETURN int #define _cbson_downcast_and_check_PROTO (Py_ssize_t size, uint8_t extra) /* Total number of C API pointers */ #define _cbson_API_POINTER_COUNT 11 #ifdef _CBSON_MODULE /* This section is used when compiling _cbsonmodule */ static _cbson_buffer_write_bytes_RETURN buffer_write_bytes _cbson_buffer_write_bytes_PROTO; static _cbson_write_dict_RETURN write_dict _cbson_write_dict_PROTO; static _cbson_write_pair_RETURN write_pair _cbson_write_pair_PROTO; static _cbson_decode_and_write_pair_RETURN decode_and_write_pair _cbson_decode_and_write_pair_PROTO; static _cbson_convert_codec_options_RETURN convert_codec_options _cbson_convert_codec_options_PROTO; static _cbson_destroy_codec_options_RETURN destroy_codec_options _cbson_destroy_codec_options_PROTO; static _cbson_buffer_write_double_RETURN buffer_write_double _cbson_buffer_write_double_PROTO; static _cbson_buffer_write_int32_RETURN buffer_write_int32 _cbson_buffer_write_int32_PROTO; static _cbson_buffer_write_int64_RETURN buffer_write_int64 _cbson_buffer_write_int64_PROTO; static _cbson_buffer_write_int32_at_position_RETURN buffer_write_int32_at_position _cbson_buffer_write_int32_at_position_PROTO; static _cbson_downcast_and_check_RETURN _downcast_and_check _cbson_downcast_and_check_PROTO; #else /* This section is used in modules that use _cbsonmodule's API */ static void **_cbson_API; #define buffer_write_bytes (*(_cbson_buffer_write_bytes_RETURN (*)_cbson_buffer_write_bytes_PROTO) _cbson_API[_cbson_buffer_write_bytes_INDEX]) #define write_dict (*(_cbson_write_dict_RETURN (*)_cbson_write_dict_PROTO) _cbson_API[_cbson_write_dict_INDEX]) #define write_pair (*(_cbson_write_pair_RETURN (*)_cbson_write_pair_PROTO) _cbson_API[_cbson_write_pair_INDEX]) #define decode_and_write_pair (*(_cbson_decode_and_write_pair_RETURN (*)_cbson_decode_and_write_pair_PROTO) _cbson_API[_cbson_decode_and_write_pair_INDEX]) #define convert_codec_options (*(_cbson_convert_codec_options_RETURN (*)_cbson_convert_codec_options_PROTO) _cbson_API[_cbson_convert_codec_options_INDEX]) #define destroy_codec_options (*(_cbson_destroy_codec_options_RETURN (*)_cbson_destroy_codec_options_PROTO) _cbson_API[_cbson_destroy_codec_options_INDEX]) #define buffer_write_double (*(_cbson_buffer_write_double_RETURN (*)_cbson_buffer_write_double_PROTO) 
_cbson_API[_cbson_buffer_write_double_INDEX]) #define buffer_write_int32 (*(_cbson_buffer_write_int32_RETURN (*)_cbson_buffer_write_int32_PROTO) _cbson_API[_cbson_buffer_write_int32_INDEX]) #define buffer_write_int64 (*(_cbson_buffer_write_int64_RETURN (*)_cbson_buffer_write_int64_PROTO) _cbson_API[_cbson_buffer_write_int64_INDEX]) #define buffer_write_int32_at_position (*(_cbson_buffer_write_int32_at_position_RETURN (*)_cbson_buffer_write_int32_at_position_PROTO) _cbson_API[_cbson_buffer_write_int32_at_position_INDEX]) #define _downcast_and_check (*(_cbson_downcast_and_check_RETURN (*)_cbson_downcast_and_check_PROTO) _cbson_API[_cbson_downcast_and_check_INDEX]) #define _cbson_IMPORT _cbson_API = (void **)PyCapsule_Import("_cbson._C_API", 0) #endif #endif // _CBSONMODULE_H pymongo-3.11.0/bson/binary.py000066400000000000000000000353071374256237000161040ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from uuid import UUID from warnings import warn from bson.py3compat import PY3 """Tools for representing BSON binary data. """ BINARY_SUBTYPE = 0 """BSON binary subtype for binary data. This is the default subtype for binary data. """ FUNCTION_SUBTYPE = 1 """BSON binary subtype for functions. """ OLD_BINARY_SUBTYPE = 2 """Old BSON binary subtype for binary data. This is the old default subtype, the current default is :data:`BINARY_SUBTYPE`. """ OLD_UUID_SUBTYPE = 3 """Old BSON binary subtype for a UUID. :class:`uuid.UUID` instances will automatically be encoded by :mod:`bson` using this subtype. .. versionadded:: 2.1 """ UUID_SUBTYPE = 4 """BSON binary subtype for a UUID. This is the new BSON binary subtype for UUIDs. The current default is :data:`OLD_UUID_SUBTYPE`. .. versionchanged:: 2.1 Changed to subtype 4. """ class UuidRepresentation: UNSPECIFIED = 0 """An unspecified UUID representation. When configured, :class:`uuid.UUID` instances will **not** be automatically encoded to or decoded from :class:`~bson.binary.Binary`. When encoding a :class:`uuid.UUID` instance, an error will be raised. To encode a :class:`uuid.UUID` instance with this configuration, it must be wrapped in the :class:`~bson.binary.Binary` class by the application code. When decoding a BSON binary field with a UUID subtype, a :class:`~bson.binary.Binary` instance will be returned instead of a :class:`uuid.UUID` instance. See :ref:`unspecified-representation-details` for details. .. versionadded:: 3.11 """ STANDARD = UUID_SUBTYPE """The standard UUID representation. :class:`uuid.UUID` instances will automatically be encoded to and decoded from BSON binary, using RFC-4122 byte order with binary subtype :data:`UUID_SUBTYPE`. See :ref:`standard-representation-details` for details. .. versionadded:: 3.11 """ PYTHON_LEGACY = OLD_UUID_SUBTYPE """The Python legacy UUID representation. :class:`uuid.UUID` instances will automatically be encoded to and decoded from BSON binary, using RFC-4122 byte order with binary subtype :data:`OLD_UUID_SUBTYPE`. 
See :ref:`python-legacy-representation-details` for details. .. versionadded:: 3.11 """ JAVA_LEGACY = 5 """The Java legacy UUID representation. :class:`uuid.UUID` instances will automatically be encoded to and decoded from BSON binary subtype :data:`OLD_UUID_SUBTYPE`, using the Java driver's legacy byte order. See :ref:`java-legacy-representation-details` for details. .. versionadded:: 3.11 """ CSHARP_LEGACY = 6 """The C#/.net legacy UUID representation. :class:`uuid.UUID` instances will automatically be encoded to and decoded from BSON binary subtype :data:`OLD_UUID_SUBTYPE`, using the C# driver's legacy byte order. See :ref:`csharp-legacy-representation-details` for details. .. versionadded:: 3.11 """ STANDARD = UuidRepresentation.STANDARD """An alias for :data:`UuidRepresentation.STANDARD`. .. versionadded:: 3.0 """ PYTHON_LEGACY = UuidRepresentation.PYTHON_LEGACY """An alias for :data:`UuidRepresentation.PYTHON_LEGACY`. .. versionadded:: 3.0 """ JAVA_LEGACY = UuidRepresentation.JAVA_LEGACY """An alias for :data:`UuidRepresentation.JAVA_LEGACY`. .. versionchanged:: 3.6 BSON binary subtype 4 is decoded using RFC-4122 byte order. .. versionadded:: 2.3 """ CSHARP_LEGACY = UuidRepresentation.CSHARP_LEGACY """An alias for :data:`UuidRepresentation.CSHARP_LEGACY`. .. versionchanged:: 3.6 BSON binary subtype 4 is decoded using RFC-4122 byte order. .. versionadded:: 2.3 """ ALL_UUID_SUBTYPES = (OLD_UUID_SUBTYPE, UUID_SUBTYPE) ALL_UUID_REPRESENTATIONS = (UuidRepresentation.UNSPECIFIED, UuidRepresentation.STANDARD, UuidRepresentation.PYTHON_LEGACY, UuidRepresentation.JAVA_LEGACY, UuidRepresentation.CSHARP_LEGACY) UUID_REPRESENTATION_NAMES = { UuidRepresentation.UNSPECIFIED: 'UuidRepresentation.UNSPECIFIED', UuidRepresentation.STANDARD: 'UuidRepresentation.STANDARD', UuidRepresentation.PYTHON_LEGACY: 'UuidRepresentation.PYTHON_LEGACY', UuidRepresentation.JAVA_LEGACY: 'UuidRepresentation.JAVA_LEGACY', UuidRepresentation.CSHARP_LEGACY: 'UuidRepresentation.CSHARP_LEGACY'} MD5_SUBTYPE = 5 """BSON binary subtype for an MD5 hash. """ USER_DEFINED_SUBTYPE = 128 """BSON binary subtype for any user defined structure. """ class Binary(bytes): """Representation of BSON binary data. This is necessary because we want to represent Python strings as the BSON string type. We need to wrap binary data so we can tell the difference between what should be considered binary data and what should be considered a string when we encode to BSON. Raises TypeError if `data` is not an instance of :class:`bytes` (:class:`str` in python 2) or `subtype` is not an instance of :class:`int`. Raises ValueError if `subtype` is not in [0, 256). .. note:: In python 3 instances of Binary with subtype 0 will be decoded directly to :class:`bytes`. :Parameters: - `data`: the binary data to represent. Can be any bytes-like type that implements the buffer protocol. - `subtype` (optional): the `binary subtype `_ to use .. versionchanged:: 3.9 Support any bytes-like type that implements the buffer protocol. """ _type_marker = 5 def __new__(cls, data, subtype=BINARY_SUBTYPE): if not isinstance(subtype, int): raise TypeError("subtype must be an instance of int") if subtype >= 256 or subtype < 0: raise ValueError("subtype must be contained in [0, 256)") # Support any type that implements the buffer protocol. self = bytes.__new__(cls, memoryview(data).tobytes()) self.__subtype = subtype return self @classmethod def from_uuid(cls, uuid, uuid_representation=UuidRepresentation.STANDARD): """Create a BSON Binary object from a Python UUID. 
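        A minimal round-trip sketch (assuming the default
        :data:`~bson.binary.UuidRepresentation.STANDARD` representation)::

            >>> import uuid
            >>> val = uuid.uuid4()
            >>> Binary.from_uuid(val).as_uuid() == val
            True
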
Creates a :class:`~bson.binary.Binary` object from a :class:`uuid.UUID` instance. Assumes that the native :class:`uuid.UUID` instance uses the byte-order implied by the provided ``uuid_representation``. Raises :exc:`TypeError` if `uuid` is not an instance of :class:`~uuid.UUID`. :Parameters: - `uuid`: A :class:`uuid.UUID` instance. - `uuid_representation`: A member of :class:`~bson.binary.UuidRepresentation`. Default: :const:`~bson.binary.UuidRepresentation.STANDARD`. See :ref:`handling-uuid-data-example` for details. .. versionadded:: 3.11 """ if not isinstance(uuid, UUID): raise TypeError("uuid must be an instance of uuid.UUID") if uuid_representation not in ALL_UUID_REPRESENTATIONS: raise ValueError("uuid_representation must be a value " "from bson.binary.UuidRepresentation") if uuid_representation == UuidRepresentation.UNSPECIFIED: raise ValueError( "cannot encode native uuid.UUID with " "UuidRepresentation.UNSPECIFIED. UUIDs can be manually " "converted to bson.Binary instances using " "bson.Binary.from_uuid() or a different UuidRepresentation " "can be configured. See the documentation for " "UuidRepresentation for more information.") subtype = OLD_UUID_SUBTYPE if uuid_representation == UuidRepresentation.PYTHON_LEGACY: payload = uuid.bytes elif uuid_representation == UuidRepresentation.JAVA_LEGACY: from_uuid = uuid.bytes payload = from_uuid[0:8][::-1] + from_uuid[8:16][::-1] elif uuid_representation == UuidRepresentation.CSHARP_LEGACY: payload = uuid.bytes_le else: # uuid_representation == UuidRepresentation.STANDARD subtype = UUID_SUBTYPE payload = uuid.bytes return cls(payload, subtype) def as_uuid(self, uuid_representation=UuidRepresentation.STANDARD): """Create a Python UUID from this BSON Binary object. Decodes this binary object as a native :class:`uuid.UUID` instance with the provided ``uuid_representation``. Raises :exc:`ValueError` if this :class:`~bson.binary.Binary` instance does not contain a UUID. :Parameters: - `uuid_representation`: A member of :class:`~bson.binary.UuidRepresentation`. Default: :const:`~bson.binary.UuidRepresentation.STANDARD`. See :ref:`handling-uuid-data-example` for details. .. versionadded:: 3.11 """ if self.subtype not in ALL_UUID_SUBTYPES: raise ValueError("cannot decode subtype %s as a uuid" % ( self.subtype,)) if uuid_representation not in ALL_UUID_REPRESENTATIONS: raise ValueError("uuid_representation must be a value from " "bson.binary.UuidRepresentation") if uuid_representation == UuidRepresentation.UNSPECIFIED: raise ValueError("uuid_representation cannot be UNSPECIFIED") elif uuid_representation == UuidRepresentation.PYTHON_LEGACY: if self.subtype == OLD_UUID_SUBTYPE: return UUID(bytes=self) elif uuid_representation == UuidRepresentation.JAVA_LEGACY: if self.subtype == OLD_UUID_SUBTYPE: return UUID(bytes=self[0:8][::-1] + self[8:16][::-1]) elif uuid_representation == UuidRepresentation.CSHARP_LEGACY: if self.subtype == OLD_UUID_SUBTYPE: return UUID(bytes_le=self) else: # uuid_representation == UuidRepresentation.STANDARD if self.subtype == UUID_SUBTYPE: return UUID(bytes=self) raise ValueError("cannot decode subtype %s to %s" % ( self.subtype, UUID_REPRESENTATION_NAMES[uuid_representation])) @property def subtype(self): """Subtype of this binary data. 
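        A small illustration (assuming the subtype constants defined at
        module level)::

            >>> Binary(b"data").subtype
            0
            >>> Binary(b"data", USER_DEFINED_SUBTYPE).subtype
            128
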
""" return self.__subtype def __getnewargs__(self): # Work around http://bugs.python.org/issue7382 data = super(Binary, self).__getnewargs__()[0] if PY3 and not isinstance(data, bytes): data = data.encode('latin-1') return data, self.__subtype def __eq__(self, other): if isinstance(other, Binary): return ((self.__subtype, bytes(self)) == (other.subtype, bytes(other))) # We don't return NotImplemented here because if we did then # Binary("foo") == "foo" would return True, since Binary is a # subclass of str... return False def __hash__(self): return super(Binary, self).__hash__() ^ hash(self.__subtype) def __ne__(self, other): return not self == other def __repr__(self): return "Binary(%s, %s)" % (bytes.__repr__(self), self.__subtype) class UUIDLegacy(Binary): """**DEPRECATED** - UUID wrapper to support working with UUIDs stored as PYTHON_LEGACY. .. note:: This class has been deprecated and will be removed in PyMongo 4.0. Use :meth:`~bson.binary.Binary.from_uuid` and :meth:`~bson.binary.Binary.as_uuid` with the appropriate :class:`~bson.binary.UuidRepresentation` to handle legacy-formatted UUIDs instead.:: from bson import Binary, UUIDLegacy, UuidRepresentation import uuid my_uuid = uuid.uuid4() legacy_uuid = UUIDLegacy(my_uuid) binary_uuid = Binary.from_uuid( my_uuid, UuidRepresentation.PYTHON_LEGACY) assert legacy_uuid == binary_uuid assert legacy_uuid.uuid == binary_uuid.as_uuid( UuidRepresentation.PYTHON_LEGACY) .. doctest:: >>> import uuid >>> from bson.binary import Binary, UUIDLegacy, STANDARD >>> from bson.codec_options import CodecOptions >>> my_uuid = uuid.uuid4() >>> coll = db.get_collection('test', ... CodecOptions(uuid_representation=STANDARD)) >>> coll.insert_one({'uuid': Binary(my_uuid.bytes, 3)}).inserted_id ObjectId('...') >>> coll.count_documents({'uuid': my_uuid}) 0 >>> coll.count_documents({'uuid': UUIDLegacy(my_uuid)}) 1 >>> coll.find({'uuid': UUIDLegacy(my_uuid)})[0]['uuid'] UUID('...') >>> >>> # Convert from subtype 3 to subtype 4 >>> doc = coll.find_one({'uuid': UUIDLegacy(my_uuid)}) >>> coll.replace_one({"_id": doc["_id"]}, doc).matched_count 1 >>> coll.count_documents({'uuid': UUIDLegacy(my_uuid)}) 0 >>> coll.count_documents({'uuid': {'$in': [UUIDLegacy(my_uuid), my_uuid]}}) 1 >>> coll.find_one({'uuid': my_uuid})['uuid'] UUID('...') Raises :exc:`TypeError` if `obj` is not an instance of :class:`~uuid.UUID`. :Parameters: - `obj`: An instance of :class:`~uuid.UUID`. .. versionchanged:: 3.11 Deprecated. The same functionality can be replicated using the :meth:`~Binary.from_uuid` and :meth:`~Binary.to_uuid` methods with :data:`~UuidRepresentation.PYTHON_LEGACY`. .. versionadded:: 2.1 """ def __new__(cls, obj): warn( "The UUIDLegacy class has been deprecated and will be removed " "in PyMongo 4.0. Use the Binary.from_uuid() and Binary.to_uuid() " "with the appropriate UuidRepresentation to handle " "legacy-formatted UUIDs instead.", DeprecationWarning, stacklevel=2) if not isinstance(obj, UUID): raise TypeError("obj must be an instance of uuid.UUID") self = Binary.__new__(cls, obj.bytes, OLD_UUID_SUBTYPE) self.__uuid = obj return self def __getnewargs__(self): # Support copy and deepcopy return (self.__uuid,) @property def uuid(self): """UUID instance wrapped by this UUIDLegacy instance. """ return self.__uuid def __repr__(self): return "UUIDLegacy('%s')" % self.__uuid pymongo-3.11.0/bson/bson-endian.h000066400000000000000000000147151374256237000166140ustar00rootroot00000000000000/* * Copyright 2013-2016 MongoDB, Inc. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef BSON_ENDIAN_H #define BSON_ENDIAN_H #if defined(__sun) # include #endif #ifdef _MSC_VER # include "bson-stdint-win32.h" # define BSON_INLINE __inline #else # include # define BSON_INLINE __inline__ #endif #define BSON_BIG_ENDIAN 4321 #define BSON_LITTLE_ENDIAN 1234 /* WORDS_BIGENDIAN from pyconfig.h / Python.h */ #ifdef WORDS_BIGENDIAN # define BSON_BYTE_ORDER BSON_BIG_ENDIAN #else # define BSON_BYTE_ORDER BSON_LITTLE_ENDIAN #endif #if defined(__sun) # define BSON_UINT16_SWAP_LE_BE(v) BSWAP_16((uint16_t)v) # define BSON_UINT32_SWAP_LE_BE(v) BSWAP_32((uint32_t)v) # define BSON_UINT64_SWAP_LE_BE(v) BSWAP_64((uint64_t)v) #elif defined(__clang__) && defined(__clang_major__) && defined(__clang_minor__) && \ (__clang_major__ >= 3) && (__clang_minor__ >= 1) # if __has_builtin(__builtin_bswap16) # define BSON_UINT16_SWAP_LE_BE(v) __builtin_bswap16(v) # endif # if __has_builtin(__builtin_bswap32) # define BSON_UINT32_SWAP_LE_BE(v) __builtin_bswap32(v) # endif # if __has_builtin(__builtin_bswap64) # define BSON_UINT64_SWAP_LE_BE(v) __builtin_bswap64(v) # endif #elif defined(__GNUC__) && (__GNUC__ >= 4) # if __GNUC__ >= 4 && defined (__GNUC_MINOR__) && __GNUC_MINOR__ >= 3 # define BSON_UINT32_SWAP_LE_BE(v) __builtin_bswap32 ((uint32_t)v) # define BSON_UINT64_SWAP_LE_BE(v) __builtin_bswap64 ((uint64_t)v) # endif # if __GNUC__ >= 4 && defined (__GNUC_MINOR__) && __GNUC_MINOR__ >= 8 # define BSON_UINT16_SWAP_LE_BE(v) __builtin_bswap16 ((uint32_t)v) # endif #endif #ifndef BSON_UINT16_SWAP_LE_BE # define BSON_UINT16_SWAP_LE_BE(v) __bson_uint16_swap_slow ((uint16_t)v) #endif #ifndef BSON_UINT32_SWAP_LE_BE # define BSON_UINT32_SWAP_LE_BE(v) __bson_uint32_swap_slow ((uint32_t)v) #endif #ifndef BSON_UINT64_SWAP_LE_BE # define BSON_UINT64_SWAP_LE_BE(v) __bson_uint64_swap_slow ((uint64_t)v) #endif #if BSON_BYTE_ORDER == BSON_LITTLE_ENDIAN # define BSON_UINT16_FROM_LE(v) ((uint16_t)v) # define BSON_UINT16_TO_LE(v) ((uint16_t)v) # define BSON_UINT16_FROM_BE(v) BSON_UINT16_SWAP_LE_BE (v) # define BSON_UINT16_TO_BE(v) BSON_UINT16_SWAP_LE_BE (v) # define BSON_UINT32_FROM_LE(v) ((uint32_t)v) # define BSON_UINT32_TO_LE(v) ((uint32_t)v) # define BSON_UINT32_FROM_BE(v) BSON_UINT32_SWAP_LE_BE (v) # define BSON_UINT32_TO_BE(v) BSON_UINT32_SWAP_LE_BE (v) # define BSON_UINT64_FROM_LE(v) ((uint64_t)v) # define BSON_UINT64_TO_LE(v) ((uint64_t)v) # define BSON_UINT64_FROM_BE(v) BSON_UINT64_SWAP_LE_BE (v) # define BSON_UINT64_TO_BE(v) BSON_UINT64_SWAP_LE_BE (v) # define BSON_DOUBLE_FROM_LE(v) ((double)v) # define BSON_DOUBLE_TO_LE(v) ((double)v) #elif BSON_BYTE_ORDER == BSON_BIG_ENDIAN # define BSON_UINT16_FROM_LE(v) BSON_UINT16_SWAP_LE_BE (v) # define BSON_UINT16_TO_LE(v) BSON_UINT16_SWAP_LE_BE (v) # define BSON_UINT16_FROM_BE(v) ((uint16_t)v) # define BSON_UINT16_TO_BE(v) ((uint16_t)v) # define BSON_UINT32_FROM_LE(v) BSON_UINT32_SWAP_LE_BE (v) # define BSON_UINT32_TO_LE(v) BSON_UINT32_SWAP_LE_BE (v) # define BSON_UINT32_FROM_BE(v) ((uint32_t)v) # define BSON_UINT32_TO_BE(v) 
((uint32_t)v) # define BSON_UINT64_FROM_LE(v) BSON_UINT64_SWAP_LE_BE (v) # define BSON_UINT64_TO_LE(v) BSON_UINT64_SWAP_LE_BE (v) # define BSON_UINT64_FROM_BE(v) ((uint64_t)v) # define BSON_UINT64_TO_BE(v) ((uint64_t)v) # define BSON_DOUBLE_FROM_LE(v) (__bson_double_swap_slow (v)) # define BSON_DOUBLE_TO_LE(v) (__bson_double_swap_slow (v)) #else # error "The endianness of target architecture is unknown." #endif /* *-------------------------------------------------------------------------- * * __bson_uint16_swap_slow -- * * Fallback endianness conversion for 16-bit integers. * * Returns: * The endian swapped version. * * Side effects: * None. * *-------------------------------------------------------------------------- */ static BSON_INLINE uint16_t __bson_uint16_swap_slow (uint16_t v) /* IN */ { return ((v & 0x00FF) << 8) | ((v & 0xFF00) >> 8); } /* *-------------------------------------------------------------------------- * * __bson_uint32_swap_slow -- * * Fallback endianness conversion for 32-bit integers. * * Returns: * The endian swapped version. * * Side effects: * None. * *-------------------------------------------------------------------------- */ static BSON_INLINE uint32_t __bson_uint32_swap_slow (uint32_t v) /* IN */ { return ((v & 0x000000FFU) << 24) | ((v & 0x0000FF00U) << 8) | ((v & 0x00FF0000U) >> 8) | ((v & 0xFF000000U) >> 24); } /* *-------------------------------------------------------------------------- * * __bson_uint64_swap_slow -- * * Fallback endianness conversion for 64-bit integers. * * Returns: * The endian swapped version. * * Side effects: * None. * *-------------------------------------------------------------------------- */ static BSON_INLINE uint64_t __bson_uint64_swap_slow (uint64_t v) /* IN */ { return ((v & 0x00000000000000FFULL) << 56) | ((v & 0x000000000000FF00ULL) << 40) | ((v & 0x0000000000FF0000ULL) << 24) | ((v & 0x00000000FF000000ULL) << 8) | ((v & 0x000000FF00000000ULL) >> 8) | ((v & 0x0000FF0000000000ULL) >> 24) | ((v & 0x00FF000000000000ULL) >> 40) | ((v & 0xFF00000000000000ULL) >> 56); } /* *-------------------------------------------------------------------------- * * __bson_double_swap_slow -- * * Fallback endianness conversion for double floating point. * * Returns: * The endian swapped version. * * Side effects: * None. * *-------------------------------------------------------------------------- */ static BSON_INLINE double __bson_double_swap_slow (double v) /* IN */ { uint64_t uv; memcpy(&uv, &v, sizeof(v)); uv = BSON_UINT64_SWAP_LE_BE(uv); memcpy(&v, &uv, sizeof(v)); return v; } #endif /* BSON_ENDIAN_H */ pymongo-3.11.0/bson/bson-stdint-win32.h000066400000000000000000000176341374256237000176260ustar00rootroot00000000000000// ISO C9x compliant stdint.h for Microsoft Visual Studio // Based on ISO/IEC 9899:TC2 Committee draft (May 6, 2005) WG14/N1124 // // Copyright (c) 2006-2013 Alexander Chemeris // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: // // 1. Redistributions of source code must retain the above copyright notice, // this list of conditions and the following disclaimer. // // 2. Redistributions in binary form must reproduce the above copyright // notice, this list of conditions and the following disclaimer in the // documentation and/or other materials provided with the distribution. // // 3. 
Neither the name of the product nor the names of its contributors may // be used to endorse or promote products derived from this software // without specific prior written permission. // // THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED // WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF // MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO // EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, // PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; // OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR // OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF // ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // /////////////////////////////////////////////////////////////////////////////// #ifndef _MSC_VER // [ #error "Use this header only with Microsoft Visual C++ compilers!" #endif // _MSC_VER ] #ifndef _MSC_STDINT_H_ // [ #define _MSC_STDINT_H_ #if _MSC_VER > 1000 #pragma once #endif #if _MSC_VER >= 1600 // [ #include #else // ] _MSC_VER >= 1600 [ #include // For Visual Studio 6 in C++ mode and for many Visual Studio versions when // compiling for ARM we should wrap include with 'extern "C++" {}' // or compiler give many errors like this: // error C2733: second C linkage of overloaded function 'wmemchr' not allowed #ifdef __cplusplus extern "C" { #endif # include #ifdef __cplusplus } #endif // Define _W64 macros to mark types changing their size, like intptr_t. #ifndef _W64 # if !defined(__midl) && (defined(_X86_) || defined(_M_IX86)) && _MSC_VER >= 1300 # define _W64 __w64 # else # define _W64 # endif #endif // 7.18.1 Integer types // 7.18.1.1 Exact-width integer types // Visual Studio 6 and Embedded Visual C++ 4 doesn't // realize that, e.g. char has the same size as __int8 // so we give up on __intX for them. 
#if (_MSC_VER < 1300) typedef signed char int8_t; typedef signed short int16_t; typedef signed int int32_t; typedef unsigned char uint8_t; typedef unsigned short uint16_t; typedef unsigned int uint32_t; #else typedef signed __int8 int8_t; typedef signed __int16 int16_t; typedef signed __int32 int32_t; typedef unsigned __int8 uint8_t; typedef unsigned __int16 uint16_t; typedef unsigned __int32 uint32_t; #endif typedef signed __int64 int64_t; typedef unsigned __int64 uint64_t; // 7.18.1.2 Minimum-width integer types typedef int8_t int_least8_t; typedef int16_t int_least16_t; typedef int32_t int_least32_t; typedef int64_t int_least64_t; typedef uint8_t uint_least8_t; typedef uint16_t uint_least16_t; typedef uint32_t uint_least32_t; typedef uint64_t uint_least64_t; // 7.18.1.3 Fastest minimum-width integer types typedef int8_t int_fast8_t; typedef int16_t int_fast16_t; typedef int32_t int_fast32_t; typedef int64_t int_fast64_t; typedef uint8_t uint_fast8_t; typedef uint16_t uint_fast16_t; typedef uint32_t uint_fast32_t; typedef uint64_t uint_fast64_t; // 7.18.1.4 Integer types capable of holding object pointers #ifdef _WIN64 // [ typedef signed __int64 intptr_t; typedef unsigned __int64 uintptr_t; #else // _WIN64 ][ typedef _W64 signed int intptr_t; typedef _W64 unsigned int uintptr_t; #endif // _WIN64 ] // 7.18.1.5 Greatest-width integer types typedef int64_t intmax_t; typedef uint64_t uintmax_t; // 7.18.2 Limits of specified-width integer types #if !defined(__cplusplus) || defined(__STDC_LIMIT_MACROS) // [ See footnote 220 at page 257 and footnote 221 at page 259 // 7.18.2.1 Limits of exact-width integer types #define INT8_MIN ((int8_t)_I8_MIN) #define INT8_MAX _I8_MAX #define INT16_MIN ((int16_t)_I16_MIN) #define INT16_MAX _I16_MAX #define INT32_MIN ((int32_t)_I32_MIN) #define INT32_MAX _I32_MAX #define INT64_MIN ((int64_t)_I64_MIN) #define INT64_MAX _I64_MAX #define UINT8_MAX _UI8_MAX #define UINT16_MAX _UI16_MAX #define UINT32_MAX _UI32_MAX #define UINT64_MAX _UI64_MAX // 7.18.2.2 Limits of minimum-width integer types #define INT_LEAST8_MIN INT8_MIN #define INT_LEAST8_MAX INT8_MAX #define INT_LEAST16_MIN INT16_MIN #define INT_LEAST16_MAX INT16_MAX #define INT_LEAST32_MIN INT32_MIN #define INT_LEAST32_MAX INT32_MAX #define INT_LEAST64_MIN INT64_MIN #define INT_LEAST64_MAX INT64_MAX #define UINT_LEAST8_MAX UINT8_MAX #define UINT_LEAST16_MAX UINT16_MAX #define UINT_LEAST32_MAX UINT32_MAX #define UINT_LEAST64_MAX UINT64_MAX // 7.18.2.3 Limits of fastest minimum-width integer types #define INT_FAST8_MIN INT8_MIN #define INT_FAST8_MAX INT8_MAX #define INT_FAST16_MIN INT16_MIN #define INT_FAST16_MAX INT16_MAX #define INT_FAST32_MIN INT32_MIN #define INT_FAST32_MAX INT32_MAX #define INT_FAST64_MIN INT64_MIN #define INT_FAST64_MAX INT64_MAX #define UINT_FAST8_MAX UINT8_MAX #define UINT_FAST16_MAX UINT16_MAX #define UINT_FAST32_MAX UINT32_MAX #define UINT_FAST64_MAX UINT64_MAX // 7.18.2.4 Limits of integer types capable of holding object pointers #ifdef _WIN64 // [ # define INTPTR_MIN INT64_MIN # define INTPTR_MAX INT64_MAX # define UINTPTR_MAX UINT64_MAX #else // _WIN64 ][ # define INTPTR_MIN INT32_MIN # define INTPTR_MAX INT32_MAX # define UINTPTR_MAX UINT32_MAX #endif // _WIN64 ] // 7.18.2.5 Limits of greatest-width integer types #define INTMAX_MIN INT64_MIN #define INTMAX_MAX INT64_MAX #define UINTMAX_MAX UINT64_MAX // 7.18.3 Limits of other integer types #ifdef _WIN64 // [ # define PTRDIFF_MIN _I64_MIN # define PTRDIFF_MAX _I64_MAX #else // _WIN64 ][ # define PTRDIFF_MIN _I32_MIN # define 
PTRDIFF_MAX _I32_MAX #endif // _WIN64 ] #define SIG_ATOMIC_MIN INT_MIN #define SIG_ATOMIC_MAX INT_MAX #ifndef SIZE_MAX // [ # ifdef _WIN64 // [ # define SIZE_MAX _UI64_MAX # else // _WIN64 ][ # define SIZE_MAX _UI32_MAX # endif // _WIN64 ] #endif // SIZE_MAX ] // WCHAR_MIN and WCHAR_MAX are also defined in #ifndef WCHAR_MIN // [ # define WCHAR_MIN 0 #endif // WCHAR_MIN ] #ifndef WCHAR_MAX // [ # define WCHAR_MAX _UI16_MAX #endif // WCHAR_MAX ] #define WINT_MIN 0 #define WINT_MAX _UI16_MAX #endif // __STDC_LIMIT_MACROS ] // 7.18.4 Limits of other integer types #if !defined(__cplusplus) || defined(__STDC_CONSTANT_MACROS) // [ See footnote 224 at page 260 // 7.18.4.1 Macros for minimum-width integer constants #define INT8_C(val) val##i8 #define INT16_C(val) val##i16 #define INT32_C(val) val##i32 #define INT64_C(val) val##i64 #define UINT8_C(val) val##ui8 #define UINT16_C(val) val##ui16 #define UINT32_C(val) val##ui32 #define UINT64_C(val) val##ui64 // 7.18.4.2 Macros for greatest-width integer constants // These #ifndef's are needed to prevent collisions with . // Check out Issue 9 for the details. #ifndef INTMAX_C // [ # define INTMAX_C INT64_C #endif // INTMAX_C ] #ifndef UINTMAX_C // [ # define UINTMAX_C UINT64_C #endif // UINTMAX_C ] #endif // __STDC_CONSTANT_MACROS ] #endif // _MSC_VER >= 1600 ] #endif // _MSC_STDINT_H_ ] pymongo-3.11.0/bson/buffer.c000066400000000000000000000104521374256237000156550ustar00rootroot00000000000000/* * Copyright 2009-2015 MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* Include Python.h so we can set Python's error indicator. */ #define PY_SSIZE_T_CLEAN #include "Python.h" #include #include #include "buffer.h" #define INITIAL_BUFFER_SIZE 256 struct buffer { char* buffer; int size; int position; }; /* Set Python's error indicator to MemoryError. * Called after allocation failures. */ static void set_memory_error(void) { PyErr_NoMemory(); } /* Allocate and return a new buffer. * Return NULL and sets MemoryError on allocation failure. */ buffer_t buffer_new(void) { buffer_t buffer; buffer = (buffer_t)malloc(sizeof(struct buffer)); if (buffer == NULL) { set_memory_error(); return NULL; } buffer->size = INITIAL_BUFFER_SIZE; buffer->position = 0; buffer->buffer = (char*)malloc(sizeof(char) * INITIAL_BUFFER_SIZE); if (buffer->buffer == NULL) { free(buffer); set_memory_error(); return NULL; } return buffer; } /* Free the memory allocated for `buffer`. * Return non-zero on failure. */ int buffer_free(buffer_t buffer) { if (buffer == NULL) { return 1; } /* Buffer will be NULL when buffer_grow fails. */ if (buffer->buffer != NULL) { free(buffer->buffer); } free(buffer); return 0; } /* Grow `buffer` to at least `min_length`. * Return non-zero and sets MemoryError on allocation failure. */ static int buffer_grow(buffer_t buffer, int min_length) { int old_size = 0; int size = buffer->size; char* old_buffer = buffer->buffer; if (size >= min_length) { return 0; } while (size < min_length) { old_size = size; size *= 2; if (size <= old_size) { /* Size did not increase. 
Could be an overflow * or size < 1. Just go with min_length. */ size = min_length; } } buffer->buffer = (char*)realloc(buffer->buffer, sizeof(char) * size); if (buffer->buffer == NULL) { free(old_buffer); set_memory_error(); return 1; } buffer->size = size; return 0; } /* Assure that `buffer` has at least `size` free bytes (and grow if needed). * Return non-zero and sets MemoryError on allocation failure. * Return non-zero and sets ValueError if `size` would exceed 2GiB. */ static int buffer_assure_space(buffer_t buffer, int size) { int new_size = buffer->position + size; /* Check for overflow. */ if (new_size < buffer->position) { PyErr_SetString(PyExc_ValueError, "Document would overflow BSON size limit"); return 1; } if (new_size <= buffer->size) { return 0; } return buffer_grow(buffer, new_size); } /* Save `size` bytes from the current position in `buffer` (and grow if needed). * Return offset for writing, or -1 on failure. * Sets MemoryError or ValueError on failure. */ buffer_position buffer_save_space(buffer_t buffer, int size) { int position = buffer->position; if (buffer_assure_space(buffer, size) != 0) { return -1; } buffer->position += size; return position; } /* Write `size` bytes from `data` to `buffer` (and grow if needed). * Return non-zero on failure. * Sets MemoryError or ValueError on failure. */ int buffer_write(buffer_t buffer, const char* data, int size) { if (buffer_assure_space(buffer, size) != 0) { return 1; } memcpy(buffer->buffer + buffer->position, data, size); buffer->position += size; return 0; } int buffer_get_position(buffer_t buffer) { return buffer->position; } char* buffer_get_buffer(buffer_t buffer) { return buffer->buffer; } void buffer_update_position(buffer_t buffer, buffer_position new_position) { buffer->position = new_position; } pymongo-3.11.0/bson/buffer.h000066400000000000000000000033541374256237000156650ustar00rootroot00000000000000/* * Copyright 2009-2015 MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef BUFFER_H #define BUFFER_H /* Note: if any of these functions return a failure condition then the buffer * has already been freed. */ /* A buffer */ typedef struct buffer* buffer_t; /* A position in the buffer */ typedef int buffer_position; /* Allocate and return a new buffer. * Return NULL on allocation failure. */ buffer_t buffer_new(void); /* Free the memory allocated for `buffer`. * Return non-zero on failure. */ int buffer_free(buffer_t buffer); /* Save `size` bytes from the current position in `buffer` (and grow if needed). * Return offset for writing, or -1 on allocation failure. */ buffer_position buffer_save_space(buffer_t buffer, int size); /* Write `size` bytes from `data` to `buffer` (and grow if needed). * Return non-zero on allocation failure. */ int buffer_write(buffer_t buffer, const char* data, int size); /* Getters for the internals of a buffer_t. * Should try to avoid using these as much as possible * since they break the abstraction. 
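 * (For example, the BSON encoder reserves room for a document's 4-byte
 * length prefix, writes the body, and then uses the saved position and
 * the raw buffer pointer to patch the length in afterwards.)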
*/ buffer_position buffer_get_position(buffer_t buffer); char* buffer_get_buffer(buffer_t buffer); void buffer_update_position(buffer_t buffer, buffer_position new_position); #endif pymongo-3.11.0/bson/code.py000066400000000000000000000064401374256237000155260ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for representing JavaScript code in BSON. """ from bson.py3compat import abc, string_type, PY3, text_type class Code(str): """BSON's JavaScript code type. Raises :class:`TypeError` if `code` is not an instance of :class:`basestring` (:class:`str` in python 3) or `scope` is not ``None`` or an instance of :class:`dict`. Scope variables can be set by passing a dictionary as the `scope` argument or by using keyword arguments. If a variable is set as a keyword argument it will override any setting for that variable in the `scope` dictionary. :Parameters: - `code`: A string containing JavaScript code to be evaluated or another instance of Code. In the latter case, the scope of `code` becomes this Code's :attr:`scope`. - `scope` (optional): dictionary representing the scope in which `code` should be evaluated - a mapping from identifiers (as strings) to values. Defaults to ``None``. This is applied after any scope associated with a given `code` above. - `**kwargs` (optional): scope variables can also be passed as keyword arguments. These are applied after `scope` and `code`. .. versionchanged:: 3.4 The default value for :attr:`scope` is ``None`` instead of ``{}``. """ _type_marker = 13 def __new__(cls, code, scope=None, **kwargs): if not isinstance(code, string_type): raise TypeError("code must be an " "instance of %s" % (string_type.__name__)) if not PY3 and isinstance(code, text_type): self = str.__new__(cls, code.encode('utf8')) else: self = str.__new__(cls, code) try: self.__scope = code.scope except AttributeError: self.__scope = None if scope is not None: if not isinstance(scope, abc.Mapping): raise TypeError("scope must be an instance of dict") if self.__scope is not None: self.__scope.update(scope) else: self.__scope = scope if kwargs: if self.__scope is not None: self.__scope.update(kwargs) else: self.__scope = kwargs return self @property def scope(self): """Scope dictionary for this instance or ``None``. """ return self.__scope def __repr__(self): return "Code(%s, %r)" % (str.__repr__(self), self.__scope) def __eq__(self, other): if isinstance(other, Code): return (self.__scope, str(self)) == (other.__scope, str(other)) return False __hash__ = None def __ne__(self, other): return not self == other pymongo-3.11.0/bson/codec_options.py000066400000000000000000000334361374256237000174510ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for specifying BSON codec options.""" import datetime import warnings from abc import abstractmethod from collections import namedtuple from bson.py3compat import ABC, abc, abstractproperty, string_type from bson.binary import (UuidRepresentation, ALL_UUID_REPRESENTATIONS, UUID_REPRESENTATION_NAMES) _RAW_BSON_DOCUMENT_MARKER = 101 def _raw_document_class(document_class): """Determine if a document_class is a RawBSONDocument class.""" marker = getattr(document_class, '_type_marker', None) return marker == _RAW_BSON_DOCUMENT_MARKER class TypeEncoder(ABC): """Base class for defining type codec classes which describe how a custom type can be transformed to one of the types BSON understands. Codec classes must implement the ``python_type`` attribute, and the ``transform_python`` method to support encoding. See :ref:`custom-type-type-codec` documentation for an example. """ @abstractproperty def python_type(self): """The Python type to be converted into something serializable.""" pass @abstractmethod def transform_python(self, value): """Convert the given Python object into something serializable.""" pass class TypeDecoder(ABC): """Base class for defining type codec classes which describe how a BSON type can be transformed to a custom type. Codec classes must implement the ``bson_type`` attribute, and the ``transform_bson`` method to support decoding. See :ref:`custom-type-type-codec` documentation for an example. """ @abstractproperty def bson_type(self): """The BSON type to be converted into our own type.""" pass @abstractmethod def transform_bson(self, value): """Convert the given BSON value into our own type.""" pass class TypeCodec(TypeEncoder, TypeDecoder): """Base class for defining type codec classes which describe how a custom type can be transformed to/from one of the types :mod:`bson` can already encode/decode. Codec classes must implement the ``python_type`` attribute, and the ``transform_python`` method to support encoding, as well as the ``bson_type`` attribute, and the ``transform_bson`` method to support decoding. See :ref:`custom-type-type-codec` documentation for an example. """ pass class TypeRegistry(object): """Encapsulates type codecs used in encoding and / or decoding BSON, as well as the fallback encoder. Type registries cannot be modified after instantiation. ``TypeRegistry`` can be initialized with an iterable of type codecs, and a callable for the fallback encoder:: >>> from bson.codec_options import TypeRegistry >>> type_registry = TypeRegistry([Codec1, Codec2, Codec3, ...], ... fallback_encoder) See :ref:`custom-type-type-registry` documentation for an example. :Parameters: - `type_codecs` (optional): iterable of type codec instances. If ``type_codecs`` contains multiple codecs that transform a single python or BSON type, the transformation specified by the type codec occurring last prevails. A TypeError will be raised if one or more type codecs modify the encoding behavior of a built-in :mod:`bson` type. - `fallback_encoder` (optional): callable that accepts a single, unencodable python value and transforms it into a type that :mod:`bson` can encode. 
See :ref:`fallback-encoder-callable` documentation for an example. """ def __init__(self, type_codecs=None, fallback_encoder=None): self.__type_codecs = list(type_codecs or []) self._fallback_encoder = fallback_encoder self._encoder_map = {} self._decoder_map = {} if self._fallback_encoder is not None: if not callable(fallback_encoder): raise TypeError("fallback_encoder %r is not a callable" % ( fallback_encoder)) for codec in self.__type_codecs: is_valid_codec = False if isinstance(codec, TypeEncoder): self._validate_type_encoder(codec) is_valid_codec = True self._encoder_map[codec.python_type] = codec.transform_python if isinstance(codec, TypeDecoder): is_valid_codec = True self._decoder_map[codec.bson_type] = codec.transform_bson if not is_valid_codec: raise TypeError( "Expected an instance of %s, %s, or %s, got %r instead" % ( TypeEncoder.__name__, TypeDecoder.__name__, TypeCodec.__name__, codec)) def _validate_type_encoder(self, codec): from bson import _BUILT_IN_TYPES for pytype in _BUILT_IN_TYPES: if issubclass(codec.python_type, pytype): err_msg = ("TypeEncoders cannot change how built-in types are " "encoded (encoder %s transforms type %s)" % (codec, pytype)) raise TypeError(err_msg) def __repr__(self): return ('%s(type_codecs=%r, fallback_encoder=%r)' % ( self.__class__.__name__, self.__type_codecs, self._fallback_encoder)) def __eq__(self, other): if not isinstance(other, type(self)): return NotImplemented return ((self._decoder_map == other._decoder_map) and (self._encoder_map == other._encoder_map) and (self._fallback_encoder == other._fallback_encoder)) _options_base = namedtuple( 'CodecOptions', ('document_class', 'tz_aware', 'uuid_representation', 'unicode_decode_error_handler', 'tzinfo', 'type_registry')) class CodecOptions(_options_base): """Encapsulates options used encoding and / or decoding BSON. The `document_class` option is used to define a custom type for use decoding BSON documents. Access to the underlying raw BSON bytes for a document is available using the :class:`~bson.raw_bson.RawBSONDocument` type:: >>> from bson.raw_bson import RawBSONDocument >>> from bson.codec_options import CodecOptions >>> codec_options = CodecOptions(document_class=RawBSONDocument) >>> coll = db.get_collection('test', codec_options=codec_options) >>> doc = coll.find_one() >>> doc.raw '\\x16\\x00\\x00\\x00\\x07_id\\x00[0\\x165\\x91\\x10\\xea\\x14\\xe8\\xc5\\x8b\\x93\\x00' The document class can be any type that inherits from :class:`~collections.MutableMapping`:: >>> class AttributeDict(dict): ... # A dict that supports attribute access. ... def __getattr__(self, key): ... return self[key] ... def __setattr__(self, key, value): ... self[key] = value ... >>> codec_options = CodecOptions(document_class=AttributeDict) >>> coll = db.get_collection('test', codec_options=codec_options) >>> doc = coll.find_one() >>> doc._id ObjectId('5b3016359110ea14e8c58b93') See :doc:`/examples/datetimes` for examples using the `tz_aware` and `tzinfo` options. See :class:`~bson.binary.UUIDLegacy` for examples using the `uuid_representation` option. :Parameters: - `document_class`: BSON documents returned in queries will be decoded to an instance of this class. Must be a subclass of :class:`~collections.MutableMapping`. Defaults to :class:`dict`. - `tz_aware`: If ``True``, BSON datetimes will be decoded to timezone aware instances of :class:`~datetime.datetime`. Otherwise they will be naive. Defaults to ``False``. 
- `uuid_representation`: The BSON representation to use when encoding and decoding instances of :class:`~uuid.UUID`. Defaults to :data:`~bson.binary.UuidRepresentation.PYTHON_LEGACY`. New applications should consider setting this to :data:`~bson.binary.UuidRepresentation.STANDARD` for cross language compatibility. See :ref:`handling-uuid-data-example` for details. - `unicode_decode_error_handler`: The error handler to apply when a Unicode-related error occurs during BSON decoding that would otherwise raise :exc:`UnicodeDecodeError`. Valid options include 'strict', 'replace', and 'ignore'. Defaults to 'strict'. - `tzinfo`: A :class:`~datetime.tzinfo` subclass that specifies the timezone to/from which :class:`~datetime.datetime` objects should be encoded/decoded. - `type_registry`: Instance of :class:`TypeRegistry` used to customize encoding and decoding behavior. .. versionadded:: 3.8 `type_registry` attribute. .. warning:: Care must be taken when changing `unicode_decode_error_handler` from its default value ('strict'). The 'replace' and 'ignore' modes should not be used when documents retrieved from the server will be modified in the client application and stored back to the server. """ def __new__(cls, document_class=dict, tz_aware=False, uuid_representation=None, unicode_decode_error_handler="strict", tzinfo=None, type_registry=None): if not (issubclass(document_class, abc.MutableMapping) or _raw_document_class(document_class)): raise TypeError("document_class must be dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or a " "sublass of collections.MutableMapping") if not isinstance(tz_aware, bool): raise TypeError("tz_aware must be True or False") if uuid_representation is None: uuid_representation = UuidRepresentation.PYTHON_LEGACY elif uuid_representation not in ALL_UUID_REPRESENTATIONS: raise ValueError("uuid_representation must be a value " "from bson.binary.UuidRepresentation") if not isinstance(unicode_decode_error_handler, (string_type, None)): raise ValueError("unicode_decode_error_handler must be a string " "or None") if tzinfo is not None: if not isinstance(tzinfo, datetime.tzinfo): raise TypeError( "tzinfo must be an instance of datetime.tzinfo") if not tz_aware: raise ValueError( "cannot specify tzinfo without also setting tz_aware=True") type_registry = type_registry or TypeRegistry() if not isinstance(type_registry, TypeRegistry): raise TypeError("type_registry must be an instance of TypeRegistry") return tuple.__new__( cls, (document_class, tz_aware, uuid_representation, unicode_decode_error_handler, tzinfo, type_registry)) def _arguments_repr(self): """Representation of the arguments used to create this object.""" document_class_repr = ( 'dict' if self.document_class is dict else repr(self.document_class)) uuid_rep_repr = UUID_REPRESENTATION_NAMES.get(self.uuid_representation, self.uuid_representation) return ('document_class=%s, tz_aware=%r, uuid_representation=%s, ' 'unicode_decode_error_handler=%r, tzinfo=%r, ' 'type_registry=%r' % (document_class_repr, self.tz_aware, uuid_rep_repr, self.unicode_decode_error_handler, self.tzinfo, self.type_registry)) def __repr__(self): return '%s(%s)' % (self.__class__.__name__, self._arguments_repr()) def with_options(self, **kwargs): """Make a copy of this CodecOptions, overriding some options:: >>> from bson.codec_options import DEFAULT_CODEC_OPTIONS >>> DEFAULT_CODEC_OPTIONS.tz_aware False >>> options = DEFAULT_CODEC_OPTIONS.with_options(tz_aware=True) >>> options.tz_aware True .. 
versionadded:: 3.5 """ return CodecOptions( kwargs.get('document_class', self.document_class), kwargs.get('tz_aware', self.tz_aware), kwargs.get('uuid_representation', self.uuid_representation), kwargs.get('unicode_decode_error_handler', self.unicode_decode_error_handler), kwargs.get('tzinfo', self.tzinfo), kwargs.get('type_registry', self.type_registry) ) DEFAULT_CODEC_OPTIONS = CodecOptions( uuid_representation=UuidRepresentation.PYTHON_LEGACY) def _parse_codec_options(options): """Parse BSON codec options.""" return CodecOptions( document_class=options.get( 'document_class', DEFAULT_CODEC_OPTIONS.document_class), tz_aware=options.get( 'tz_aware', DEFAULT_CODEC_OPTIONS.tz_aware), uuid_representation=options.get('uuidrepresentation'), unicode_decode_error_handler=options.get( 'unicode_decode_error_handler', DEFAULT_CODEC_OPTIONS.unicode_decode_error_handler), tzinfo=options.get('tzinfo', DEFAULT_CODEC_OPTIONS.tzinfo), type_registry=options.get( 'type_registry', DEFAULT_CODEC_OPTIONS.type_registry)) pymongo-3.11.0/bson/dbref.py000066400000000000000000000111751374256237000156770ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for manipulating DBRefs (references to MongoDB documents).""" from copy import deepcopy from bson.py3compat import iteritems, string_type from bson.son import SON class DBRef(object): """A reference to a document stored in MongoDB. """ # DBRef isn't actually a BSON "type" so this number was arbitrarily chosen. _type_marker = 100 def __init__(self, collection, id, database=None, _extra={}, **kwargs): """Initialize a new :class:`DBRef`. Raises :class:`TypeError` if `collection` or `database` is not an instance of :class:`basestring` (:class:`str` in python 3). `database` is optional and allows references to documents to work across databases. Any additional keyword arguments will create additional fields in the resultant embedded document. :Parameters: - `collection`: name of the collection the document is stored in - `id`: the value of the document's ``"_id"`` field - `database` (optional): name of the database to reference - `**kwargs` (optional): additional keyword arguments will create additional, custom fields .. mongodoc:: dbrefs """ if not isinstance(collection, string_type): raise TypeError("collection must be an " "instance of %s" % string_type.__name__) if database is not None and not isinstance(database, string_type): raise TypeError("database must be an " "instance of %s" % string_type.__name__) self.__collection = collection self.__id = id self.__database = database kwargs.update(_extra) self.__kwargs = kwargs @property def collection(self): """Get the name of this DBRef's collection as unicode. """ return self.__collection @property def id(self): """Get this DBRef's _id. """ return self.__id @property def database(self): """Get the name of this DBRef's database. Returns None if this DBRef doesn't specify a database. 
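        For example (an illustrative sketch; any collection and database
        names could be used)::

          >>> DBRef("users", 5, database="app").database
          'app'
          >>> DBRef("users", 5).database is None
          True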
""" return self.__database def __getattr__(self, key): try: return self.__kwargs[key] except KeyError: raise AttributeError(key) # Have to provide __setstate__ to avoid # infinite recursion since we override # __getattr__. def __setstate__(self, state): self.__dict__.update(state) def as_doc(self): """Get the SON document representation of this DBRef. Generally not needed by application developers """ doc = SON([("$ref", self.collection), ("$id", self.id)]) if self.database is not None: doc["$db"] = self.database doc.update(self.__kwargs) return doc def __repr__(self): extra = "".join([", %s=%r" % (k, v) for k, v in iteritems(self.__kwargs)]) if self.database is None: return "DBRef(%r, %r%s)" % (self.collection, self.id, extra) return "DBRef(%r, %r, %r%s)" % (self.collection, self.id, self.database, extra) def __eq__(self, other): if isinstance(other, DBRef): us = (self.__database, self.__collection, self.__id, self.__kwargs) them = (other.__database, other.__collection, other.__id, other.__kwargs) return us == them return NotImplemented def __ne__(self, other): return not self == other def __hash__(self): """Get a hash value for this :class:`DBRef`.""" return hash((self.__collection, self.__id, self.__database, tuple(sorted(self.__kwargs.items())))) def __deepcopy__(self, memo): """Support function for `copy.deepcopy()`.""" return DBRef(deepcopy(self.__collection, memo), deepcopy(self.__id, memo), deepcopy(self.__database, memo), deepcopy(self.__kwargs, memo)) pymongo-3.11.0/bson/decimal128.py000066400000000000000000000242711374256237000164470ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for working with the BSON decimal128 type. .. versionadded:: 3.4 .. note:: The Decimal128 BSON type requires MongoDB 3.4+. """ import decimal import struct import sys from bson.py3compat import (PY3 as _PY3, string_type as _string_type) if _PY3: _from_bytes = int.from_bytes # pylint: disable=no-member, invalid-name else: import binascii def _from_bytes(value, dummy, _int=int, _hexlify=binascii.hexlify): "An implementation of int.from_bytes for python 2.x." return _int(_hexlify(value), 16) _PACK_64 = struct.Struct("= 3.3, cdecimal decimal.Context(clamp=1) # pylint: disable=unexpected-keyword-arg _CTX_OPTIONS['clamp'] = 1 except TypeError: # Python < 3.3 _CTX_OPTIONS['_clamp'] = 1 _DEC128_CTX = decimal.Context(**_CTX_OPTIONS.copy()) def create_decimal128_context(): """Returns an instance of :class:`decimal.Context` appropriate for working with IEEE-754 128-bit decimal floating point values. """ opts = _CTX_OPTIONS.copy() opts['traps'] = [] return decimal.Context(**opts) def _decimal_to_128(value): """Converts a decimal.Decimal to BID (high bits, low bits). 
:Parameters: - `value`: An instance of decimal.Decimal """ with decimal.localcontext(_DEC128_CTX) as ctx: value = ctx.create_decimal(value) if value.is_infinite(): return _NINF if value.is_signed() else _PINF sign, digits, exponent = value.as_tuple() if value.is_nan(): if digits: raise ValueError("NaN with debug payload is not supported") if value.is_snan(): return _NSNAN if value.is_signed() else _PSNAN return _NNAN if value.is_signed() else _PNAN significand = int("".join([str(digit) for digit in digits])) bit_length = significand.bit_length() high = 0 low = 0 for i in range(min(64, bit_length)): if significand & (1 << i): low |= 1 << i for i in range(64, bit_length): if significand & (1 << i): high |= 1 << (i - 64) biased_exponent = exponent + _EXPONENT_BIAS if high >> 49 == 1: high = high & 0x7fffffffffff high |= _EXPONENT_MASK high |= (biased_exponent & 0x3fff) << 47 else: high |= biased_exponent << 49 if sign: high |= _SIGN return high, low class Decimal128(object): """BSON Decimal128 type:: >>> Decimal128(Decimal("0.0005")) Decimal128('0.0005') >>> Decimal128("0.0005") Decimal128('0.0005') >>> Decimal128((3474527112516337664, 5)) Decimal128('0.0005') :Parameters: - `value`: An instance of :class:`decimal.Decimal`, string, or tuple of (high bits, low bits) from Binary Integer Decimal (BID) format. .. note:: :class:`~Decimal128` uses an instance of :class:`decimal.Context` configured for IEEE-754 Decimal128 when validating parameters. Signals like :class:`decimal.InvalidOperation`, :class:`decimal.Inexact`, and :class:`decimal.Overflow` are trapped and raised as exceptions:: >>> Decimal128(".13.1") Traceback (most recent call last): File "", line 1, in ... decimal.InvalidOperation: [] >>> >>> Decimal128("1E-6177") Traceback (most recent call last): File "", line 1, in ... decimal.Inexact: [] >>> >>> Decimal128("1E6145") Traceback (most recent call last): File "", line 1, in ... decimal.Overflow: [, ] To ensure the result of a calculation can always be stored as BSON Decimal128 use the context returned by :func:`create_decimal128_context`:: >>> import decimal >>> decimal128_ctx = create_decimal128_context() >>> with decimal.localcontext(decimal128_ctx) as ctx: ... Decimal128(ctx.create_decimal(".13.3")) ... Decimal128('NaN') >>> >>> with decimal.localcontext(decimal128_ctx) as ctx: ... Decimal128(ctx.create_decimal("1E-6177")) ... Decimal128('0E-6176') >>> >>> with decimal.localcontext(DECIMAL128_CTX) as ctx: ... Decimal128(ctx.create_decimal("1E6145")) ... 
Decimal128('Infinity') To match the behavior of MongoDB's Decimal128 implementation str(Decimal(value)) may not match str(Decimal128(value)) for NaN values:: >>> Decimal128(Decimal('NaN')) Decimal128('NaN') >>> Decimal128(Decimal('-NaN')) Decimal128('NaN') >>> Decimal128(Decimal('sNaN')) Decimal128('NaN') >>> Decimal128(Decimal('-sNaN')) Decimal128('NaN') However, :meth:`~Decimal128.to_decimal` will return the exact value:: >>> Decimal128(Decimal('NaN')).to_decimal() Decimal('NaN') >>> Decimal128(Decimal('-NaN')).to_decimal() Decimal('-NaN') >>> Decimal128(Decimal('sNaN')).to_decimal() Decimal('sNaN') >>> Decimal128(Decimal('-sNaN')).to_decimal() Decimal('-sNaN') Two instances of :class:`Decimal128` compare equal if their Binary Integer Decimal encodings are equal:: >>> Decimal128('NaN') == Decimal128('NaN') True >>> Decimal128('NaN').bid == Decimal128('NaN').bid True This differs from :class:`decimal.Decimal` comparisons for NaN:: >>> Decimal('NaN') == Decimal('NaN') False """ __slots__ = ('__high', '__low') _type_marker = 19 def __init__(self, value): if isinstance(value, (_string_type, decimal.Decimal)): self.__high, self.__low = _decimal_to_128(value) elif isinstance(value, (list, tuple)): if len(value) != 2: raise ValueError('Invalid size for creation of Decimal128 ' 'from list or tuple. Must have exactly 2 ' 'elements.') self.__high, self.__low = value else: raise TypeError("Cannot convert %r to Decimal128" % (value,)) def to_decimal(self): """Returns an instance of :class:`decimal.Decimal` for this :class:`Decimal128`. """ high = self.__high low = self.__low sign = 1 if (high & _SIGN) else 0 if (high & _SNAN) == _SNAN: return decimal.Decimal((sign, (), 'N')) elif (high & _NAN) == _NAN: return decimal.Decimal((sign, (), 'n')) elif (high & _INF) == _INF: return decimal.Decimal((sign, (), 'F')) if (high & _EXPONENT_MASK) == _EXPONENT_MASK: exponent = ((high & 0x1fffe00000000000) >> 47) - _EXPONENT_BIAS return decimal.Decimal((sign, (0,), exponent)) else: exponent = ((high & 0x7fff800000000000) >> 49) - _EXPONENT_BIAS arr = bytearray(15) mask = 0x00000000000000ff for i in range(14, 6, -1): arr[i] = (low & mask) >> ((14 - i) << 3) mask = mask << 8 mask = 0x00000000000000ff for i in range(6, 0, -1): arr[i] = (high & mask) >> ((6 - i) << 3) mask = mask << 8 mask = 0x0001000000000000 arr[0] = (high & mask) >> 48 # cdecimal only accepts a tuple for digits. digits = tuple( int(digit) for digit in str(_from_bytes(arr, 'big'))) with decimal.localcontext(_DEC128_CTX) as ctx: return ctx.create_decimal((sign, digits, exponent)) @classmethod def from_bid(cls, value): """Create an instance of :class:`Decimal128` from Binary Integer Decimal string. :Parameters: - `value`: 16 byte string (128-bit IEEE 754-2008 decimal floating point in Binary Integer Decimal (BID) format). """ if not isinstance(value, bytes): raise TypeError("value must be an instance of bytes") if len(value) != 16: raise ValueError("value must be exactly 16 bytes") return cls((_UNPACK_64(value[8:])[0], _UNPACK_64(value[:8])[0])) @property def bid(self): """The Binary Integer Decimal (BID) encoding of this instance.""" return _PACK_64(self.__low) + _PACK_64(self.__high) def __str__(self): dec = self.to_decimal() if dec.is_nan(): # Required by the drivers spec to match MongoDB behavior. 
return "NaN" return str(dec) def __repr__(self): return "Decimal128('%s')" % (str(self),) def __setstate__(self, value): self.__high, self.__low = value def __getstate__(self): return self.__high, self.__low def __eq__(self, other): if isinstance(other, Decimal128): return self.bid == other.bid return NotImplemented def __ne__(self, other): return not self == other pymongo-3.11.0/bson/encoding_helpers.c000066400000000000000000000106201374256237000177110ustar00rootroot00000000000000/* * Copyright 2009-2015 MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "encoding_helpers.h" /* * Portions Copyright 2001 Unicode, Inc. * * Disclaimer * * This source code is provided as is by Unicode, Inc. No claims are * made as to fitness for any particular purpose. No warranties of any * kind are expressed or implied. The recipient agrees to determine * applicability of information provided. If this file has been * purchased on magnetic or optical media from Unicode, Inc., the * sole remedy for any claim will be exchange of defective media * within 90 days of receipt. * * Limitations on Rights to Redistribute This Code * * Unicode, Inc. hereby grants the right to freely use the information * supplied in this file in the creation of products supporting the * Unicode Standard, and to make copies of this file in any form * for internal or external distribution as long as this notice * remains attached. */ /* * Index into the table below with the first byte of a UTF-8 sequence to * get the number of trailing bytes that are supposed to follow it. */ static const char trailingBytesForUTF8[256] = { 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3,4,4,4,4,5,5,5,5 }; /* --------------------------------------------------------------------- */ /* * Utility routine to tell whether a sequence of bytes is legal UTF-8. * This must be called with the length pre-determined by the first byte. * The length can be set by: * length = trailingBytesForUTF8[*source]+1; * and the sequence is illegal right away if there aren't that many bytes * available. * If presented with a length > 4, this returns 0. The Unicode * definition of UTF-8 goes up to 4-byte sequences. */ static unsigned char isLegalUTF8(const unsigned char* source, int length) { unsigned char a; const unsigned char* srcptr = source + length; switch (length) { default: return 0; /* Everything else falls through when "true"... 
*/ case 4: if ((a = (*--srcptr)) < 0x80 || a > 0xBF) return 0; case 3: if ((a = (*--srcptr)) < 0x80 || a > 0xBF) return 0; case 2: if ((a = (*--srcptr)) > 0xBF) return 0; switch (*source) { /* no fall-through in this inner switch */ case 0xE0: if (a < 0xA0) return 0; break; case 0xF0: if (a < 0x90) return 0; break; case 0xF4: if ((a > 0x8F) || (a < 0x80)) return 0; break; default: if (a < 0x80) return 0; } case 1: if (*source >= 0x80 && *source < 0xC2) return 0; if (*source > 0xF4) return 0; } return 1; } result_t check_string(const unsigned char* string, const int length, const char check_utf8, const char check_null) { int position = 0; /* By default we go character by character. Will be different for checking * UTF-8 */ int sequence_length = 1; if (!check_utf8 && !check_null) { return VALID; } while (position < length) { if (check_null && *(string + position) == 0) { return HAS_NULL; } if (check_utf8) { sequence_length = trailingBytesForUTF8[*(string + position)] + 1; if ((position + sequence_length) > length) { return NOT_UTF_8; } if (!isLegalUTF8(string + position, sequence_length)) { return NOT_UTF_8; } } position += sequence_length; } return VALID; } pymongo-3.11.0/bson/encoding_helpers.h000066400000000000000000000015431374256237000177220ustar00rootroot00000000000000/* * Copyright 2009-2015 MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ENCODING_HELPERS_H #define ENCODING_HELPERS_H typedef enum { VALID, NOT_UTF_8, HAS_NULL } result_t; result_t check_string(const unsigned char* string, const int length, const char check_utf8, const char check_null); #endif pymongo-3.11.0/bson/errors.py000066400000000000000000000022071374256237000161250ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Exceptions raised by the BSON package.""" class BSONError(Exception): """Base class for all BSON exceptions. """ class InvalidBSON(BSONError): """Raised when trying to create a BSON object from invalid data. """ class InvalidStringData(BSONError): """Raised when trying to encode a string containing non-UTF8 data. """ class InvalidDocument(BSONError): """Raised when trying to create a BSON object from an invalid document. """ class InvalidId(BSONError): """Raised when trying to create an ObjectId from invalid data. """ pymongo-3.11.0/bson/int64.py000066400000000000000000000020401374256237000155500ustar00rootroot00000000000000# Copyright 2014-2015 MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A BSON wrapper for long (int in python3)""" from bson.py3compat import PY3 if PY3: long = int class Int64(long): """Representation of the BSON int64 type. This is necessary because every integral number is an :class:`int` in Python 3. Small integral numbers are encoded to BSON int32 by default, but Int64 numbers will always be encoded to BSON int64. :Parameters: - `value`: the numeric value to represent """ _type_marker = 18 pymongo-3.11.0/bson/json_util.py000066400000000000000000000765031374256237000166310ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for using Python's :mod:`json` module with BSON documents. This module provides two helper methods `dumps` and `loads` that wrap the native :mod:`json` methods and provide explicit BSON conversion to and from JSON. :class:`~bson.json_util.JSONOptions` provides a way to control how JSON is emitted and parsed, with the default being the legacy PyMongo format. :mod:`~bson.json_util` can also generate Canonical or Relaxed `Extended JSON`_ when :const:`CANONICAL_JSON_OPTIONS` or :const:`RELAXED_JSON_OPTIONS` is provided, respectively. .. _Extended JSON: https://github.com/mongodb/specifications/blob/master/source/extended-json.rst Example usage (deserialization): .. doctest:: >>> from bson.json_util import loads >>> loads('[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$scope": {}, "$code": "function x() { return 1; }"}}, {"bin": {"$type": "80", "$binary": "AQIDBA=="}}]') [{u'foo': [1, 2]}, {u'bar': {u'hello': u'world'}}, {u'code': Code('function x() { return 1; }', {})}, {u'bin': Binary('...', 128)}] Example usage (serialization): .. doctest:: >>> from bson import Binary, Code >>> from bson.json_util import dumps >>> dumps([{'foo': [1, 2]}, ... {'bar': {'hello': 'world'}}, ... {'code': Code("function x() { return 1; }", {})}, ... {'bin': Binary(b"\x01\x02\x03\x04")}]) '[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$code": "function x() { return 1; }", "$scope": {}}}, {"bin": {"$binary": "AQIDBA==", "$type": "00"}}]' Example usage (with :const:`CANONICAL_JSON_OPTIONS`): .. doctest:: >>> from bson import Binary, Code >>> from bson.json_util import dumps, CANONICAL_JSON_OPTIONS >>> dumps([{'foo': [1, 2]}, ... {'bar': {'hello': 'world'}}, ... {'code': Code("function x() { return 1; }")}, ... {'bin': Binary(b"\x01\x02\x03\x04")}], ... 
json_options=CANONICAL_JSON_OPTIONS) '[{"foo": [{"$numberInt": "1"}, {"$numberInt": "2"}]}, {"bar": {"hello": "world"}}, {"code": {"$code": "function x() { return 1; }"}}, {"bin": {"$binary": {"base64": "AQIDBA==", "subType": "00"}}}]' Example usage (with :const:`RELAXED_JSON_OPTIONS`): .. doctest:: >>> from bson import Binary, Code >>> from bson.json_util import dumps, RELAXED_JSON_OPTIONS >>> dumps([{'foo': [1, 2]}, ... {'bar': {'hello': 'world'}}, ... {'code': Code("function x() { return 1; }")}, ... {'bin': Binary(b"\x01\x02\x03\x04")}], ... json_options=RELAXED_JSON_OPTIONS) '[{"foo": [1, 2]}, {"bar": {"hello": "world"}}, {"code": {"$code": "function x() { return 1; }"}}, {"bin": {"$binary": {"base64": "AQIDBA==", "subType": "00"}}}]' Alternatively, you can manually pass the `default` to :func:`json.dumps`. It won't handle :class:`~bson.binary.Binary` and :class:`~bson.code.Code` instances (as they are extended strings you can't provide custom defaults), but it will be faster as there is less recursion. .. note:: If your application does not need the flexibility offered by :class:`JSONOptions` and spends a large amount of time in the `json_util` module, look to `python-bsonjs `_ for a nice performance improvement. `python-bsonjs` is a fast BSON to MongoDB Extended JSON converter for Python built on top of `libbson `_. `python-bsonjs` works best with PyMongo when using :class:`~bson.raw_bson.RawBSONDocument`. .. versionchanged:: 2.8 The output format for :class:`~bson.timestamp.Timestamp` has changed from '{"t": , "i": }' to '{"$timestamp": {"t": , "i": }}'. This new format will be decoded to an instance of :class:`~bson.timestamp.Timestamp`. The old format will continue to be decoded to a python dict as before. Encoding to the old format is no longer supported as it was never correct and loses type information. Added support for $numberLong and $undefined - new in MongoDB 2.6 - and parsing $date in ISO-8601 format. .. versionchanged:: 2.7 Preserves order when rendering SON, Timestamp, Code, Binary, and DBRef instances. .. versionchanged:: 2.3 Added dumps and loads helpers to automatically handle conversion to and from json and supports :class:`~bson.binary.Binary` and :class:`~bson.code.Code` """ import base64 import datetime import json import math import re import uuid from pymongo.errors import ConfigurationError import bson from bson import EPOCH_AWARE, RE_TYPE, SON from bson.binary import (Binary, UuidRepresentation, ALL_UUID_SUBTYPES, UUID_SUBTYPE) from bson.code import Code from bson.codec_options import CodecOptions from bson.dbref import DBRef from bson.decimal128 import Decimal128 from bson.int64 import Int64 from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId from bson.py3compat import (PY3, iteritems, integer_types, string_type, text_type) from bson.regex import Regex from bson.timestamp import Timestamp from bson.tz_util import utc _RE_OPT_TABLE = { "i": re.I, "l": re.L, "m": re.M, "s": re.S, "u": re.U, "x": re.X, } # Dollar-prefixed keys which may appear in DBRefs. _DBREF_KEYS = frozenset(['$id', '$ref', '$db']) class DatetimeRepresentation: LEGACY = 0 """Legacy MongoDB Extended JSON datetime representation. :class:`datetime.datetime` instances will be encoded to JSON in the format `{"$date": }`, where `dateAsMilliseconds` is a 64-bit signed integer giving the number of milliseconds since the Unix epoch UTC. This was the default encoding before PyMongo version 3.4. .. 
versionadded:: 3.4 """ NUMBERLONG = 1 """NumberLong datetime representation. :class:`datetime.datetime` instances will be encoded to JSON in the format `{"$date": {"$numberLong": ""}}`, where `dateAsMilliseconds` is the string representation of a 64-bit signed integer giving the number of milliseconds since the Unix epoch UTC. .. versionadded:: 3.4 """ ISO8601 = 2 """ISO-8601 datetime representation. :class:`datetime.datetime` instances greater than or equal to the Unix epoch UTC will be encoded to JSON in the format `{"$date": ""}`. :class:`datetime.datetime` instances before the Unix epoch UTC will be encoded as if the datetime representation is :const:`~DatetimeRepresentation.NUMBERLONG`. .. versionadded:: 3.4 """ class JSONMode: LEGACY = 0 """Legacy Extended JSON representation. In this mode, :func:`~bson.json_util.dumps` produces PyMongo's legacy non-standard JSON output. Consider using :const:`~bson.json_util.JSONMode.RELAXED` or :const:`~bson.json_util.JSONMode.CANONICAL` instead. .. versionadded:: 3.5 """ RELAXED = 1 """Relaxed Extended JSON representation. In this mode, :func:`~bson.json_util.dumps` produces Relaxed Extended JSON, a mostly JSON-like format. Consider using this for things like a web API, where one is sending a document (or a projection of a document) that only uses ordinary JSON type primitives. In particular, the ``int``, :class:`~bson.int64.Int64`, and ``float`` numeric types are represented in the native JSON number format. This output is also the most human readable and is useful for debugging and documentation. .. seealso:: The specification for Relaxed `Extended JSON`_. .. versionadded:: 3.5 """ CANONICAL = 2 """Canonical Extended JSON representation. In this mode, :func:`~bson.json_util.dumps` produces Canonical Extended JSON, a type preserving format. Consider using this for things like testing, where one has to precisely specify expected types in JSON. In particular, the ``int``, :class:`~bson.int64.Int64`, and ``float`` numeric types are encoded with type wrappers. .. seealso:: The specification for Canonical `Extended JSON`_. .. versionadded:: 3.5 """ class JSONOptions(CodecOptions): """Encapsulates JSON options for :func:`dumps` and :func:`loads`. :Parameters: - `strict_number_long`: If ``True``, :class:`~bson.int64.Int64` objects are encoded to MongoDB Extended JSON's *Strict mode* type `NumberLong`, ie ``'{"$numberLong": "" }'``. Otherwise they will be encoded as an `int`. Defaults to ``False``. - `datetime_representation`: The representation to use when encoding instances of :class:`datetime.datetime`. Defaults to :const:`~DatetimeRepresentation.LEGACY`. - `strict_uuid`: If ``True``, :class:`uuid.UUID` object are encoded to MongoDB Extended JSON's *Strict mode* type `Binary`. Otherwise it will be encoded as ``'{"$uuid": "" }'``. Defaults to ``False``. - `json_mode`: The :class:`JSONMode` to use when encoding BSON types to Extended JSON. Defaults to :const:`~JSONMode.LEGACY`. - `document_class`: BSON documents returned by :func:`loads` will be decoded to an instance of this class. Must be a subclass of :class:`collections.MutableMapping`. Defaults to :class:`dict`. - `uuid_representation`: The :class:`~bson.binary.UuidRepresentation` to use when encoding and decoding instances of :class:`uuid.UUID`. Defaults to :const:`~bson.binary.UuidRepresentation.PYTHON_LEGACY`. - `tz_aware`: If ``True``, MongoDB Extended JSON's *Strict mode* type `Date` will be decoded to timezone aware instances of :class:`datetime.datetime`. Otherwise they will be naive. 
Defaults to ``True``. - `tzinfo`: A :class:`datetime.tzinfo` subclass that specifies the timezone from which :class:`~datetime.datetime` objects should be decoded. Defaults to :const:`~bson.tz_util.utc`. - `args`: arguments to :class:`~bson.codec_options.CodecOptions` - `kwargs`: arguments to :class:`~bson.codec_options.CodecOptions` .. seealso:: The specification for Relaxed and Canonical `Extended JSON`_. .. versionadded:: 3.4 .. versionchanged:: 3.5 Accepts the optional parameter `json_mode`. """ def __new__(cls, strict_number_long=False, datetime_representation=DatetimeRepresentation.LEGACY, strict_uuid=False, json_mode=JSONMode.LEGACY, *args, **kwargs): kwargs["tz_aware"] = kwargs.get("tz_aware", True) if kwargs["tz_aware"]: kwargs["tzinfo"] = kwargs.get("tzinfo", utc) if datetime_representation not in (DatetimeRepresentation.LEGACY, DatetimeRepresentation.NUMBERLONG, DatetimeRepresentation.ISO8601): raise ConfigurationError( "JSONOptions.datetime_representation must be one of LEGACY, " "NUMBERLONG, or ISO8601 from DatetimeRepresentation.") self = super(JSONOptions, cls).__new__(cls, *args, **kwargs) if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL): raise ConfigurationError( "JSONOptions.json_mode must be one of LEGACY, RELAXED, " "or CANONICAL from JSONMode.") self.json_mode = json_mode if self.json_mode == JSONMode.RELAXED: self.strict_number_long = False self.datetime_representation = DatetimeRepresentation.ISO8601 self.strict_uuid = True elif self.json_mode == JSONMode.CANONICAL: self.strict_number_long = True self.datetime_representation = DatetimeRepresentation.NUMBERLONG self.strict_uuid = True else: self.strict_number_long = strict_number_long self.datetime_representation = datetime_representation self.strict_uuid = strict_uuid return self def _arguments_repr(self): return ('strict_number_long=%r, ' 'datetime_representation=%r, ' 'strict_uuid=%r, json_mode=%r, %s' % ( self.strict_number_long, self.datetime_representation, self.strict_uuid, self.json_mode, super(JSONOptions, self)._arguments_repr())) LEGACY_JSON_OPTIONS = JSONOptions(json_mode=JSONMode.LEGACY) """:class:`JSONOptions` for encoding to PyMongo's legacy JSON format. .. seealso:: The documentation for :const:`bson.json_util.JSONMode.LEGACY`. .. versionadded:: 3.5 """ DEFAULT_JSON_OPTIONS = LEGACY_JSON_OPTIONS """The default :class:`JSONOptions` for JSON encoding/decoding. The same as :const:`LEGACY_JSON_OPTIONS`. This will change to :const:`RELAXED_JSON_OPTIONS` in a future release. .. versionadded:: 3.4 """ CANONICAL_JSON_OPTIONS = JSONOptions(json_mode=JSONMode.CANONICAL) """:class:`JSONOptions` for Canonical Extended JSON. .. seealso:: The documentation for :const:`bson.json_util.JSONMode.CANONICAL`. .. versionadded:: 3.5 """ RELAXED_JSON_OPTIONS = JSONOptions(json_mode=JSONMode.RELAXED) """:class:`JSONOptions` for Relaxed Extended JSON. .. seealso:: The documentation for :const:`bson.json_util.JSONMode.RELAXED`. .. versionadded:: 3.5 """ STRICT_JSON_OPTIONS = JSONOptions( strict_number_long=True, datetime_representation=DatetimeRepresentation.ISO8601, strict_uuid=True) """**DEPRECATED** - :class:`JSONOptions` for MongoDB Extended JSON's *Strict mode* encoding. .. versionadded:: 3.4 .. versionchanged:: 3.5 Deprecated. Use :const:`RELAXED_JSON_OPTIONS` or :const:`CANONICAL_JSON_OPTIONS` instead. """ def dumps(obj, *args, **kwargs): """Helper function that wraps :func:`json.dumps`. 
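    Positional and keyword arguments other than ``json_options`` are passed
    through to :func:`json.dumps` unchanged, so the usual options of the
    standard library encoder still apply. For example (illustrative)::

      >>> dumps({'b': 2, 'a': 1}, sort_keys=True)
      '{"a": 1, "b": 2}'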
Recursive function that handles all BSON types including :class:`~bson.binary.Binary` and :class:`~bson.code.Code`. :Parameters: - `json_options`: A :class:`JSONOptions` instance used to modify the encoding of MongoDB Extended JSON types. Defaults to :const:`DEFAULT_JSON_OPTIONS`. .. versionchanged:: 3.4 Accepts optional parameter `json_options`. See :class:`JSONOptions`. .. versionchanged:: 2.7 Preserves order when rendering SON, Timestamp, Code, Binary, and DBRef instances. """ json_options = kwargs.pop("json_options", DEFAULT_JSON_OPTIONS) return json.dumps(_json_convert(obj, json_options), *args, **kwargs) def loads(s, *args, **kwargs): """Helper function that wraps :func:`json.loads`. Automatically passes the object_hook for BSON type conversion. Raises ``TypeError``, ``ValueError``, ``KeyError``, or :exc:`~bson.errors.InvalidId` on invalid MongoDB Extended JSON. :Parameters: - `json_options`: A :class:`JSONOptions` instance used to modify the decoding of MongoDB Extended JSON types. Defaults to :const:`DEFAULT_JSON_OPTIONS`. .. versionchanged:: 3.5 Parses Relaxed and Canonical Extended JSON as well as PyMongo's legacy format. Now raises ``TypeError`` or ``ValueError`` when parsing JSON type wrappers with values of the wrong type or any extra keys. .. versionchanged:: 3.4 Accepts optional parameter `json_options`. See :class:`JSONOptions`. """ json_options = kwargs.pop("json_options", DEFAULT_JSON_OPTIONS) kwargs["object_pairs_hook"] = lambda pairs: object_pairs_hook( pairs, json_options) return json.loads(s, *args, **kwargs) def _json_convert(obj, json_options=DEFAULT_JSON_OPTIONS): """Recursive helper method that converts BSON types so they can be converted into json. """ if hasattr(obj, 'iteritems') or hasattr(obj, 'items'): # PY3 support return SON(((k, _json_convert(v, json_options)) for k, v in iteritems(obj))) elif hasattr(obj, '__iter__') and not isinstance(obj, (text_type, bytes)): return list((_json_convert(v, json_options) for v in obj)) try: return default(obj, json_options) except TypeError: return obj def object_pairs_hook(pairs, json_options=DEFAULT_JSON_OPTIONS): return object_hook(json_options.document_class(pairs), json_options) def object_hook(dct, json_options=DEFAULT_JSON_OPTIONS): if "$oid" in dct: return _parse_canonical_oid(dct) if "$ref" in dct: return _parse_canonical_dbref(dct) if "$date" in dct: return _parse_canonical_datetime(dct, json_options) if "$regex" in dct: return _parse_legacy_regex(dct) if "$minKey" in dct: return _parse_canonical_minkey(dct) if "$maxKey" in dct: return _parse_canonical_maxkey(dct) if "$binary" in dct: if "$type" in dct: return _parse_legacy_binary(dct, json_options) else: return _parse_canonical_binary(dct, json_options) if "$code" in dct: return _parse_canonical_code(dct) if "$uuid" in dct: return _parse_legacy_uuid(dct, json_options) if "$undefined" in dct: return None if "$numberLong" in dct: return _parse_canonical_int64(dct) if "$timestamp" in dct: tsp = dct["$timestamp"] return Timestamp(tsp["t"], tsp["i"]) if "$numberDecimal" in dct: return _parse_canonical_decimal128(dct) if "$dbPointer" in dct: return _parse_canonical_dbpointer(dct) if "$regularExpression" in dct: return _parse_canonical_regex(dct) if "$symbol" in dct: return _parse_canonical_symbol(dct) if "$numberInt" in dct: return _parse_canonical_int32(dct) if "$numberDouble" in dct: return _parse_canonical_double(dct) return dct def _parse_legacy_regex(doc): pattern = doc["$regex"] # Check if this is the $regex query operator. 
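    # If the pattern has already been decoded to a Regex instance, ``doc`` is a
    # MongoDB query using the $regex operator (e.g. {"$regex": Regex(...),
    # "$options": "i"}) rather than a legacy JSON regex wrapper, so it is
    # returned unchanged instead of being converted.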
if isinstance(pattern, Regex): return doc flags = 0 # PyMongo always adds $options but some other tools may not. for opt in doc.get("$options", ""): flags |= _RE_OPT_TABLE.get(opt, 0) return Regex(pattern, flags) def _parse_legacy_uuid(doc, json_options): """Decode a JSON legacy $uuid to Python UUID.""" if len(doc) != 1: raise TypeError('Bad $uuid, extra field(s): %s' % (doc,)) if json_options.uuid_representation == UuidRepresentation.UNSPECIFIED: return Binary.from_uuid(uuid.UUID(doc["$uuid"])) else: return uuid.UUID(doc["$uuid"]) def _binary_or_uuid(data, subtype, json_options): # special handling for UUID if subtype in ALL_UUID_SUBTYPES: uuid_representation = json_options.uuid_representation binary_value = Binary(data, subtype) if uuid_representation == UuidRepresentation.UNSPECIFIED: return binary_value if subtype == UUID_SUBTYPE: # Legacy behavior: use STANDARD with binary subtype 4. uuid_representation = UuidRepresentation.STANDARD elif uuid_representation == UuidRepresentation.STANDARD: # subtype == OLD_UUID_SUBTYPE # Legacy behavior: STANDARD is the same as PYTHON_LEGACY. uuid_representation = UuidRepresentation.PYTHON_LEGACY return binary_value.as_uuid(uuid_representation) if PY3 and subtype == 0: return data return Binary(data, subtype) def _parse_legacy_binary(doc, json_options): if isinstance(doc["$type"], int): doc["$type"] = "%02x" % doc["$type"] subtype = int(doc["$type"], 16) if subtype >= 0xffffff80: # Handle mongoexport values subtype = int(doc["$type"][6:], 16) data = base64.b64decode(doc["$binary"].encode()) return _binary_or_uuid(data, subtype, json_options) def _parse_canonical_binary(doc, json_options): binary = doc["$binary"] b64 = binary["base64"] subtype = binary["subType"] if not isinstance(b64, string_type): raise TypeError('$binary base64 must be a string: %s' % (doc,)) if not isinstance(subtype, string_type) or len(subtype) > 2: raise TypeError('$binary subType must be a string at most 2 ' 'characters: %s' % (doc,)) if len(binary) != 2: raise TypeError('$binary must include only "base64" and "subType" ' 'components: %s' % (doc,)) data = base64.b64decode(b64.encode()) return _binary_or_uuid(data, int(subtype, 16), json_options) def _parse_canonical_datetime(doc, json_options): """Decode a JSON datetime to python datetime.datetime.""" dtm = doc["$date"] if len(doc) != 1: raise TypeError('Bad $date, extra field(s): %s' % (doc,)) # mongoexport 2.6 and newer if isinstance(dtm, string_type): # Parse offset if dtm[-1] == 'Z': dt = dtm[:-1] offset = 'Z' elif dtm[-6] in ('+', '-') and dtm[-3] == ':': # (+|-)HH:MM dt = dtm[:-6] offset = dtm[-6:] elif dtm[-5] in ('+', '-'): # (+|-)HHMM dt = dtm[:-5] offset = dtm[-5:] elif dtm[-3] in ('+', '-'): # (+|-)HH dt = dtm[:-3] offset = dtm[-3:] else: dt = dtm offset = '' # Parse the optional factional seconds portion. 
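        # For example, with dt == "2015-06-25T12:00:00.123" (the offset has
        # already been stripped above) this yields microsecond == 123000 and
        # dt == "2015-06-25T12:00:00".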
dot_index = dt.rfind('.') microsecond = 0 if dot_index != -1: microsecond = int(float(dt[dot_index:]) * 1000000) dt = dt[:dot_index] aware = datetime.datetime.strptime( dt, "%Y-%m-%dT%H:%M:%S").replace(microsecond=microsecond, tzinfo=utc) if offset and offset != 'Z': if len(offset) == 6: hours, minutes = offset[1:].split(':') secs = (int(hours) * 3600 + int(minutes) * 60) elif len(offset) == 5: secs = (int(offset[1:3]) * 3600 + int(offset[3:]) * 60) elif len(offset) == 3: secs = int(offset[1:3]) * 3600 if offset[0] == "-": secs *= -1 aware = aware - datetime.timedelta(seconds=secs) if json_options.tz_aware: if json_options.tzinfo: aware = aware.astimezone(json_options.tzinfo) return aware else: return aware.replace(tzinfo=None) return bson._millis_to_datetime(int(dtm), json_options) def _parse_canonical_oid(doc): """Decode a JSON ObjectId to bson.objectid.ObjectId.""" if len(doc) != 1: raise TypeError('Bad $oid, extra field(s): %s' % (doc,)) return ObjectId(doc['$oid']) def _parse_canonical_symbol(doc): """Decode a JSON symbol to Python string.""" symbol = doc['$symbol'] if len(doc) != 1: raise TypeError('Bad $symbol, extra field(s): %s' % (doc,)) return text_type(symbol) def _parse_canonical_code(doc): """Decode a JSON code to bson.code.Code.""" for key in doc: if key not in ('$code', '$scope'): raise TypeError('Bad $code, extra field(s): %s' % (doc,)) return Code(doc['$code'], scope=doc.get('$scope')) def _parse_canonical_regex(doc): """Decode a JSON regex to bson.regex.Regex.""" regex = doc['$regularExpression'] if len(doc) != 1: raise TypeError('Bad $regularExpression, extra field(s): %s' % (doc,)) if len(regex) != 2: raise TypeError('Bad $regularExpression must include only "pattern"' 'and "options" components: %s' % (doc,)) return Regex(regex['pattern'], regex['options']) def _parse_canonical_dbref(doc): """Decode a JSON DBRef to bson.dbref.DBRef.""" for key in doc: if key.startswith('$') and key not in _DBREF_KEYS: # Other keys start with $, so dct cannot be parsed as a DBRef. return doc return DBRef(doc.pop('$ref'), doc.pop('$id'), database=doc.pop('$db', None), **doc) def _parse_canonical_dbpointer(doc): """Decode a JSON (deprecated) DBPointer to bson.dbref.DBRef.""" dbref = doc['$dbPointer'] if len(doc) != 1: raise TypeError('Bad $dbPointer, extra field(s): %s' % (doc,)) if isinstance(dbref, DBRef): dbref_doc = dbref.as_doc() # DBPointer must not contain $db in its value. 
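        # A DBPointer consists of only a collection name ($ref) and an
        # ObjectId ($id), which is what the checks below enforce.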
if dbref.database is not None: raise TypeError( 'Bad $dbPointer, extra field $db: %s' % (dbref_doc,)) if not isinstance(dbref.id, ObjectId): raise TypeError( 'Bad $dbPointer, $id must be an ObjectId: %s' % (dbref_doc,)) if len(dbref_doc) != 2: raise TypeError( 'Bad $dbPointer, extra field(s) in DBRef: %s' % (dbref_doc,)) return dbref else: raise TypeError('Bad $dbPointer, expected a DBRef: %s' % (doc,)) def _parse_canonical_int32(doc): """Decode a JSON int32 to python int.""" i_str = doc['$numberInt'] if len(doc) != 1: raise TypeError('Bad $numberInt, extra field(s): %s' % (doc,)) if not isinstance(i_str, string_type): raise TypeError('$numberInt must be string: %s' % (doc,)) return int(i_str) def _parse_canonical_int64(doc): """Decode a JSON int64 to bson.int64.Int64.""" l_str = doc['$numberLong'] if len(doc) != 1: raise TypeError('Bad $numberLong, extra field(s): %s' % (doc,)) return Int64(l_str) def _parse_canonical_double(doc): """Decode a JSON double to python float.""" d_str = doc['$numberDouble'] if len(doc) != 1: raise TypeError('Bad $numberDouble, extra field(s): %s' % (doc,)) if not isinstance(d_str, string_type): raise TypeError('$numberDouble must be string: %s' % (doc,)) return float(d_str) def _parse_canonical_decimal128(doc): """Decode a JSON decimal128 to bson.decimal128.Decimal128.""" d_str = doc['$numberDecimal'] if len(doc) != 1: raise TypeError('Bad $numberDecimal, extra field(s): %s' % (doc,)) if not isinstance(d_str, string_type): raise TypeError('$numberDecimal must be string: %s' % (doc,)) return Decimal128(d_str) def _parse_canonical_minkey(doc): """Decode a JSON MinKey to bson.min_key.MinKey.""" if type(doc['$minKey']) is not int or doc['$minKey'] != 1: raise TypeError('$minKey value must be 1: %s' % (doc,)) if len(doc) != 1: raise TypeError('Bad $minKey, extra field(s): %s' % (doc,)) return MinKey() def _parse_canonical_maxkey(doc): """Decode a JSON MaxKey to bson.max_key.MaxKey.""" if type(doc['$maxKey']) is not int or doc['$maxKey'] != 1: raise TypeError('$maxKey value must be 1: %s', (doc,)) if len(doc) != 1: raise TypeError('Bad $minKey, extra field(s): %s' % (doc,)) return MaxKey() def _encode_binary(data, subtype, json_options): if json_options.json_mode == JSONMode.LEGACY: return SON([ ('$binary', base64.b64encode(data).decode()), ('$type', "%02x" % subtype)]) return {'$binary': SON([ ('base64', base64.b64encode(data).decode()), ('subType', "%02x" % subtype)])} def default(obj, json_options=DEFAULT_JSON_OPTIONS): # We preserve key order when rendering SON, DBRef, etc. as JSON by # returning a SON for those types instead of a dict. 
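    # default() converts a single BSON-specific object into a JSON-serializable
    # form. It is applied recursively by _json_convert() above, and (as noted in
    # the module docstring) it can also be passed directly as the ``default``
    # parameter of json.dumps().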
if isinstance(obj, ObjectId): return {"$oid": str(obj)} if isinstance(obj, DBRef): return _json_convert(obj.as_doc(), json_options=json_options) if isinstance(obj, datetime.datetime): if (json_options.datetime_representation == DatetimeRepresentation.ISO8601): if not obj.tzinfo: obj = obj.replace(tzinfo=utc) if obj >= EPOCH_AWARE: off = obj.tzinfo.utcoffset(obj) if (off.days, off.seconds, off.microseconds) == (0, 0, 0): tz_string = 'Z' else: tz_string = obj.strftime('%z') millis = int(obj.microsecond / 1000) fracsecs = ".%03d" % (millis,) if millis else "" return {"$date": "%s%s%s" % ( obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string)} millis = bson._datetime_to_millis(obj) if (json_options.datetime_representation == DatetimeRepresentation.LEGACY): return {"$date": millis} return {"$date": {"$numberLong": str(millis)}} if json_options.strict_number_long and isinstance(obj, Int64): return {"$numberLong": str(obj)} if isinstance(obj, (RE_TYPE, Regex)): flags = "" if obj.flags & re.IGNORECASE: flags += "i" if obj.flags & re.LOCALE: flags += "l" if obj.flags & re.MULTILINE: flags += "m" if obj.flags & re.DOTALL: flags += "s" if obj.flags & re.UNICODE: flags += "u" if obj.flags & re.VERBOSE: flags += "x" if isinstance(obj.pattern, text_type): pattern = obj.pattern else: pattern = obj.pattern.decode('utf-8') if json_options.json_mode == JSONMode.LEGACY: return SON([("$regex", pattern), ("$options", flags)]) return {'$regularExpression': SON([("pattern", pattern), ("options", flags)])} if isinstance(obj, MinKey): return {"$minKey": 1} if isinstance(obj, MaxKey): return {"$maxKey": 1} if isinstance(obj, Timestamp): return {"$timestamp": SON([("t", obj.time), ("i", obj.inc)])} if isinstance(obj, Code): if obj.scope is None: return {'$code': str(obj)} return SON([ ('$code', str(obj)), ('$scope', _json_convert(obj.scope, json_options))]) if isinstance(obj, Binary): return _encode_binary(obj, obj.subtype, json_options) if PY3 and isinstance(obj, bytes): return _encode_binary(obj, 0, json_options) if isinstance(obj, uuid.UUID): if json_options.strict_uuid: binval = Binary.from_uuid( obj, uuid_representation=json_options.uuid_representation) return _encode_binary(binval, binval.subtype, json_options) else: return {"$uuid": obj.hex} if isinstance(obj, Decimal128): return {"$numberDecimal": str(obj)} if isinstance(obj, bool): return obj if (json_options.json_mode == JSONMode.CANONICAL and isinstance(obj, integer_types)): if -2 ** 31 <= obj < 2 ** 31: return {'$numberInt': text_type(obj)} return {'$numberLong': text_type(obj)} if json_options.json_mode != JSONMode.LEGACY and isinstance(obj, float): if math.isnan(obj): return {'$numberDouble': 'NaN'} elif math.isinf(obj): representation = 'Infinity' if obj > 0 else '-Infinity' return {'$numberDouble': representation} elif json_options.json_mode == JSONMode.CANONICAL: # repr() will return the shortest string guaranteed to produce the # original value, when float() is called on it. str produces a # shorter string in Python 2. return {'$numberDouble': text_type(repr(obj))} raise TypeError("%r is not JSON serializable" % obj) pymongo-3.11.0/bson/max_key.py000066400000000000000000000024431374256237000162500ustar00rootroot00000000000000# Copyright 2010-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Representation for the MongoDB internal MaxKey type. """ class MaxKey(object): """MongoDB internal MaxKey type. .. versionchanged:: 2.7 ``MaxKey`` now implements comparison operators. """ _type_marker = 127 def __eq__(self, other): return isinstance(other, MaxKey) def __hash__(self): return hash(self._type_marker) def __ne__(self, other): return not self == other def __le__(self, other): return isinstance(other, MaxKey) def __lt__(self, dummy): return False def __ge__(self, dummy): return True def __gt__(self, other): return not isinstance(other, MaxKey) def __repr__(self): return "MaxKey()" pymongo-3.11.0/bson/min_key.py000066400000000000000000000024431374256237000162460ustar00rootroot00000000000000# Copyright 2010-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Representation for the MongoDB internal MinKey type. """ class MinKey(object): """MongoDB internal MinKey type. .. versionchanged:: 2.7 ``MinKey`` now implements comparison operators. """ _type_marker = 255 def __eq__(self, other): return isinstance(other, MinKey) def __hash__(self): return hash(self._type_marker) def __ne__(self, other): return not self == other def __le__(self, dummy): return True def __lt__(self, other): return not isinstance(other, MinKey) def __ge__(self, other): return isinstance(other, MinKey) def __gt__(self, dummy): return False def __repr__(self): return "MinKey()" pymongo-3.11.0/bson/objectid.py000066400000000000000000000222411374256237000163740ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for working with MongoDB `ObjectIds `_. 
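For example (an illustrative round trip; see the :class:`ObjectId` class
documentation below for details)::

  >>> oid = ObjectId('0123456789ab0123456789ab')
  >>> str(oid)
  '0123456789ab0123456789ab'
  >>> ObjectId(str(oid)) == oid
  True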
""" import binascii import calendar import datetime import os import struct import threading import time from random import SystemRandom from bson.errors import InvalidId from bson.py3compat import PY3, bytes_from_hex, string_type, text_type from bson.tz_util import utc _MAX_COUNTER_VALUE = 0xFFFFFF def _raise_invalid_id(oid): raise InvalidId( "%r is not a valid ObjectId, it must be a 12-byte input" " or a 24-character hex string" % oid) def _random_bytes(): """Get the 5-byte random field of an ObjectId.""" return os.urandom(5) class ObjectId(object): """A MongoDB ObjectId. """ _pid = os.getpid() _inc = SystemRandom().randint(0, _MAX_COUNTER_VALUE) _inc_lock = threading.Lock() __random = _random_bytes() __slots__ = ('__id',) _type_marker = 7 def __init__(self, oid=None): """Initialize a new ObjectId. An ObjectId is a 12-byte unique identifier consisting of: - a 4-byte value representing the seconds since the Unix epoch, - a 5-byte random value, - a 3-byte counter, starting with a random value. By default, ``ObjectId()`` creates a new unique identifier. The optional parameter `oid` can be an :class:`ObjectId`, or any 12 :class:`bytes` or, in Python 2, any 12-character :class:`str`. For example, the 12 bytes b'foo-bar-quux' do not follow the ObjectId specification but they are acceptable input:: >>> ObjectId(b'foo-bar-quux') ObjectId('666f6f2d6261722d71757578') `oid` can also be a :class:`unicode` or :class:`str` of 24 hex digits:: >>> ObjectId('0123456789ab0123456789ab') ObjectId('0123456789ab0123456789ab') >>> >>> # A u-prefixed unicode literal: >>> ObjectId(u'0123456789ab0123456789ab') ObjectId('0123456789ab0123456789ab') Raises :class:`~bson.errors.InvalidId` if `oid` is not 12 bytes nor 24 hex digits, or :class:`TypeError` if `oid` is not an accepted type. :Parameters: - `oid` (optional): a valid ObjectId. .. mongodoc:: objectids .. versionchanged:: 3.8 :class:`~bson.objectid.ObjectId` now implements the `ObjectID specification version 0.2 `_. """ if oid is None: self.__generate() elif isinstance(oid, bytes) and len(oid) == 12: self.__id = oid else: self.__validate(oid) @classmethod def from_datetime(cls, generation_time): """Create a dummy ObjectId instance with a specific generation time. This method is useful for doing range queries on a field containing :class:`ObjectId` instances. .. warning:: It is not safe to insert a document containing an ObjectId generated using this method. This method deliberately eliminates the uniqueness guarantee that ObjectIds generally provide. ObjectIds generated with this method should be used exclusively in queries. `generation_time` will be converted to UTC. Naive datetime instances will be treated as though they already contain UTC. An example using this helper to get documents where ``"_id"`` was generated before January 1, 2010 would be: >>> gen_time = datetime.datetime(2010, 1, 1) >>> dummy_id = ObjectId.from_datetime(gen_time) >>> result = collection.find({"_id": {"$lt": dummy_id}}) :Parameters: - `generation_time`: :class:`~datetime.datetime` to be used as the generation time for the resulting ObjectId. """ if generation_time.utcoffset() is not None: generation_time = generation_time - generation_time.utcoffset() timestamp = calendar.timegm(generation_time.timetuple()) oid = struct.pack( ">I", int(timestamp)) + b"\x00\x00\x00\x00\x00\x00\x00\x00" return cls(oid) @classmethod def is_valid(cls, oid): """Checks if a `oid` string is valid or not. :Parameters: - `oid`: the object id to validate .. 
versionadded:: 2.3 """ if not oid: return False try: ObjectId(oid) return True except (InvalidId, TypeError): return False @classmethod def _random(cls): """Generate a 5-byte random number once per process. """ pid = os.getpid() if pid != cls._pid: cls._pid = pid cls.__random = _random_bytes() return cls.__random def __generate(self): """Generate a new value for this ObjectId. """ # 4 bytes current time oid = struct.pack(">I", int(time.time())) # 5 bytes random oid += ObjectId._random() # 3 bytes inc with ObjectId._inc_lock: oid += struct.pack(">I", ObjectId._inc)[1:4] ObjectId._inc = (ObjectId._inc + 1) % (_MAX_COUNTER_VALUE + 1) self.__id = oid def __validate(self, oid): """Validate and use the given id for this ObjectId. Raises TypeError if id is not an instance of (:class:`basestring` (:class:`str` or :class:`bytes` in python 3), ObjectId) and InvalidId if it is not a valid ObjectId. :Parameters: - `oid`: a valid ObjectId """ if isinstance(oid, ObjectId): self.__id = oid.binary # bytes or unicode in python 2, str in python 3 elif isinstance(oid, string_type): if len(oid) == 24: try: self.__id = bytes_from_hex(oid) except (TypeError, ValueError): _raise_invalid_id(oid) else: _raise_invalid_id(oid) else: raise TypeError("id must be an instance of (bytes, %s, ObjectId), " "not %s" % (text_type.__name__, type(oid))) @property def binary(self): """12-byte binary representation of this ObjectId. """ return self.__id @property def generation_time(self): """A :class:`datetime.datetime` instance representing the time of generation for this :class:`ObjectId`. The :class:`datetime.datetime` is timezone aware, and represents the generation time in UTC. It is precise to the second. """ timestamp = struct.unpack(">I", self.__id[0:4])[0] return datetime.datetime.fromtimestamp(timestamp, utc) def __getstate__(self): """return value of object for pickling. needed explicitly because __slots__() defined. """ return self.__id def __setstate__(self, value): """explicit state set from pickling """ # Provide backwards compatability with OIDs # pickled with pymongo-1.9 or older. if isinstance(value, dict): oid = value["_ObjectId__id"] else: oid = value # ObjectIds pickled in python 2.x used `str` for __id. # In python 3.x this has to be converted to `bytes` # by encoding latin-1. if PY3 and isinstance(oid, text_type): self.__id = oid.encode('latin-1') else: self.__id = oid def __str__(self): if PY3: return binascii.hexlify(self.__id).decode() return binascii.hexlify(self.__id) def __repr__(self): return "ObjectId('%s')" % (str(self),) def __eq__(self, other): if isinstance(other, ObjectId): return self.__id == other.binary return NotImplemented def __ne__(self, other): if isinstance(other, ObjectId): return self.__id != other.binary return NotImplemented def __lt__(self, other): if isinstance(other, ObjectId): return self.__id < other.binary return NotImplemented def __le__(self, other): if isinstance(other, ObjectId): return self.__id <= other.binary return NotImplemented def __gt__(self, other): if isinstance(other, ObjectId): return self.__id > other.binary return NotImplemented def __ge__(self, other): if isinstance(other, ObjectId): return self.__id >= other.binary return NotImplemented def __hash__(self): """Get a hash value for this :class:`ObjectId`.""" return hash(self.__id) pymongo-3.11.0/bson/py3compat.py000066400000000000000000000053771374256237000165430ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. 
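The layout and helpers described above can be exercised directly; this sketch relies only on the ``ObjectId`` API defined in this module::

    import datetime

    from bson.objectid import ObjectId

    oid = ObjectId()
    # generation_time is timezone-aware and expressed in UTC; it is decoded
    # from the leading 4-byte timestamp.
    assert oid.generation_time.tzinfo is not None

    # from_datetime() zeroes the random and counter bytes, so the result is
    # only suitable as a boundary for range queries on _id.
    boundary = ObjectId.from_datetime(datetime.datetime(2010, 1, 1))
    assert boundary.binary[4:] == b"\x00" * 8

    assert ObjectId.is_valid(str(oid))
    assert not ObjectId.is_valid("invalid")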
# # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Utility functions and definitions for python3 compatibility.""" import sys PY3 = sys.version_info[0] == 3 if PY3: import codecs import collections.abc as abc import _thread as thread from abc import ABC, abstractmethod from io import BytesIO as StringIO def abstractproperty(func): return property(abstractmethod(func)) MAXSIZE = sys.maxsize imap = map def b(s): # BSON and socket operations deal in binary data. In # python 3 that means instances of `bytes`. In python # 2.7 you can create an alias for `bytes` using # the b prefix (e.g. b'foo'). # See http://python3porting.com/problems.html#nicer-solutions return codecs.latin_1_encode(s)[0] def bytes_from_hex(h): return bytes.fromhex(h) def iteritems(d): return iter(d.items()) def itervalues(d): return iter(d.values()) def reraise(exctype, value, trace=None): raise exctype(str(value)).with_traceback(trace) def reraise_instance(exc_instance, trace=None): raise exc_instance.with_traceback(trace) def _unicode(s): return s text_type = str string_type = str integer_types = int else: import collections as abc import thread from abc import ABCMeta, abstractproperty from itertools import imap try: from cStringIO import StringIO except ImportError: from StringIO import StringIO ABC = ABCMeta('ABC', (object,), {}) MAXSIZE = sys.maxint def b(s): # See comments above. In python 2.x b('foo') is just 'foo'. return s def bytes_from_hex(h): return h.decode('hex') def iteritems(d): return d.iteritems() def itervalues(d): return d.itervalues() def reraise(exctype, value, trace=None): _reraise(exctype, str(value), trace) def reraise_instance(exc_instance, trace=None): _reraise(exc_instance, None, trace) # "raise x, y, z" raises SyntaxError in Python 3 exec("""def _reraise(exc, value, trace): raise exc, value, trace """) _unicode = unicode string_type = basestring text_type = unicode integer_types = (int, long) pymongo-3.11.0/bson/raw_bson.py000066400000000000000000000117221374256237000164250ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for representing raw BSON documents. """ from bson import _raw_to_dict, _get_object_size from bson.py3compat import abc, iteritems from bson.codec_options import ( DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER) from bson.son import SON class RawBSONDocument(abc.Mapping): """Representation for a MongoDB document that provides access to the raw BSON bytes that compose it. 
Only when a field is accessed or modified within the document does RawBSONDocument decode its bytes. """ __slots__ = ('__raw', '__inflated_doc', '__codec_options') _type_marker = _RAW_BSON_DOCUMENT_MARKER def __init__(self, bson_bytes, codec_options=None): """Create a new :class:`RawBSONDocument` :class:`RawBSONDocument` is a representation of a BSON document that provides access to the underlying raw BSON bytes. Only when a field is accessed or modified within the document does RawBSONDocument decode its bytes. :class:`RawBSONDocument` implements the ``Mapping`` abstract base class from the standard library so it can be used like a read-only ``dict``:: >>> from bson import encode >>> raw_doc = RawBSONDocument(encode({'_id': 'my_doc'})) >>> raw_doc.raw b'...' >>> raw_doc['_id'] 'my_doc' :Parameters: - `bson_bytes`: the BSON bytes that compose this document - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions` whose ``document_class`` must be :class:`RawBSONDocument`. The default is :attr:`DEFAULT_RAW_BSON_OPTIONS`. .. versionchanged:: 3.8 :class:`RawBSONDocument` now validates that the ``bson_bytes`` passed in represent a single bson document. .. versionchanged:: 3.5 If a :class:`~bson.codec_options.CodecOptions` is passed in, its `document_class` must be :class:`RawBSONDocument`. """ self.__raw = bson_bytes self.__inflated_doc = None # Can't default codec_options to DEFAULT_RAW_BSON_OPTIONS in signature, # it refers to this class RawBSONDocument. if codec_options is None: codec_options = DEFAULT_RAW_BSON_OPTIONS elif codec_options.document_class is not RawBSONDocument: raise TypeError( "RawBSONDocument cannot use CodecOptions with document " "class %s" % (codec_options.document_class, )) self.__codec_options = codec_options # Validate the bson object size. _get_object_size(bson_bytes, 0, len(bson_bytes)) @property def raw(self): """The raw BSON bytes composing this document.""" return self.__raw def items(self): """Lazily decode and iterate elements in this document.""" return iteritems(self.__inflated) @property def __inflated(self): if self.__inflated_doc is None: # We already validated the object's size when this document was # created, so no need to do that again. # Use SON to preserve ordering of elements. self.__inflated_doc = _inflate_bson( self.__raw, self.__codec_options) return self.__inflated_doc def __getitem__(self, item): return self.__inflated[item] def __iter__(self): return iter(self.__inflated) def __len__(self): return len(self.__inflated) def __eq__(self, other): if isinstance(other, RawBSONDocument): return self.__raw == other.raw return NotImplemented def __repr__(self): return ("RawBSONDocument(%r, codec_options=%r)" % (self.raw, self.__codec_options)) def _inflate_bson(bson_bytes, codec_options): """Inflates the top level fields of a BSON document. :Parameters: - `bson_bytes`: the BSON bytes that compose this document - `codec_options`: An instance of :class:`~bson.codec_options.CodecOptions` whose ``document_class`` must be :class:`RawBSONDocument`. """ # Use SON to preserve ordering of elements. return _raw_to_dict( bson_bytes, 4, len(bson_bytes)-1, codec_options, SON()) DEFAULT_RAW_BSON_OPTIONS = DEFAULT.with_options(document_class=RawBSONDocument) """The default :class:`~bson.codec_options.CodecOptions` for :class:`RawBSONDocument`. """ pymongo-3.11.0/bson/regex.py000066400000000000000000000103031374256237000157170ustar00rootroot00000000000000# Copyright 2013-present MongoDB, Inc. 
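A short sketch of the lazy decoding described above, using the module-level ``bson.encode``/``bson.decode`` helpers::

    from bson import decode, encode
    from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument

    raw_doc = RawBSONDocument(encode({"_id": "my_doc", "n": 1}))
    # The bytes are stored verbatim; nothing is inflated until a field is read.
    assert isinstance(raw_doc.raw, bytes)
    assert raw_doc["n"] == 1

    # Decoding with DEFAULT_RAW_BSON_OPTIONS yields RawBSONDocument instances,
    # and equality is defined on the underlying raw bytes.
    assert decode(raw_doc.raw, codec_options=DEFAULT_RAW_BSON_OPTIONS) == raw_doc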
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for representing MongoDB regular expressions. """ import re from bson.son import RE_TYPE from bson.py3compat import string_type, text_type def str_flags_to_int(str_flags): flags = 0 if "i" in str_flags: flags |= re.IGNORECASE if "l" in str_flags: flags |= re.LOCALE if "m" in str_flags: flags |= re.MULTILINE if "s" in str_flags: flags |= re.DOTALL if "u" in str_flags: flags |= re.UNICODE if "x" in str_flags: flags |= re.VERBOSE return flags class Regex(object): """BSON regular expression data.""" _type_marker = 11 @classmethod def from_native(cls, regex): """Convert a Python regular expression into a ``Regex`` instance. Note that in Python 3, a regular expression compiled from a :class:`str` has the ``re.UNICODE`` flag set. If it is undesirable to store this flag in a BSON regular expression, unset it first:: >>> pattern = re.compile('.*') >>> regex = Regex.from_native(pattern) >>> regex.flags ^= re.UNICODE >>> db.collection.insert({'pattern': regex}) :Parameters: - `regex`: A regular expression object from ``re.compile()``. .. warning:: Python regular expressions use a different syntax and different set of flags than MongoDB, which uses `PCRE`_. A regular expression retrieved from the server may not compile in Python, or may match a different set of strings in Python than when used in a MongoDB query. .. _PCRE: http://www.pcre.org/ """ if not isinstance(regex, RE_TYPE): raise TypeError( "regex must be a compiled regular expression, not %s" % type(regex)) return Regex(regex.pattern, regex.flags) def __init__(self, pattern, flags=0): """BSON regular expression data. This class is useful to store and retrieve regular expressions that are incompatible with Python's regular expression dialect. :Parameters: - `pattern`: string - `flags`: (optional) an integer bitmask, or a string of flag characters like "im" for IGNORECASE and MULTILINE """ if not isinstance(pattern, (text_type, bytes)): raise TypeError("pattern must be a string, not %s" % type(pattern)) self.pattern = pattern if isinstance(flags, string_type): self.flags = str_flags_to_int(flags) elif isinstance(flags, int): self.flags = flags else: raise TypeError( "flags must be a string or int, not %s" % type(flags)) def __eq__(self, other): if isinstance(other, Regex): return self.pattern == other.pattern and self.flags == other.flags else: return NotImplemented __hash__ = None def __ne__(self, other): return not self == other def __repr__(self): return "Regex(%r, %r)" % (self.pattern, self.flags) def try_compile(self): """Compile this :class:`Regex` as a Python regular expression. .. warning:: Python regular expressions use a different syntax and different set of flags than MongoDB, which uses `PCRE`_. A regular expression retrieved from the server may not compile in Python, or may match a different set of strings in Python than when used in a MongoDB query. :meth:`try_compile()` may raise :exc:`re.error`. .. 
_PCRE: http://www.pcre.org/ """ return re.compile(self.pattern, self.flags) pymongo-3.11.0/bson/son.py000066400000000000000000000132341374256237000154120ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for creating and manipulating SON, the Serialized Ocument Notation. Regular dictionaries can be used instead of SON objects, but not when the order of keys is important. A SON object can be used just like a normal Python dictionary.""" import copy import re from bson.py3compat import abc, iteritems # This sort of sucks, but seems to be as good as it gets... # This is essentially the same as re._pattern_type RE_TYPE = type(re.compile("")) class SON(dict): """SON data. A subclass of dict that maintains ordering of keys and provides a few extra niceties for dealing with SON. SON provides an API similar to collections.OrderedDict from Python 2.7+. """ def __init__(self, data=None, **kwargs): self.__keys = [] dict.__init__(self) self.update(data) self.update(kwargs) def __new__(cls, *args, **kwargs): instance = super(SON, cls).__new__(cls, *args, **kwargs) instance.__keys = [] return instance def __repr__(self): result = [] for key in self.__keys: result.append("(%r, %r)" % (key, self[key])) return "SON([%s])" % ", ".join(result) def __setitem__(self, key, value): if key not in self.__keys: self.__keys.append(key) dict.__setitem__(self, key, value) def __delitem__(self, key): self.__keys.remove(key) dict.__delitem__(self, key) def keys(self): return list(self.__keys) def copy(self): other = SON() other.update(self) return other # TODO this is all from UserDict.DictMixin. it could probably be made more # efficient. 
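A brief sketch of the ordering behaviour implemented above; nothing beyond the ``SON`` class itself is assumed::

    from bson.son import SON

    doc = SON([("b", 2), ("a", 1)])
    doc["c"] = 3
    # Insertion order is preserved, and repr() reflects it.
    assert list(doc.keys()) == ["b", "a", "c"]
    assert repr(doc) == "SON([('b', 2), ('a', 1), ('c', 3)])"

    # Deleting and re-inserting a key moves it to the end.
    del doc["b"]
    doc["b"] = 2
    assert list(doc.keys()) == ["a", "c", "b"]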
# second level definitions support higher levels def __iter__(self): for k in self.__keys: yield k def has_key(self, key): return key in self.__keys # third level takes advantage of second level definitions def iteritems(self): for k in self: yield (k, self[k]) def iterkeys(self): return self.__iter__() # fourth level uses definitions from lower levels def itervalues(self): for _, v in self.iteritems(): yield v def values(self): return [v for _, v in self.iteritems()] def items(self): return [(key, self[key]) for key in self] def clear(self): self.__keys = [] super(SON, self).clear() def setdefault(self, key, default=None): try: return self[key] except KeyError: self[key] = default return default def pop(self, key, *args): if len(args) > 1: raise TypeError("pop expected at most 2 arguments, got "\ + repr(1 + len(args))) try: value = self[key] except KeyError: if args: return args[0] raise del self[key] return value def popitem(self): try: k, v = next(self.iteritems()) except StopIteration: raise KeyError('container is empty') del self[k] return (k, v) def update(self, other=None, **kwargs): # Make progressively weaker assumptions about "other" if other is None: pass elif hasattr(other, 'iteritems'): # iteritems saves memory and lookups for k, v in other.iteritems(): self[k] = v elif hasattr(other, 'keys'): for k in other.keys(): self[k] = other[k] else: for k, v in other: self[k] = v if kwargs: self.update(kwargs) def get(self, key, default=None): try: return self[key] except KeyError: return default def __eq__(self, other): """Comparison to another SON is order-sensitive while comparison to a regular dictionary is order-insensitive. """ if isinstance(other, SON): return len(self) == len(other) and self.items() == other.items() return self.to_dict() == other def __ne__(self, other): return not self == other def __len__(self): return len(self.__keys) def to_dict(self): """Convert a SON document to a normal Python dictionary instance. This is trickier than just *dict(...)* because it needs to be recursive. """ def transform_value(value): if isinstance(value, list): return [transform_value(v) for v in value] elif isinstance(value, abc.Mapping): return dict([ (k, transform_value(v)) for k, v in iteritems(value)]) else: return value return transform_value(dict(self)) def __deepcopy__(self, memo): out = SON() val_id = id(self) if val_id in memo: return memo.get(val_id) memo[val_id] = out for k, v in self.iteritems(): if not isinstance(v, RE_TYPE): v = copy.deepcopy(v, memo) out[k] = v return out pymongo-3.11.0/bson/time64.c000066400000000000000000000515121374256237000155160ustar00rootroot00000000000000/* Copyright (c) 2007-2010 Michael G Schwern This software originally derived from Paul Sheer's pivotal_gmtime_r.c. The MIT License: Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
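The comparison and conversion behaviour defined above can be summarised in a short sketch::

    from bson.son import SON

    ordered = SON([("a", 1), ("b", 2)])
    reordered = SON([("b", 2), ("a", 1)])
    # Comparison with another SON is order-sensitive...
    assert ordered != reordered
    # ...while comparison with a plain dict is order-insensitive.
    assert ordered == {"b": 2, "a": 1}

    # to_dict() recursively converts nested SON instances to plain dicts.
    nested = SON([("outer", SON([("inner", 1)]))])
    assert not isinstance(nested.to_dict()["outer"], SON)
    assert nested.to_dict() == {"outer": {"inner": 1}}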
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ /* Programmers who have available to them 64-bit time values as a 'long long' type can use localtime64_r() and gmtime64_r() which correctly converts the time even on 32-bit systems. Whether you have 64-bit time values will depend on the operating system. localtime64_r() is a 64-bit equivalent of localtime_r(). gmtime64_r() is a 64-bit equivalent of gmtime_r(). */ #ifdef _MSC_VER #define _CRT_SECURE_NO_WARNINGS #endif /* Including Python.h fixes issues with interpreters built with -std=c99. */ #define PY_SSIZE_T_CLEAN #include "Python.h" #include #include "time64.h" #include "time64_limits.h" /* Spec says except for stftime() and the _r() functions, these all return static memory. Stabbings! */ static struct TM Static_Return_Date; static const int days_in_month[2][12] = { {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, {31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}, }; static const int julian_days_by_month[2][12] = { {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334}, {0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335}, }; static const int length_of_year[2] = { 365, 366 }; /* Some numbers relating to the gregorian cycle */ static const Year years_in_gregorian_cycle = 400; #define days_in_gregorian_cycle ((365 * 400) + 100 - 4 + 1) static const Time64_T seconds_in_gregorian_cycle = days_in_gregorian_cycle * 60LL * 60LL * 24LL; /* Year range we can trust the time funcitons with */ #define MAX_SAFE_YEAR 2037 #define MIN_SAFE_YEAR 1971 /* 28 year Julian calendar cycle */ #define SOLAR_CYCLE_LENGTH 28 /* Year cycle from MAX_SAFE_YEAR down. */ static const int safe_years_high[SOLAR_CYCLE_LENGTH] = { 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025, 2026, 2027, 2028, 2029, 2030, 2031, 2032, 2033, 2034, 2035, 2036, 2037, 2010, 2011, 2012, 2013, 2014, 2015 }; /* Year cycle from MIN_SAFE_YEAR up */ static const int safe_years_low[SOLAR_CYCLE_LENGTH] = { 1996, 1997, 1998, 1971, 1972, 1973, 1974, 1975, 1976, 1977, 1978, 1979, 1980, 1981, 1982, 1983, 1984, 1985, 1986, 1987, 1988, 1989, 1990, 1991, 1992, 1993, 1994, 1995, }; /* Let's assume people are going to be looking for dates in the future. Let's provide some cheats so you can skip ahead. This has a 4x speed boost when near 2008. */ /* Number of days since epoch on Jan 1st, 2008 GMT */ #define CHEAT_DAYS (1199145600 / 24 / 60 / 60) #define CHEAT_YEARS 108 #define IS_LEAP(n) ((!(((n) + 1900) % 400) || (!(((n) + 1900) % 4) && (((n) + 1900) % 100))) != 0) #define _TIME64_WRAP(a,b,m) ((a) = ((a) < 0 ) ? 
((b)--, (a) + (m)) : (a)) #ifdef USE_SYSTEM_LOCALTIME # define SHOULD_USE_SYSTEM_LOCALTIME(a) ( \ (a) <= SYSTEM_LOCALTIME_MAX && \ (a) >= SYSTEM_LOCALTIME_MIN \ ) #else # define SHOULD_USE_SYSTEM_LOCALTIME(a) (0) #endif #ifdef USE_SYSTEM_GMTIME # define SHOULD_USE_SYSTEM_GMTIME(a) ( \ (a) <= SYSTEM_GMTIME_MAX && \ (a) >= SYSTEM_GMTIME_MIN \ ) #else # define SHOULD_USE_SYSTEM_GMTIME(a) (0) #endif /* Multi varadic macros are a C99 thing, alas */ #ifdef TIME_64_DEBUG # define TIME64_TRACE(format) (fprintf(stderr, format)) # define TIME64_TRACE1(format, var1) (fprintf(stderr, format, var1)) # define TIME64_TRACE2(format, var1, var2) (fprintf(stderr, format, var1, var2)) # define TIME64_TRACE3(format, var1, var2, var3) (fprintf(stderr, format, var1, var2, var3)) #else # define TIME64_TRACE(format) ((void)0) # define TIME64_TRACE1(format, var1) ((void)0) # define TIME64_TRACE2(format, var1, var2) ((void)0) # define TIME64_TRACE3(format, var1, var2, var3) ((void)0) #endif static int is_exception_century(Year year) { int is_exception = ((year % 100 == 0) && !(year % 400 == 0)); TIME64_TRACE1("# is_exception_century: %s\n", is_exception ? "yes" : "no"); return(is_exception); } /* Compare two dates. The result is like cmp. Ignores things like gmtoffset and dst */ int cmp_date( const struct TM* left, const struct tm* right ) { if( left->tm_year > right->tm_year ) return 1; else if( left->tm_year < right->tm_year ) return -1; if( left->tm_mon > right->tm_mon ) return 1; else if( left->tm_mon < right->tm_mon ) return -1; if( left->tm_mday > right->tm_mday ) return 1; else if( left->tm_mday < right->tm_mday ) return -1; if( left->tm_hour > right->tm_hour ) return 1; else if( left->tm_hour < right->tm_hour ) return -1; if( left->tm_min > right->tm_min ) return 1; else if( left->tm_min < right->tm_min ) return -1; if( left->tm_sec > right->tm_sec ) return 1; else if( left->tm_sec < right->tm_sec ) return -1; return 0; } /* Check if a date is safely inside a range. The intention is to check if its a few days inside. */ int date_in_safe_range( const struct TM* date, const struct tm* min, const struct tm* max ) { if( cmp_date(date, min) == -1 ) return 0; if( cmp_date(date, max) == 1 ) return 0; return 1; } /* timegm() is not in the C or POSIX spec, but it is such a useful extension I would be remiss in leaving it out. 
Also I need it for localtime64() */ Time64_T timegm64(const struct TM *date) { Time64_T days = 0; Time64_T seconds = 0; Year year; Year orig_year = (Year)date->tm_year; int cycles = 0; if( orig_year > 100 ) { cycles = (int)((orig_year - 100) / 400); orig_year -= cycles * 400; days += (Time64_T)cycles * days_in_gregorian_cycle; } else if( orig_year < -300 ) { cycles = (int)((orig_year - 100) / 400); orig_year -= cycles * 400; days += (Time64_T)cycles * days_in_gregorian_cycle; } TIME64_TRACE3("# timegm/ cycles: %d, days: %lld, orig_year: %lld\n", cycles, days, orig_year); if( orig_year > 70 ) { year = 70; while( year < orig_year ) { days += length_of_year[IS_LEAP(year)]; year++; } } else if ( orig_year < 70 ) { year = 69; do { days -= length_of_year[IS_LEAP(year)]; year--; } while( year >= orig_year ); } days += julian_days_by_month[IS_LEAP(orig_year)][date->tm_mon]; days += date->tm_mday - 1; seconds = days * 60 * 60 * 24; seconds += date->tm_hour * 60 * 60; seconds += date->tm_min * 60; seconds += date->tm_sec; return(seconds); } #ifndef NDEBUG static int check_tm(struct TM *tm) { /* Don't forget leap seconds */ assert(tm->tm_sec >= 0); assert(tm->tm_sec <= 61); assert(tm->tm_min >= 0); assert(tm->tm_min <= 59); assert(tm->tm_hour >= 0); assert(tm->tm_hour <= 23); assert(tm->tm_mday >= 1); assert(tm->tm_mday <= days_in_month[IS_LEAP(tm->tm_year)][tm->tm_mon]); assert(tm->tm_mon >= 0); assert(tm->tm_mon <= 11); assert(tm->tm_wday >= 0); assert(tm->tm_wday <= 6); assert(tm->tm_yday >= 0); assert(tm->tm_yday <= length_of_year[IS_LEAP(tm->tm_year)]); #ifdef HAS_TM_TM_GMTOFF assert(tm->tm_gmtoff >= -24 * 60 * 60); assert(tm->tm_gmtoff <= 24 * 60 * 60); #endif return 1; } #endif /* The exceptional centuries without leap years cause the cycle to shift by 16 */ static Year cycle_offset(Year year) { const Year start_year = 2000; Year year_diff = year - start_year; Year exceptions; if( year > start_year ) year_diff--; exceptions = year_diff / 100; exceptions -= year_diff / 400; TIME64_TRACE3("# year: %lld, exceptions: %lld, year_diff: %lld\n", year, exceptions, year_diff); return exceptions * 16; } /* For a given year after 2038, pick the latest possible matching year in the 28 year calendar cycle. A matching year... 1) Starts on the same day of the week. 2) Has the same leap year status. This is so the calendars match up. Also the previous year must match. When doing Jan 1st you might wind up on Dec 31st the previous year when doing a -UTC time zone. Finally, the next year must have the same start day of week. This is for Dec 31st with a +UTC time zone. It doesn't need the same leap year status since we only care about January 1st. 
*/ static int safe_year(const Year year) { int safe_year = 0; Year year_cycle; if( year >= MIN_SAFE_YEAR && year <= MAX_SAFE_YEAR ) { return (int)year; } year_cycle = year + cycle_offset(year); /* safe_years_low is off from safe_years_high by 8 years */ if( year < MIN_SAFE_YEAR ) year_cycle -= 8; /* Change non-leap xx00 years to an equivalent */ if( is_exception_century(year) ) year_cycle += 11; /* Also xx01 years, since the previous year will be wrong */ if( is_exception_century(year - 1) ) year_cycle += 17; year_cycle %= SOLAR_CYCLE_LENGTH; if( year_cycle < 0 ) year_cycle = SOLAR_CYCLE_LENGTH + year_cycle; assert( year_cycle >= 0 ); assert( year_cycle < SOLAR_CYCLE_LENGTH ); if( year < MIN_SAFE_YEAR ) safe_year = safe_years_low[year_cycle]; else if( year > MAX_SAFE_YEAR ) safe_year = safe_years_high[year_cycle]; else assert(0); TIME64_TRACE3("# year: %lld, year_cycle: %lld, safe_year: %d\n", year, year_cycle, safe_year); assert(safe_year <= MAX_SAFE_YEAR && safe_year >= MIN_SAFE_YEAR); return safe_year; } void copy_tm_to_TM64(const struct tm *src, struct TM *dest) { if( src == NULL ) { memset(dest, 0, sizeof(*dest)); } else { # ifdef USE_TM64 dest->tm_sec = src->tm_sec; dest->tm_min = src->tm_min; dest->tm_hour = src->tm_hour; dest->tm_mday = src->tm_mday; dest->tm_mon = src->tm_mon; dest->tm_year = (Year)src->tm_year; dest->tm_wday = src->tm_wday; dest->tm_yday = src->tm_yday; dest->tm_isdst = src->tm_isdst; # ifdef HAS_TM_TM_GMTOFF dest->tm_gmtoff = src->tm_gmtoff; # endif # ifdef HAS_TM_TM_ZONE dest->tm_zone = src->tm_zone; # endif # else /* They're the same type */ memcpy(dest, src, sizeof(*dest)); # endif } } void copy_TM64_to_tm(const struct TM *src, struct tm *dest) { if( src == NULL ) { memset(dest, 0, sizeof(*dest)); } else { # ifdef USE_TM64 dest->tm_sec = src->tm_sec; dest->tm_min = src->tm_min; dest->tm_hour = src->tm_hour; dest->tm_mday = src->tm_mday; dest->tm_mon = src->tm_mon; dest->tm_year = (int)src->tm_year; dest->tm_wday = src->tm_wday; dest->tm_yday = src->tm_yday; dest->tm_isdst = src->tm_isdst; # ifdef HAS_TM_TM_GMTOFF dest->tm_gmtoff = src->tm_gmtoff; # endif # ifdef HAS_TM_TM_ZONE dest->tm_zone = src->tm_zone; # endif # else /* They're the same type */ memcpy(dest, src, sizeof(*dest)); # endif } } /* Simulate localtime_r() to the best of our ability */ struct tm * fake_localtime_r(const time_t *time, struct tm *result) { const struct tm *static_result = localtime(time); assert(result != NULL); if( static_result == NULL ) { memset(result, 0, sizeof(*result)); return NULL; } else { memcpy(result, static_result, sizeof(*result)); return result; } } /* Simulate gmtime_r() to the best of our ability */ struct tm * fake_gmtime_r(const time_t *time, struct tm *result) { const struct tm *static_result = gmtime(time); assert(result != NULL); if( static_result == NULL ) { memset(result, 0, sizeof(*result)); return NULL; } else { memcpy(result, static_result, sizeof(*result)); return result; } } static Time64_T seconds_between_years(Year left_year, Year right_year) { int increment = (left_year > right_year) ? 
1 : -1; Time64_T seconds = 0; int cycles; if( left_year > 2400 ) { cycles = (int)((left_year - 2400) / 400); left_year -= cycles * 400; seconds += cycles * seconds_in_gregorian_cycle; } else if( left_year < 1600 ) { cycles = (int)((left_year - 1600) / 400); left_year += cycles * 400; seconds += cycles * seconds_in_gregorian_cycle; } while( left_year != right_year ) { seconds += length_of_year[IS_LEAP(right_year - 1900)] * 60 * 60 * 24; right_year += increment; } return seconds * increment; } Time64_T mktime64(const struct TM *input_date) { struct tm safe_date; struct TM date; Time64_T time; Year year = input_date->tm_year + 1900; if( date_in_safe_range(input_date, &SYSTEM_MKTIME_MIN, &SYSTEM_MKTIME_MAX) ) { copy_TM64_to_tm(input_date, &safe_date); return (Time64_T)mktime(&safe_date); } /* Have to make the year safe in date else it won't fit in safe_date */ date = *input_date; date.tm_year = safe_year(year) - 1900; copy_TM64_to_tm(&date, &safe_date); time = (Time64_T)mktime(&safe_date); time += seconds_between_years(year, (Year)(safe_date.tm_year + 1900)); return time; } /* Because I think mktime() is a crappy name */ Time64_T timelocal64(const struct TM *date) { return mktime64(date); } struct TM *gmtime64_r (const Time64_T *in_time, struct TM *p) { int v_tm_sec, v_tm_min, v_tm_hour, v_tm_mon, v_tm_wday; Time64_T v_tm_tday; int leap; Time64_T m; Time64_T time = *in_time; Year year = 70; int cycles = 0; assert(p != NULL); #ifdef USE_SYSTEM_GMTIME /* Use the system gmtime() if time_t is small enough */ if( SHOULD_USE_SYSTEM_GMTIME(*in_time) ) { time_t safe_time = (time_t)*in_time; struct tm safe_date; GMTIME_R(&safe_time, &safe_date); copy_tm_to_TM64(&safe_date, p); assert(check_tm(p)); return p; } #endif #ifdef HAS_TM_TM_GMTOFF p->tm_gmtoff = 0; #endif p->tm_isdst = 0; #ifdef HAS_TM_TM_ZONE p->tm_zone = "UTC"; #endif v_tm_sec = (int)(time % 60); time /= 60; v_tm_min = (int)(time % 60); time /= 60; v_tm_hour = (int)(time % 24); time /= 24; v_tm_tday = time; _TIME64_WRAP (v_tm_sec, v_tm_min, 60); _TIME64_WRAP (v_tm_min, v_tm_hour, 60); _TIME64_WRAP (v_tm_hour, v_tm_tday, 24); v_tm_wday = (int)((v_tm_tday + 4) % 7); if (v_tm_wday < 0) v_tm_wday += 7; m = v_tm_tday; if (m >= CHEAT_DAYS) { year = CHEAT_YEARS; m -= CHEAT_DAYS; } if (m >= 0) { /* Gregorian cycles, this is huge optimization for distant times */ cycles = (int)(m / (Time64_T) days_in_gregorian_cycle); if( cycles ) { m -= (cycles * (Time64_T) days_in_gregorian_cycle); year += (cycles * years_in_gregorian_cycle); } /* Years */ leap = IS_LEAP (year); while (m >= (Time64_T) length_of_year[leap]) { m -= (Time64_T) length_of_year[leap]; year++; leap = IS_LEAP (year); } /* Months */ v_tm_mon = 0; while (m >= (Time64_T) days_in_month[leap][v_tm_mon]) { m -= (Time64_T) days_in_month[leap][v_tm_mon]; v_tm_mon++; } } else { year--; /* Gregorian cycles */ cycles = (int)((m / (Time64_T) days_in_gregorian_cycle) + 1); if( cycles ) { m -= (cycles * (Time64_T) days_in_gregorian_cycle); year += (cycles * years_in_gregorian_cycle); } /* Years */ leap = IS_LEAP (year); while (m < (Time64_T) -length_of_year[leap]) { m += (Time64_T) length_of_year[leap]; year--; leap = IS_LEAP (year); } /* Months */ v_tm_mon = 11; while (m < (Time64_T) -days_in_month[leap][v_tm_mon]) { m += (Time64_T) days_in_month[leap][v_tm_mon]; v_tm_mon--; } m += (Time64_T) days_in_month[leap][v_tm_mon]; } p->tm_year = (int)year; if( p->tm_year != year ) { #ifdef EOVERFLOW errno = EOVERFLOW; #endif return NULL; } /* At this point m is less than a year so casting to an int is safe */ 
p->tm_mday = (int) m + 1; p->tm_yday = julian_days_by_month[leap][v_tm_mon] + (int)m; p->tm_sec = v_tm_sec; p->tm_min = v_tm_min; p->tm_hour = v_tm_hour; p->tm_mon = v_tm_mon; p->tm_wday = v_tm_wday; assert(check_tm(p)); return p; } struct TM *localtime64_r (const Time64_T *time, struct TM *local_tm) { time_t safe_time; struct tm safe_date; struct TM gm_tm; Year orig_year; int month_diff; assert(local_tm != NULL); #ifdef USE_SYSTEM_LOCALTIME /* Use the system localtime() if time_t is small enough */ if( SHOULD_USE_SYSTEM_LOCALTIME(*time) ) { safe_time = (time_t)*time; TIME64_TRACE1("Using system localtime for %lld\n", *time); LOCALTIME_R(&safe_time, &safe_date); copy_tm_to_TM64(&safe_date, local_tm); assert(check_tm(local_tm)); return local_tm; } #endif if( gmtime64_r(time, &gm_tm) == NULL ) { TIME64_TRACE1("gmtime64_r returned null for %lld\n", *time); return NULL; } orig_year = gm_tm.tm_year; if (gm_tm.tm_year > (2037 - 1900) || gm_tm.tm_year < (1970 - 1900) ) { TIME64_TRACE1("Mapping tm_year %lld to safe_year\n", (Year)gm_tm.tm_year); gm_tm.tm_year = safe_year((Year)(gm_tm.tm_year + 1900)) - 1900; } safe_time = (time_t)timegm64(&gm_tm); if( LOCALTIME_R(&safe_time, &safe_date) == NULL ) { TIME64_TRACE1("localtime_r(%d) returned NULL\n", (int)safe_time); return NULL; } copy_tm_to_TM64(&safe_date, local_tm); local_tm->tm_year = (int)orig_year; if( local_tm->tm_year != orig_year ) { TIME64_TRACE2("tm_year overflow: tm_year %lld, orig_year %lld\n", (Year)local_tm->tm_year, (Year)orig_year); #ifdef EOVERFLOW errno = EOVERFLOW; #endif return NULL; } month_diff = local_tm->tm_mon - gm_tm.tm_mon; /* When localtime is Dec 31st previous year and gmtime is Jan 1st next year. */ if( month_diff == 11 ) { local_tm->tm_year--; } /* When localtime is Jan 1st, next year and gmtime is Dec 31st, previous year. */ if( month_diff == -11 ) { local_tm->tm_year++; } /* GMT is Jan 1st, xx01 year, but localtime is still Dec 31st in a non-leap xx00. There is one point in the cycle we can't account for which the safe xx00 year is a leap year. So we need to correct for Dec 31st comming out as the 366th day of the year. 
*/ if( !IS_LEAP(local_tm->tm_year) && local_tm->tm_yday == 365 ) local_tm->tm_yday--; assert(check_tm(local_tm)); return local_tm; } int valid_tm_wday( const struct TM* date ) { if( 0 <= date->tm_wday && date->tm_wday <= 6 ) return 1; else return 0; } int valid_tm_mon( const struct TM* date ) { if( 0 <= date->tm_mon && date->tm_mon <= 11 ) return 1; else return 0; } /* Non-thread safe versions of the above */ struct TM *localtime64(const Time64_T *time) { #ifdef _MSC_VER _tzset(); #else tzset(); #endif return localtime64_r(time, &Static_Return_Date); } struct TM *gmtime64(const Time64_T *time) { return gmtime64_r(time, &Static_Return_Date); } pymongo-3.11.0/bson/time64.h000066400000000000000000000027511374256237000155240ustar00rootroot00000000000000#ifndef TIME64_H # define TIME64_H #include #include "time64_config.h" /* Set our custom types */ typedef INT_64_T Int64; typedef Int64 Time64_T; typedef Int64 Year; /* A copy of the tm struct but with a 64 bit year */ struct TM64 { int tm_sec; int tm_min; int tm_hour; int tm_mday; int tm_mon; Year tm_year; int tm_wday; int tm_yday; int tm_isdst; #ifdef HAS_TM_TM_GMTOFF long tm_gmtoff; #endif #ifdef HAS_TM_TM_ZONE char *tm_zone; #endif }; /* Decide which tm struct to use */ #ifdef USE_TM64 #define TM TM64 #else #define TM tm #endif /* Declare public functions */ struct TM *gmtime64_r (const Time64_T *, struct TM *); struct TM *localtime64_r (const Time64_T *, struct TM *); struct TM *gmtime64 (const Time64_T *); struct TM *localtime64 (const Time64_T *); Time64_T timegm64 (const struct TM *); Time64_T mktime64 (const struct TM *); Time64_T timelocal64 (const struct TM *); /* Not everyone has gm/localtime_r(), provide a replacement */ #ifdef HAS_LOCALTIME_R # define LOCALTIME_R(clock, result) localtime_r(clock, result) #else # define LOCALTIME_R(clock, result) fake_localtime_r(clock, result) #endif #ifdef HAS_GMTIME_R # define GMTIME_R(clock, result) gmtime_r(clock, result) #else # define GMTIME_R(clock, result) fake_gmtime_r(clock, result) #endif #endif pymongo-3.11.0/bson/time64_config.h000066400000000000000000000032221374256237000170430ustar00rootroot00000000000000/* Configuration ------------- Define as appropriate for your system. Sensible defaults provided. */ #ifndef TIME64_CONFIG_H # define TIME64_CONFIG_H /* Debugging TIME_64_DEBUG Define if you want debugging messages */ /* #define TIME_64_DEBUG */ /* INT_64_T A 64 bit integer type to use to store time and others. Must be defined. */ #define INT_64_T long long /* USE_TM64 Should we use a 64 bit safe replacement for tm? This will let you go past year 2 billion but the struct will be incompatible with tm. Conversion functions will be provided. */ /* #define USE_TM64 */ /* Availability of system functions. HAS_GMTIME_R Define if your system has gmtime_r() HAS_LOCALTIME_R Define if your system has localtime_r() HAS_TIMEGM Define if your system has timegm(), a GNU extension. */ #if !defined(WIN32) && !defined(_MSC_VER) #define HAS_GMTIME_R #define HAS_LOCALTIME_R #endif /* #define HAS_TIMEGM */ /* Details of non-standard tm struct elements. HAS_TM_TM_GMTOFF True if your tm struct has a "tm_gmtoff" element. A BSD extension. HAS_TM_TM_ZONE True if your tm struct has a "tm_zone" element. A BSD extension. */ /* #define HAS_TM_TM_GMTOFF */ /* #define HAS_TM_TM_ZONE */ /* USE_SYSTEM_LOCALTIME USE_SYSTEM_GMTIME USE_SYSTEM_MKTIME USE_SYSTEM_TIMEGM Should we use the system functions if the time is inside their range? 
Your system localtime() is probably more accurate, but our gmtime() is fast and safe. */ #define USE_SYSTEM_LOCALTIME /* #define USE_SYSTEM_GMTIME */ #define USE_SYSTEM_MKTIME /* #define USE_SYSTEM_TIMEGM */ #endif /* TIME64_CONFIG_H */ pymongo-3.11.0/bson/time64_limits.h000066400000000000000000000027241374256237000171050ustar00rootroot00000000000000/* Maximum and minimum inputs your system's respective time functions can correctly handle. time64.h will use your system functions if the input falls inside these ranges and corresponding USE_SYSTEM_* constant is defined. */ #ifndef TIME64_LIMITS_H #define TIME64_LIMITS_H /* Max/min for localtime() */ #define SYSTEM_LOCALTIME_MAX 2147483647 #define SYSTEM_LOCALTIME_MIN -2147483647-1 /* Max/min for gmtime() */ #define SYSTEM_GMTIME_MAX 2147483647 #define SYSTEM_GMTIME_MIN -2147483647-1 /* Max/min for mktime() */ static const struct tm SYSTEM_MKTIME_MAX = { 7, 14, 19, 18, 0, 138, 1, 17, 0 #ifdef HAS_TM_TM_GMTOFF ,-28800 #endif #ifdef HAS_TM_TM_ZONE ,"PST" #endif }; static const struct tm SYSTEM_MKTIME_MIN = { 52, 45, 12, 13, 11, 1, 5, 346, 0 #ifdef HAS_TM_TM_GMTOFF ,-28800 #endif #ifdef HAS_TM_TM_ZONE ,"PST" #endif }; /* Max/min for timegm() */ #ifdef HAS_TIMEGM static const struct tm SYSTEM_TIMEGM_MAX = { 7, 14, 3, 19, 0, 138, 2, 18, 0 #ifdef HAS_TM_TM_GMTOFF ,0 #endif #ifdef HAS_TM_TM_ZONE ,"UTC" #endif }; static const struct tm SYSTEM_TIMEGM_MIN = { 52, 45, 20, 13, 11, 1, 5, 346, 0 #ifdef HAS_TM_TM_GMTOFF ,0 #endif #ifdef HAS_TM_TM_ZONE ,"UTC" #endif }; #endif /* HAS_TIMEGM */ #endif /* TIME64_LIMITS_H */ pymongo-3.11.0/bson/timestamp.py000066400000000000000000000075341374256237000166240ustar00rootroot00000000000000# Copyright 2010-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for representing MongoDB internal Timestamps. """ import calendar import datetime from bson.py3compat import integer_types from bson.tz_util import utc UPPERBOUND = 4294967296 class Timestamp(object): """MongoDB internal timestamps used in the opLog. """ _type_marker = 17 def __init__(self, time, inc): """Create a new :class:`Timestamp`. This class is only for use with the MongoDB opLog. If you need to store a regular timestamp, please use a :class:`~datetime.datetime`. Raises :class:`TypeError` if `time` is not an instance of :class: `int` or :class:`~datetime.datetime`, or `inc` is not an instance of :class:`int`. Raises :class:`ValueError` if `time` or `inc` is not in [0, 2**32). 
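For illustration, a minimal sketch of constructing and comparing ``Timestamp`` instances; only the class defined here is assumed::

    import calendar
    import datetime

    from bson.timestamp import Timestamp

    when = datetime.datetime(2019, 8, 11, 17, 47, 44)  # naive, treated as UTC
    from_dt = Timestamp(when, 1)
    from_int = Timestamp(calendar.timegm(when.timetuple()), 1)
    assert from_dt == from_int

    # Ordering compares (time, inc) pairs, and as_datetime() returns an
    # aware datetime in UTC (the inc counter is discarded).
    assert from_dt < Timestamp(from_dt.time, 2)
    assert from_dt.as_datetime().tzinfo is not None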
:Parameters: - `time`: time in seconds since epoch UTC, or a naive UTC :class:`~datetime.datetime`, or an aware :class:`~datetime.datetime` - `inc`: the incrementing counter """ if isinstance(time, datetime.datetime): if time.utcoffset() is not None: time = time - time.utcoffset() time = int(calendar.timegm(time.timetuple())) if not isinstance(time, integer_types): raise TypeError("time must be an instance of int") if not isinstance(inc, integer_types): raise TypeError("inc must be an instance of int") if not 0 <= time < UPPERBOUND: raise ValueError("time must be contained in [0, 2**32)") if not 0 <= inc < UPPERBOUND: raise ValueError("inc must be contained in [0, 2**32)") self.__time = time self.__inc = inc @property def time(self): """Get the time portion of this :class:`Timestamp`. """ return self.__time @property def inc(self): """Get the inc portion of this :class:`Timestamp`. """ return self.__inc def __eq__(self, other): if isinstance(other, Timestamp): return (self.__time == other.time and self.__inc == other.inc) else: return NotImplemented def __hash__(self): return hash(self.time) ^ hash(self.inc) def __ne__(self, other): return not self == other def __lt__(self, other): if isinstance(other, Timestamp): return (self.time, self.inc) < (other.time, other.inc) return NotImplemented def __le__(self, other): if isinstance(other, Timestamp): return (self.time, self.inc) <= (other.time, other.inc) return NotImplemented def __gt__(self, other): if isinstance(other, Timestamp): return (self.time, self.inc) > (other.time, other.inc) return NotImplemented def __ge__(self, other): if isinstance(other, Timestamp): return (self.time, self.inc) >= (other.time, other.inc) return NotImplemented def __repr__(self): return "Timestamp(%s, %s)" % (self.__time, self.__inc) def as_datetime(self): """Return a :class:`~datetime.datetime` instance corresponding to the time portion of this :class:`Timestamp`. The returned datetime's timezone is UTC. """ return datetime.datetime.fromtimestamp(self.__time, utc) pymongo-3.11.0/bson/tz_util.py000066400000000000000000000027561374256237000163140ustar00rootroot00000000000000# Copyright 2010-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Timezone related utilities for BSON.""" from datetime import (timedelta, tzinfo) ZERO = timedelta(0) class FixedOffset(tzinfo): """Fixed offset timezone, in minutes east from UTC. Implementation based from the Python `standard library documentation `_. Defining __getinitargs__ enables pickling / copying. 
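For illustration, a minimal sketch using ``FixedOffset`` and the module-level ``utc`` constant defined at the end of this module::

    from datetime import datetime, timedelta

    from bson.tz_util import FixedOffset, utc

    # The offset may be given in minutes east of UTC, or as a timedelta.
    ist = FixedOffset(330, "IST")
    aware = datetime(2020, 1, 1, 12, 0, tzinfo=ist)
    assert aware.utcoffset() == timedelta(minutes=330)
    assert aware.tzname() == "IST"

    # utc is simply FixedOffset(0, "UTC"); dst() is always zero.
    assert datetime(2020, 1, 1, tzinfo=utc).dst() == timedelta(0)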
""" def __init__(self, offset, name): if isinstance(offset, timedelta): self.__offset = offset else: self.__offset = timedelta(minutes=offset) self.__name = name def __getinitargs__(self): return self.__offset, self.__name def utcoffset(self, dt): return self.__offset def tzname(self, dt): return self.__name def dst(self, dt): return ZERO utc = FixedOffset(0, "UTC") """Fixed offset timezone representing UTC.""" pymongo-3.11.0/doc/000077500000000000000000000000001374256237000140425ustar00rootroot00000000000000pymongo-3.11.0/doc/__init__.py000066400000000000000000000000001374256237000161410ustar00rootroot00000000000000pymongo-3.11.0/doc/api/000077500000000000000000000000001374256237000146135ustar00rootroot00000000000000pymongo-3.11.0/doc/api/bson/000077500000000000000000000000001374256237000155545ustar00rootroot00000000000000pymongo-3.11.0/doc/api/bson/binary.rst000066400000000000000000000015001374256237000175660ustar00rootroot00000000000000:mod:`binary` -- Tools for representing binary data to be stored in MongoDB =========================================================================== .. automodule:: bson.binary :synopsis: Tools for representing binary data to be stored in MongoDB .. autodata:: BINARY_SUBTYPE .. autodata:: FUNCTION_SUBTYPE .. autodata:: OLD_BINARY_SUBTYPE .. autodata:: OLD_UUID_SUBTYPE .. autodata:: UUID_SUBTYPE .. autodata:: STANDARD .. autodata:: PYTHON_LEGACY .. autodata:: JAVA_LEGACY .. autodata:: CSHARP_LEGACY .. autodata:: MD5_SUBTYPE .. autodata:: USER_DEFINED_SUBTYPE .. autoclass:: UuidRepresentation :members: .. autoclass:: Binary(data, subtype=BINARY_SUBTYPE) :members: :show-inheritance: .. autoclass:: UUIDLegacy(obj) :members: :show-inheritance: pymongo-3.11.0/doc/api/bson/code.rst000066400000000000000000000004311374256237000172160ustar00rootroot00000000000000:mod:`code` -- Tools for representing JavaScript code ===================================================== .. automodule:: bson.code :synopsis: Tools for representing JavaScript code .. autoclass:: Code(code, scope=None, **kwargs) :members: :show-inheritance: pymongo-3.11.0/doc/api/bson/codec_options.rst000066400000000000000000000003501374256237000211340ustar00rootroot00000000000000:mod:`codec_options` -- Tools for specifying BSON codec options =============================================================== .. automodule:: bson.codec_options :synopsis: Tools for specifying BSON codec options. :members: pymongo-3.11.0/doc/api/bson/dbref.rst000066400000000000000000000004651374256237000173750ustar00rootroot00000000000000:mod:`dbref` -- Tools for manipulating DBRefs (references to documents stored in MongoDB) ========================================================================================= .. automodule:: bson.dbref :synopsis: Tools for manipulating DBRefs (references to documents stored in MongoDB) :members: pymongo-3.11.0/doc/api/bson/decimal128.rst000066400000000000000000000002171374256237000201370ustar00rootroot00000000000000:mod:`decimal128` -- Support for BSON Decimal128 ================================================ .. automodule:: bson.decimal128 :members: pymongo-3.11.0/doc/api/bson/errors.rst000066400000000000000000000003321374256237000176200ustar00rootroot00000000000000:mod:`errors` -- Exceptions raised by the :mod:`bson` package ============================================================= .. 
automodule:: bson.errors :synopsis: Exceptions raised by the bson package :members: pymongo-3.11.0/doc/api/bson/index.rst000066400000000000000000000006471374256237000174240ustar00rootroot00000000000000:mod:`bson` -- BSON (Binary JSON) Encoding and Decoding ======================================================= .. automodule:: bson :synopsis: BSON (Binary JSON) Encoding and Decoding :members: Sub-modules: .. toctree:: :maxdepth: 2 binary code codec_options dbref decimal128 errors int64 json_util max_key min_key objectid raw_bson regex son timestamp tz_util pymongo-3.11.0/doc/api/bson/int64.rst000066400000000000000000000003231374256237000172500ustar00rootroot00000000000000:mod:`int64` -- Tools for representing BSON int64 ================================================= .. versionadded:: 3.0 .. automodule:: bson.int64 :synopsis: Tools for representing BSON int64 :members: pymongo-3.11.0/doc/api/bson/json_util.rst000066400000000000000000000005111374256237000203110ustar00rootroot00000000000000:mod:`json_util` -- Tools for using Python's :mod:`json` module with BSON documents =================================================================================== .. automodule:: bson.json_util :synopsis: Tools for using Python's json module with BSON documents :members: :undoc-members: :member-order: bysource pymongo-3.11.0/doc/api/bson/max_key.rst000066400000000000000000000003701374256237000177430ustar00rootroot00000000000000:mod:`max_key` -- Representation for the MongoDB internal MaxKey type ===================================================================== .. automodule:: bson.max_key :synopsis: Representation for the MongoDB internal MaxKey type :members: pymongo-3.11.0/doc/api/bson/min_key.rst000066400000000000000000000003701374256237000177410ustar00rootroot00000000000000:mod:`min_key` -- Representation for the MongoDB internal MinKey type ===================================================================== .. automodule:: bson.min_key :synopsis: Representation for the MongoDB internal MinKey type :members: pymongo-3.11.0/doc/api/bson/objectid.rst000066400000000000000000000012751374256237000200760ustar00rootroot00000000000000:mod:`objectid` -- Tools for working with MongoDB ObjectIds =========================================================== .. automodule:: bson.objectid :synopsis: Tools for working with MongoDB ObjectIds .. autoclass:: bson.objectid.ObjectId(oid=None) :members: .. describe:: str(o) Get a hex encoded version of :class:`ObjectId` `o`. The following property always holds: .. testsetup:: from bson.objectid import ObjectId .. doctest:: >>> o = ObjectId() >>> o == ObjectId(str(o)) True This representation is useful for urls or other places where ``o.binary`` is inappropriate. pymongo-3.11.0/doc/api/bson/raw_bson.rst000066400000000000000000000003401374256237000201150ustar00rootroot00000000000000:mod:`raw_bson` -- Tools for representing raw BSON documents. ============================================================= .. automodule:: bson.raw_bson :synopsis: Tools for representing raw BSON documents. :members: pymongo-3.11.0/doc/api/bson/regex.rst000066400000000000000000000004061374256237000174200ustar00rootroot00000000000000:mod:`regex` -- Tools for representing MongoDB regular expressions ================================================================== .. versionadded:: 2.7 .. 
automodule:: bson.regex :synopsis: Tools for representing MongoDB regular expressions :members: pymongo-3.11.0/doc/api/bson/son.rst000066400000000000000000000003361374256237000171070ustar00rootroot00000000000000:mod:`son` -- Tools for working with SON, an ordered mapping ============================================================ .. automodule:: bson.son :synopsis: Tools for working with SON, an ordered mapping :members: pymongo-3.11.0/doc/api/bson/timestamp.rst000066400000000000000000000003731374256237000203140ustar00rootroot00000000000000:mod:`timestamp` -- Tools for representing MongoDB internal Timestamps ====================================================================== .. automodule:: bson.timestamp :synopsis: Tools for representing MongoDB internal Timestamps :members: pymongo-3.11.0/doc/api/bson/tz_util.rst000066400000000000000000000003521374256237000200000ustar00rootroot00000000000000:mod:`tz_util` -- Utilities for dealing with timezones in Python ================================================================ .. automodule:: bson.tz_util :synopsis: Utilities for dealing with timezones in Python :members: pymongo-3.11.0/doc/api/gridfs/000077500000000000000000000000001374256237000160715ustar00rootroot00000000000000pymongo-3.11.0/doc/api/gridfs/errors.rst000066400000000000000000000003441374256237000201400ustar00rootroot00000000000000:mod:`errors` -- Exceptions raised by the :mod:`gridfs` package ================================================================= .. automodule:: gridfs.errors :synopsis: Exceptions raised by the gridfs package :members: pymongo-3.11.0/doc/api/gridfs/grid_file.rst000066400000000000000000000007031374256237000205470ustar00rootroot00000000000000:mod:`grid_file` -- Tools for representing files stored in GridFS ================================================================= .. automodule:: gridfs.grid_file :synopsis: Tools for representing files stored in GridFS .. autoclass:: GridIn :members: .. autoattribute:: _id .. autoclass:: GridOut :members: .. autoattribute:: _id .. automethod:: __iter__ .. autoclass:: GridOutCursor :members: pymongo-3.11.0/doc/api/gridfs/index.rst000066400000000000000000000003631374256237000177340ustar00rootroot00000000000000:mod:`gridfs` -- Tools for working with GridFS ============================================== .. automodule:: gridfs :synopsis: Tools for working with GridFS :members: Sub-modules: .. toctree:: :maxdepth: 2 errors grid_file pymongo-3.11.0/doc/api/index.rst000066400000000000000000000007451374256237000164620ustar00rootroot00000000000000API Documentation ================= The PyMongo distribution contains three top-level packages for interacting with MongoDB. :mod:`bson` is an implementation of the `BSON format `_, :mod:`pymongo` is a full-featured driver for MongoDB, and :mod:`gridfs` is a set of tools for working with the `GridFS `_ storage specification. .. toctree:: :maxdepth: 2 bson/index pymongo/index gridfs/index pymongo-3.11.0/doc/api/pymongo/000077500000000000000000000000001374256237000163035ustar00rootroot00000000000000pymongo-3.11.0/doc/api/pymongo/bulk.rst000066400000000000000000000003041374256237000177670ustar00rootroot00000000000000:mod:`bulk` -- The bulk write operations interface ================================================== .. automodule:: pymongo.bulk :synopsis: The bulk write operations interface. 
:members: pymongo-3.11.0/doc/api/pymongo/change_stream.rst000066400000000000000000000003141374256237000216330ustar00rootroot00000000000000:mod:`change_stream` -- Watch changes on a collection, database, or cluster =========================================================================== .. automodule:: pymongo.change_stream :members: pymongo-3.11.0/doc/api/pymongo/client_session.rst000066400000000000000000000002751374256237000220620ustar00rootroot00000000000000:mod:`client_session` -- Logical sessions for sequential operations =================================================================== .. automodule:: pymongo.client_session :members: pymongo-3.11.0/doc/api/pymongo/collation.rst000066400000000000000000000012061374256237000210200ustar00rootroot00000000000000:mod:`collation` -- Tools for working with collations. ====================================================== .. automodule:: pymongo.collation :synopsis: Tools for working with collations. .. autoclass:: pymongo.collation.Collation .. autoclass:: pymongo.collation.CollationStrength :members: :member-order: bysource .. autoclass:: pymongo.collation.CollationAlternate :members: :member-order: bysource .. autoclass:: pymongo.collation.CollationCaseFirst :members: :member-order: bysource .. autoclass:: pymongo.collation.CollationMaxVariable :members: :member-order: bysource pymongo-3.11.0/doc/api/pymongo/collection.rst000066400000000000000000000075661374256237000212060ustar00rootroot00000000000000:mod:`collection` -- Collection level operations ================================================ .. automodule:: pymongo.collection :synopsis: Collection level operations .. autodata:: pymongo.ASCENDING .. autodata:: pymongo.DESCENDING .. autodata:: pymongo.GEO2D .. autodata:: pymongo.GEOHAYSTACK .. autodata:: pymongo.GEOSPHERE .. autodata:: pymongo.HASHED .. autodata:: pymongo.TEXT .. autoclass:: pymongo.collection.ReturnDocument .. autoattribute:: BEFORE :annotation: .. autoattribute:: AFTER :annotation: .. autoclass:: pymongo.collection.Collection(database, name, create=False, **kwargs) .. describe:: c[name] || c.name Get the `name` sub-collection of :class:`Collection` `c`. Raises :class:`~pymongo.errors.InvalidName` if an invalid collection name is used. .. autoattribute:: full_name .. autoattribute:: name .. autoattribute:: database .. autoattribute:: codec_options .. autoattribute:: read_preference .. autoattribute:: write_concern .. autoattribute:: read_concern .. automethod:: with_options .. automethod:: bulk_write .. automethod:: insert_one .. automethod:: insert_many .. automethod:: replace_one .. automethod:: update_one .. automethod:: update_many .. automethod:: delete_one .. automethod:: delete_many .. automethod:: aggregate .. automethod:: aggregate_raw_batches .. automethod:: watch .. automethod:: find(filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, modifiers=None, batch_size=0, manipulate=True, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None, session=None) .. 
automethod:: find_raw_batches(filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, modifiers=None, batch_size=0, manipulate=True, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None) .. automethod:: find_one(filter=None, *args, **kwargs) .. automethod:: find_one_and_delete .. automethod:: find_one_and_replace(filter, replacement, projection=None, sort=None, return_document=ReturnDocument.BEFORE, hint=None, session=None, **kwargs) .. automethod:: find_one_and_update(filter, update, projection=None, sort=None, return_document=ReturnDocument.BEFORE, array_filters=None, hint=None, session=None, **kwargs) .. automethod:: count_documents .. automethod:: estimated_document_count .. automethod:: distinct .. automethod:: create_index .. automethod:: create_indexes .. automethod:: drop_index .. automethod:: drop_indexes .. automethod:: reindex .. automethod:: list_indexes .. automethod:: index_information .. automethod:: drop .. automethod:: rename .. automethod:: options .. automethod:: map_reduce .. automethod:: inline_map_reduce .. automethod:: parallel_scan .. automethod:: initialize_unordered_bulk_op .. automethod:: initialize_ordered_bulk_op .. automethod:: group .. automethod:: count .. automethod:: insert(doc_or_docs, manipulate=True, check_keys=True, continue_on_error=False, **kwargs) .. automethod:: save(to_save, manipulate=True, check_keys=True, **kwargs) .. automethod:: update(spec, document, upsert=False, manipulate=False, multi=False, check_keys=True, **kwargs) .. automethod:: remove(spec_or_id=None, multi=True, **kwargs) .. automethod:: find_and_modify .. automethod:: ensure_index pymongo-3.11.0/doc/api/pymongo/command_cursor.rst000066400000000000000000000004101374256237000220430ustar00rootroot00000000000000:mod:`command_cursor` -- Tools for iterating over MongoDB command results ========================================================================= .. automodule:: pymongo.command_cursor :synopsis: Tools for iterating over MongoDB command results :members: pymongo-3.11.0/doc/api/pymongo/cursor.rst000066400000000000000000000026451374256237000203610ustar00rootroot00000000000000:mod:`cursor` -- Tools for iterating over MongoDB query results =============================================================== .. automodule:: pymongo.cursor :synopsis: Tools for iterating over MongoDB query results .. autoclass:: pymongo.cursor.CursorType .. autoattribute:: NON_TAILABLE :annotation: .. autoattribute:: TAILABLE :annotation: .. autoattribute:: TAILABLE_AWAIT :annotation: .. autoattribute:: EXHAUST :annotation: .. autoclass:: pymongo.cursor.Cursor(collection, filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, modifiers=None, batch_size=0, manipulate=True, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None) :members: .. describe:: c[index] See :meth:`__getitem__`. .. automethod:: __getitem__ .. 
autoclass:: pymongo.cursor.RawBatchCursor(collection, filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, modifiers=None, batch_size=0, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None) pymongo-3.11.0/doc/api/pymongo/cursor_manager.rst000066400000000000000000000004571374256237000220520ustar00rootroot00000000000000:mod:`cursor_manager` -- Managers to handle when cursors are killed after being closed ====================================================================================== .. automodule:: pymongo.cursor_manager :synopsis: Managers to handle when cursors are killed after being closed :members: pymongo-3.11.0/doc/api/pymongo/database.rst000066400000000000000000000017651374256237000206120ustar00rootroot00000000000000:mod:`database` -- Database level operations ============================================ .. automodule:: pymongo.database :synopsis: Database level operations .. autodata:: pymongo.auth.MECHANISMS .. autodata:: pymongo.OFF .. autodata:: pymongo.SLOW_ONLY .. autodata:: pymongo.ALL .. autoclass:: pymongo.database.Database :members: .. describe:: db[collection_name] || db.collection_name Get the `collection_name` :class:`~pymongo.collection.Collection` of :class:`Database` `db`. Raises :class:`~pymongo.errors.InvalidName` if an invalid collection name is used. .. note:: Use dictionary style access if `collection_name` is an attribute of the :class:`Database` class eg: db[`collection_name`]. .. autoattribute:: codec_options .. autoattribute:: read_preference .. autoattribute:: write_concern .. autoattribute:: read_concern .. autoclass:: pymongo.database.SystemJS :members: pymongo-3.11.0/doc/api/pymongo/driver_info.rst000066400000000000000000000002451374256237000213440ustar00rootroot00000000000000:mod:`driver_info` ================== .. automodule:: pymongo.driver_info .. autoclass:: pymongo.driver_info.DriverInfo(name=None, version=None, platform=None) pymongo-3.11.0/doc/api/pymongo/encryption.rst000066400000000000000000000002411374256237000212240ustar00rootroot00000000000000:mod:`encryption` -- Client-Side Field Level Encryption ======================================================= .. automodule:: pymongo.encryption :members: pymongo-3.11.0/doc/api/pymongo/encryption_options.rst000066400000000000000000000005301374256237000230000ustar00rootroot00000000000000:mod:`encryption_options` -- Automatic Client-Side Field Level Encryption ========================================================================= .. automodule:: pymongo.encryption_options :synopsis: Support for automatic client-side field level encryption .. autoclass:: pymongo.encryption_options.AutoEncryptionOpts :members: pymongo-3.11.0/doc/api/pymongo/errors.rst000066400000000000000000000003461374256237000203540ustar00rootroot00000000000000:mod:`errors` -- Exceptions raised by the :mod:`pymongo` package ================================================================ .. automodule:: pymongo.errors :synopsis: Exceptions raised by the pymongo package :members: pymongo-3.11.0/doc/api/pymongo/event_loggers.rst000066400000000000000000000003251374256237000217000ustar00rootroot00000000000000:mod:`event_loggers` -- Example loggers =========================================== .. automodule:: pymongo.event_loggers :synopsis: A collection of simple listeners for monitoring driver events. 
:members:pymongo-3.11.0/doc/api/pymongo/index.rst000066400000000000000000000021701374256237000201440ustar00rootroot00000000000000:mod:`pymongo` -- Python driver for MongoDB =========================================== .. automodule:: pymongo :synopsis: Python driver for MongoDB .. autodata:: version .. data:: MongoClient Alias for :class:`pymongo.mongo_client.MongoClient`. .. data:: MongoReplicaSetClient Alias for :class:`pymongo.mongo_replica_set_client.MongoReplicaSetClient`. .. data:: ReadPreference Alias for :class:`pymongo.read_preferences.ReadPreference`. .. autofunction:: has_c .. data:: MIN_SUPPORTED_WIRE_VERSION The minimum wire protocol version PyMongo supports. .. data:: MAX_SUPPORTED_WIRE_VERSION The maximum wire protocol version PyMongo supports. Sub-modules: .. toctree:: :maxdepth: 2 bulk change_stream client_session collation collection command_cursor cursor cursor_manager database driver_info encryption encryption_options errors message mongo_client mongo_replica_set_client monitoring operations pool read_concern read_preferences results son_manipulator uri_parser write_concern event_loggers pymongo-3.11.0/doc/api/pymongo/ismaster.rst000066400000000000000000000003731374256237000206670ustar00rootroot00000000000000:orphan: :mod:`ismaster` -- A wrapper for ismaster command responses. ============================================================ .. automodule:: pymongo.ismaster .. autoclass:: pymongo.ismaster.IsMaster(doc) .. autoattribute:: document pymongo-3.11.0/doc/api/pymongo/message.rst000066400000000000000000000003661374256237000204660ustar00rootroot00000000000000:mod:`message` -- Tools for creating messages to be sent to MongoDB =================================================================== .. automodule:: pymongo.message :synopsis: Tools for creating messages to be sent to MongoDB :members: pymongo-3.11.0/doc/api/pymongo/mongo_client.rst000066400000000000000000000035611374256237000215170ustar00rootroot00000000000000:mod:`mongo_client` -- Tools for connecting to MongoDB ====================================================== .. automodule:: pymongo.mongo_client :synopsis: Tools for connecting to MongoDB .. autoclass:: pymongo.mongo_client.MongoClient(host='localhost', port=27017, document_class=dict, tz_aware=False, connect=True, **kwargs) .. automethod:: close .. describe:: c[db_name] || c.db_name Get the `db_name` :class:`~pymongo.database.Database` on :class:`MongoClient` `c`. Raises :class:`~pymongo.errors.InvalidName` if an invalid database name is used. .. autoattribute:: event_listeners .. autoattribute:: address .. autoattribute:: primary .. autoattribute:: secondaries .. autoattribute:: arbiters .. autoattribute:: is_primary .. autoattribute:: is_mongos .. autoattribute:: max_pool_size .. autoattribute:: min_pool_size .. autoattribute:: max_idle_time_ms .. autoattribute:: nodes .. autoattribute:: max_bson_size .. autoattribute:: max_message_size .. autoattribute:: max_write_batch_size .. autoattribute:: local_threshold_ms .. autoattribute:: server_selection_timeout .. autoattribute:: codec_options .. autoattribute:: read_preference .. autoattribute:: write_concern .. autoattribute:: read_concern .. automethod:: start_session .. automethod:: list_databases .. automethod:: list_database_names .. automethod:: database_names .. automethod:: drop_database .. automethod:: get_default_database .. automethod:: get_database .. automethod:: server_info .. automethod:: watch .. automethod:: close_cursor .. automethod:: kill_cursors .. automethod:: set_cursor_manager .. 
autoattribute:: is_locked .. automethod:: fsync .. automethod:: unlock pymongo-3.11.0/doc/api/pymongo/mongo_replica_set_client.rst000066400000000000000000000024731374256237000240720ustar00rootroot00000000000000:mod:`mongo_replica_set_client` -- Tools for connecting to a MongoDB replica set ================================================================================ .. automodule:: pymongo.mongo_replica_set_client :synopsis: Tools for connecting to a MongoDB replica set .. autoclass:: pymongo.mongo_replica_set_client.MongoReplicaSetClient(hosts_or_uri, document_class=dict, tz_aware=False, connect=True, **kwargs) .. automethod:: close .. describe:: c[db_name] || c.db_name Get the `db_name` :class:`~pymongo.database.Database` on :class:`MongoReplicaSetClient` `c`. Raises :class:`~pymongo.errors.InvalidName` if an invalid database name is used. .. autoattribute:: primary .. autoattribute:: secondaries .. autoattribute:: arbiters .. autoattribute:: max_pool_size .. autoattribute:: max_bson_size .. autoattribute:: max_message_size .. autoattribute:: local_threshold_ms .. autoattribute:: codec_options .. autoattribute:: read_preference .. autoattribute:: write_concern .. automethod:: database_names .. automethod:: drop_database .. automethod:: get_database .. automethod:: close_cursor .. automethod:: kill_cursors .. automethod:: set_cursor_manager .. automethod:: get_default_database pymongo-3.11.0/doc/api/pymongo/monitoring.rst000066400000000000000000000050461374256237000212270ustar00rootroot00000000000000:mod:`monitoring` -- Tools for monitoring driver events. ======================================================== .. automodule:: pymongo.monitoring :synopsis: Tools for monitoring driver events. .. autofunction:: register(listener) .. autoclass:: CommandListener :members: :inherited-members: .. autoclass:: ServerListener :members: :inherited-members: .. autoclass:: ServerHeartbeatListener :members: :inherited-members: .. autoclass:: TopologyListener :members: :inherited-members: .. autoclass:: ConnectionPoolListener :members: :inherited-members: .. autoclass:: CommandStartedEvent :members: :inherited-members: .. autoclass:: CommandSucceededEvent :members: :inherited-members: .. autoclass:: CommandFailedEvent :members: :inherited-members: .. autoclass:: ServerDescriptionChangedEvent :members: :inherited-members: .. autoclass:: ServerOpeningEvent :members: :inherited-members: .. autoclass:: ServerClosedEvent :members: :inherited-members: .. autoclass:: TopologyDescriptionChangedEvent :members: :inherited-members: .. autoclass:: TopologyOpenedEvent :members: :inherited-members: .. autoclass:: TopologyClosedEvent :members: :inherited-members: .. autoclass:: ServerHeartbeatStartedEvent :members: :inherited-members: .. autoclass:: ServerHeartbeatSucceededEvent :members: :inherited-members: .. autoclass:: ServerHeartbeatFailedEvent :members: :inherited-members: .. autoclass:: PoolCreatedEvent :members: :inherited-members: .. autoclass:: PoolClearedEvent :members: :inherited-members: .. autoclass:: PoolClosedEvent :members: :inherited-members: .. autoclass:: ConnectionCreatedEvent :members: :inherited-members: .. autoclass:: ConnectionReadyEvent :members: :inherited-members: .. autoclass:: ConnectionClosedReason :members: .. autoclass:: ConnectionClosedEvent :members: :inherited-members: .. autoclass:: ConnectionCheckOutStartedEvent :members: :inherited-members: .. autoclass:: ConnectionCheckOutFailedReason :members: .. autoclass:: ConnectionCheckOutFailedEvent :members: :inherited-members: .. 
autoclass:: ConnectionCheckedOutEvent :members: :inherited-members: .. autoclass:: ConnectionCheckedInEvent :members: :inherited-members: pymongo-3.11.0/doc/api/pymongo/operations.rst000066400000000000000000000002751374256237000212240ustar00rootroot00000000000000:mod:`operations` -- Operation class definitions ================================================ .. automodule:: pymongo.operations :synopsis: Operation class definitions :members: pymongo-3.11.0/doc/api/pymongo/pool.rst000066400000000000000000000003351374256237000200070ustar00rootroot00000000000000:mod:`pool` -- Pool module for use with a MongoDB client. ============================================================== .. automodule:: pymongo.pool :synopsis: Pool module for use with a MongoDB client. :members: pymongo-3.11.0/doc/api/pymongo/read_concern.rst000066400000000000000000000003651374256237000214630ustar00rootroot00000000000000:mod:`read_concern` -- Tools for working with read concern. =========================================================== .. automodule:: pymongo.read_concern :synopsis: Tools for working with read concern. :members: :inherited-members: pymongo-3.11.0/doc/api/pymongo/read_preferences.rst000066400000000000000000000021411374256237000223270ustar00rootroot00000000000000:mod:`read_preferences` -- Utilities for choosing which member of a replica set to read from. ============================================================================================= .. automodule:: pymongo.read_preferences :synopsis: Utilities for choosing which member of a replica set to read from. .. autoclass:: pymongo.read_preferences.Primary .. max_staleness, min_wire_version, mongos_mode, and tag_sets don't make sense for Primary. .. autoattribute:: document .. autoattribute:: mode .. autoattribute:: name .. autoclass:: pymongo.read_preferences.PrimaryPreferred :inherited-members: .. autoclass:: pymongo.read_preferences.Secondary :inherited-members: .. autoclass:: pymongo.read_preferences.SecondaryPreferred :inherited-members: .. autoclass:: pymongo.read_preferences.Nearest :inherited-members: .. autoclass:: ReadPreference .. autoattribute:: PRIMARY .. autoattribute:: PRIMARY_PREFERRED .. autoattribute:: SECONDARY .. autoattribute:: SECONDARY_PREFERRED .. autoattribute:: NEAREST pymongo-3.11.0/doc/api/pymongo/results.rst000066400000000000000000000003021374256237000205310ustar00rootroot00000000000000:mod:`results` -- Result class definitions ========================================== .. automodule:: pymongo.results :synopsis: Result class definitions :members: :inherited-members: pymongo-3.11.0/doc/api/pymongo/server_description.rst000066400000000000000000000007101374256237000227440ustar00rootroot00000000000000:orphan: :mod:`server_description` -- An object representation of a server the driver is connected to. ============================================================================================= .. automodule:: pymongo.server_description .. autoclass:: pymongo.server_description.ServerDescription() .. autoattribute:: address .. autoattribute:: all_hosts .. autoattribute:: server_type .. autoattribute:: server_type_name pymongo-3.11.0/doc/api/pymongo/son_manipulator.rst000066400000000000000000000005201374256237000222440ustar00rootroot00000000000000:mod:`son_manipulator` -- Manipulators that can edit SON documents as they are saved or retrieved ================================================================================================= .. 
automodule:: pymongo.son_manipulator :synopsis: Manipulators that can edit SON documents as they are saved or retrieved :members: pymongo-3.11.0/doc/api/pymongo/topology_description.rst000066400000000000000000000010601374256237000233110ustar00rootroot00000000000000:orphan: :mod:`topology_description` -- An object representation of a deployment of MongoDB servers. =========================================================================================== .. automodule:: pymongo.topology_description .. autoclass:: pymongo.topology_description.TopologyDescription() .. automethod:: has_readable_server(read_preference=ReadPreference.PRIMARY) .. automethod:: has_writable_server .. automethod:: server_descriptions .. autoattribute:: topology_type .. autoattribute:: topology_type_name pymongo-3.11.0/doc/api/pymongo/uri_parser.rst000066400000000000000000000003501374256237000212060ustar00rootroot00000000000000:mod:`uri_parser` -- Tools to parse and validate a MongoDB URI ============================================================== .. automodule:: pymongo.uri_parser :synopsis: Tools to parse and validate a MongoDB URI. :members: pymongo-3.11.0/doc/api/pymongo/write_concern.rst000066400000000000000000000003341374256237000216760ustar00rootroot00000000000000:mod:`write_concern` -- Tools for specifying write concern ========================================================== .. automodule:: pymongo.write_concern :synopsis: Tools for specifying write concern. :members: pymongo-3.11.0/doc/atlas.rst000066400000000000000000000056021374256237000157030ustar00rootroot00000000000000Using PyMongo with MongoDB Atlas ================================ `Atlas `_ is MongoDB, Inc.'s hosted MongoDB as a service offering. To connect to Atlas, pass the connection string provided by Atlas to :class:`~pymongo.mongo_client.MongoClient`:: client = pymongo.MongoClient() Connections to Atlas require TLS/SSL. For connections using TLS/SSL, PyMongo may require third party dependencies as determined by your version of Python. With PyMongo 3.3+, you can install PyMongo 3.3+ and any TLS/SSL-related dependencies using the following pip command:: $ python -m pip install pymongo[tls] Starting with PyMongo 3.11 this installs `PyOpenSSL `_, `requests`_ and `service_identity `_ for users of Python versions older than 2.7.9. PyOpenSSL supports SNI for these old Python versions, allowing applications to connect to Atlas free and shared tier instances. Earlier versions of PyMongo require you to manually install the dependencies. For a list of TLS/SSL-related dependencies, see :doc:`examples/tls`. .. note:: Connecting to Atlas "Free Tier" or "Shared Cluster" instances requires Server Name Indication (SNI) support. SNI support requires CPython 2.7.9 / PyPy 2.5.1 or newer or PyMongo 3.11+ with PyOpenSSL. To check if your version of Python supports SNI run the following command:: $ python -c "import ssl; print(getattr(ssl, 'HAS_SNI', False))" You should see "True". .. warning:: Industry best practices recommend, and some regulations require, the use of TLS 1.1 or newer. Though no application changes are required for PyMongo to make use of the newest protocols, some operating systems or versions may not provide an OpenSSL version new enough to support them. Users of macOS older than 10.13 (High Sierra) will need to install Python from `python.org`_, `homebrew`_, `macports`_, or another similar source.
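Whatever the platform, a quick way to see which OpenSSL version your Python
interpreter itself links against (normally what PyMongo uses, and possibly
different from the system ``openssl`` binary) is to ask the :mod:`ssl`
module. This is only an informal diagnostic, not an officially documented
check::

    $ python -c "import ssl; print(ssl.OPENSSL_VERSION)"

If the version reported here is older than 1.0.1, TLS 1.1 or newer will not
be available to PyMongo either.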
Users of Linux or other non-macOS Unix can check their OpenSSL version like this:: $ openssl version If the version number is less than 1.0.1 support for TLS 1.1 or newer is not available. Contact your operating system vendor for a solution or upgrade to a newer distribution. You can check your Python interpreter by installing the `requests`_ module and executing the following command:: python -c "import requests; print(requests.get('https://www.howsmyssl.com/a/check', verify=False).json()['tls_version'])" You should see "TLS 1.X" where X is >= 1. You can read more about TLS versions and their security implications here: ``_ .. _python.org: https://www.python.org/downloads/ .. _homebrew: https://brew.sh/ .. _macports: https://www.macports.org/ .. _requests: https://pypi.python.org/pypi/requests pymongo-3.11.0/doc/changelog.rst000066400000000000000000003617571374256237000165460ustar00rootroot00000000000000Changelog ========= Changes in Version 3.11.0 ------------------------- Version 3.11 adds support for MongoDB 4.4 and includes a number of bug fixes. Highlights include: - Support for :ref:`OCSP` (Online Certificate Status Protocol). - Support for `PyOpenSSL `_ as an alternative TLS implementation. PyOpenSSL is required for :ref:`OCSP` support. It will also be installed when using the "tls" extra if the version of Python in use is older than 2.7.9. - Support for the :ref:`MONGODB-AWS` authentication mechanism. - Support for the ``directConnection`` URI option and kwarg to :class:`~pymongo.mongo_client.MongoClient`. - Support for speculative authentication attempts in connection handshakes which reduces the number of network roundtrips needed to authenticate new connections on MongoDB 4.4+. - Support for creating collections in multi-document transactions with :meth:`~pymongo.database.Database.create_collection` on MongoDB 4.4+. - Added index hinting support to the :meth:`~pymongo.collection.Collection.replace_one`, :meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.update_many`, :meth:`~pymongo.collection.Collection.find_one_and_replace`, :meth:`~pymongo.collection.Collection.find_one_and_update`, :meth:`~pymongo.collection.Collection.delete_one`, :meth:`~pymongo.collection.Collection.delete_many`, and :meth:`~pymongo.collection.Collection.find_one_and_delete` commands. - Added index hinting support to the :class:`~pymongo.operations.ReplaceOne`, :class:`~pymongo.operations.UpdateOne`, :class:`~pymongo.operations.UpdateMany`, :class:`~pymongo.operations.DeleteOne`, and :class:`~pymongo.operations.DeleteMany` bulk operations. - Added support for :data:`bson.binary.UuidRepresentation.UNSPECIFIED` and ``MongoClient(uuidRepresentation='unspecified')`` which will become the default UUID representation starting in PyMongo 4.0. See :ref:`handling-uuid-data-example` for details. - Added the ``background`` parameter to :meth:`pymongo.database.Database.validate_collection`. For a description of this parameter see the MongoDB documentation for the `validate command`_. - Added the ``allow_disk_use`` parameters to :meth:`pymongo.collection.Collection.find`. - Added the ``hedge`` parameter to :class:`~pymongo.read_preferences.PrimaryPreferred`, :class:`~pymongo.read_preferences.Secondary`, :class:`~pymongo.read_preferences.SecondaryPreferred`, :class:`~pymongo.read_preferences.Nearest` to support disabling (or explicitly enabling) hedged reads in MongoDB 4.4+. 
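As a brief illustration of the new ``hedge`` option from the last bullet, a
read preference with hedged reads enabled can be attached to a collection as
sketched below. The connection string, database, and collection names are
placeholders, and a MongoDB 4.4+ sharded cluster is assumed::

    from pymongo import MongoClient
    from pymongo.read_preferences import SecondaryPreferred

    client = MongoClient("mongodb://localhost:27017")  # placeholder URI
    coll = client.test_db.get_collection(
        "docs",
        read_preference=SecondaryPreferred(hedge={"enabled": True}))
    doc = coll.find_one({})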
- Fixed a bug in change streams that could cause PyMongo to miss some change documents when resuming a stream that was started without a resume token and whose first batch did not contain any change documents. - Fixed a bug where using gevent.Timeout to time out an operation could lead to a deadlock. Deprecations: - Deprecated the ``oplog_replay`` parameter to :meth:`pymongo.collection.Collection.find`. Starting in MongoDB 4.4, the server optimizes queries against the oplog collection without requiring the user to set this flag. - Deprecated :meth:`pymongo.collection.Collection.reindex`. Use :meth:`~pymongo.database.Database.command` to run the ``reIndex`` command instead. - Deprecated :meth:`pymongo.mongo_client.MongoClient.fsync`. Use :meth:`~pymongo.database.Database.command` to run the ``fsync`` command instead. - Deprecated :meth:`pymongo.mongo_client.MongoClient.unlock`. Use :meth:`~pymongo.database.Database.command` to run the ``fsyncUnlock`` command instead. See the documentation for more information. - Deprecated :attr:`pymongo.mongo_client.MongoClient.is_locked`. Use :meth:`~pymongo.database.Database.command` to run the ``currentOp`` command instead. See the documentation for more information. Unavoidable breaking changes: - :class:`~gridfs.GridFSBucket` and :class:`~gridfs.GridFS` do not support multi-document transactions. Running a GridFS operation in a transaction now always raises the following error: ``InvalidOperation: GridFS does not support multi-document transactions`` .. _validate command: https://docs.mongodb.com/manual/reference/command/validate/ Issues Resolved ............... See the `PyMongo 3.11.0 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.11.0 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=24799 Changes in Version 3.10.1 ------------------------- Version 3.10.1 fixes the following issues discovered since the release of 3.10.0: - Fix a TypeError logged to stderr that could be triggered during server maintenance or during :meth:`pymongo.mongo_client.MongoClient.close`. - Avoid creating new connections during :meth:`pymongo.mongo_client.MongoClient.close`. Issues Resolved ............... See the `PyMongo 3.10.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.10.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=25039 Changes in Version 3.10.0 ------------------------- Version 3.10 includes a number of improvements and bug fixes. Highlights include: - Support for Client-Side Field Level Encryption with MongoDB 4.2. See :doc:`examples/encryption` for examples. - Support for Python 3.8. - Added :attr:`pymongo.client_session.ClientSession.in_transaction`. - Do not hold the Topology lock while creating connections in a MongoClient's background thread. This change fixes a bug where application operations would block while the background thread ensures that all server pools have minPoolSize connections. - Fix a UnicodeDecodeError bug when coercing a PyMongoError with a non-ascii error message to unicode on Python 2. - Fix an edge case bug where PyMongo could exceed the server's maxMessageSizeBytes when generating a compressed bulk write command. Issues Resolved ............... See the `PyMongo 3.10 release notes in JIRA`_ for the list of resolved issues in this release. ..
_PyMongo 3.10 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=23944 Changes in Version 3.9.0 ------------------------ Version 3.9 adds support for MongoDB 4.2. Highlights include: - Support for MongoDB 4.2 sharded transactions. Sharded transactions have the same API as replica set transactions. See :ref:`transactions-ref`. - New method :meth:`pymongo.client_session.ClientSession.with_transaction` to support conveniently running a transaction in a session with automatic retries and at-most-once semantics. - Initial support for client side field level encryption. See the docstring for :class:`~pymongo.mongo_client.MongoClient`, :class:`~pymongo.encryption_options.AutoEncryptionOpts`, and :mod:`~pymongo.encryption` for details. **Note: Support for client side encryption is in beta. Backwards-breaking changes may be made before the final release.** - Added the ``max_commit_time_ms`` parameter to :meth:`~pymongo.client_session.ClientSession.start_transaction`. - Implement the `URI options specification`_ in the :meth:`~pymongo.mongo_client.MongoClient` constructor. Consequently, there are a number of changes in connection options: - The ``tlsInsecure`` option has been added. - The ``tls`` option has been added. The older ``ssl`` option has been retained as an alias to the new ``tls`` option. - ``wTimeout`` has been deprecated in favor of ``wTimeoutMS``. - ``wTimeoutMS`` now overrides ``wTimeout`` if the user provides both. - ``j`` has been deprecated in favor of ``journal``. - ``journal`` now overrides ``j`` if the user provides both. - ``ssl_cert_reqs`` has been deprecated in favor of ``tlsAllowInvalidCertificates``. Instead of ``ssl.CERT_NONE``, ``ssl.CERT_OPTIONAL`` and ``ssl.CERT_REQUIRED``, the new option expects a boolean value - ``True`` is equivalent to ``ssl.CERT_NONE``, while ``False`` is equivalent to ``ssl.CERT_REQUIRED``. - ``ssl_match_hostname`` has been deprecated in favor of ``tlsAllowInvalidHostnames``. - ``ssl_ca_certs`` has been deprecated in favor of ``tlsCAFile``. - ``ssl_certfile`` has been deprecated in favor of ``tlsCertificateKeyFile``. - ``ssl_pem_passphrase`` has been deprecated in favor of ``tlsCertificateKeyFilePassword``. - ``waitQueueMultiple`` has been deprecated without replacement. This option was a poor solution for putting an upper bound on queuing since it didn't affect queuing in other parts of the driver. - The ``retryWrites`` URI option now defaults to ``True``. Supported write operations that fail with a retryable error will automatically be retried one time, with at-most-once semantics. - Support for retryable reads and the ``retryReads`` URI option which is enabled by default. See the :class:`~pymongo.mongo_client.MongoClient` documentation for details. Now that supported operations are retried automatically and transparently, users should consider adjusting any custom retry logic to prevent an application from inadvertently retrying for too long. - Support zstandard for wire protocol compression. - Support for periodically polling DNS SRV records to update the mongos proxy list without having to change client configuration. - New method :meth:`pymongo.database.Database.aggregate` to support running database level aggregations. - Support for publishing Connection Monitoring and Pooling events via the new :class:`~pymongo.monitoring.ConnectionPoolListener` class. See :mod:`~pymongo.monitoring` for an example. 
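A rough sketch of the connection pool monitoring hook mentioned in the last
bullet follows; the class name and log messages are invented for illustration.
A listener subclasses :class:`~pymongo.monitoring.ConnectionPoolListener`,
overrides the event methods, and is registered before any clients are
created::

    import logging

    from pymongo import MongoClient, monitoring

    class PoolLogger(monitoring.ConnectionPoolListener):
        # This sketch overrides every pool and connection event method.
        def pool_created(self, event):
            logging.info("Pool created for %s", event.address)

        def pool_cleared(self, event):
            logging.info("Pool cleared for %s", event.address)

        def pool_closed(self, event):
            logging.info("Pool closed for %s", event.address)

        def connection_created(self, event):
            logging.info("Connection %s created", event.connection_id)

        def connection_ready(self, event):
            logging.info("Connection %s ready", event.connection_id)

        def connection_closed(self, event):
            logging.info("Connection %s closed: %s",
                         event.connection_id, event.reason)

        def connection_check_out_started(self, event):
            logging.info("Check out started on pool %s", event.address)

        def connection_check_out_failed(self, event):
            logging.info("Check out failed: %s", event.reason)

        def connection_checked_out(self, event):
            logging.info("Connection %s checked out", event.connection_id)

        def connection_checked_in(self, event):
            logging.info("Connection %s checked in", event.connection_id)

    monitoring.register(PoolLogger())
    client = MongoClient()  # globally registered listeners apply to new clients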
- :meth:`pymongo.collection.Collection.aggregate` and :meth:`pymongo.database.Database.aggregate` now support the ``$merge`` pipeline stage and use read preference :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY` if the ``$out`` or ``$merge`` pipeline stages are used. - Support for specifying a pipeline or document in :meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.update_many`, :meth:`~pymongo.collection.Collection.find_one_and_update`, :meth:`~pymongo.operations.UpdateOne`, and :meth:`~pymongo.operations.UpdateMany`. - New BSON utility functions :func:`~bson.encode` and :func:`~bson.decode` - :class:`~bson.binary.Binary` now supports any bytes-like type that implements the buffer protocol. - Resume tokens can now be accessed from a ``ChangeStream`` cursor using the :attr:`~pymongo.change_stream.ChangeStream.resume_token` attribute. - Connections now survive primary step-down when using MongoDB 4.2+. Applications should expect less socket connection turnover during replica set elections. Unavoidable breaking changes: - Applications that use MongoDB with the MMAPv1 storage engine must now explicitly disable retryable writes via the connection string (e.g. ``MongoClient("mongodb://my.mongodb.cluster/db?retryWrites=false")``) or the :class:`~pymongo.mongo_client.MongoClient` constructor's keyword argument (e.g. ``MongoClient("mongodb://my.mongodb.cluster/db", retryWrites=False)``) to avoid running into :class:`~pymongo.errors.OperationFailure` exceptions during write operations. The MMAPv1 storage engine is deprecated and does not support retryable writes which are now turned on by default. - In order to ensure that the ``connectTimeoutMS`` URI option is honored when connecting to clusters with a ``mongodb+srv://`` connection string, the minimum required version of the optional ``dnspython`` dependency has been bumped to 1.16.0. This is a breaking change for applications that use PyMongo's SRV support with a version of ``dnspython`` older than 1.16.0. .. _URI options specification: https://github.com/mongodb/specifications/blob/master/source/uri-options/uri-options.rst Issues Resolved ............... See the `PyMongo 3.9 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.9 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=21787 Changes in Version 3.8.0 ------------------------ .. warning:: PyMongo no longer supports Python 2.6. RHEL 6 users should install Python 2.7 or newer from `Red Hat Software Collections `_. CentOS 6 users should install Python 2.7 or newer from `SCL `_ .. warning:: PyMongo no longer supports PyPy3 versions older than 3.5. Users must upgrade to PyPy3.5+. - :class:`~bson.objectid.ObjectId` now implements the `ObjectID specification version 0.2 `_. - For better performance and to better follow the GridFS spec, :class:`~gridfs.grid_file.GridOut` now uses a single cursor to read all the chunks in the file. Previously, each chunk in the file was queried individually using :meth:`~pymongo.collection.Collection.find_one`. - :meth:`gridfs.grid_file.GridOut.read` now only checks for extra chunks after reading the entire file. Previously, this method would check for extra chunks on every call. - :meth:`~pymongo.database.Database.current_op` now always uses the ``Database``'s :attr:`~pymongo.database.Database.codec_options` when decoding the command response. Previously the codec_options was only used when the MongoDB server version was <= 3.0. 
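To put the two GridFS bullets above in context, reading a stored file back
through :class:`~gridfs.grid_file.GridOut` looks roughly like the sketch
below; the server address, database name, and file contents are placeholders::

    import gridfs

    from pymongo import MongoClient

    db = MongoClient("mongodb://localhost:27017").test_database
    fs = gridfs.GridFS(db)
    file_id = fs.put(b"hello gridfs", filename="example.txt")
    # As of 3.8, reading the file back streams all chunks with a single cursor.
    contents = fs.get(file_id).read()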
- Undeprecated :meth:`~pymongo.mongo_client.MongoClient.get_default_database` and added the ``default`` parameter. - TLS Renegotiation is now disabled when possible. - Custom types can now be directly encoded to, and decoded from MongoDB using the :class:`~bson.codec_options.TypeCodec` and :class:`~bson.codec_options.TypeRegistry` APIs. For more information, see the :doc:`custom type example `. - Attempting a multi-document transaction on a sharded cluster now raises a :exc:`~pymongo.errors.ConfigurationError`. - :meth:`pymongo.cursor.Cursor.distinct` and :meth:`pymongo.cursor.Cursor.count` now send the Cursor's :meth:`~pymongo.cursor.Cursor.comment` as the "comment" top-level command option instead of "$comment". Also, note that "comment" must be a string. - Add the ``filter`` parameter to :meth:`~pymongo.database.Database.list_collection_names`. - Changes can now be requested from a ``ChangeStream`` cursor without blocking indefinitely using the new :meth:`pymongo.change_stream.ChangeStream.try_next` method. - Fixed a reference leak bug when splitting a batched write command based on maxWriteBatchSize or the max message size. - Deprecated running find queries that set :meth:`~pymongo.cursor.Cursor.min` and/or :meth:`~pymongo.cursor.Cursor.max` but do not also set a :meth:`~pymongo.cursor.Cursor.hint` of which index to use. The find command is expected to require a :meth:`~pymongo.cursor.Cursor.hint` when using min/max starting in MongoDB 4.2. - Documented support for the uuidRepresentation URI option, which has been supported since PyMongo 2.7. Valid values are `pythonLegacy` (the default), `javaLegacy`, `csharpLegacy` and `standard`. New applications should consider setting this to `standard` for cross language compatibility. - :class:`~bson.raw_bson.RawBSONDocument` now validates that the ``bson_bytes`` passed in represent a single bson document. Earlier versions would mistakenly accept multiple bson documents. - Iterating over a :class:`~bson.raw_bson.RawBSONDocument` now maintains the same field order of the underlying raw BSON document. - Applications can now register a custom server selector. For more information see the :doc:`server selector example `. - The connection pool now implements a LIFO policy. Unavoidable breaking changes: - In order to follow the ObjectID Spec version 0.2, an ObjectId's 3-byte machine identifier and 2-byte process id have been replaced with a single 5-byte random value generated per process. This is a breaking change for any application that attempts to interpret those bytes. Issues Resolved ............... See the `PyMongo 3.8 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.8 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=19904 Changes in Version 3.7.2 ------------------------ Version 3.7.2 fixes a few issues discovered since the release of 3.7.1. - Fixed a bug in retryable writes where a previous command's "txnNumber" field could be sent leading to incorrect results. - Fixed a memory leak of a few bytes on some insert, update, or delete commands when running against MongoDB 3.6+. - Fixed a bug that caused :meth:`pymongo.collection.Collection.ensure_index` to only cache a single index per database. - Updated the documentation examples to use :meth:`pymongo.collection.Collection.count_documents` instead of :meth:`pymongo.collection.Collection.count` and :meth:`pymongo.cursor.Cursor.count`. Issues Resolved ............... 
See the `PyMongo 3.7.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.7.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=21519 Changes in Version 3.7.1 ------------------------ Version 3.7.1 fixes a few issues discovered since the release of 3.7.0. - Calling :meth:`~pymongo.database.Database.authenticate` more than once with the same credentials results in OperationFailure. - Authentication fails when SCRAM-SHA-1 is used to authenticate users with only MONGODB-CR credentials. - A millisecond rounding problem when decoding datetimes in the pure Python BSON decoder on 32 bit systems and AWS lambda. Issues Resolved ............... See the `PyMongo 3.7.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.7.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=21096 Changes in Version 3.7.0 ------------------------ Version 3.7 adds support for MongoDB 4.0. Highlights include: - Support for single replica set multi-document ACID transactions. See :ref:`transactions-ref`. - Support for wire protocol compression. See the :meth:`~pymongo.mongo_client.MongoClient` documentation for details. - Support for Python 3.7. - New count methods, :meth:`~pymongo.collection.Collection.count_documents` and :meth:`~pymongo.collection.Collection.estimated_document_count`. :meth:`~pymongo.collection.Collection.count_documents` is always accurate when used with MongoDB 3.6+, or when used with older standalone or replica set deployments. With older sharded clusters is it always accurate when used with Primary read preference. It can also be used in a transaction, unlike the now deprecated :meth:`pymongo.collection.Collection.count` and :meth:`pymongo.cursor.Cursor.count` methods. - Support for watching changes on all collections in a database using the new :meth:`pymongo.database.Database.watch` method. - Support for watching changes on all collections in all databases using the new :meth:`pymongo.mongo_client.MongoClient.watch` method. - Support for watching changes starting at a user provided timestamp using the new ``start_at_operation_time`` parameter for the ``watch()`` helpers. - Better support for using PyMongo in a FIPS 140-2 environment. Specifically, the following features and changes allow PyMongo to function when MD5 support is disabled in OpenSSL by the FIPS Object Module: - Support for the :ref:`SCRAM-SHA-256 ` authentication mechanism. The :ref:`GSSAPI `, :ref:`PLAIN `, and :ref:`MONGODB-X509 ` mechanisms can also be used to avoid issues with OpenSSL in FIPS environments. - MD5 checksums are now optional in GridFS. See the `disable_md5` option of :class:`~gridfs.GridFS` and :class:`~gridfs.GridFSBucket`. - :class:`~bson.objectid.ObjectId` machine bytes are now hashed using `FNV-1a `_ instead of MD5. - The :meth:`~pymongo.database.Database.list_collection_names` and :meth:`~pymongo.database.Database.collection_names` methods use the nameOnly option when supported by MongoDB. - The :meth:`pymongo.collection.Collection.watch` method now returns an instance of the :class:`~pymongo.change_stream.CollectionChangeStream` class which is a subclass of :class:`~pymongo.change_stream.ChangeStream`. - SCRAM client and server keys are cached for improved performance, following `RFC 5802 `_. - If not specified, the authSource for the :ref:`PLAIN ` authentication mechanism defaults to $external. - wtimeoutMS is once again supported as a URI option. 
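For example, the re-enabled URI option from the last bullet can be combined
with a write concern directly in the connection string; the host and replica
set name below are placeholders::

    from pymongo import MongoClient

    client = MongoClient(
        "mongodb://localhost:27017/?replicaSet=rs0&w=majority&wtimeoutMS=5000")
    print(client.write_concern)  # reflects w='majority' and wtimeout=5000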
- When using unacknowledged write concern and connected to MongoDB server version 3.6 or greater, the `bypass_document_validation` option is now supported in the following write helpers: :meth:`~pymongo.collection.Collection.insert_one`, :meth:`~pymongo.collection.Collection.replace_one`, :meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.update_many`. Deprecations: - Deprecated :meth:`pymongo.collection.Collection.count` and :meth:`pymongo.cursor.Cursor.count`. These two methods use the `count` command and `may or may not be accurate `_, depending on the options used and connected MongoDB topology. Use :meth:`~pymongo.collection.Collection.count_documents` instead. - Deprecated the snapshot option of :meth:`~pymongo.collection.Collection.find` and :meth:`~pymongo.collection.Collection.find_one`. The option was deprecated in MongoDB 3.6 and removed in MongoDB 4.0. - Deprecated the max_scan option of :meth:`~pymongo.collection.Collection.find` and :meth:`~pymongo.collection.Collection.find_one`. The option was deprecated in MongoDB 4.0. Use `maxTimeMS` instead. - Deprecated :meth:`~pymongo.mongo_client.MongoClient.close_cursor`. Use :meth:`~pymongo.cursor.Cursor.close` instead. - Deprecated :meth:`~pymongo.mongo_client.MongoClient.database_names`. Use :meth:`~pymongo.mongo_client.MongoClient.list_database_names` instead. - Deprecated :meth:`~pymongo.database.Database.collection_names`. Use :meth:`~pymongo.database.Database.list_collection_names` instead. - Deprecated :meth:`~pymongo.collection.Collection.parallel_scan`. MongoDB 4.2 will remove the parallelCollectionScan command. Unavoidable breaking changes: - Commands that fail with server error codes 10107, 13435, 13436, 11600, 11602, 189, 91 (NotMaster, NotMasterNoSlaveOk, NotMasterOrSecondary, InterruptedAtShutdown, InterruptedDueToReplStateChange, PrimarySteppedDown, ShutdownInProgress respectively) now always raise :class:`~pymongo.errors.NotMasterError` instead of :class:`~pymongo.errors.OperationFailure`. - :meth:`~pymongo.collection.Collection.parallel_scan` no longer uses an implicit session. Explicit sessions are still supported. - Unacknowledged writes (``w=0``) with an explicit ``session`` parameter now raise a client side error. Since PyMongo does not wait for a response for an unacknowledged write, two unacknowledged writes run serially by the client may be executed simultaneously on the server. However, the server requires a single session must not be used simultaneously by more than one operation. Therefore explicit sessions cannot support unacknowledged writes. Unacknowledged writes without a ``session`` parameter are still supported. Issues Resolved ............... See the `PyMongo 3.7 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.7 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=19287 Changes in Version 3.6.1 ------------------------ Version 3.6.1 fixes bugs reported since the release of 3.6.0: - Fix regression in PyMongo 3.5.0 that causes idle sockets to be closed almost instantly when ``maxIdleTimeMS`` is set. Idle sockets are now closed after ``maxIdleTimeMS`` milliseconds. - :attr:`pymongo.mongo_client.MongoClient.max_idle_time_ms` now returns milliseconds instead of seconds. - Properly import and use the `monotonic `_ library for monotonic time when it is installed. 
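A short sketch of the option and property involved in the ``maxIdleTimeMS``
fixes above; the connection string points at a placeholder local server::

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017/?maxIdleTimeMS=30000")
    # Idle pooled connections are now closed only after 30 seconds.
    print(client.max_idle_time_ms)  # 30000 (milliseconds, not seconds)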
- :meth:`~pymongo.collection.Collection.aggregate` now ignores the ``batchSize`` argument when running a pipeline with a ``$out`` stage. - Always send handshake metadata for new connections. Issues Resolved ............... See the `PyMongo 3.6.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.6.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=19438 Changes in Version 3.6.0 ------------------------ Version 3.6 adds support for MongoDB 3.6, drops support for CPython 3.3 (PyPy3 is still supported), and drops support for MongoDB versions older than 2.6. If connecting to a MongoDB 2.4 server or older, PyMongo now throws a :exc:`~pymongo.errors.ConfigurationError`. Highlights include: - Support for change streams. See the :meth:`~pymongo.collection.Collection.watch` method for details. - Support for array_filters in :meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.update_many`, :meth:`~pymongo.collection.Collection.find_one_and_update`, :meth:`~pymongo.operations.UpdateOne`, and :meth:`~pymongo.operations.UpdateMany`. - New Session API, see :meth:`~pymongo.mongo_client.MongoClient.start_session`. - New methods :meth:`~pymongo.collection.Collection.find_raw_batches` and :meth:`~pymongo.collection.Collection.aggregate_raw_batches` for use with external libraries that can parse raw batches of BSON data. - New methods :meth:`~pymongo.mongo_client.MongoClient.list_databases` and :meth:`~pymongo.mongo_client.MongoClient.list_database_names`. - New methods :meth:`~pymongo.database.Database.list_collections` and :meth:`~pymongo.database.Database.list_collection_names`. - Support for mongodb+srv:// URIs. See :class:`~pymongo.mongo_client.MongoClient` for details. - Index management helpers (:meth:`~pymongo.collection.Collection.create_index`, :meth:`~pymongo.collection.Collection.create_indexes`, :meth:`~pymongo.collection.Collection.drop_index`, :meth:`~pymongo.collection.Collection.drop_indexes`, :meth:`~pymongo.collection.Collection.reindex`) now support maxTimeMS. - Support for retryable writes and the ``retryWrites`` URI option. See :class:`~pymongo.mongo_client.MongoClient` for details. Deprecations: - The `useCursor` option for :meth:`~pymongo.collection.Collection.aggregate` is deprecated. The option was only necessary when upgrading from MongoDB 2.4 to MongoDB 2.6. MongoDB 2.4 is no longer supported. - The :meth:`~pymongo.database.Database.add_user` and :meth:`~pymongo.database.Database.remove_user` methods are deprecated. See the method docstrings for alternatives. Unavoidable breaking changes: - Starting in MongoDB 3.6, the deprecated methods :meth:`~pymongo.database.Database.authenticate` and :meth:`~pymongo.database.Database.logout` now invalidate all cursors created prior. Instead of using these methods to change credentials, pass credentials for one user to the :class:`~pymongo.mongo_client.MongoClient` at construction time, and either grant access to several databases to one user account, or use a distinct client object for each user. - BSON binary subtype 4 is decoded using RFC-4122 byte order regardless of the UUID representation. This is a change in behavior for applications that use UUID representation :data:`bson.binary.JAVA_LEGACY` or :data:`bson.binary.CSHARP_LEGACY` to decode BSON binary subtype 4. 
Other UUID representations, :data:`bson.binary.PYTHON_LEGACY` (the default) and :data:`bson.binary.STANDARD`, and the decoding of BSON binary subtype 3 are unchanged. Issues Resolved ............... See the `PyMongo 3.6 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.6 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=18043 Changes in Version 3.5.1 ------------------------ Version 3.5.1 fixes bugs reported since the release of 3.5.0: - Work around socket.getsockopt issue with NetBSD. - :meth:`pymongo.command_cursor.CommandCursor.close` now closes the cursor synchronously instead of deferring to a background thread. - Fix documentation build warnings with Sphinx 1.6.x. Issues Resolved ............... See the `PyMongo 3.5.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.5.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=18721 Changes in Version 3.5 ---------------------- Version 3.5 implements a number of improvements and bug fixes: Highlights include: - Username and password can be passed to :class:`~pymongo.mongo_client.MongoClient` as keyword arguments. Before, the only way to pass them was in the URI. - Increased the performance of using :class:`~bson.raw_bson.RawBSONDocument`. - Increased the performance of :meth:`~pymongo.mongo_client.MongoClient.database_names` by using the `nameOnly` option for listDatabases when available. - Increased the performance of :meth:`~pymongo.collection.Collection.bulk_write` by reducing the memory overhead of :class:`~pymongo.operations.InsertOne`, :class:`~pymongo.operations.DeleteOne`, and :class:`~pymongo.operations.DeleteMany`. - Added the `collation` option to :class:`~pymongo.operations.DeleteOne`, :class:`~pymongo.operations.DeleteMany`, :class:`~pymongo.operations.ReplaceOne`, :class:`~pymongo.operations.UpdateOne`, and :class:`~pymongo.operations.UpdateMany`. - Implemented the `MongoDB Extended JSON `_ specification. - :class:`~bson.decimal128.Decimal128` now works when cdecimal is installed. - PyMongo is now tested against a wider array of operating systems and CPU architectures (including s390x, ARM64, and POWER8). Changes and Deprecations: - :meth:`~pymongo.collection.Collection.find` has new options `return_key`, `show_record_id`, `snapshot`, `hint`, `max_time_ms`, `max_scan`, `min`, `max`, and `comment`. Deprecated the option `modifiers`. - Deprecated :meth:`~pymongo.collection.Collection.group`. The group command was deprecated in MongoDB 3.4 and is expected to be removed in MongoDB 3.6. Applications should use :meth:`~pymongo.collection.Collection.aggregate` with the `$group` pipeline stage instead. - Deprecated :meth:`~pymongo.database.Database.authenticate`. Authenticating multiple users conflicts with support for logical sessions in MongoDB 3.6. To authenticate as multiple users, create multiple instances of :class:`~pymongo.mongo_client.MongoClient`. - Deprecated :meth:`~pymongo.database.Database.eval`. The eval command was deprecated in MongoDB 3.0 and will be removed in a future server version. - Deprecated :class:`~pymongo.database.SystemJS`. - Deprecated :meth:`~pymongo.mongo_client.MongoClient.get_default_database`. Applications should use :meth:`~pymongo.mongo_client.MongoClient.get_database` without the `name` parameter instead. - Deprecated the MongoClient option `socketKeepAlive`. 
It now defaults to true and disabling it is not recommended, see `does TCP keepalive time affect MongoDB Deployments? `_ - Deprecated :meth:`~pymongo.collection.Collection.initialize_ordered_bulk_op`, :meth:`~pymongo.collection.Collection.initialize_unordered_bulk_op`, and :class:`~pymongo.bulk.BulkOperationBuilder`. Use :meth:`~pymongo.collection.Collection.bulk_write` instead. - Deprecated :const:`~bson.json_util.STRICT_JSON_OPTIONS`. Use :const:`~bson.json_util.RELAXED_JSON_OPTIONS` or :const:`~bson.json_util.CANONICAL_JSON_OPTIONS` instead. - If a custom :class:`~bson.codec_options.CodecOptions` is passed to :class:`RawBSONDocument`, its `document_class` must be :class:`RawBSONDocument`. - :meth:`~pymongo.collection.Collection.list_indexes` no longer raises OperationFailure when the collection (or database) does not exist on MongoDB >= 3.0. Instead, it returns an empty :class:`~pymongo.command_cursor.CommandCursor` to make the behavior consistent across all MongoDB versions. - In Python 3, :meth:`~bson.json_util.loads` now automatically decodes JSON $binary with a subtype of 0 into :class:`bytes` instead of :class:`~bson.binary.Binary`. See the :doc:`/python3` for more details. - :meth:`~bson.json_util.loads` now raises ``TypeError`` or ``ValueError`` when parsing JSON type wrappers with values of the wrong type or any extra keys. - :meth:`pymongo.cursor.Cursor.close` and :meth:`pymongo.mongo_client.MongoClient.close` now kill cursors synchronously instead of deferring to a background thread. - :meth:`~pymongo.uri_parser.parse_uri` now returns the original value of the ``readPreference`` MongoDB URI option instead of the validated read preference mode. Issues Resolved ............... See the `PyMongo 3.5 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.5 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=17590 Changes in Version 3.4 ---------------------- Version 3.4 implements the new server features introduced in MongoDB 3.4 and a whole lot more: Highlights include: - Complete support for MongoDB 3.4: - Unicode aware string comparison using :doc:`examples/collations`. - Support for the new :class:`~bson.decimal128.Decimal128` BSON type. - A new maxStalenessSeconds read preference option. - A username is no longer required for the MONGODB-X509 authentication mechanism when connected to MongoDB >= 3.4. - :meth:`~pymongo.collection.Collection.parallel_scan` supports maxTimeMS. - :attr:`~pymongo.write_concern.WriteConcern` is automatically applied by all helpers for commands that write to the database when connected to MongoDB 3.4+. This change affects the following helpers: - :meth:`~pymongo.mongo_client.MongoClient.drop_database` - :meth:`~pymongo.database.Database.create_collection` - :meth:`~pymongo.database.Database.drop_collection` - :meth:`~pymongo.collection.Collection.aggregate` (when using $out) - :meth:`~pymongo.collection.Collection.create_indexes` - :meth:`~pymongo.collection.Collection.create_index` - :meth:`~pymongo.collection.Collection.drop_indexes` - :meth:`~pymongo.collection.Collection.drop_indexes` - :meth:`~pymongo.collection.Collection.drop_index` - :meth:`~pymongo.collection.Collection.map_reduce` (when output is not "inline") - :meth:`~pymongo.collection.Collection.reindex` - :meth:`~pymongo.collection.Collection.rename` - Improved support for logging server discovery and monitoring events. See :mod:`~pymongo.monitoring` for examples. 
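A bare-bones sketch of such a listener follows; the class name and log
messages are invented, and the pattern simply mirrors what
:mod:`~pymongo.monitoring` documents::

    import logging

    from pymongo import MongoClient, monitoring

    class HeartbeatLogger(monitoring.ServerHeartbeatListener):
        def started(self, event):
            logging.info("Heartbeat sent to %s", event.connection_id)

        def succeeded(self, event):
            logging.info("Heartbeat to %s succeeded", event.connection_id)

        def failed(self, event):
            logging.warning("Heartbeat to %s failed", event.connection_id)

    monitoring.register(HeartbeatLogger())
    client = MongoClient()  # listeners apply to clients created afterwards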
- Support for matching iPAddress subjectAltName values for TLS certificate verification. - TLS compression is now explicitly disabled when possible. - The Server Name Indication (SNI) TLS extension is used when possible. - Finer control over JSON encoding/decoding with :class:`~bson.json_util.JSONOptions`. - Allow :class:`~bson.code.Code` objects to have a scope of ``None``, signifying no scope. Also allow encoding Code objects with an empty scope (i.e. ``{}``). .. warning:: Starting in PyMongo 3.4, :attr:`bson.code.Code.scope` may return ``None``, as the default scope is ``None`` instead of ``{}``. .. note:: PyMongo 3.4+ attempts to create sockets non-inheritable when possible (i.e. it sets the close-on-exec flag on socket file descriptors). Support is limited to a subset of POSIX operating systems (not including Windows) and the flag usually cannot be set in a single atomic operation. CPython 3.4+ implements `PEP 446`_, creating all file descriptors non-inheritable by default. Users that require this behavior are encouraged to upgrade to CPython 3.4+. Since 3.4rc0, the max staleness option has been renamed from ``maxStalenessMS`` to ``maxStalenessSeconds``, its smallest value has changed from twice ``heartbeatFrequencyMS`` to 90 seconds, and its default value has changed from ``None`` or 0 to -1. .. _PEP 446: https://www.python.org/dev/peps/pep-0446/ Issues Resolved ............... See the `PyMongo 3.4 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.4 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16594 Changes in Version 3.3.1 ------------------------ Version 3.3.1 fixes a memory leak when decoding elements inside of a :class:`~bson.raw_bson.RawBSONDocument`. Issues Resolved ............... See the `PyMongo 3.3.1 release notes in Jira`_ for the list of resolved issues in this release. .. _PyMongo 3.3.1 release notes in Jira: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=17636 Changes in Version 3.3 ---------------------- Version 3.3 adds the following major new features: - C extensions support on big endian systems. - Kerberos authentication support on Windows using `WinKerberos `_. - A new ``ssl_clrfile`` option to support certificate revocation lists. - A new ``ssl_pem_passphrase`` option to support encrypted key files. - Support for publishing server discovery and monitoring events. See :mod:`~pymongo.monitoring` for details. - New connection pool options ``minPoolSize`` and ``maxIdleTimeMS``. - New ``heartbeatFrequencyMS`` option controls the rate at which background monitoring threads re-check servers. Default is once every 10 seconds. .. warning:: PyMongo 3.3 drops support for MongoDB versions older than 2.4. It also drops support for python 3.2 (pypy3 continues to be supported). Issues Resolved ............... See the `PyMongo 3.3 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16005 Changes in Version 3.2.2 ------------------------ Version 3.2.2 fixes a few issues reported since the release of 3.2.1, including a fix for using the `connect` option in the MongoDB URI and support for setting the batch size for a query to 1 when using MongoDB 3.2+. Issues Resolved ............... See the `PyMongo 3.2.2 release notes in JIRA`_ for the list of resolved issues in this release. .. 
_PyMongo 3.2.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16538 Changes in Version 3.2.1 ------------------------ Version 3.2.1 fixes a few issues reported since the release of 3.2, including running the mapreduce command twice when calling the :meth:`~pymongo.collection.Collection.inline_map_reduce` method and a :exc:`TypeError` being raised when calling :meth:`~gridfs.GridFSBucket.download_to_stream`. This release also improves error messaging around BSON decoding. Issues Resolved ............... See the `PyMongo 3.2.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.2.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16312 Changes in Version 3.2 ---------------------- Version 3.2 implements the new server features introduced in MongoDB 3.2. Highlights include: - Full support for MongoDB 3.2 including: - Support for :class:`~pymongo.read_concern.ReadConcern` - :class:`~pymongo.write_concern.WriteConcern` is now applied to :meth:`~pymongo.collection.Collection.find_one_and_replace`, :meth:`~pymongo.collection.Collection.find_one_and_update`, and :meth:`~pymongo.collection.Collection.find_one_and_delete`. - Support for the new `bypassDocumentValidation` option in write helpers. - Support for reading and writing raw BSON with :class:`~bson.raw_bson.RawBSONDocument` .. note:: Certain :class:`~pymongo.mongo_client.MongoClient` properties now block until a connection is established or raise :exc:`~pymongo.errors.ServerSelectionTimeoutError` if no server is available. See :class:`~pymongo.mongo_client.MongoClient` for details. Issues Resolved ............... See the `PyMongo 3.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=15612 Changes in Version 3.1.1 ------------------------ Version 3.1.1 fixes a few issues reported since the release of 3.1, including a regression in error handling for oversize command documents and interrupt handling issues in the C extensions. Issues Resolved ............... See the `PyMongo 3.1.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.1.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16211 Changes in Version 3.1 ---------------------- Version 3.1 implements a few new features and fixes bugs reported since the release of 3.0.3. Highlights include: - Command monitoring support. See :mod:`~pymongo.monitoring` for details. - Configurable error handling for :exc:`UnicodeDecodeError`. See the `unicode_decode_error_handler` option of :class:`~bson.codec_options.CodecOptions`. - Optional automatic timezone conversion when decoding BSON datetime. See the `tzinfo` option of :class:`~bson.codec_options.CodecOptions`. - An implementation of :class:`~gridfs.GridFSBucket` from the new GridFS spec. - Compliance with the new Connection String spec. - Reduced idle CPU usage in Python 2. Changes in internal classes ........................... The private ``PeriodicExecutor`` class no longer takes a ``condition_class`` option, and the private ``thread_util.Event`` class is removed. Issues Resolved ............... See the `PyMongo 3.1 release notes in JIRA`_ for the list of resolved issues in this release. .. 
_PyMongo 3.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=14796 Changes in Version 3.0.3 ------------------------ Version 3.0.3 fixes issues reported since the release of 3.0.2, including a feature breaking bug in the GSSAPI implementation. Issues Resolved ............... See the `PyMongo 3.0.3 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.0.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=15528 Changes in Version 3.0.2 ------------------------ Version 3.0.2 fixes issues reported since the release of 3.0.1, most importantly a bug that could route operations to replica set members that are not in primary or secondary state when using :class:`~pymongo.read_preferences.PrimaryPreferred` or :class:`~pymongo.read_preferences.Nearest`. It is a recommended upgrade for all users of PyMongo 3.0.x. Issues Resolved ............... See the `PyMongo 3.0.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.0.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=15430 Changes in Version 3.0.1 ------------------------ Version 3.0.1 fixes issues reported since the release of 3.0, most importantly a bug in GridFS.delete that could prevent file chunks from actually being deleted. Issues Resolved ............... See the `PyMongo 3.0.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.0.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=15322 Changes in Version 3.0 ---------------------- PyMongo 3.0 is a partial rewrite of PyMongo bringing a large number of improvements: - A unified client class. MongoClient is the one and only client class for connecting to a standalone mongod, replica set, or sharded cluster. Migrating from a standalone, to a replica set, to a sharded cluster can be accomplished with only a simple URI change. - MongoClient is much more responsive to configuration changes in your MongoDB deployment. All connected servers are monitored in a non-blocking manner. Slow to respond or down servers no longer block server discovery, reducing application startup time and time to respond to new or reconfigured servers and replica set failovers. - A unified CRUD API. All official MongoDB drivers now implement a standard CRUD API allowing polyglot developers to move from language to language with ease. - Single source support for Python 2.x and 3.x. PyMongo no longer relies on 2to3 to support Python 3. - A rewritten pure Python BSON implementation, improving performance with pypy and cpython deployments without support for C extensions. - Better support for greenlet based async frameworks including eventlet. - Immutable client, database, and collection classes, avoiding a host of thread safety issues in client applications. PyMongo 3.0 brings a large number of API changes. Be sure to read the changes listed below before upgrading from PyMongo 2.x. .. warning:: PyMongo no longer supports Python 2.4, 2.5, or 3.1. If you must use PyMongo with these versions of Python the 2.x branch of PyMongo will be minimally supported for some time. SONManipulator changes ...................... The :class:`~pymongo.son_manipulator.SONManipulator` API has limitations as a technique for transforming your data. 
Instead, it is more flexible and straightforward to transform outgoing documents in your own code before passing them to PyMongo, and transform incoming documents after receiving them from PyMongo. Thus the :meth:`~pymongo.database.Database.add_son_manipulator` method is deprecated. PyMongo 3's new CRUD API does **not** apply SON manipulators to documents passed to :meth:`~pymongo.collection.Collection.bulk_write`, :meth:`~pymongo.collection.Collection.insert_one`, :meth:`~pymongo.collection.Collection.insert_many`, :meth:`~pymongo.collection.Collection.update_one`, or :meth:`~pymongo.collection.Collection.update_many`. SON manipulators are **not** applied to documents returned by the new methods :meth:`~pymongo.collection.Collection.find_one_and_delete`, :meth:`~pymongo.collection.Collection.find_one_and_replace`, and :meth:`~pymongo.collection.Collection.find_one_and_update`. SSL/TLS changes ............... When `ssl` is ``True`` the `ssl_cert_reqs` option now defaults to :attr:`ssl.CERT_REQUIRED` if not provided. PyMongo will attempt to load OS provided CA certificates to verify the server, raising :exc:`~pymongo.errors.ConfigurationError` if it cannot. Gevent Support .............. In previous versions, PyMongo supported Gevent in two modes: you could call ``gevent.monkey.patch_socket()`` and pass ``use_greenlets=True`` to :class:`~pymongo.mongo_client.MongoClient`, or you could simply call ``gevent.monkey.patch_all()`` and omit the ``use_greenlets`` argument. In PyMongo 3.0, the ``use_greenlets`` option is gone. To use PyMongo with Gevent simply call ``gevent.monkey.patch_all()``. For more information, see :doc:`PyMongo's Gevent documentation `. :class:`~pymongo.mongo_client.MongoClient` changes .................................................. :class:`~pymongo.mongo_client.MongoClient` is now the one and only client class for a standalone server, mongos, or replica set. It includes the functionality that had been split into ``MongoReplicaSetClient``: it can connect to a replica set, discover all its members, and monitor the set for stepdowns, elections, and reconfigs. :class:`~pymongo.mongo_client.MongoClient` now also supports the full :class:`~pymongo.read_preferences.ReadPreference` API. The obsolete classes ``MasterSlaveConnection``, ``Connection``, and ``ReplicaSetConnection`` are removed. The :class:`~pymongo.mongo_client.MongoClient` constructor no longer blocks while connecting to the server or servers, and it no longer raises :class:`~pymongo.errors.ConnectionFailure` if they are unavailable, nor :class:`~pymongo.errors.ConfigurationError` if the user's credentials are wrong. Instead, the constructor returns immediately and launches the connection process on background threads. The ``connect`` option is added to control whether these threads are started immediately, or when the client is first used. Therefore the ``alive`` method is removed since it no longer provides meaningful information; even if the client is disconnected, it may discover a server in time to fulfill the next operation. In PyMongo 2.x, :class:`~pymongo.mongo_client.MongoClient` accepted a list of standalone MongoDB servers and used the first it could connect to:: MongoClient(['host1.com:27017', 'host2.com:27017']) A list of multiple standalones is no longer supported; if multiple servers are listed they must be members of the same replica set, or mongoses in the same sharded cluster. The behavior for a list of mongoses is changed from "high availability" to "load balancing". 
Before, the client connected to the lowest-latency mongos in the list, and used it until a network error prompted it to re-evaluate all mongoses' latencies and reconnect to one of them. In PyMongo 3, the client monitors its network latency to all the mongoses continuously, and distributes operations evenly among those with the lowest latency. See :ref:`mongos-load-balancing` for more information. The client methods ``start_request``, ``in_request``, and ``end_request`` are removed, and so is the ``auto_start_request`` option. Requests were designed to make read-your-writes consistency more likely with the ``w=0`` write concern. Additionally, a thread in a request used the same member for all secondary reads in a replica set. To ensure read-your-writes consistency in PyMongo 3.0, do not override the default write concern with ``w=0``, and do not override the default :ref:`read preference ` of PRIMARY. Support for the ``slaveOk`` (or ``slave_okay``), ``safe``, and ``network_timeout`` options has been removed. Use :attr:`~pymongo.read_preferences.ReadPreference.SECONDARY_PREFERRED` instead of slave_okay. Accept the default write concern, acknowledged writes, instead of setting safe=True. Use socketTimeoutMS in place of network_timeout (note that network_timeout was in seconds, whereas socketTimeoutMS is milliseconds). The ``max_pool_size`` option has been removed. It is replaced by the ``maxPoolSize`` MongoDB URI option. ``maxPoolSize`` is now a supported URI option in PyMongo and can be passed as a keyword argument. The ``copy_database`` method is removed, see the :doc:`copy_database examples ` for alternatives. The ``disconnect`` method is removed. Use :meth:`~pymongo.mongo_client.MongoClient.close` instead. The ``get_document_class`` method is removed. Use :attr:`~pymongo.mongo_client.MongoClient.codec_options` instead. The ``get_lasterror_options``, ``set_lasterror_options``, and ``unset_lasterror_options`` methods are removed. Write concern options can be passed to :class:`~pymongo.mongo_client.MongoClient` as keyword arguments or MongoDB URI options. The :meth:`~pymongo.mongo_client.MongoClient.get_database` method is added for getting a Database instance with its options configured differently than the MongoClient's.
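As a quick illustration of that pattern, here is a minimal sketch (the connection string, the ``reports`` database name, and the specific options shown are hypothetical examples, not prescribed values)::

    from bson.codec_options import CodecOptions
    from bson.son import SON
    from pymongo import MongoClient, ReadPreference
    from pymongo.write_concern import WriteConcern

    client = MongoClient("mongodb://localhost:27017")

    # The client keeps its defaults; only this Database handle is configured
    # with a different document class, read preference, and write concern.
    reports = client.get_database(
        "reports",
        codec_options=CodecOptions(document_class=SON),
        read_preference=ReadPreference.SECONDARY_PREFERRED,
        write_concern=WriteConcern(w="majority"),
    )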
The following read-only attributes have been added: - :attr:`~pymongo.mongo_client.MongoClient.codec_options` The following attributes are now read-only: - :attr:`~pymongo.mongo_client.MongoClient.read_preference` - :attr:`~pymongo.mongo_client.MongoClient.write_concern` The following attributes have been removed: - :attr:`~pymongo.mongo_client.MongoClient.document_class` (use :attr:`~pymongo.mongo_client.MongoClient.codec_options` instead) - :attr:`~pymongo.mongo_client.MongoClient.host` (use :attr:`~pymongo.mongo_client.MongoClient.address` instead) - :attr:`~pymongo.mongo_client.MongoClient.min_wire_version` - :attr:`~pymongo.mongo_client.MongoClient.max_wire_version` - :attr:`~pymongo.mongo_client.MongoClient.port` (use :attr:`~pymongo.mongo_client.MongoClient.address` instead) - :attr:`~pymongo.mongo_client.MongoClient.safe` (use :attr:`~pymongo.mongo_client.MongoClient.write_concern` instead) - :attr:`~pymongo.mongo_client.MongoClient.slave_okay` (use :attr:`~pymongo.mongo_client.MongoClient.read_preference` instead) - :attr:`~pymongo.mongo_client.MongoClient.tag_sets` (use :attr:`~pymongo.mongo_client.MongoClient.read_preference` instead) - :attr:`~pymongo.mongo_client.MongoClient.tz_aware` (use :attr:`~pymongo.mongo_client.MongoClient.codec_options` instead) The following attributes have been renamed: - :attr:`~pymongo.mongo_client.MongoClient.secondary_acceptable_latency_ms` is now :attr:`~pymongo.mongo_client.MongoClient.local_threshold_ms` and is now read-only. :class:`~pymongo.cursor.Cursor` changes ....................................... The ``conn_id`` property is renamed to :attr:`~pymongo.cursor.Cursor.address`. Cursor management changes ......................... :class:`~pymongo.cursor_manager.CursorManager` and :meth:`~pymongo.mongo_client.MongoClient.set_cursor_manager` are no longer deprecated. If you subclass :class:`~pymongo.cursor_manager.CursorManager` your implementation of :meth:`~pymongo.cursor_manager.CursorManager.close` must now take a second parameter, `address`. The ``BatchCursorManager`` class is removed. The second parameter to :meth:`~pymongo.mongo_client.MongoClient.close_cursor` is renamed from ``_conn_id`` to ``address``. :meth:`~pymongo.mongo_client.MongoClient.kill_cursors` now accepts an `address` parameter. :class:`~pymongo.database.Database` changes ........................................... The ``connection`` property is renamed to :attr:`~pymongo.database.Database.client`. The following read-only attributes have been added: - :attr:`~pymongo.database.Database.codec_options` The following attributes are now read-only: - :attr:`~pymongo.database.Database.read_preference` - :attr:`~pymongo.database.Database.write_concern` Use :meth:`~pymongo.mongo_client.MongoClient.get_database` for getting a Database instance with its options configured differently than the MongoClient's. The following attributes have been removed: - :attr:`~pymongo.database.Database.safe` - :attr:`~pymongo.database.Database.secondary_acceptable_latency_ms` - :attr:`~pymongo.database.Database.slave_okay` - :attr:`~pymongo.database.Database.tag_sets` The following methods have been added: - :meth:`~pymongo.database.Database.get_collection` The following methods have been changed: - :meth:`~pymongo.database.Database.command`. Support for `as_class`, `uuid_subtype`, `tag_sets`, and `secondary_acceptable_latency_ms` have been removed. 
You can instead pass an instance of :class:`~bson.codec_options.CodecOptions` as `codec_options` and an instance of a read preference class from :mod:`~pymongo.read_preferences` as `read_preference`. The `fields` and `compile_re` options are also removed. The `fields` option was undocumented and never really worked. Regular expressions are always decoded to :class:`~bson.regex.Regex`. The following methods have been deprecated: - :meth:`~pymongo.database.Database.add_son_manipulator` The following methods have been removed: The ``get_lasterror_options``, ``set_lasterror_options``, and ``unset_lasterror_options`` methods have been removed. Use :class:`~pymongo.write_concern.WriteConcern` with :meth:`~pymongo.mongo_client.MongoClient.get_database` instead. :class:`~pymongo.collection.Collection` changes ............................................... The following read-only attributes have been added: - :attr:`~pymongo.collection.Collection.codec_options` The following attributes are now read-only: - :attr:`~pymongo.collection.Collection.read_preference` - :attr:`~pymongo.collection.Collection.write_concern` Use :meth:`~pymongo.database.Database.get_collection` or :meth:`~pymongo.collection.Collection.with_options` for getting a Collection instance with its options configured differently than the Database's. The following attributes have been removed: - :attr:`~pymongo.collection.Collection.safe` - :attr:`~pymongo.collection.Collection.secondary_acceptable_latency_ms` - :attr:`~pymongo.collection.Collection.slave_okay` - :attr:`~pymongo.collection.Collection.tag_sets` The following methods have been added: - :meth:`~pymongo.collection.Collection.bulk_write` - :meth:`~pymongo.collection.Collection.insert_one` - :meth:`~pymongo.collection.Collection.insert_many` - :meth:`~pymongo.collection.Collection.update_one` - :meth:`~pymongo.collection.Collection.update_many` - :meth:`~pymongo.collection.Collection.replace_one` - :meth:`~pymongo.collection.Collection.delete_one` - :meth:`~pymongo.collection.Collection.delete_many` - :meth:`~pymongo.collection.Collection.find_one_and_delete` - :meth:`~pymongo.collection.Collection.find_one_and_replace` - :meth:`~pymongo.collection.Collection.find_one_and_update` - :meth:`~pymongo.collection.Collection.with_options` - :meth:`~pymongo.collection.Collection.create_indexes` - :meth:`~pymongo.collection.Collection.list_indexes` The following methods have changed: - :meth:`~pymongo.collection.Collection.aggregate` now **always** returns an instance of :class:`~pymongo.command_cursor.CommandCursor`. See the documentation for all options. - :meth:`~pymongo.collection.Collection.count` now optionally takes a filter argument, as well as other options supported by the count command. - :meth:`~pymongo.collection.Collection.distinct` now optionally takes a filter argument. - :meth:`~pymongo.collection.Collection.create_index` no longer caches indexes, therefore the `cache_for` parameter has been removed. It also no longer supports the `bucket_size` and `drop_dups` aliases for `bucketSize` and `dropDups`.
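To make the Collection-level changes above concrete, here is a minimal sketch that combines :meth:`~pymongo.collection.Collection.with_options` with the new CRUD helpers (the ``test.events`` namespace and the example documents are hypothetical)::

    from pymongo import MongoClient, ReadPreference
    from pymongo.write_concern import WriteConcern

    events = MongoClient().test.events

    # with_options returns a copy configured differently; the original
    # collection object is unchanged.
    tuned = events.with_options(
        read_preference=ReadPreference.SECONDARY_PREFERRED,
        write_concern=WriteConcern(w="majority"),
    )

    # The new CRUD helpers replace the deprecated insert/update/remove methods.
    result = tuned.insert_one({"type": "login"})
    tuned.update_one({"_id": result.inserted_id}, {"$set": {"seen": True}})
    tuned.delete_many({"type": "stale"})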
The following methods are deprecated: - :meth:`~pymongo.collection.Collection.save` - :meth:`~pymongo.collection.Collection.insert` - :meth:`~pymongo.collection.Collection.update` - :meth:`~pymongo.collection.Collection.remove` - :meth:`~pymongo.collection.Collection.find_and_modify` - :meth:`~pymongo.collection.Collection.ensure_index` The following methods have been removed: The ``get_lasterror_options``, ``set_lasterror_options``, and ``unset_lasterror_options`` methods have been removed. Use :class:`~pymongo.write_concern.WriteConcern` with :meth:`~pymongo.collection.Collection.with_options` instead. Changes to :meth:`~pymongo.collection.Collection.find` and :meth:`~pymongo.collection.Collection.find_one` `````````````````````````````````````````````````````````````````````````````````````````````````````````` The following find/find_one options have been renamed: These renames only affect your code if you passed these as keyword arguments, like find(fields=['fieldname']). If you passed only positional parameters these changes are not significant for your application. - spec -> filter - fields -> projection - partial -> allow_partial_results The following find/find_one options have been added: - cursor_type (see :class:`~pymongo.cursor.CursorType` for values) - oplog_replay - modifiers The following find/find_one options have been removed: - network_timeout (use :meth:`~pymongo.cursor.Cursor.max_time_ms` instead) - slave_okay (use one of the read preference classes from :mod:`~pymongo.read_preferences` and :meth:`~pymongo.collection.Collection.with_options` instead) - read_preference (use :meth:`~pymongo.collection.Collection.with_options` instead) - tag_sets (use one of the read preference classes from :mod:`~pymongo.read_preferences` and :meth:`~pymongo.collection.Collection.with_options` instead) - secondary_acceptable_latency_ms (use the `localThresholdMS` URI option instead) - max_scan (use the new `modifiers` option instead) - snapshot (use the new `modifiers` option instead) - tailable (use the new `cursor_type` option instead) - await_data (use the new `cursor_type` option instead) - exhaust (use the new `cursor_type` option instead) - as_class (use :meth:`~pymongo.collection.Collection.with_options` with :class:`~bson.codec_options.CodecOptions` instead) - compile_re (BSON regular expressions are always decoded to :class:`~bson.regex.Regex`) The following find/find_one options are deprecated: - manipulate The following renames need special handling. - timeout -> no_cursor_timeout - The default for `timeout` was True. The default for `no_cursor_timeout` is False. If you were previously passing False for `timeout` you must pass **True** for `no_cursor_timeout` to keep the previous behavior. :mod:`~pymongo.errors` changes .............................. The exception classes ``UnsupportedOption`` and ``TimeoutError`` are deleted. :mod:`~gridfs` changes ...................... Since PyMongo 1.6, methods ``open`` and ``close`` of :class:`~gridfs.GridFS` raised an ``UnsupportedAPI`` exception, as did the entire ``GridFile`` class. The unsupported methods, the class, and the exception are all deleted. :mod:`~bson` changes .................... The `compile_re` option is removed from all methods that accepted it in :mod:`~bson` and :mod:`~bson.json_util`. Additionally, it is removed from :meth:`~pymongo.collection.Collection.find`, :meth:`~pymongo.collection.Collection.find_one`, :meth:`~pymongo.collection.Collection.aggregate`, :meth:`~pymongo.database.Database.command`, and so on. 
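Because `compile_re` is gone, BSON regular expressions are no longer turned into Python patterns automatically. A minimal sketch of handling them explicitly, under the assumption of a hypothetical ``test.pages`` namespace with a ``pattern`` field, might look like this::

    import re

    from bson.regex import Regex
    from pymongo import MongoClient

    pages = MongoClient().test.pages

    # $type 11 matches BSON regular expression values.
    doc = pages.find_one({"pattern": {"$type": 11}})
    if doc is not None:
        regex = doc["pattern"]  # decoded as bson.regex.Regex
        try:
            # May raise re.error for constructs Python's re module rejects.
            compiled = regex.try_compile()
        except re.error:
            compiled = None

    # A native Python pattern can be stored back as a BSON regular expression.
    pages.update_one(
        {"_id": "home"},
        {"$set": {"pattern": Regex.from_native(re.compile(r"^/docs/"))}})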
PyMongo now always represents BSON regular expressions as :class:`~bson.regex.Regex` objects. This prevents errors for incompatible patterns, see `PYTHON-500`_. Use :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a BSON regular expression to a Python regular expression object. PyMongo now decodes the int64 BSON type to :class:`~bson.int64.Int64`, a trivial wrapper around long (in python 2.x) or int (in python 3.x). This allows BSON int64 to be round tripped without losing type information in python 3. Note that if you store a python long (or a python int larger than 4 bytes) it will be returned from PyMongo as :class:`~bson.int64.Int64`. The `as_class`, `tz_aware`, and `uuid_subtype` options are removed from all BSON encoding and decoding methods. Use :class:`~bson.codec_options.CodecOptions` to configure these options. The APIs affected are: - :func:`~bson.decode_all` - :func:`~bson.decode_iter` - :func:`~bson.decode_file_iter` - :meth:`~bson.BSON.encode` - :meth:`~bson.BSON.decode` This is a breaking change for any application that uses the BSON API directly and changes any of the named parameter defaults. No changes are required for applications that use the default values for these options. The behavior remains the same. .. _PYTHON-500: https://jira.mongodb.org/browse/PYTHON-500 Issues Resolved ............... See the `PyMongo 3.0 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 3.0 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12501 Changes in Version 2.9.5 ------------------------ Version 2.9.5 works around ssl module deprecations in Python 3.6, and expected future ssl module deprecations. It also fixes bugs found since the release of 2.9.4. - Use ssl.SSLContext and ssl.PROTOCOL_TLS_CLIENT when available. - Fixed a C extensions build issue when the interpreter was built with -std=c99 - Fixed various build issues with MinGW32. - Fixed a write concern bug in :meth:`~pymongo.database.Database.add_user` and :meth:`~pymongo.database.Database.remove_user` when connected to MongoDB 3.2+ - Fixed various test failures related to changes in gevent, MongoDB, and our CI test environment. Issues Resolved ............... See the `PyMongo 2.9.5 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.9.5 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=17605 Changes in Version 2.9.4 ------------------------ Version 2.9.4 fixes issues reported since the release of 2.9.3. - Fixed __repr__ for closed instances of :class:`~pymongo.mongo_client.MongoClient`. - Fixed :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` handling of uuidRepresentation. - Fixed building and testing the documentation with python 3.x. - New documentation for :doc:`examples/tls` and :doc:`atlas`. Issues Resolved ............... See the `PyMongo 2.9.4 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.9.4 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16885 Changes in Version 2.9.3 ------------------------ Version 2.9.3 fixes a few issues reported since the release of 2.9.2 including thread safety issues in :meth:`~pymongo.collection.Collection.ensure_index`, :meth:`~pymongo.collection.Collection.drop_index`, and :meth:`~pymongo.collection.Collection.drop_indexes`. Issues Resolved ............... 
See the `PyMongo 2.9.3 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.9.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16539 Changes in Version 2.9.2 ------------------------ Version 2.9.2 restores Python 3.1 support, which was broken in PyMongo 2.8. It improves an error message when decoding BSON as well as fixes a couple other issues including :meth:`~pymongo.collection.Collection.aggregate` ignoring :attr:`~pymongo.collection.Collection.codec_options` and :meth:`~pymongo.database.Database.command` raising a superfluous `DeprecationWarning`. Issues Resolved ............... See the `PyMongo 2.9.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.9.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16303 Changes in Version 2.9.1 ------------------------ Version 2.9.1 fixes two interrupt handling issues in the C extensions and adapts a test case for a behavior change in MongoDB 3.2. Issues Resolved ............... See the `PyMongo 2.9.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.9.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=16208 Changes in Version 2.9 ---------------------- Version 2.9 provides an upgrade path to PyMongo 3.x. Most of the API changes from PyMongo 3.0 have been backported in a backward compatible way, allowing applications to be written against PyMongo >= 2.9, rather than PyMongo 2.x or PyMongo 3.x. See the :doc:`/migrate-to-pymongo3` for detailed examples. .. note:: There are a number of new deprecations in this release for features that were removed in PyMongo 3.0. :class:`~pymongo.mongo_client.MongoClient`: - :attr:`~pymongo.mongo_client.MongoClient.host` - :attr:`~pymongo.mongo_client.MongoClient.port` - :attr:`~pymongo.mongo_client.MongoClient.use_greenlets` - :attr:`~pymongo.mongo_client.MongoClient.document_class` - :attr:`~pymongo.mongo_client.MongoClient.tz_aware` - :attr:`~pymongo.mongo_client.MongoClient.secondary_acceptable_latency_ms` - :attr:`~pymongo.mongo_client.MongoClient.tag_sets` - :attr:`~pymongo.mongo_client.MongoClient.uuid_subtype` - :meth:`~pymongo.mongo_client.MongoClient.disconnect` - :meth:`~pymongo.mongo_client.MongoClient.alive` :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`: - :attr:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.use_greenlets` - :attr:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.document_class` - :attr:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.tz_aware` - :attr:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.secondary_acceptable_latency_ms` - :attr:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.tag_sets` - :attr:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.uuid_subtype` - :meth:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.alive` :class:`~pymongo.database.Database`: - :attr:`~pymongo.database.Database.secondary_acceptable_latency_ms` - :attr:`~pymongo.database.Database.tag_sets` - :attr:`~pymongo.database.Database.uuid_subtype` :class:`~pymongo.collection.Collection`: - :attr:`~pymongo.collection.Collection.secondary_acceptable_latency_ms` - :attr:`~pymongo.collection.Collection.tag_sets` - :attr:`~pymongo.collection.Collection.uuid_subtype` ..
warning:: In previous versions of PyMongo, changing the value of :attr:`~pymongo.mongo_client.MongoClient.document_class` changed the behavior of all existing instances of :class:`~pymongo.collection.Collection`:: >>> coll = client.test.test >>> coll.find_one() {u'_id': ObjectId('5579dc7cfba5220cc14d9a18')} >>> from bson.son import SON >>> client.document_class = SON >>> coll.find_one() SON([(u'_id', ObjectId('5579dc7cfba5220cc14d9a18'))]) The document_class setting is now configurable at the client, database, collection, and per-operation level. This required breaking the existing behavior. To change the document class per operation in a forward compatible way use :meth:`~pymongo.collection.Collection.with_options`:: >>> coll.find_one() {u'_id': ObjectId('5579dc7cfba5220cc14d9a18')} >>> from bson.codec_options import CodecOptions >>> coll.with_options(CodecOptions(SON)).find_one() SON([(u'_id', ObjectId('5579dc7cfba5220cc14d9a18'))]) Issues Resolved ............... See the `PyMongo 2.9 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.9 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=14795 Changes in Version 2.8.1 ------------------------ Version 2.8.1 fixes a number of issues reported since the release of PyMongo 2.8. It is a recommended upgrade for all users of PyMongo 2.x. Issues Resolved ............... See the `PyMongo 2.8.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.8.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=15324 Changes in Version 2.8 ---------------------- Version 2.8 is a major release that provides full support for MongoDB 3.0 and fixes a number of bugs. Special thanks to Don Mitchell, Ximing, Can Zhang, Sergey Azovskov, and Heewa Barfchin for their contributions to this release. Highlights include: - Support for the SCRAM-SHA-1 authentication mechanism (new in MongoDB 3.0). - JSON decoder support for the new $numberLong and $undefined types. - JSON decoder support for the $date type as an ISO-8601 string. - Support passing an index name to :meth:`~pymongo.cursor.Cursor.hint`. - The :meth:`~pymongo.cursor.Cursor.count` method will use a hint if one has been provided through :meth:`~pymongo.cursor.Cursor.hint`. - A new socketKeepAlive option for the connection pool. - New generator based BSON decode functions, :func:`~bson.decode_iter` and :func:`~bson.decode_file_iter`. - Internal changes to support alternative storage engines like wiredtiger. .. note:: There are a number of deprecations in this release for features that will be removed in PyMongo 3.0. These include: - :meth:`~pymongo.mongo_client.MongoClient.start_request` - :meth:`~pymongo.mongo_client.MongoClient.in_request` - :meth:`~pymongo.mongo_client.MongoClient.end_request` - :meth:`~pymongo.mongo_client.MongoClient.copy_database` - :meth:`~pymongo.database.Database.error` - :meth:`~pymongo.database.Database.last_status` - :meth:`~pymongo.database.Database.previous_error` - :meth:`~pymongo.database.Database.reset_error_history` - :class:`~pymongo.master_slave_connection.MasterSlaveConnection` The JSON format for :class:`~bson.timestamp.Timestamp` has changed from '{"t": <int>, "i": <int>}' to '{"$timestamp": {"t": <int>, "i": <int>}}'. This new format will be decoded to an instance of :class:`~bson.timestamp.Timestamp`. The old format will continue to be decoded to a python dict as before.
Encoding to the old format is no longer supported as it was never correct and loses type information. Issues Resolved ............... See the `PyMongo 2.8 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.8 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=14223 Changes in Version 2.7.2 ------------------------ Version 2.7.2 includes fixes for upsert reporting in the bulk API for MongoDB versions previous to 2.6, a regression in how son manipulators are applied in :meth:`~pymongo.collection.Collection.insert`, a few obscure connection pool semaphore leaks, and a few other minor issues. See the list of issues resolved for full details. Issues Resolved ............... See the `PyMongo 2.7.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.7.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=14005 Changes in Version 2.7.1 ------------------------ Version 2.7.1 fixes a number of issues reported since the release of 2.7, most importantly a fix for creating indexes and manipulating users through mongos versions older than 2.4.0. Issues Resolved ............... See the `PyMongo 2.7.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.7.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=13823 Changes in Version 2.7 ---------------------- PyMongo 2.7 is a major release with a large number of new features and bug fixes. Highlights include: - Full support for MongoDB 2.6. - A new :doc:`bulk write operations API `. - Support for server side query timeouts using :meth:`~pymongo.cursor.Cursor.max_time_ms`. - Support for writing :meth:`~pymongo.collection.Collection.aggregate` output to a collection. - A new :meth:`~pymongo.collection.Collection.parallel_scan` helper. - :class:`~pymongo.errors.OperationFailure` and its subclasses now include a :attr:`~pymongo.errors.OperationFailure.details` attribute with complete error details from the server. - A new GridFS :meth:`~gridfs.GridFS.find` method that returns a :class:`~gridfs.grid_file.GridOutCursor`. - Greatly improved :doc:`support for mod_wsgi ` when using PyMongo's C extensions. Read `Jesse's blog post `_ for details. - Improved C extension support for ARM little endian. Breaking changes ................ Version 2.7 drops support for replica sets running MongoDB versions older than 1.6.2. Issues Resolved ............... See the `PyMongo 2.7 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.7 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12892 Changes in Version 2.6.3 ------------------------ Version 2.6.3 fixes issues reported since the release of 2.6.2, most importantly a semaphore leak when a connection to the server fails. Issues Resolved ............... See the `PyMongo 2.6.3 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.6.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=13098 Changes in Version 2.6.2 ------------------------ Version 2.6.2 fixes a :exc:`TypeError` problem when max_pool_size=None is used in Python 3. Issues Resolved ............... See the `PyMongo 2.6.2 release notes in JIRA`_ for the list of resolved issues in this release. .. 
_PyMongo 2.6.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12910 Changes in Version 2.6.1 ------------------------ Version 2.6.1 fixes a reference leak in the :meth:`~pymongo.collection.Collection.insert` method. Issues Resolved ............... See the `PyMongo 2.6.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.6.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12905 Changes in Version 2.6 ---------------------- Version 2.6 includes some frequently requested improvements and adds support for some early MongoDB 2.6 features. Special thanks go to Justin Patrin for his work on the connection pool in this release. Important new features: - The ``max_pool_size`` option for :class:`~pymongo.mongo_client.MongoClient` and :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` now actually caps the number of sockets the pool will open concurrently. Once the pool has reached :attr:`~pymongo.mongo_client.MongoClient.max_pool_size` operations will block waiting for a socket to become available. If ``waitQueueTimeoutMS`` is set, an operation that blocks waiting for a socket will raise :exc:`~pymongo.errors.ConnectionFailure` after the timeout. By default ``waitQueueTimeoutMS`` is not set. See :ref:`connection-pooling` for more information. - The :meth:`~pymongo.collection.Collection.insert` method automatically splits large batches of documents into multiple insert messages based on :attr:`~pymongo.mongo_client.MongoClient.max_message_size` - Support for the exhaust cursor flag. See :meth:`~pymongo.collection.Collection.find` for details and caveats. - Support for the PLAIN and MONGODB-X509 authentication mechanisms. See :doc:`the authentication docs ` for more information. - Support aggregation output as a :class:`~pymongo.cursor.Cursor`. See :meth:`~pymongo.collection.Collection.aggregate` for details. .. warning:: SIGNIFICANT BEHAVIOR CHANGE in 2.6. Previously, `max_pool_size` would limit only the idle sockets the pool would hold onto, not the number of open sockets. The default has also changed, from 10 to 100. If you pass a value for ``max_pool_size`` make sure it is large enough for the expected load. (Sockets are only opened when needed, so there is no cost to having a ``max_pool_size`` larger than necessary. Err towards a larger value.) If your application accepts the default, continue to do so. See :ref:`connection-pooling` for more information. Issues Resolved ............... See the `PyMongo 2.6 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.6 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12380 Changes in Version 2.5.2 ------------------------ Version 2.5.2 fixes a NULL pointer dereference issue when decoding an invalid :class:`~bson.dbref.DBRef`. Issues Resolved ............... See the `PyMongo 2.5.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.5.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12581 Changes in Version 2.5.1 ------------------------ Version 2.5.1 is a minor release that fixes issues discovered after the release of 2.5. Most importantly, this release addresses some race conditions in replica set monitoring. Issues Resolved ............... See the `PyMongo 2.5.1 release notes in JIRA`_ for the list of resolved issues in this release. .. 
_PyMongo 2.5.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12484 Changes in Version 2.5 ---------------------- Version 2.5 includes changes to support new features in MongoDB 2.4. Important new features: - Support for :ref:`GSSAPI (Kerberos) authentication `. - Support for SSL certificate validation with hostname matching. - Support for delegated and role based authentication. - New GEOSPHERE (2dsphere) and HASHED index constants. .. note:: :meth:`~pymongo.database.Database.authenticate` now raises a subclass of :class:`~pymongo.errors.PyMongoError` if authentication fails due to invalid credentials or configuration issues. Issues Resolved ............... See the `PyMongo 2.5 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.5 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=11981 Changes in Version 2.4.2 ------------------------ Version 2.4.2 is a minor release that fixes issues discovered after the release of 2.4.1. Most importantly, PyMongo will no longer select a replica set member for read operations that is not in primary or secondary state. Issues Resolved ............... See the `PyMongo 2.4.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.4.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12299 Changes in Version 2.4.1 ------------------------ Version 2.4.1 is a minor release that fixes issues discovered after the release of 2.4. Most importantly, this release fixes a regression using :meth:`~pymongo.collection.Collection.aggregate`, and possibly other commands, with mongos. Issues Resolved ............... See the `PyMongo 2.4.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.4.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=12286 Changes in Version 2.4 ---------------------- Version 2.4 includes a few important new features and a large number of bug fixes. Important new features: - New :class:`~pymongo.mongo_client.MongoClient` and :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` classes - these connection classes do acknowledged write operations (previously referred to as 'safe' writes) by default. :class:`~pymongo.connection.Connection` and :class:`~pymongo.replica_set_connection.ReplicaSetConnection` are deprecated but still support the old default fire-and-forget behavior. - A new write concern API implemented as a :attr:`~pymongo.collection.Collection.write_concern` attribute on the connection, :class:`~pymongo.database.Database`, or :class:`~pymongo.collection.Collection` classes. - :class:`~pymongo.mongo_client.MongoClient` (and :class:`~pymongo.connection.Connection`) now support Unix Domain Sockets. - :class:`~pymongo.cursor.Cursor` can be copied with functions from the :mod:`copy` module. - The :meth:`~pymongo.database.Database.set_profiling_level` method now supports a `slow_ms` option. - The replica set monitor task (used by :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` and :class:`~pymongo.replica_set_connection.ReplicaSetConnection`) is a daemon thread once again, meaning you won't have to call :meth:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient.close` before exiting the python interactive shell. .. 
warning:: The constructors for :class:`~pymongo.mongo_client.MongoClient`, :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`, :class:`~pymongo.connection.Connection`, and :class:`~pymongo.replica_set_connection.ReplicaSetConnection` now raise :exc:`~pymongo.errors.ConnectionFailure` instead of its subclass :exc:`~pymongo.errors.AutoReconnect` if the server is unavailable. Applications that expect to catch :exc:`~pymongo.errors.AutoReconnect` should now catch :exc:`~pymongo.errors.ConnectionFailure` while creating a new connection. Issues Resolved ............... See the `PyMongo 2.4 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.4 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=11485 Changes in Version 2.3 ---------------------- Version 2.3 adds support for new features and behavior changes in MongoDB 2.2. Important New Features: - Support for expanded read preferences including directing reads to tagged servers - See :ref:`secondary-reads` for more information. - Support for mongos failover. - A new :meth:`~pymongo.collection.Collection.aggregate` method to support MongoDB's new `aggregation framework `_. - Support for legacy Java and C# byte order when encoding and decoding UUIDs. - Support for connecting directly to an arbiter. .. warning:: Starting with MongoDB 2.2 the getLastError command requires authentication when the server's `authentication features `_ are enabled. Changes to PyMongo were required to support this behavior change. Users of authentication must upgrade to PyMongo 2.3 (or newer) for "safe" write operations to function correctly. Issues Resolved ............... See the `PyMongo 2.3 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=11146 Changes in Version 2.2.1 ------------------------ Version 2.2.1 is a minor release that fixes issues discovered after the release of 2.2. Most importantly, this release fixes an incompatibility with mod_wsgi 2.x that could cause connections to leak. Users of mod_wsgi 2.x are strongly encouraged to upgrade from PyMongo 2.2. Issues Resolved ............... See the `PyMongo 2.2.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.2.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=11185 Changes in Version 2.2 ---------------------- Version 2.2 adds a few more frequently requested features and fixes a number of bugs. Special thanks go to Alex Grönholm for his contributions to Python 3 support and maintaining the original pymongo3 port. Christoph Simon, Wouter Bolsterlee, Mike O'Brien, and Chris Tompkinson also contributed to this release. Important New Features: - Support for Python 3 - See the :doc:`python3` for more information. - Support for Gevent - See :doc:`examples/gevent` for more information. - Improved connection pooling. See `PYTHON-287 `_. .. warning:: A number of methods and method parameters that were deprecated in PyMongo 1.9 or older versions have been removed in this release. The full list of changes can be found in the following JIRA ticket: https://jira.mongodb.org/browse/PYTHON-305 BSON module aliases from the pymongo package that were deprecated in PyMongo 1.9 have also been removed in this release. 
See the following JIRA ticket for details: https://jira.mongodb.org/browse/PYTHON-304 As a result of this cleanup some minor code changes may be required to use this release. Issues Resolved ............... See the `PyMongo 2.2 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.2 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=10584 Changes in Version 2.1.1 ------------------------ Version 2.1.1 is a minor release that fixes a few issues discovered after the release of 2.1. You can now use :class:`~pymongo.replica_set_connection.ReplicaSetConnection` to run inline map reduce commands on secondaries. See :meth:`~pymongo.collection.Collection.inline_map_reduce` for details. Special thanks go to Samuel Clay and Ross Lawley for their contributions to this release. Issues Resolved ............... See the `PyMongo 2.1.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.1.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?version=11081&styleName=Html&projectId=10004 Changes in Version 2.1 ---------------------- Version 2.1 adds a few frequently requested features and includes the usual round of bug fixes and improvements. Special thanks go to Alexey Borzenkov, Dan Crosta, Kostya Rybnikov, Flavio Percoco Premoli, Jonas Haag, and Jesse Davis for their contributions to this release. Important New Features: - ReplicaSetConnection - :class:`~pymongo.replica_set_connection.ReplicaSetConnection` can be used to distribute reads to secondaries in a replica set. It supports automatic failover handling and periodically checks the state of the replica set to handle issues like primary stepdown or secondaries being removed for backup operations. Read preferences are defined through :class:`~pymongo.read_preferences.ReadPreference`. - PyMongo supports the new BSON binary subtype 4 for UUIDs. The default subtype to use can be set through :attr:`~pymongo.collection.Collection.uuid_subtype`. The current default remains :attr:`~bson.binary.OLD_UUID_SUBTYPE` but will be changed to :attr:`~bson.binary.UUID_SUBTYPE` in a future release. - The getLastError option 'w' can be set to a string, allowing for options like "majority" available in newer versions of MongoDB. - Added support for the MongoDB URI options socketTimeoutMS and connectTimeoutMS. - Added support for the ContinueOnError insert flag. - Added basic SSL support. - Added basic support for Jython. - Secondaries can be used for :meth:`~pymongo.cursor.Cursor.count`, :meth:`~pymongo.cursor.Cursor.distinct`, :meth:`~pymongo.collection.Collection.group`, and querying :class:`~gridfs.GridFS`. - Added document_class and tz_aware options to :class:`~pymongo.master_slave_connection.MasterSlaveConnection` Issues Resolved ............... See the `PyMongo 2.1 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=10583 Changes in Version 2.0.1 ------------------------ Version 2.0.1 fixes a regression in :class:`~gridfs.grid_file.GridIn` when writing pre-chunked strings. Thanks go to Alexey Borzenkov for reporting the issue and submitting a patch. Issues Resolved ............... - `PYTHON-271 `_: Regression in GridFS leads to serious loss of data. Changes in Version 2.0 ---------------------- Version 2.0 adds a large number of features and fixes a number of issues.
Special thanks go to James Murty, Abhay Vardhan, David Pisoni, Ryan Smith-Roberts, Andrew Pendleton, Mher Movsisyan, Reed O'Brien, Michael Schurter, Josip Delic and Jonas Haag for their contributions to this release. Important New Features: - PyMongo now performs automatic per-socket database authentication. You no longer have to re-authenticate for each new thread or after a replica set failover. Authentication credentials are cached by the driver until the application calls :meth:`~pymongo.database.Database.logout`. - slave_okay can be set independently at the connection, database, collection or query level. Each level will inherit the slave_okay setting from the previous level and each level can override the previous level's setting. - safe and getLastError options (e.g. w, wtimeout, etc.) can be set independently at the connection, database, collection or query level. Each level will inherit settings from the previous level and each level can override the previous level's setting. - PyMongo now supports the `await_data` and `partial` cursor flags. If the `await_data` flag is set on a `tailable` cursor the server will block for some extra time waiting for more data to return. The `partial` flag tells a mongos to return partial data for a query if not all shards are available. - :meth:`~pymongo.collection.Collection.map_reduce` will accept a `dict` or instance of :class:`~bson.son.SON` as the `out` parameter. - The URI parser has been moved into its own module and can be used directly by application code. - AutoReconnect exception now provides information about the error that actually occurred instead of a generic failure message. - A number of new helper methods have been added with options for setting and unsetting cursor flags, re-indexing a collection, fsync and locking a server, and getting the server's current operations. API changes: - If only one host:port pair is specified :class:`~pymongo.connection.Connection` will make a direct connection to only that host. Please note that `slave_okay` must be `True` in order to query from a secondary. - If more than one host:port pair is specified or the `replicaset` option is used PyMongo will treat the specified host:port pair(s) as a seed list and connect using replica set behavior. .. warning:: The default subtype for :class:`~bson.binary.Binary` has changed from :const:`~bson.binary.OLD_BINARY_SUBTYPE` (2) to :const:`~bson.binary.BINARY_SUBTYPE` (0). Issues Resolved ............... See the `PyMongo 2.0 release notes in JIRA`_ for the list of resolved issues in this release. .. _PyMongo 2.0 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=10274 Changes in Version 1.11 ----------------------- Version 1.11 adds a few new features and fixes a few more bugs. New Features: - Basic IPv6 support: pymongo prefers IPv4 but will try IPv6. You can also specify an IPv6 address literal in the `host` parameter or a MongoDB URI provided it is enclosed in '[' and ']'. - max_pool_size option: previously pymongo had a hard coded pool size of 10 connections. With this change you can specify a different pool size as a parameter to :class:`~pymongo.connection.Connection` (max_pool_size=<integer>) or in the MongoDB URI (maxPoolSize=<integer>). - Find by metadata in GridFS: You can now specify query fields as keyword parameters for :meth:`~gridfs.GridFS.get_version` and :meth:`~gridfs.GridFS.get_last_version`.
- Per-query slave_okay option: slave_okay=True is now a valid keyword argument for :meth:`~pymongo.collection.Collection.find` and :meth:`~pymongo.collection.Collection.find_one`. API changes: - :meth:`~pymongo.database.Database.validate_collection` now returns a dict instead of a string. This change was required to deal with an API change on the server. This method also now takes the optional `scandata` and `full` parameters. See the documentation for more details. .. warning:: The `pool_size`, `auto_start_request`, and `timeout` parameters for :class:`~pymongo.connection.Connection` have been completely removed in this release. They were deprecated in pymongo-1.4 and have had no effect since then. Please make sure that your code doesn't currently pass these parameters when creating a Connection instance. Issues resolved ............... - `PYTHON-241 `_: Support setting slaveok at the cursor level. - `PYTHON-240 `_: Queries can sometimes permanently fail after a replica set fail over. - `PYTHON-238 `_: error after few million requests - `PYTHON-237 `_: Basic IPv6 support. - `PYTHON-236 `_: Restore option to specify pool size in Connection. - `PYTHON-212 `_: pymongo does not recover after stale config - `PYTHON-138 `_: Find method for GridFS Changes in Version 1.10.1 ------------------------- Version 1.10.1 is primarily a bugfix release. It fixes a regression in version 1.10 that broke pickling of ObjectIds. A number of other bugs have been fixed as well. There are two behavior changes to be aware of: - If a read slave raises :class:`~pymongo.errors.AutoReconnect` :class:`~pymongo.master_slave_connection.MasterSlaveConnection` will now retry the query on each slave until it is successful or all slaves have raised :class:`~pymongo.errors.AutoReconnect`. Any other exception will immediately be raised. The order that the slaves are tried is random. Previously the read would be sent to one randomly chosen slave and :class:`~pymongo.errors.AutoReconnect` was immediately raised in case of a connection failure. - A Python `long` is now always BSON encoded as an int64. Previously the encoding was based only on the value of the field and a `long` with a value less than `2147483648` or greater than `-2147483649` would always be BSON encoded as an int32. Issues resolved ............... - `PYTHON-234 `_: Fix setup.py to raise exception if any when building extensions - `PYTHON-233 `_: Add information to build and test with extensions on windows - `PYTHON-232 `_: Traceback when hashing a DBRef instance - `PYTHON-231 `_: Traceback when pickling a DBRef instance - `PYTHON-230 `_: Pickled ObjectIds are not compatible between pymongo 1.9 and 1.10 - `PYTHON-228 `_: Cannot pickle bson.ObjectId - `PYTHON-227 `_: Traceback when calling find() on system.js - `PYTHON-216 `_: MasterSlaveConnection is missing disconnect() method - `PYTHON-186 `_: When storing integers, type is selected according to value instead of type - `PYTHON-173 `_: as_class option is not propogated by Cursor.clone - `PYTHON-113 `_: Redunducy in MasterSlaveConnection Changes in Version 1.10 ----------------------- Version 1.10 includes changes to support new features in MongoDB 1.8.x. Highlights include a modified map/reduce API including an inline map/reduce helper method, a new find_and_modify helper, and the ability to query the server for the maximum BSON document size it supports. - added :meth:`~pymongo.collection.Collection.find_and_modify`. - added :meth:`~pymongo.collection.Collection.inline_map_reduce`. 
- changed :meth:`~pymongo.collection.Collection.map_reduce`. .. warning:: MongoDB versions greater than 1.7.4 no longer generate temporary collections for map/reduce results. An output collection name must be provided and the output will replace any existing output collection with the same name. :meth:`~pymongo.collection.Collection.map_reduce` now requires the `out` parameter. Issues resolved ............... - PYTHON-225: :class:`~pymongo.objectid.ObjectId` class definition should use __slots__. - PYTHON-223: Documentation fix. - PYTHON-220: Documentation fix. - PYTHON-219: KeyError in :meth:`~pymongo.collection.Collection.find_and_modify` - PYTHON-213: Query server for maximum BSON document size. - PYTHON-208: Fix :class:`~pymongo.connection.Connection` __repr__. - PYTHON-207: Changes to Map/Reduce API. - PYTHON-205: Accept slaveOk in the URI to match the URI docs. - PYTHON-203: When slave_okay=True and we only specify one host don't autodetect other set members. - PYTHON-194: Show size when whining about a document being too large. - PYTHON-184: Raise :class:`~pymongo.errors.DuplicateKeyError` for duplicate keys in capped collections. - PYTHON-178: Don't segfault when trying to encode a recursive data structure. - PYTHON-177: Don't segfault when decoding dicts with broken iterators. - PYTHON-172: Fix a typo. - PYTHON-170: Add :meth:`~pymongo.collection.Collection.find_and_modify`. - PYTHON-169: Support deepcopy of DBRef. - PYTHON-167: Duplicate of PYTHON-166. - PYTHON-166: Fixes a concurrency issue. - PYTHON-158: Add code and err string to `db assertion` messages. Changes in Version 1.9 ---------------------- Version 1.9 adds a new package to the PyMongo distribution, :mod:`bson`. :mod:`bson` contains all of the `BSON `_ encoding and decoding logic, and the BSON types that were formerly in the :mod:`pymongo` package. The following modules have been renamed: - :mod:`pymongo.bson` -> :mod:`bson` - :mod:`pymongo._cbson` -> :mod:`bson._cbson` and :mod:`pymongo._cmessage` - :mod:`pymongo.binary` -> :mod:`bson.binary` - :mod:`pymongo.code` -> :mod:`bson.code` - :mod:`pymongo.dbref` -> :mod:`bson.dbref` - :mod:`pymongo.json_util` -> :mod:`bson.json_util` - :mod:`pymongo.max_key` -> :mod:`bson.max_key` - :mod:`pymongo.min_key` -> :mod:`bson.min_key` - :mod:`pymongo.objectid` -> :mod:`bson.objectid` - :mod:`pymongo.son` -> :mod:`bson.son` - :mod:`pymongo.timestamp` -> :mod:`bson.timestamp` - :mod:`pymongo.tz_util` -> :mod:`bson.tz_util` In addition, the following exception classes have been renamed: - :class:`pymongo.errors.InvalidBSON` -> :class:`bson.errors.InvalidBSON` - :class:`pymongo.errors.InvalidStringData` -> :class:`bson.errors.InvalidStringData` - :class:`pymongo.errors.InvalidDocument` -> :class:`bson.errors.InvalidDocument` - :class:`pymongo.errors.InvalidId` -> :class:`bson.errors.InvalidId` The above exceptions now inherit from :class:`bson.errors.BSONError` rather than :class:`pymongo.errors.PyMongoError`. .. note:: All of the renamed modules and exceptions above have aliases created with the old names, so these changes should not break existing code. The old names will eventually be deprecated and then removed, so users should begin migrating towards the new names now. .. warning:: The change to the exception hierarchy mentioned above is possibly breaking. If your code is catching :class:`~pymongo.errors.PyMongoError`, then the exceptions raised by :mod:`bson` will not be caught, even though they would have been caught previously. 
Before upgrading, it is recommended that users check for any cases like this. - the C extension now shares buffer.c/h with the Ruby driver - :mod:`bson` no longer raises :class:`~pymongo.errors.InvalidName`, all occurrences have been replaced with :class:`~bson.errors.InvalidDocument`. - renamed :meth:`bson._to_dicts` to :meth:`~bson.decode_all`. - renamed :meth:`~bson.BSON.from_dict` to :meth:`~bson.BSON.encode` and :meth:`~bson.BSON.to_dict` to :meth:`~bson.BSON.decode`. - added :meth:`~pymongo.cursor.Cursor.batch_size`. - allow updating (some) file metadata after a :class:`~gridfs.grid_file.GridIn` instance has been closed. - performance improvements for reading from GridFS. - special cased slice with the same start and stop to return an empty cursor. - allow writing :class:`unicode` to GridFS if an :attr:`encoding` attribute has been specified for the file. - added :meth:`gridfs.GridFS.get_version`. - scope variables for :class:`~bson.code.Code` can now be specified as keyword arguments. - added :meth:`~gridfs.grid_file.GridOut.readline` to :class:`~gridfs.grid_file.GridOut`. - make a best effort to transparently auto-reconnect if a :class:`~pymongo.connection.Connection` has been idle for a while. - added :meth:`~pymongo.database.SystemJS.list` to :class:`~pymongo.database.SystemJS`. - added `file_document` argument to :meth:`~gridfs.grid_file.GridOut` to allow initializing from an existing file document. - raise :class:`~pymongo.errors.TimeoutError` even if the ``getLastError`` command was run manually and not through "safe" mode. - added :class:`uuid` support to :mod:`~bson.json_util`. Changes in Version 1.8.1 ------------------------ - fixed a typo in the C extension that could cause safe-mode operations to report a failure (:class:`SystemError`) even when none occurred. - added a :meth:`__ne__` implementation to any class where we define :meth:`__eq__`. Changes in Version 1.8 ---------------------- Version 1.8 adds support for connecting to replica sets, specifying per-operation values for `w` and `wtimeout`, and decoding to timezone-aware datetimes. - fixed a reference leak in the C extension when decoding a :class:`~bson.dbref.DBRef`. - added support for `w`, `wtimeout`, and `fsync` (and any other options for `getLastError`) to "safe mode" operations. - added :attr:`~pymongo.connection.Connection.nodes` property. - added a maximum pool size of 10 sockets. - added support for replica sets. - DEPRECATED :meth:`~pymongo.connection.Connection.from_uri` and :meth:`~pymongo.connection.Connection.paired`, both are supplanted by extended functionality in :meth:`~pymongo.connection.Connection`. - added tz aware support for datetimes in :class:`~bson.objectid.ObjectId`, :class:`~bson.timestamp.Timestamp` and :mod:`~bson.json_util` methods. - added :meth:`~pymongo.collection.Collection.drop` helper. - reuse the socket used for finding the master when a :class:`~pymongo.connection.Connection` is first created. - added support for :class:`~bson.min_key.MinKey`, :class:`~bson.max_key.MaxKey` and :class:`~bson.timestamp.Timestamp` to :mod:`~bson.json_util`. - added support for decoding datetimes as aware (UTC) - it is highly recommended to enable this by setting the `tz_aware` parameter to :meth:`~pymongo.connection.Connection` to ``True``. - added `network_timeout` option for individual calls to :meth:`~pymongo.collection.Collection.find` and :meth:`~pymongo.collection.Collection.find_one`. - added :meth:`~gridfs.GridFS.exists` to check if a file exists in GridFS. 
- added support for additional keys in :class:`~bson.dbref.DBRef` instances. - added :attr:`~pymongo.errors.OperationFailure.code` attribute to :class:`~pymongo.errors.OperationFailure` exceptions. - fixed serialization of int and float subclasses in the C extension. Changes in Version 1.7 ---------------------- Version 1.7 is a recommended upgrade for all PyMongo users. The full release notes are below, and some more in depth discussion of the highlights is `here `_. - no longer attempt to build the C extension on big-endian systems. - added :class:`~bson.min_key.MinKey` and :class:`~bson.max_key.MaxKey`. - use unsigned for :class:`~bson.timestamp.Timestamp` in BSON encoder/decoder. - support ``True`` as ``"ok"`` in command responses, in addition to ``1.0`` - necessary for server versions **>= 1.5.X** - BREAKING change to :meth:`~pymongo.collection.Collection.index_information` to add support for querying unique status and other index information. - added :attr:`~pymongo.connection.Connection.document_class`, to specify class for returned documents. - added `as_class` argument for :meth:`~pymongo.collection.Collection.find`, and in the BSON decoder. - added support for creating :class:`~bson.timestamp.Timestamp` instances using a :class:`~datetime.datetime`. - allow `dropTarget` argument for :class:`~pymongo.collection.Collection.rename`. - handle aware :class:`~datetime.datetime` instances, by converting to UTC. - added support for :class:`~pymongo.cursor.Cursor.max_scan`. - raise :class:`~gridfs.errors.FileExists` exception when creating a duplicate GridFS file. - use `y2038 `_ for time handling in the C extension - eliminates 2038 problems when extension is installed. - added `sort` parameter to :meth:`~pymongo.collection.Collection.find` - finalized deprecation of changes from versions **<= 1.4** - take any non-:class:`dict` as an ``"_id"`` query for :meth:`~pymongo.collection.Collection.find_one` or :meth:`~pymongo.collection.Collection.remove` - added ability to pass a :class:`dict` for `fields` argument to :meth:`~pymongo.collection.Collection.find` (supports ``"$slice"`` and field negation) - simplified code to find master, since paired setups don't always have a remote - fixed bug in C encoder for certain invalid types (like :class:`~pymongo.collection.Collection` instances). - don't transparently map ``"filename"`` key to :attr:`name` attribute for GridFS. Changes in Version 1.6 ---------------------- The biggest change in version 1.6 is a complete re-implementation of :mod:`gridfs` with a lot of improvements over the old implementation. There are many details and examples of using the new API in `this blog post `_. The old API has been removed in this version, so existing code will need to be modified before upgrading to 1.6. - fixed issue where connection pool was being shared across :class:`~pymongo.connection.Connection` instances. - more improvements to Python code caching in C extension - should improve behavior on mod_wsgi. - added :meth:`~bson.objectid.ObjectId.from_datetime`. - complete rewrite of :mod:`gridfs` support. - improvements to the :meth:`~pymongo.database.Database.command` API. - fixed :meth:`~pymongo.collection.Collection.drop_indexes` behavior on non-existent collections. - disallow empty bulk inserts. Changes in Version 1.5.2 ------------------------ - fixed response handling to ignore unknown response flags in queries. - handle server versions containing '-pre-'. 
Changes in Version 1.5.1 ------------------------ - added :data:`~gridfs.grid_file.GridFile._id` property for :class:`~gridfs.grid_file.GridFile` instances. - fix for making a :class:`~pymongo.connection.Connection` (with `slave_okay` set) directly to a slave in a replica pair. - accept kwargs for :meth:`~pymongo.collection.Collection.create_index` and :meth:`~pymongo.collection.Collection.ensure_index` to support all indexing options. - add :data:`pymongo.GEO2D` and support for geo indexing. - improvements to Python code caching in C extension - should improve behavior on mod_wsgi. Changes in Version 1.5 ---------------------- - added subtype constants to :mod:`~bson.binary` module. - DEPRECATED `options` argument to :meth:`~pymongo.collection.Collection` and :meth:`~pymongo.database.Database.create_collection` in favor of kwargs. - added :meth:`~pymongo.has_c` to check for C extension. - added :meth:`~pymongo.connection.Connection.copy_database`. - added :data:`~pymongo.cursor.Cursor.alive` to tell when a cursor might have more data to return (useful for tailable cursors). - added :class:`~bson.timestamp.Timestamp` to better support dealing with internal MongoDB timestamps. - added `name` argument for :meth:`~pymongo.collection.Collection.create_index` and :meth:`~pymongo.collection.Collection.ensure_index`. - fixed connection pooling w/ fork - :meth:`~pymongo.connection.Connection.paired` takes all kwargs that are allowed for :meth:`~pymongo.connection.Connection`. - :meth:`~pymongo.collection.Collection.insert` returns list for bulk inserts of size one. - fixed handling of :class:`datetime.datetime` instances in :mod:`~bson.json_util`. - added :meth:`~pymongo.connection.Connection.from_uri` to support MongoDB connection uri scheme. - fixed chunk number calculation when unaligned in :mod:`gridfs`. - :meth:`~pymongo.database.Database.command` takes a string for simple commands. - added :data:`~pymongo.database.Database.system_js` helper for dealing with server-side JS. - don't wrap queries containing ``"$query"`` (support manual use of ``"$min"``, etc.). - added :class:`~gridfs.errors.GridFSError` as base class for :mod:`gridfs` exceptions. Changes in Version 1.4 ---------------------- Perhaps the most important change in version 1.4 is that we have decided to **no longer support Python 2.3**. The most immediate reason for this is to allow some improvements to connection pooling. This will also allow us to use some new (as in Python 2.4 ;) idioms and will help begin the path towards supporting Python 3.0. If you need to use Python 2.3 you should consider using version 1.3 of this driver, although that will no longer be actively supported. Other changes: - move ``"_id"`` to front only for top-level documents (fixes some corner cases). - :meth:`~pymongo.collection.Collection.update` and :meth:`~pymongo.collection.Collection.remove` return the entire response to the *lastError* command when safe is ``True``. - completed removal of things that were deprecated in version 1.2 or earlier. - enforce that collection names do not contain the NULL byte. - fix to allow using UTF-8 collection names with the C extension. - added :class:`~pymongo.errors.PyMongoError` as base exception class for all :mod:`~pymongo.errors`. this changes the exception hierarchy somewhat, and is a BREAKING change if you depend on :class:`~pymongo.errors.ConnectionFailure` being a :class:`IOError` or :class:`~bson.errors.InvalidBSON` being a :class:`ValueError`, for example. 
- added :class:`~pymongo.errors.DuplicateKeyError` for calls to :meth:`~pymongo.collection.Collection.insert` or :meth:`~pymongo.collection.Collection.update` with `safe` set to ``True``. - removed :mod:`~pymongo.thread_util`. - added :meth:`~pymongo.database.Database.add_user` and :meth:`~pymongo.database.Database.remove_user` helpers. - fix for :meth:`~pymongo.database.Database.authenticate` when using non-UTF-8 names or passwords. - minor fixes for :class:`~pymongo.master_slave_connection.MasterSlaveConnection`. - clean up all cases where :class:`~pymongo.errors.ConnectionFailure` is raised. - simplification of connection pooling - makes driver ~2x faster for simple benchmarks. see :ref:`connection-pooling` for more information. - DEPRECATED `pool_size`, `auto_start_request` and `timeout` parameters to :class:`~pymongo.connection.Connection`. DEPRECATED :meth:`~pymongo.connection.Connection.start_request`. - use :meth:`socket.sendall`. - removed :meth:`~bson.son.SON.from_xml` as it was only being used for some internal testing - also eliminates dependency on :mod:`elementtree`. - implementation of :meth:`~pymongo.message.update` in C. - deprecate :meth:`~pymongo.database.Database._command` in favor of :meth:`~pymongo.database.Database.command`. - send all commands without wrapping as ``{"query": ...}``. - support string as `key` argument to :meth:`~pymongo.collection.Collection.group` (keyf) and run all groups as commands. - support for equality testing for :class:`~bson.code.Code` instances. - allow the NULL byte in strings and disallow it in key names or regex patterns Changes in Version 1.3 ---------------------- - DEPRECATED running :meth:`~pymongo.collection.Collection.group` as :meth:`~pymongo.database.Database.eval`, also changed default for :meth:`~pymongo.collection.Collection.group` to running as a command - remove :meth:`pymongo.cursor.Cursor.__len__`, which was deprecated in 1.1.1 - needed to do this aggressively due to it's presence breaking **Django** template *for* loops - DEPRECATED :meth:`~pymongo.connection.Connection.host`, :meth:`~pymongo.connection.Connection.port`, :meth:`~pymongo.database.Database.connection`, :meth:`~pymongo.database.Database.name`, :meth:`~pymongo.collection.Collection.database`, :meth:`~pymongo.collection.Collection.name` and :meth:`~pymongo.collection.Collection.full_name` in favor of :attr:`~pymongo.connection.Connection.host`, :attr:`~pymongo.connection.Connection.port`, :attr:`~pymongo.database.Database.connection`, :attr:`~pymongo.database.Database.name`, :attr:`~pymongo.collection.Collection.database`, :attr:`~pymongo.collection.Collection.name` and :attr:`~pymongo.collection.Collection.full_name`, respectively. The deprecation schedule for this change will probably be faster than usual, as it carries some performance implications. 
- added :meth:`~pymongo.connection.Connection.disconnect` Changes in Version 1.2.1 ------------------------ - added :doc:`changelog` to docs - added ``setup.py doc --test`` to run doctests for tutorial, examples - moved most examples to Sphinx docs (and remove from *examples/* directory) - raise :class:`~bson.errors.InvalidId` instead of :class:`TypeError` when passing a 24 character string to :class:`~bson.objectid.ObjectId` that contains non-hexadecimal characters - allow :class:`unicode` instances for :class:`~bson.objectid.ObjectId` init Changes in Version 1.2 ---------------------- - `spec` parameter for :meth:`~pymongo.collection.Collection.remove` is now optional to allow for deleting all documents in a :class:`~pymongo.collection.Collection` - always wrap queries with ``{query: ...}`` even when no special options - get around some issues with queries on fields named ``query`` - enforce 4MB document limit on the client side - added :meth:`~pymongo.collection.Collection.map_reduce` helper - see :doc:`example ` - added :meth:`~pymongo.cursor.Cursor.distinct` method on :class:`~pymongo.cursor.Cursor` instances to allow distinct with queries - fix for :meth:`~pymongo.cursor.Cursor.__getitem__` after :meth:`~pymongo.cursor.Cursor.skip` - allow any UTF-8 string in :class:`~bson.BSON` encoder, not just ASCII subset - added :attr:`~bson.objectid.ObjectId.generation_time` - removed support for legacy :class:`~bson.objectid.ObjectId` format - pretty sure this was never used, and is just confusing - DEPRECATED :meth:`~bson.objectid.ObjectId.url_encode` and :meth:`~bson.objectid.ObjectId.url_decode` in favor of :meth:`str` and :meth:`~bson.objectid.ObjectId`, respectively - allow *oplog.$main* as a valid collection name - some minor fixes for installation process - added support for datetime and regex in :mod:`~bson.json_util` Changes in Version 1.1.2 ------------------------ - improvements to :meth:`~pymongo.collection.Collection.insert` speed (using C for insert message creation) - use random number for request_id - fix some race conditions with :class:`~pymongo.errors.AutoReconnect` Changes in Version 1.1.1 ------------------------ - added `multi` parameter for :meth:`~pymongo.collection.Collection.update` - fix unicode regex patterns with C extension - added :meth:`~pymongo.collection.Collection.distinct` - added `database` support for :class:`~bson.dbref.DBRef` - added :mod:`~bson.json_util` with helpers for encoding / decoding special types to JSON - DEPRECATED :meth:`pymongo.cursor.Cursor.__len__` in favor of :meth:`~pymongo.cursor.Cursor.count` with `with_limit_and_skip` set to ``True`` due to performance regression - switch documentation to Sphinx Changes in Version 1.1 ---------------------- - added :meth:`__hash__` for :class:`~bson.dbref.DBRef` and :class:`~bson.objectid.ObjectId` - bulk :meth:`~pymongo.collection.Collection.insert` works with any iterable - fix :class:`~bson.objectid.ObjectId` generation when using :mod:`multiprocessing` - added :attr:`~pymongo.cursor.Cursor.collection` - added `network_timeout` parameter for :meth:`~pymongo.connection.Connection` - DEPRECATED `slave_okay` parameter for individual queries - fix for `safe` mode when multi-threaded - added `safe` parameter for :meth:`~pymongo.collection.Collection.remove` - added `tailable` parameter for :meth:`~pymongo.collection.Collection.find` Changes in Version 1.0 ---------------------- - fixes for :class:`~pymongo.master_slave_connection.MasterSlaveConnection` - added `finalize` parameter for 
:meth:`~pymongo.collection.Collection.group` - improvements to :meth:`~pymongo.collection.Collection.insert` speed - improvements to :mod:`gridfs` speed - added :meth:`~pymongo.cursor.Cursor.__getitem__` and :meth:`~pymongo.cursor.Cursor.__len__` for :class:`~pymongo.cursor.Cursor` instances Changes in Version 0.16 ----------------------- - support for encoding/decoding :class:`uuid.UUID` instances - fix for :meth:`~pymongo.cursor.Cursor.explain` with limits Changes in Version 0.15.2 ------------------------- - documentation changes only Changes in Version 0.15.1 ------------------------- - various performance improvements - API CHANGE no longer need to specify direction for :meth:`~pymongo.collection.Collection.create_index` and :meth:`~pymongo.collection.Collection.ensure_index` when indexing a single key - support for encoding :class:`tuple` instances as :class:`list` instances Changes in Version 0.15 ----------------------- - fix string representation of :class:`~bson.objectid.ObjectId` instances - added `timeout` parameter for :meth:`~pymongo.collection.Collection.find` - allow scope for `reduce` function in :meth:`~pymongo.collection.Collection.group` Changes in Version 0.14.2 ------------------------- - minor bugfixes Changes in Version 0.14.1 ------------------------- - :meth:`~gridfs.grid_file.GridFile.seek` and :meth:`~gridfs.grid_file.GridFile.tell` for (read mode) :class:`~gridfs.grid_file.GridFile` instances Changes in Version 0.14 ----------------------- - support for long in :class:`~bson.BSON` - added :meth:`~pymongo.collection.Collection.rename` - added `snapshot` parameter for :meth:`~pymongo.collection.Collection.find` Changes in Version 0.13 ----------------------- - better :class:`~pymongo.master_slave_connection.MasterSlaveConnection` support - API CHANGE :meth:`~pymongo.collection.Collection.insert` and :meth:`~pymongo.collection.Collection.save` both return inserted ``_id`` - DEPRECATED passing an index name to :meth:`~pymongo.cursor.Cursor.hint` Changes in Version 0.12 ----------------------- - improved :class:`~bson.objectid.ObjectId` generation - added :class:`~pymongo.errors.AutoReconnect` exception for when reconnection is possible - make :mod:`gridfs` thread-safe - fix for :mod:`gridfs` with non :class:`~bson.objectid.ObjectId` ``_id`` Changes in Version 0.11.3 ------------------------- - don't allow NULL bytes in string encoder - fixes for Python 2.3 Changes in Version 0.11.2 ------------------------- - PEP 8 - updates for :meth:`~pymongo.collection.Collection.group` - VS build Changes in Version 0.11.1 ------------------------- - fix for connection pooling under Python 2.5 Changes in Version 0.11 ----------------------- - better build failure detection - driver support for selecting fields in sub-documents - disallow insertion of invalid key names - added `timeout` parameter for :meth:`~pymongo.connection.Connection` Changes in Version 0.10.3 ------------------------- - fix bug with large :meth:`~pymongo.cursor.Cursor.limit` - better exception when modules get reloaded out from underneath the C extension - better exception messages when calling a :class:`~pymongo.collection.Collection` or :class:`~pymongo.database.Database` instance Changes in Version 0.10.2 ------------------------- - support subclasses of :class:`dict` in C encoder Changes in Version 0.10.1 ------------------------- - alias :class:`~pymongo.connection.Connection` as :attr:`pymongo.Connection` - raise an exception rather than silently overflowing in encoder Changes in Version 0.10 
----------------------- - added :meth:`~pymongo.collection.Collection.ensure_index` Changes in Version 0.9.7 ------------------------ - allow sub-collections of *$cmd* as valid :class:`~pymongo.collection.Collection` names - add version as :attr:`pymongo.version` - add ``--no_ext`` command line option to *setup.py* .. toctree:: :hidden: python3 examples/gevent pymongo-3.11.0/doc/compatibility-policy.rst000066400000000000000000000043521374256237000207460ustar00rootroot00000000000000Compatibility Policy ==================== Semantic Versioning ------------------- PyMongo's version numbers follow `semantic versioning`_: each version number is structured "major.minor.patch". Patch releases fix bugs, minor releases add features (and may fix bugs), and major releases include API changes that break backwards compatibility (and may add features and fix bugs). Deprecation ----------- Before we remove a feature in a major release, PyMongo's maintainers make an effort to release at least one minor version that *deprecates* it. We add "**DEPRECATED**" to the feature's documentation, and update the code to raise a `DeprecationWarning`_. You can ensure your code is future-proof by running your code with the latest PyMongo release and looking for DeprecationWarnings. Starting with Python 2.7, the interpreter silences DeprecationWarnings by default. For example, the following code uses the deprecated ``insert`` method but does not raise any warning: .. code-block:: python # "insert.py" from pymongo import MongoClient client = MongoClient() client.test.test.insert({}) To print deprecation warnings to stderr, run python with "-Wd":: $ python -Wd insert.py insert.py:4: DeprecationWarning: insert is deprecated. Use insert_one or insert_many instead. client.test.test.insert({}) You can turn warnings into exceptions with "python -We":: $ python -We insert.py Traceback (most recent call last): File "insert.py", line 4, in client.test.test.insert({}) File "/home/durin/work/mongo-python-driver/pymongo/collection.py", line 2906, in insert "instead.", DeprecationWarning, stacklevel=2) DeprecationWarning: insert is deprecated. Use insert_one or insert_many instead. If your own code's test suite passes with "python -We" then it uses no deprecated PyMongo features. .. seealso:: The Python documentation on `the warnings module`_, and `the -W command line option`_. .. _semantic versioning: http://semver.org/ .. _DeprecationWarning: https://docs.python.org/2/library/exceptions.html#exceptions.DeprecationWarning .. _the warnings module: https://docs.python.org/2/library/warnings.html .. _the -W command line option: https://docs.python.org/2/using/cmdline.html#cmdoption-W pymongo-3.11.0/doc/conf.py000066400000000000000000000126751374256237000153540ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # PyMongo documentation build configuration file # # This file is execfile()d with the current directory set to its containing dir. import sys, os sys.path[0:0] = [os.path.abspath('..')] import pymongo # -- General configuration ----------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.coverage', 'sphinx.ext.todo', 'doc.mongo_extensions', 'sphinx.ext.intersphinx'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix of source filenames. 
source_suffix = '.rst' # The master toctree document. master_doc = 'index' # General information about the project. project = u'PyMongo' copyright = u'MongoDB, Inc. 2008-present. MongoDB, Mongo, and the leaf logo are registered trademarks of MongoDB, Inc' html_show_sphinx = False # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. version = pymongo.version # The full version, including alpha/beta/rc tags. release = pymongo.version # List of documents that shouldn't be included in the build. unused_docs = [] # List of directories, relative to source directory, that shouldn't be searched # for source files. exclude_trees = ['_build'] # The reST default role (used for this markup: `text`) to use for all documents. #default_role = None # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. #show_authors = False # If true, the current module name will be prepended to all description # unit titles (such as .. function::). add_module_names = True # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # A list of ignored prefixes for module index sorting. #modindex_common_prefix = [] # -- Options for extensions ---------------------------------------------------- autoclass_content = 'init' doctest_path = [os.path.abspath('..')] doctest_test_doctest_blocks = '' doctest_global_setup = """ from pymongo.mongo_client import MongoClient client = MongoClient() client.drop_database("doctest_test") db = client.doctest_test """ # -- Options for HTML output --------------------------------------------------- # Theme gratefully vendored from CPython source. html_theme = "pydoctheme" html_theme_path = ["."] html_theme_options = { 'collapsiblesidebar': True, 'googletag': False } # Additional static files. html_static_path = ['static'] # These paths are either relative to html_static_path # or fully qualified paths (eg. https://...) # Note: html_js_files was added in Sphinx 1.8. html_js_files = [ 'delighted.js', ] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". #html_title = None # A shorter title for the navigation bar. Default is the same as html_title. #html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. #html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. #html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". #html_static_path = ['_static'] # Custom sidebar templates, maps document names to template names. #html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. #html_additional_pages = {} # If true, links to the reST sources are added to the pages. #html_show_sourcelink = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. #html_use_opensearch = '' # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). 
#html_file_suffix = '' # Output file base name for HTML help builder. htmlhelp_basename = 'PyMongo' + release.replace('.', '_') # -- Options for LaTeX output -------------------------------------------------- # The paper size ('letter' or 'a4'). #latex_paper_size = 'letter' # The font size ('10pt', '11pt' or '12pt'). #latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ('index', 'PyMongo.tex', u'PyMongo Documentation', u'Michael Dirolf', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of # the title page. #latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. #latex_use_parts = False # Additional stuff for the LaTeX preamble. #latex_preamble = '' # Documents to append as an appendix to all manuals. #latex_appendices = [] # If false, no module index is generated. #latex_use_modindex = True intersphinx_mapping = { 'gevent': ('http://www.gevent.org/', None), 'py': ('https://docs.python.org/3/', None), } pymongo-3.11.0/doc/contributors.rst000066400000000000000000000046271374256237000173420ustar00rootroot00000000000000Contributors ============ The following is a list of people who have contributed to **PyMongo**. If you belong here and are missing please let us know (or send a pull request after adding yourself to the list): - Mike Dirolf (mdirolf) - Jeff Jenkins (jeffjenkins) - Jim Jones - Eliot Horowitz (erh) - Michael Stephens (mikejs) - Joakim Sernbrant (serbaut) - Alexander Artemenko (svetlyak40wt) - Mathias Stearn (RedBeard0531) - Fajran Iman Rusadi (fajran) - Brad Clements (bkc) - Andrey Fedorov (andreyf) - Joshua Roesslein (joshthecoder) - Gregg Lind (gregglind) - Michael Schurter (schmichael) - Daniel Lundin - Michael Richardson (mtrichardson) - Dan McKinley (mcfunley) - David Wolever (wolever) - Carlos Valiente (carletes) - Jehiah Czebotar (jehiah) - Drew Perttula (drewp) - Carl Baatz (c-w-b) - Johan Bergstrom (jbergstroem) - Jonas Haag (jonashaag) - Kristina Chodorow (kchodorow) - Andrew Sibley (sibsibsib) - Flavio Percoco Premoli (FlaPer87) - Ken Kurzweil (kurzweil) - Christian Wyglendowski (dowski) - James Murty (jmurty) - Brendan W. McAdams (bwmcadams) - Bernie Hackett (behackett) - Reed O'Brien (reedobrien) - Francisco Souza (fsouza) - Alexey I. 
Froloff (raorn) - Steve Lacy (slacy) - Richard Shea (shearic) - Vladimir Sidorenko (gearheart) - Aaron Westendorf (awestendorf) - Dan Crosta (dcrosta) - Ryan Smith-Roberts (rmsr) - David Pisoni (gefilte) - Abhay Vardhan (abhayv) - Alexey Borzenkov (snaury) - Kostya Rybnikov (k-bx) - A Jesse Jiryu Davis (ajdavis) - Samuel Clay (samuelclay) - Ross Lawley (rozza) - Wouter Bolsterlee (wbolster) - Alex Grönholm (agronholm) - Christoph Simon (kalanzun) - Chris Tompkinson (tompko) - Mike O'Brien (mpobrien) - T Dampier (dampier) - Michael Henson (hensom) - Craig Hobbs (craigahobbs) - Emily Stolfo (estolfo) - Sam Helman (shelman) - Justin Patrin (reversefold) - Xiuming Chen (cxmcc) - Tyler Jones (thomascirca) - Amalia Hawkins (hawka) - Yuchen Ying (yegle) - Kyle Erf (3rf) - Luke Lovett (lovett89) - Jaroslav Semančík (girogiro) - Don Mitchell (dmitchell) - Ximing (armnotstrong) - Can Zhang (cannium) - Sergey Azovskov (last-g) - Heewa Barfchin (heewa) - Anna Herlihy (aherlihy) - Len Buckens (buckensl) - ultrabug - Shane Harvey (ShaneHarvey) - Cao Siyang (caosiyang) - Zhecong Kwok (gzcf) - TaoBeier(tao12345666333) - Jagrut Trivedi(Jagrut) - Shrey Batra(shreybatra) - Felipe Rodrigues(fbidu) - Terence Honles (terencehonles) - Paul Fisher (thetorpedodog) - Julius Park (juliusgeo) pymongo-3.11.0/doc/developer/000077500000000000000000000000001374256237000160275ustar00rootroot00000000000000pymongo-3.11.0/doc/developer/index.rst000066400000000000000000000002021374256237000176620ustar00rootroot00000000000000Developer Guide =============== Technical guide for contributors to PyMongo. .. toctree:: :maxdepth: 1 periodic_executor pymongo-3.11.0/doc/developer/periodic_executor.rst000066400000000000000000000122371374256237000223020ustar00rootroot00000000000000Periodic Executors ================== .. currentmodule:: pymongo PyMongo implements a :class:`~periodic_executor.PeriodicExecutor` for two purposes: as the background thread for :class:`~monitor.Monitor`, and to regularly check if there are `OP_KILL_CURSORS` messages that must be sent to the server. Killing Cursors --------------- An incompletely iterated :class:`~cursor.Cursor` on the client represents an open cursor object on the server. In code like this, we lose a reference to the cursor before finishing iteration:: for doc in collection.find(): raise Exception() We try to send an `OP_KILL_CURSORS` to the server to tell it to clean up the server-side cursor. But we must not take any locks directly from the cursor's destructor (see `PYTHON-799`_), so we cannot safely use the PyMongo data structures required to send a message. The solution is to add the cursor's id to an array on the :class:`~mongo_client.MongoClient` without taking any locks. Each client has a :class:`~periodic_executor.PeriodicExecutor` devoted to checking the array for cursor ids. Any it sees are the result of cursors that were freed while the server-side cursor was still open. The executor can safely take the locks it needs in order to send the `OP_KILL_CURSORS` message. .. _PYTHON-799: https://jira.mongodb.org/browse/PYTHON-799 Stopping Executors ------------------ Just as :class:`~cursor.Cursor` must not take any locks from its destructor, neither can :class:`~mongo_client.MongoClient` and :class:`~topology.Topology`. Thus, although the client calls :meth:`close` on its kill-cursors thread, and the topology calls :meth:`close` on all its monitor threads, the :meth:`close` method cannot actually call :meth:`wake` on the executor, since :meth:`wake` takes a lock. 
Instead, executors wake periodically to check if ``self.close`` is set, and if so they exit. A thread can log spurious errors if it wakes late in the Python interpreter's shutdown sequence, so we try to join threads before then. Each periodic executor (either a monitor or a kill-cursors thread) adds a weakref to itself to a set called ``_EXECUTORS``, in the ``periodic_executor`` module. An `exit handler`_ runs on shutdown and tells all executors to stop, then tries (with a short timeout) to join all executor threads. .. _exit handler: https://docs.python.org/2/library/atexit.html Monitoring ---------- For each server in the topology, :class:`~topology.Topology` uses a periodic executor to launch a monitor thread. This thread must not prevent the topology from being freed, so it weakrefs the topology. Furthermore, it uses a weakref callback to terminate itself soon after the topology is freed. Solid lines represent strong references, dashed lines weak ones: .. generated with graphviz: "dot -Tpng periodic-executor-refs.dot > periodic-executor-refs.png" .. image:: ../static/periodic-executor-refs.png See `Stopping Executors`_ above for an explanation of the ``_EXECUTORS`` set. It is a requirement of the `Server Discovery And Monitoring Spec`_ that a sleeping monitor can be awakened early. Aside from infrequent wakeups to do their appointed chores, and occasional interruptions, periodic executors also wake periodically to check if they should terminate. Our first implementation of this idea was the obvious one: use the Python standard library's threading.Condition.wait with a timeout. Another thread wakes the executor early by signaling the condition variable. A topology cannot signal the condition variable to tell the executor to terminate, because it would risk a deadlock in the garbage collector: no destructor or weakref callback can take a lock to signal the condition variable (see `PYTHON-863`_); thus the only way for a dying object to terminate a periodic executor is to set its "stopped" flag and let the executor see the flag next time it wakes. We erred on the side of prompt cleanup, and set the check interval at 100ms. We assumed that checking a flag and going back to sleep 10 times a second was cheap on modern machines. Starting in Python 3.2, the builtin C implementation of lock.acquire takes a timeout parameter, so Python 3.2+ Condition variables sleep simply by calling lock.acquire; they are implemented as efficiently as expected. But in Python 2, lock.acquire has no timeout. To wait with a timeout, a Python 2 condition variable sleeps a millisecond, tries to acquire the lock, sleeps twice as long, and tries again. This exponential backoff reaches a maximum sleep time of 50ms. If PyMongo calls the condition variable's "wait" method with a short timeout, the exponential backoff is restarted frequently. Overall, the condition variable is not waking a few times a second, but hundreds of times. (See `PYTHON-983`_.) Thus the current design of periodic executors is surprisingly simple: they do a simple `time.sleep` for a half-second, check if it is time to wake or terminate, and sleep again. .. _Server Discovery And Monitoring Spec: https://github.com/mongodb/specifications/blob/master/source/server-discovery-and-monitoring/server-discovery-and-monitoring.rst#requesting-an-immediate-check .. _PYTHON-863: https://jira.mongodb.org/browse/PYTHON-863 .. 
_PYTHON-983: https://jira.mongodb.org/browse/PYTHON-983 pymongo-3.11.0/doc/examples/000077500000000000000000000000001374256237000156605ustar00rootroot00000000000000pymongo-3.11.0/doc/examples/aggregation.rst000066400000000000000000000132471374256237000207100ustar00rootroot00000000000000Aggregation Examples ==================== There are several methods of performing aggregations in MongoDB. These examples cover the new aggregation framework, using map reduce and using the group method. .. testsetup:: from pymongo import MongoClient client = MongoClient() client.drop_database('aggregation_example') Setup ----- To start, we'll insert some example data which we can perform aggregations on: .. doctest:: >>> from pymongo import MongoClient >>> db = MongoClient().aggregation_example >>> result = db.things.insert_many([{"x": 1, "tags": ["dog", "cat"]}, ... {"x": 2, "tags": ["cat"]}, ... {"x": 2, "tags": ["mouse", "cat", "dog"]}, ... {"x": 3, "tags": []}]) >>> result.inserted_ids [ObjectId('...'), ObjectId('...'), ObjectId('...'), ObjectId('...')] .. _aggregate-examples: Aggregation Framework --------------------- This example shows how to use the :meth:`~pymongo.collection.Collection.aggregate` method to use the aggregation framework. We'll perform a simple aggregation to count the number of occurrences for each tag in the ``tags`` array, across the entire collection. To achieve this we need to pass in three operations to the pipeline. First, we need to unwind the ``tags`` array, then group by the tags and sum them up, finally we sort by count. As python dictionaries don't maintain order you should use :class:`~bson.son.SON` or :class:`collections.OrderedDict` where explicit ordering is required eg "$sort": .. note:: aggregate requires server version **>= 2.1.0**. .. doctest:: >>> from bson.son import SON >>> pipeline = [ ... {"$unwind": "$tags"}, ... {"$group": {"_id": "$tags", "count": {"$sum": 1}}}, ... {"$sort": SON([("count", -1), ("_id", -1)])} ... ] >>> import pprint >>> pprint.pprint(list(db.things.aggregate(pipeline))) [{u'_id': u'cat', u'count': 3}, {u'_id': u'dog', u'count': 2}, {u'_id': u'mouse', u'count': 1}] To run an explain plan for this aggregation use the :meth:`~pymongo.database.Database.command` method:: >>> db.command('aggregate', 'things', pipeline=pipeline, explain=True) {u'ok': 1.0, u'stages': [...]} As well as simple aggregations the aggregation framework provides projection capabilities to reshape the returned data. Using projections and aggregation, you can add computed fields, create new virtual sub-objects, and extract sub-fields into the top-level of results. .. seealso:: The full documentation for MongoDB's `aggregation framework `_ Map/Reduce ---------- Another option for aggregation is to use the map reduce framework. Here we will define **map** and **reduce** functions to also count the number of occurrences for each tag in the ``tags`` array, across the entire collection. Our **map** function just emits a single `(key, 1)` pair for each tag in the array: .. doctest:: >>> from bson.code import Code >>> mapper = Code(""" ... function () { ... this.tags.forEach(function(z) { ... emit(z, 1); ... }); ... } ... """) The **reduce** function sums over all of the emitted values for a given key: .. doctest:: >>> reducer = Code(""" ... function (key, values) { ... var total = 0; ... for (var i = 0; i < values.length; i++) { ... total += values[i]; ... } ... return total; ... } ... """) .. 
note:: We can't just return ``values.length`` as the **reduce** function might be called iteratively on the results of other reduce steps. Finally, we call :meth:`~pymongo.collection.Collection.map_reduce` and iterate over the result collection: .. doctest:: >>> result = db.things.map_reduce(mapper, reducer, "myresults") >>> for doc in result.find().sort("_id"): ... pprint.pprint(doc) ... {u'_id': u'cat', u'value': 3.0} {u'_id': u'dog', u'value': 2.0} {u'_id': u'mouse', u'value': 1.0} Advanced Map/Reduce ------------------- PyMongo's API supports all of the features of MongoDB's map/reduce engine. One interesting feature is the ability to get more detailed results when desired, by passing `full_response=True` to :meth:`~pymongo.collection.Collection.map_reduce`. This returns the full response to the map/reduce command, rather than just the result collection: .. doctest:: >>> pprint.pprint( ... db.things.map_reduce(mapper, reducer, "myresults", full_response=True)) {...u'ok': 1.0,... u'result': u'myresults'...} All of the optional map/reduce parameters are also supported, simply pass them as keyword arguments. In this example we use the `query` parameter to limit the documents that will be mapped over: .. doctest:: >>> results = db.things.map_reduce( ... mapper, reducer, "myresults", query={"x": {"$lt": 2}}) >>> for doc in results.find().sort("_id"): ... pprint.pprint(doc) ... {u'_id': u'cat', u'value': 1.0} {u'_id': u'dog', u'value': 1.0} You can use :class:`~bson.son.SON` or :class:`collections.OrderedDict` to specify a different database to store the result collection: .. doctest:: >>> from bson.son import SON >>> pprint.pprint( ... db.things.map_reduce( ... mapper, ... reducer, ... out=SON([("replace", "results"), ("db", "outdb")]), ... full_response=True)) {...u'ok': 1.0,... u'result': {u'collection': u'results', u'db': u'outdb'}...} .. seealso:: The full list of options for MongoDB's `map reduce engine `_ pymongo-3.11.0/doc/examples/authentication.rst000066400000000000000000000334521374256237000214400ustar00rootroot00000000000000Authentication Examples ======================= MongoDB supports several different authentication mechanisms. These examples cover all authentication methods currently supported by PyMongo, documenting Python module and MongoDB version dependencies. .. _percent escaped: Percent-Escaping Username and Password -------------------------------------- Username and password must be percent-escaped with :meth:`urllib.parse.quote_plus` in Python 3, or :meth:`urllib.quote_plus` in Python 2, to be used in a MongoDB URI. For example, in Python 3:: >>> from pymongo import MongoClient >>> import urllib.parse >>> username = urllib.parse.quote_plus('user') >>> username 'user' >>> password = urllib.parse.quote_plus('pass/word') >>> password 'pass%2Fword' >>> MongoClient('mongodb://%s:%s@127.0.0.1' % (username, password)) ... .. _scram_sha_256: SCRAM-SHA-256 (RFC 7677) ------------------------ .. versionadded:: 3.7 SCRAM-SHA-256 is the default authentication mechanism supported by a cluster configured for authentication with MongoDB 4.0 or later. Authentication requires a username, a password, and a database name. The default database name is "admin", this can be overridden with the ``authSource`` option. Credentials can be specified as arguments to :class:`~pymongo.mongo_client.MongoClient`:: >>> from pymongo import MongoClient >>> client = MongoClient('example.com', ... username='user', ... password='password', ... authSource='the_database', ... 
authMechanism='SCRAM-SHA-256') Or through the MongoDB URI:: >>> uri = "mongodb://user:password@example.com/?authSource=the_database&authMechanism=SCRAM-SHA-256" >>> client = MongoClient(uri) SCRAM-SHA-1 (RFC 5802) ---------------------- .. versionadded:: 2.8 SCRAM-SHA-1 is the default authentication mechanism supported by a cluster configured for authentication with MongoDB 3.0 or later. Authentication requires a username, a password, and a database name. The default database name is "admin", this can be overridden with the ``authSource`` option. Credentials can be specified as arguments to :class:`~pymongo.mongo_client.MongoClient`:: >>> from pymongo import MongoClient >>> client = MongoClient('example.com', ... username='user', ... password='password', ... authSource='the_database', ... authMechanism='SCRAM-SHA-1') Or through the MongoDB URI:: >>> uri = "mongodb://user:password@example.com/?authSource=the_database&authMechanism=SCRAM-SHA-1" >>> client = MongoClient(uri) For best performance on Python versions older than 2.7.8 install `backports.pbkdf2`_. .. _backports.pbkdf2: https://pypi.python.org/pypi/backports.pbkdf2/ MONGODB-CR ---------- .. warning:: MONGODB-CR was deprecated with the release of MongoDB 3.6 and is no longer supported by MongoDB 4.0. Before MongoDB 3.0 the default authentication mechanism was MONGODB-CR, the "MongoDB Challenge-Response" protocol:: >>> from pymongo import MongoClient >>> client = MongoClient('example.com', ... username='user', ... password='password', ... authMechanism='MONGODB-CR') >>> >>> uri = "mongodb://user:password@example.com/?authSource=the_database&authMechanism=MONGODB-CR" >>> client = MongoClient(uri) Default Authentication Mechanism -------------------------------- If no mechanism is specified, PyMongo automatically uses MONGODB-CR when connected to a pre-3.0 version of MongoDB, SCRAM-SHA-1 when connected to MongoDB 3.0 through 3.6, and negotiates the mechanism to use (SCRAM-SHA-1 or SCRAM-SHA-256) when connected to MongoDB 4.0+. Default Database and "authSource" --------------------------------- You can specify both a default database and the authentication database in the URI:: >>> uri = "mongodb://user:password@example.com/default_db?authSource=admin" >>> client = MongoClient(uri) PyMongo will authenticate on the "admin" database, but the default database will be "default_db":: >>> # get_database with no "name" argument chooses the DB from the URI >>> db = MongoClient(uri).get_database() >>> print(db.name) 'default_db' .. _mongodb_x509: MONGODB-X509 ------------ .. versionadded:: 2.6 The MONGODB-X509 mechanism authenticates a username derived from the distinguished subject name of the X.509 certificate presented by the driver during SSL negotiation. This authentication method requires the use of SSL connections with certificate validation and is available in MongoDB 2.6 and newer:: >>> import ssl >>> from pymongo import MongoClient >>> client = MongoClient('example.com', ... username="" ... authMechanism="MONGODB-X509", ... ssl=True, ... ssl_certfile='/path/to/client.pem', ... ssl_cert_reqs=ssl.CERT_REQUIRED, ... ssl_ca_certs='/path/to/ca.pem') MONGODB-X509 authenticates against the $external virtual database, so you do not have to specify a database in the URI:: >>> uri = "mongodb://@example.com/?authMechanism=MONGODB-X509" >>> client = MongoClient(uri, ... ssl=True, ... ssl_certfile='/path/to/client.pem', ... ssl_cert_reqs=ssl.CERT_REQUIRED, ... ssl_ca_certs='/path/to/ca.pem') >>> .. 
versionchanged:: 3.4 When connected to MongoDB >= 3.4 the username is no longer required. .. _gssapi: GSSAPI (Kerberos) ----------------- .. versionadded:: 2.5 GSSAPI (Kerberos) authentication is available in the Enterprise Edition of MongoDB. Unix ~~~~ To authenticate using GSSAPI you must first install the python `kerberos`_ or `pykerberos`_ module using easy_install or pip. Make sure you run kinit before using the following authentication methods:: $ kinit mongodbuser@EXAMPLE.COM mongodbuser@EXAMPLE.COM's Password: $ klist Credentials cache: FILE:/tmp/krb5cc_1000 Principal: mongodbuser@EXAMPLE.COM Issued Expires Principal Feb 9 13:48:51 2013 Feb 9 23:48:51 2013 krbtgt/EXAMPLE.COM@EXAMPLE.COM Now authenticate using the MongoDB URI. GSSAPI authenticates against the $external virtual database so you do not have to specify a database in the URI:: >>> # Note: the kerberos principal must be url encoded. >>> from pymongo import MongoClient >>> uri = "mongodb://mongodbuser%40EXAMPLE.COM@mongo-server.example.com/?authMechanism=GSSAPI" >>> client = MongoClient(uri) >>> The default service name used by MongoDB and PyMongo is `mongodb`. You can specify a custom service name with the ``authMechanismProperties`` option:: >>> from pymongo import MongoClient >>> uri = "mongodb://mongodbuser%40EXAMPLE.COM@mongo-server.example.com/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_NAME:myservicename" >>> client = MongoClient(uri) Windows (SSPI) ~~~~~~~~~~~~~~ .. versionadded:: 3.3 First install the `winkerberos`_ module. Unlike authentication on Unix kinit is not used. If the user to authenticate is different from the user that owns the application process provide a password to authenticate:: >>> uri = "mongodb://mongodbuser%40EXAMPLE.COM:mongodbuserpassword@example.com/?authMechanism=GSSAPI" Two extra ``authMechanismProperties`` are supported on Windows platforms: - CANONICALIZE_HOST_NAME - Uses the fully qualified domain name (FQDN) of the MongoDB host for the server principal (GSSAPI libraries on Unix do this by default):: >>> uri = "mongodb://mongodbuser%40EXAMPLE.COM@example.com/?authMechanism=GSSAPI&authMechanismProperties=CANONICALIZE_HOST_NAME:true" - SERVICE_REALM - This is used when the user's realm is different from the service's realm:: >>> uri = "mongodb://mongodbuser%40EXAMPLE.COM@example.com/?authMechanism=GSSAPI&authMechanismProperties=SERVICE_REALM:otherrealm" .. _kerberos: http://pypi.python.org/pypi/kerberos .. _pykerberos: https://pypi.python.org/pypi/pykerberos .. _winkerberos: https://pypi.python.org/pypi/winkerberos/ .. _sasl_plain: SASL PLAIN (RFC 4616) --------------------- .. versionadded:: 2.6 MongoDB Enterprise Edition version 2.6 and newer support the SASL PLAIN authentication mechanism, initially intended for delegating authentication to an LDAP server. Using the PLAIN mechanism is very similar to MONGODB-CR. These examples use the $external virtual database for LDAP support:: >>> from pymongo import MongoClient >>> uri = "mongodb://user:password@example.com/?authMechanism=PLAIN" >>> client = MongoClient(uri) >>> SASL PLAIN is a clear-text authentication mechanism. We **strongly** recommend that you connect to MongoDB using SSL with certificate validation when using the SASL PLAIN mechanism:: >>> import ssl >>> from pymongo import MongoClient >>> uri = "mongodb://user:password@example.com/?authMechanism=PLAIN" >>> client = MongoClient(uri, ... ssl=True, ... ssl_certfile='/path/to/client.pem', ... ssl_cert_reqs=ssl.CERT_REQUIRED, ... ssl_ca_certs='/path/to/ca.pem') >>> .. 
_MONGODB-AWS: MONGODB-AWS ----------- .. versionadded:: 3.11 The MONGODB-AWS authentication mechanism is available in MongoDB 4.4+ and requires extra pymongo dependencies. To use it, install pymongo with the ``aws`` extra:: $ python -m pip install 'pymongo[aws]' The MONGODB-AWS mechanism authenticates using AWS IAM credentials (an access key ID and a secret access key), `temporary AWS IAM credentials`_ obtained from an `AWS Security Token Service (STS)`_ `Assume Role`_ request, AWS Lambda `environment variables`_, or temporary AWS IAM credentials assigned to an `EC2 instance`_ or ECS task. The use of temporary credentials, in addition to an access key ID and a secret access key, also requires a security (or session) token. Credentials can be configured through the MongoDB URI, environment variables, or the local EC2 or ECS endpoint. The order in which the client searches for credentials is: #. Credentials passed through the URI #. Environment variables #. ECS endpoint if and only if ``AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`` is set. #. EC2 endpoint MONGODB-AWS authenticates against the "$external" virtual database, so none of the URIs in this section need to include the ``authSource`` URI option. AWS IAM credentials ~~~~~~~~~~~~~~~~~~~ Applications can authenticate using AWS IAM credentials by providing a valid access key id and secret access key pair as the username and password, respectively, in the MongoDB URI. A sample URI would be:: >>> from pymongo import MongoClient >>> uri = "mongodb://:@localhost/?authMechanism=MONGODB-AWS" >>> client = MongoClient(uri) .. note:: The access_key_id and secret_access_key passed into the URI MUST be `percent escaped`_. AssumeRole ~~~~~~~~~~ Applications can authenticate using temporary credentials returned from an assume role request. These temporary credentials consist of an access key ID, a secret access key, and a security token passed into the URI. A sample URI would be:: >>> from pymongo import MongoClient >>> uri = "mongodb://:@example.com/?authMechanism=MONGODB-AWS&authMechanismProperties=AWS_SESSION_TOKEN:" >>> client = MongoClient(uri) .. note:: The access_key_id, secret_access_key, and session_token passed into the URI MUST be `percent escaped`_. AWS Lambda (Environment Variables) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ When the username and password are not provided and the MONGODB-AWS mechanism is set, the client will fallback to using the `environment variables`_ ``AWS_ACCESS_KEY_ID``, ``AWS_SECRET_ACCESS_KEY``, and ``AWS_SESSION_TOKEN`` for the access key ID, secret access key, and session token, respectively:: $ export AWS_ACCESS_KEY_ID= $ export AWS_SECRET_ACCESS_KEY= $ export AWS_SESSION_TOKEN= $ python >>> from pymongo import MongoClient >>> uri = "mongodb://example.com/?authMechanism=MONGODB-AWS" >>> client = MongoClient(uri) .. note:: No username, password, or session token is passed into the URI. PyMongo will use credentials set via the environment variables. These environment variables MUST NOT be `percent escaped`_. ECS Container ~~~~~~~~~~~~~ Applications can authenticate from an ECS container via temporary credentials assigned to the machine. A sample URI on an ECS container would be:: >>> from pymongo import MongoClient >>> uri = "mongodb://localhost/?authMechanism=MONGODB-AWS" >>> client = MongoClient(uri) .. note:: No username, password, or session token is passed into the URI. PyMongo will query the ECS container endpoint to obtain these credentials. 
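As a concrete illustration of the percent-escaping requirement mentioned in the notes above, the following is a minimal sketch (Python 3 shown) of building a MONGODB-AWS URI from temporary AssumeRole credentials. It assumes the optional ``boto3`` package is installed and uses a hypothetical role ARN and hostname::

    import urllib.parse

    import boto3
    from pymongo import MongoClient

    # Hypothetical role ARN and session name - adjust for your environment.
    creds = boto3.client("sts").assume_role(
        RoleArn="arn:aws:iam::123456789012:role/my-mongodb-role",
        RoleSessionName="pymongo-example",
    )["Credentials"]

    # Credentials embedded in the URI MUST be percent-escaped.
    key_id = urllib.parse.quote_plus(creds["AccessKeyId"])
    secret = urllib.parse.quote_plus(creds["SecretAccessKey"])
    token = urllib.parse.quote_plus(creds["SessionToken"])

    uri = ("mongodb://%s:%s@example.com/?authMechanism=MONGODB-AWS"
           "&authMechanismProperties=AWS_SESSION_TOKEN:%s" % (key_id, secret, token))
    client = MongoClient(uri)
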
EC2 Instance ~~~~~~~~~~~~ Applications can authenticate from an EC2 instance via temporary credentials assigned to the machine. A sample URI on an EC2 machine would be:: >>> from pymongo import MongoClient >>> uri = "mongodb://localhost/?authMechanism=MONGODB-AWS" >>> client = MongoClient(uri) .. note:: No username, password, or session token is passed into the URI. PyMongo will query the EC2 instance endpoint to obtain these credentials. .. _temporary AWS IAM credentials: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp.html .. _AWS Security Token Service (STS): https://docs.aws.amazon.com/STS/latest/APIReference/Welcome.html .. _Assume Role: https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html .. _EC2 instance: https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_use_switch-role-ec2.html .. _environment variables: https://docs.aws.amazon.com/lambda/latest/dg/configuration-envvars.html#configuration-envvars-runtime pymongo-3.11.0/doc/examples/bulk.rst000066400000000000000000000132671374256237000173600ustar00rootroot00000000000000Bulk Write Operations ===================== .. testsetup:: from pymongo import MongoClient client = MongoClient() client.drop_database('bulk_example') This tutorial explains how to take advantage of PyMongo's bulk write operation features. Executing write operations in batches reduces the number of network round trips, increasing write throughput. Bulk Insert ----------- .. versionadded:: 2.6 A batch of documents can be inserted by passing a list to the :meth:`~pymongo.collection.Collection.insert_many` method. PyMongo will automatically split the batch into smaller sub-batches based on the maximum message size accepted by MongoDB, supporting very large bulk insert operations. .. doctest:: >>> import pymongo >>> db = pymongo.MongoClient().bulk_example >>> db.test.insert_many([{'i': i} for i in range(10000)]).inserted_ids [...] >>> db.test.count_documents({}) 10000 Mixed Bulk Write Operations --------------------------- .. versionadded:: 2.7 PyMongo also supports executing mixed bulk write operations. A batch of insert, update, and remove operations can be executed together using the bulk write operations API. .. _ordered_bulk: Ordered Bulk Write Operations ............................. Ordered bulk write operations are batched and sent to the server in the order provided for serial execution. The return value is an instance of :class:`~pymongo.results.BulkWriteResult` describing the type and count of operations performed. .. doctest:: :options: +NORMALIZE_WHITESPACE >>> from pprint import pprint >>> from pymongo import InsertOne, DeleteMany, ReplaceOne, UpdateOne >>> result = db.test.bulk_write([ ... DeleteMany({}), # Remove all documents from the previous example. ... InsertOne({'_id': 1}), ... InsertOne({'_id': 2}), ... InsertOne({'_id': 3}), ... UpdateOne({'_id': 1}, {'$set': {'foo': 'bar'}}), ... UpdateOne({'_id': 4}, {'$inc': {'j': 1}}, upsert=True), ... ReplaceOne({'j': 1}, {'j': 2})]) >>> pprint(result.bulk_api_result) {'nInserted': 3, 'nMatched': 2, 'nModified': 2, 'nRemoved': 10000, 'nUpserted': 1, 'upserted': [{u'_id': 4, u'index': 5}], 'writeConcernErrors': [], 'writeErrors': []} .. warning:: ``nModified`` is only reported by MongoDB 2.6 and later. When connected to an earlier server version, or in certain mixed version sharding configurations, PyMongo omits this field from the results of a bulk write operation. The first write failure that occurs (e.g. 
duplicate key error) aborts the remaining operations, and PyMongo raises :class:`~pymongo.errors.BulkWriteError`. The :attr:`details` attibute of the exception instance provides the execution results up until the failure occurred and details about the failure - including the operation that caused the failure. .. doctest:: :options: +NORMALIZE_WHITESPACE >>> from pymongo import InsertOne, DeleteOne, ReplaceOne >>> from pymongo.errors import BulkWriteError >>> requests = [ ... ReplaceOne({'j': 2}, {'i': 5}), ... InsertOne({'_id': 4}), # Violates the unique key constraint on _id. ... DeleteOne({'i': 5})] >>> try: ... db.test.bulk_write(requests) ... except BulkWriteError as bwe: ... pprint(bwe.details) ... {'nInserted': 0, 'nMatched': 1, 'nModified': 1, 'nRemoved': 0, 'nUpserted': 0, 'upserted': [], 'writeConcernErrors': [], 'writeErrors': [{u'code': 11000, u'errmsg': u'...E11000...duplicate key error...', u'index': 1,... u'op': {'_id': 4}}]} .. _unordered_bulk: Unordered Bulk Write Operations ............................... Unordered bulk write operations are batched and sent to the server in **arbitrary order** where they may be executed in parallel. Any errors that occur are reported after all operations are attempted. In the next example the first and third operations fail due to the unique constraint on _id. Since we are doing unordered execution the second and fourth operations succeed. .. doctest:: :options: +NORMALIZE_WHITESPACE >>> requests = [ ... InsertOne({'_id': 1}), ... DeleteOne({'_id': 2}), ... InsertOne({'_id': 3}), ... ReplaceOne({'_id': 4}, {'i': 1})] >>> try: ... db.test.bulk_write(requests, ordered=False) ... except BulkWriteError as bwe: ... pprint(bwe.details) ... {'nInserted': 0, 'nMatched': 1, 'nModified': 1, 'nRemoved': 1, 'nUpserted': 0, 'upserted': [], 'writeConcernErrors': [], 'writeErrors': [{u'code': 11000, u'errmsg': u'...E11000...duplicate key error...', u'index': 0,... u'op': {'_id': 1}}, {u'code': 11000, u'errmsg': u'...E11000...duplicate key error...', u'index': 2,... u'op': {'_id': 3}}]} Write Concern ............. Bulk operations are executed with the :attr:`~pymongo.collection.Collection.write_concern` of the collection they are executed against. Write concern errors (e.g. wtimeout) will be reported after all operations are attempted, regardless of execution order. :: >>> from pymongo import WriteConcern >>> coll = db.get_collection( ... 'test', write_concern=WriteConcern(w=3, wtimeout=1)) >>> try: ... coll.bulk_write([InsertOne({'a': i}) for i in range(4)]) ... except BulkWriteError as bwe: ... pprint(bwe.details) ... {'nInserted': 4, 'nMatched': 0, 'nModified': 0, 'nRemoved': 0, 'nUpserted': 0, 'upserted': [], 'writeConcernErrors': [{u'code': 64... u'errInfo': {u'wtimeout': True}, u'errmsg': u'waiting for replication timed out'}], 'writeErrors': []} pymongo-3.11.0/doc/examples/collations.rst000066400000000000000000000114741374256237000205700ustar00rootroot00000000000000Collations ========== .. seealso:: The API docs for :mod:`~pymongo.collation`. Collations are a new feature in MongoDB version 3.4. They provide a set of rules to use when comparing strings that comply with the conventions of a particular language, such as Spanish or German. If no collation is specified, the server sorts strings based on a binary comparison. Many languages have specific ordering rules, and collations allow users to build applications that adhere to language-specific comparison rules. In French, for example, the last accent in a given word determines the sorting order. 
The correct sorting order for the following four words in French is:: cote < côte < coté < côté Specifying a French collation allows users to sort string fields using the French sort order. Usage ----- Users can specify a collation for a :ref:`collection`, an :ref:`index`, or a :ref:`CRUD command `. Collation Parameters: ~~~~~~~~~~~~~~~~~~~~~ Collations can be specified with the :class:`~pymongo.collation.Collation` model or with plain Python dictionaries. The structure is the same:: Collation(locale=, caseLevel=, caseFirst=, strength=, numericOrdering=, alternate=, maxVariable=, backwards=) The only required parameter is ``locale``, which the server parses as an `ICU format locale ID `_. For example, set ``locale`` to ``en_US`` to represent US English or ``fr_CA`` to represent Canadian French. For a complete description of the available parameters, see the MongoDB `manual `_. .. COMMENT add link for manual entry. .. _collation-on-collection: Assign a Default Collation to a Collection ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The following example demonstrates how to create a new collection called ``contacts`` and assign a default collation with the ``fr_CA`` locale. This operation ensures that all queries that are run against the ``contacts`` collection use the ``fr_CA`` collation unless another collation is explicitly specified:: from pymongo import MongoClient from pymongo.collation import Collation db = MongoClient().test collection = db.create_collection('contacts', collation=Collation(locale='fr_CA')) .. _collation-on-index: Assign a Default Collation to an Index ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ When creating a new index, you can specify a default collation. The following example shows how to create an index on the ``name`` field of the ``contacts`` collection, with the ``unique`` parameter enabled and a default collation with ``locale`` set to ``fr_CA``:: from pymongo import MongoClient from pymongo.collation import Collation contacts = MongoClient().test.contacts contacts.create_index('name', unique=True, collation=Collation(locale='fr_CA')) .. _collation-on-operation: Specify a Collation for a Query ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Individual queries can specify a collation to use when sorting results. The following example demonstrates a query that runs on the ``contacts`` collection in database ``test``. It matches on documents that contain ``New York`` in the ``city`` field, and sorts on the ``name`` field with the ``fr_CA`` collation:: from pymongo import MongoClient from pymongo.collation import Collation collection = MongoClient().test.contacts docs = collection.find({'city': 'New York'}).sort('name').collation( Collation(locale='fr_CA')) Other Query Types ~~~~~~~~~~~~~~~~~ You can use collations to control document matching rules for several different types of queries. All the various update and delete methods (:meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.update_many`, :meth:`~pymongo.collection.Collection.delete_one`, etc.) support collation, and you can create query filters which employ collations to comply with any of the languages and variants available to the ``locale`` parameter. The following example uses a collation with ``strength`` set to :const:`~pymongo.collation.CollationStrength.SECONDARY`, which considers only the base character and character accents in string comparisons, but not case sensitivity, for example. 
All documents in the ``contacts`` collection with ``jürgen`` (case-insensitive) in the ``first_name`` field are updated:: from pymongo import MongoClient from pymongo.collation import Collation, CollationStrength contacts = MongoClient().test.contacts result = contacts.update_many( {'first_name': 'jürgen'}, {'$set': {'verified': 1}}, collation=Collation(locale='de', strength=CollationStrength.SECONDARY)) pymongo-3.11.0/doc/examples/copydb.rst000066400000000000000000000027701374256237000177000ustar00rootroot00000000000000Copying a Database ================== To copy a database within a single mongod process, or between mongod servers, simply connect to the target mongod and use the :meth:`~pymongo.database.Database.command` method:: >>> from pymongo import MongoClient >>> client = MongoClient('target.example.com') >>> client.admin.command('copydb', fromdb='source_db_name', todb='target_db_name') To copy from a different mongod server that is not password-protected:: >>> client.admin.command('copydb', fromdb='source_db_name', todb='target_db_name', fromhost='source.example.com') If the target server is password-protected, authenticate to the "admin" database:: >>> client = MongoClient('target.example.com', ... username='administrator', ... password='pwd') >>> client.admin.command('copydb', fromdb='source_db_name', todb='target_db_name', fromhost='source.example.com') See the :doc:`authentication examples `. If the **source** server is password-protected, use the `copyDatabase function in the mongo shell`_. Versions of PyMongo before 3.0 included a ``copy_database`` helper method, but it has been removed. .. _copyDatabase function in the mongo shell: http://docs.mongodb.org/manual/reference/method/db.copyDatabase/ pymongo-3.11.0/doc/examples/custom_type.rst000066400000000000000000000355341374256237000207770ustar00rootroot00000000000000Custom Type Example =================== This is an example of using a custom type with PyMongo. The example here shows how to subclass :class:`~bson.codec_options.TypeCodec` to write a type codec, which is used to populate a :class:`~bson.codec_options.TypeRegistry`. The type registry can then be used to create a custom-type-aware :class:`~pymongo.collection.Collection`. Read and write operations issued against the resulting collection object transparently manipulate documents as they are saved to or retrieved from MongoDB. Setting Up ---------- We'll start by getting a clean database to use for the example: .. doctest:: >>> from pymongo import MongoClient >>> client = MongoClient() >>> client.drop_database('custom_type_example') >>> db = client.custom_type_example Since the purpose of the example is to demonstrate working with custom types, we'll need a custom data type to use. For this example, we will be working with the :py:class:`~decimal.Decimal` type from Python's standard library. Since the BSON library's :class:`~bson.decimal128.Decimal128` type (that implements the IEEE 754 decimal128 decimal-based floating-point numbering format) is distinct from Python's built-in :py:class:`~decimal.Decimal` type, attempting to save an instance of ``Decimal`` with PyMongo, results in an :exc:`~bson.errors.InvalidDocument` exception. .. doctest:: >>> from decimal import Decimal >>> num = Decimal("45.321") >>> db.test.insert_one({'num': num}) Traceback (most recent call last): ... bson.errors.InvalidDocument: cannot encode object: Decimal('45.321'), of type: .. 
_custom-type-type-codec: The :class:`~bson.codec_options.TypeCodec` Class ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. versionadded:: 3.8 In order to encode a custom type, we must first define a **type codec** for that type. A type codec describes how an instance of a custom type can be *transformed* to and/or from one of the types :mod:`~bson` already understands. Depending on the desired functionality, users must choose from the following base classes when defining type codecs: * :class:`~bson.codec_options.TypeEncoder`: subclass this to define a codec that encodes a custom Python type to a known BSON type. Users must implement the ``python_type`` property/attribute and the ``transform_python`` method. * :class:`~bson.codec_options.TypeDecoder`: subclass this to define a codec that decodes a specified BSON type into a custom Python type. Users must implement the ``bson_type`` property/attribute and the ``transform_bson`` method. * :class:`~bson.codec_options.TypeCodec`: subclass this to define a codec that can both encode and decode a custom type. Users must implement the ``python_type`` and ``bson_type`` properties/attributes, as well as the ``transform_python`` and ``transform_bson`` methods. The type codec for our custom type simply needs to define how a :py:class:`~decimal.Decimal` instance can be converted into a :class:`~bson.decimal128.Decimal128` instance and vice-versa. Since we are interested in both encoding and decoding our custom type, we use the ``TypeCodec`` base class to define our codec: .. doctest:: >>> from bson.decimal128 import Decimal128 >>> from bson.codec_options import TypeCodec >>> class DecimalCodec(TypeCodec): ... python_type = Decimal # the Python type acted upon by this type codec ... bson_type = Decimal128 # the BSON type acted upon by this type codec ... def transform_python(self, value): ... """Function that transforms a custom type value into a type ... that BSON can encode.""" ... return Decimal128(value) ... def transform_bson(self, value): ... """Function that transforms a vanilla BSON type value into our ... custom type.""" ... return value.to_decimal() >>> decimal_codec = DecimalCodec() .. _custom-type-type-registry: The :class:`~bson.codec_options.TypeRegistry` Class ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. versionadded:: 3.8 Before we can begin encoding and decoding our custom type objects, we must first inform PyMongo about the corresponding codec. This is done by creating a :class:`~bson.codec_options.TypeRegistry` instance: .. doctest:: >>> from bson.codec_options import TypeRegistry >>> type_registry = TypeRegistry([decimal_codec]) Note that type registries can be instantiated with any number of type codecs. Once instantiated, registries are immutable and the only way to add codecs to a registry is to create a new one. Putting It Together ------------------- Finally, we can define a :class:`~bson.codec_options.CodecOptions` instance with our ``type_registry`` and use it to get a :class:`~pymongo.collection.Collection` object that understands the :py:class:`~decimal.Decimal` data type: .. doctest:: >>> from bson.codec_options import CodecOptions >>> codec_options = CodecOptions(type_registry=type_registry) >>> collection = db.get_collection('test', codec_options=codec_options) Now, we can seamlessly encode and decode instances of :py:class:`~decimal.Decimal`: .. 
doctest:: >>> collection.insert_one({'num': Decimal("45.321")}) >>> mydoc = collection.find_one() >>> import pprint >>> pprint.pprint(mydoc) {u'_id': ObjectId('...'), u'num': Decimal('45.321')} We can see what's actually being saved to the database by creating a fresh collection object without the customized codec options and using that to query MongoDB: .. doctest:: >>> vanilla_collection = db.get_collection('test') >>> pprint.pprint(vanilla_collection.find_one()) {u'_id': ObjectId('...'), u'num': Decimal128('45.321')} Encoding Subtypes ^^^^^^^^^^^^^^^^^ Consider the situation where, in addition to encoding :py:class:`~decimal.Decimal`, we also need to encode a type that subclasses ``Decimal``. PyMongo does this automatically for types that inherit from Python types that are BSON-encodable by default, but the type codec system described above does not offer the same flexibility. Consider this subtype of ``Decimal`` that has a method to return its value as an integer: .. doctest:: >>> class DecimalInt(Decimal): ... def my_method(self): ... """Method implementing some custom logic.""" ... return int(self) If we try to save an instance of this type without first registering a type codec for it, we get an error: .. doctest:: >>> collection.insert_one({'num': DecimalInt("45.321")}) Traceback (most recent call last): ... bson.errors.InvalidDocument: cannot encode object: Decimal('45.321'), of type: In order to proceed further, we must define a type codec for ``DecimalInt``. This is trivial to do since the same transformation as the one used for ``Decimal`` is adequate for encoding ``DecimalInt`` as well: .. doctest:: >>> class DecimalIntCodec(DecimalCodec): ... @property ... def python_type(self): ... """The Python type acted upon by this type codec.""" ... return DecimalInt >>> decimalint_codec = DecimalIntCodec() .. note:: No attempt is made to modify decoding behavior because without additional information, it is impossible to discern which incoming :class:`~bson.decimal128.Decimal128` value needs to be decoded as ``Decimal`` and which needs to be decoded as ``DecimalInt``. This example only considers the situation where a user wants to *encode* documents containing either of these types. After creating a new codec options object and using it to get a collection object, we can seamlessly encode instances of ``DecimalInt``: .. doctest:: >>> type_registry = TypeRegistry([decimal_codec, decimalint_codec]) >>> codec_options = CodecOptions(type_registry=type_registry) >>> collection = db.get_collection('test', codec_options=codec_options) >>> collection.drop() >>> collection.insert_one({'num': DecimalInt("45.321")}) >>> mydoc = collection.find_one() >>> pprint.pprint(mydoc) {u'_id': ObjectId('...'), u'num': Decimal('45.321')} Note that the ``transform_bson`` method of the base codec class results in these values being decoded as ``Decimal`` (and not ``DecimalInt``). .. _decoding-binary-types: Decoding :class:`~bson.binary.Binary` Types ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ The decoding treatment of :class:`~bson.binary.Binary` types having ``subtype = 0`` by the :mod:`bson` module varies slightly depending on the version of the Python runtime in use. This must be taken into account while writing a ``TypeDecoder`` that modifies how this datatype is decoded. On Python 3.x, :class:`~bson.binary.Binary` data (``subtype = 0``) is decoded as a ``bytes`` instance: .. code-block:: python >>> # On Python 3.x. 
>>> from bson.binary import Binary >>> newcoll = db.get_collection('new') >>> newcoll.insert_one({'_id': 1, 'data': Binary(b"123", subtype=0)}) >>> doc = newcoll.find_one() >>> type(doc['data']) bytes On Python 2.7.x, the same data is decoded as a :class:`~bson.binary.Binary` instance: .. code-block:: python >>> # On Python 2.7.x >>> newcoll = db.get_collection('new') >>> doc = newcoll.find_one() >>> type(doc['data']) bson.binary.Binary As a consequence of this disparity, users must set the ``bson_type`` attribute on their :class:`~bson.codec_options.TypeDecoder` classes differently, depending on the python version in use. .. note:: For codebases requiring compatibility with both Python 2 and 3, type decoders will have to be registered for both possible ``bson_type`` values. .. _fallback-encoder-callable: The ``fallback_encoder`` Callable --------------------------------- .. versionadded:: 3.8 In addition to type codecs, users can also register a callable to encode types that BSON doesn't recognize and for which no type codec has been registered. This callable is the **fallback encoder** and like the ``transform_python`` method, it accepts an unencodable value as a parameter and returns a BSON-encodable value. The following fallback encoder encodes python's :py:class:`~decimal.Decimal` type to a :class:`~bson.decimal128.Decimal128`: .. doctest:: >>> def fallback_encoder(value): ... if isinstance(value, Decimal): ... return Decimal128(value) ... return value After declaring the callback, we must create a type registry and codec options with this fallback encoder before it can be used for initializing a collection: .. doctest:: >>> type_registry = TypeRegistry(fallback_encoder=fallback_encoder) >>> codec_options = CodecOptions(type_registry=type_registry) >>> collection = db.get_collection('test', codec_options=codec_options) >>> collection.drop() We can now seamlessly encode instances of :py:class:`~decimal.Decimal`: .. doctest:: >>> collection.insert_one({'num': Decimal("45.321")}) >>> mydoc = collection.find_one() >>> pprint.pprint(mydoc) {u'_id': ObjectId('...'), u'num': Decimal128('45.321')} .. note:: Fallback encoders are invoked *after* attempts to encode the given value with standard BSON encoders and any configured type encoders have failed. Therefore, in a type registry configured with a type encoder and fallback encoder that both target the same custom type, the behavior specified in the type encoder will prevail. Because fallback encoders don't need to declare the types that they encode beforehand, they can be used to support interesting use-cases that cannot be serviced by ``TypeEncoder``. One such use-case is described in the next section. Encoding Unknown Types ^^^^^^^^^^^^^^^^^^^^^^ In this example, we demonstrate how a fallback encoder can be used to save arbitrary objects to the database. We will use the the standard library's :py:mod:`pickle` module to serialize the unknown types and so naturally, this approach only works for types that are picklable. We start by defining some arbitrary custom types: .. code-block:: python class MyStringType(object): def __init__(self, value): self.__value = value def __repr__(self): return "MyStringType('%s')" % (self.__value,) class MyNumberType(object): def __init__(self, value): self.__value = value def __repr__(self): return "MyNumberType(%s)" % (self.__value,) We also define a fallback encoder that pickles whatever objects it receives and returns them as :class:`~bson.binary.Binary` instances with a custom subtype. 
The custom subtype, in turn, allows us to write a TypeDecoder that identifies pickled artifacts upon retrieval and transparently decodes them back into Python objects: .. code-block:: python import pickle from bson.binary import Binary, USER_DEFINED_SUBTYPE def fallback_pickle_encoder(value): return Binary(pickle.dumps(value), USER_DEFINED_SUBTYPE) class PickledBinaryDecoder(TypeDecoder): bson_type = Binary def transform_bson(self, value): if value.subtype == USER_DEFINED_SUBTYPE: return pickle.loads(value) return value .. note:: The above example is written assuming the use of Python 3. If you are using Python 2, ``bson_type`` must be set to ``Binary``. See the :ref:`decoding-binary-types` section for a detailed explanation. Finally, we create a ``CodecOptions`` instance: .. code-block:: python codec_options = CodecOptions(type_registry=TypeRegistry( [PickledBinaryDecoder()], fallback_encoder=fallback_pickle_encoder)) We can now round trip our custom objects to MongoDB: .. code-block:: python collection = db.get_collection('test_fe', codec_options=codec_options) collection.insert_one({'_id': 1, 'str': MyStringType("hello world"), 'num': MyNumberType(2)}) mydoc = collection.find_one() assert isinstance(mydoc['str'], MyStringType) assert isinstance(mydoc['num'], MyNumberType) Limitations ----------- PyMongo's type codec and fallback encoder features have the following limitations: #. Users cannot customize the encoding behavior of Python types that PyMongo already understands like ``int`` and ``str`` (the 'built-in types'). Attempting to instantiate a type registry with one or more codecs that act upon a built-in type results in a ``TypeError``. This limitation extends to all subtypes of the standard types. #. Chaining type encoders is not supported. A custom type value, once transformed by a codec's ``transform_python`` method, *must* result in a type that is either BSON-encodable by default, or can be transformed by the fallback encoder into something BSON-encodable--it *cannot* be transformed a second time by a different type codec. #. The :meth:`~pymongo.database.Database.command` method does not apply the user's TypeDecoders while decoding the command response document. #. :mod:`gridfs` does not apply custom type encoding or decoding to any documents received from or to returned to the user. pymongo-3.11.0/doc/examples/datetimes.rst000066400000000000000000000074511374256237000204000ustar00rootroot00000000000000Datetimes and Timezones ======================= .. testsetup:: import datetime from pymongo import MongoClient from bson.codec_options import CodecOptions client = MongoClient() client.drop_database('dt_example') db = client.dt_example These examples show how to handle Python :class:`datetime.datetime` objects correctly in PyMongo. Basic Usage ----------- PyMongo uses :class:`datetime.datetime` objects for representing dates and times in MongoDB documents. Because MongoDB assumes that dates and times are in UTC, care should be taken to ensure that dates and times written to the database reflect UTC. For example, the following code stores the current UTC date and time into MongoDB: .. doctest:: >>> result = db.objects.insert_one( ... {"last_modified": datetime.datetime.utcnow()}) Always use :meth:`datetime.datetime.utcnow`, which returns the current time in UTC, instead of :meth:`datetime.datetime.now`, which returns the current local time. Avoid doing this: .. doctest:: >>> result = db.objects.insert_one( ... 
{"last_modified": datetime.datetime.now()}) The value for `last_modified` is very different between these two examples, even though both documents were stored at around the same local time. This will be confusing to the application that reads them: .. doctest:: >>> [doc['last_modified'] for doc in db.objects.find()] # doctest: +SKIP [datetime.datetime(2015, 7, 8, 18, 17, 28, 324000), datetime.datetime(2015, 7, 8, 11, 17, 42, 911000)] :class:`bson.codec_options.CodecOptions` has a `tz_aware` option that enables "aware" :class:`datetime.datetime` objects, i.e., datetimes that know what timezone they're in. By default, PyMongo retrieves naive datetimes: .. doctest:: >>> result = db.tzdemo.insert_one( ... {'date': datetime.datetime(2002, 10, 27, 6, 0, 0)}) >>> db.tzdemo.find_one()['date'] datetime.datetime(2002, 10, 27, 6, 0) >>> options = CodecOptions(tz_aware=True) >>> db.get_collection('tzdemo', codec_options=options).find_one()['date'] # doctest: +SKIP datetime.datetime(2002, 10, 27, 6, 0, tzinfo=) Saving Datetimes with Timezones ------------------------------- When storing :class:`datetime.datetime` objects that specify a timezone (i.e. they have a `tzinfo` property that isn't ``None``), PyMongo will convert those datetimes to UTC automatically: .. doctest:: >>> import pytz >>> pacific = pytz.timezone('US/Pacific') >>> aware_datetime = pacific.localize( ... datetime.datetime(2002, 10, 27, 6, 0, 0)) >>> result = db.times.insert_one({"date": aware_datetime}) >>> db.times.find_one()['date'] datetime.datetime(2002, 10, 27, 14, 0) Reading Time ------------ As previously mentioned, by default all :class:`datetime.datetime` objects returned by PyMongo will be naive but reflect UTC (i.e. the time as stored in MongoDB). By setting the `tz_aware` option on :class:`~bson.codec_options.CodecOptions`, :class:`datetime.datetime` objects will be timezone-aware and have a `tzinfo` property that reflects the UTC timezone. PyMongo 3.1 introduced a `tzinfo` property that can be set on :class:`~bson.codec_options.CodecOptions` to convert :class:`datetime.datetime` objects to local time automatically. For example, if we wanted to read all times out of MongoDB in US/Pacific time: >>> from bson.codec_options import CodecOptions >>> db.times.find_one()['date'] datetime.datetime(2002, 10, 27, 14, 0) >>> aware_times = db.times.with_options(codec_options=CodecOptions( ... tz_aware=True, ... tzinfo=pytz.timezone('US/Pacific'))) >>> result = aware_times.find_one() datetime.datetime(2002, 10, 27, 6, 0, # doctest: +NORMALIZE_WHITESPACE tzinfo=) pymongo-3.11.0/doc/examples/encryption.rst000066400000000000000000000465341374256237000206200ustar00rootroot00000000000000.. _Client-Side Field Level Encryption: Client-Side Field Level Encryption ================================== New in MongoDB 4.2, client-side field level encryption allows an application to encrypt specific data fields in addition to pre-existing MongoDB encryption features such as `Encryption at Rest `_ and `TLS/SSL (Transport Encryption) `_. With field level encryption, applications can encrypt fields in documents *prior* to transmitting data over the wire to the server. Client-side field level encryption supports workloads where applications must guarantee that unauthorized parties, including server administrators, cannot read the encrypted data. .. 
mongodoc:: client-side-field-level-encryption Dependencies ------------ To get started using client-side field level encryption in your project, you will need to install the `pymongocrypt `_ library as well as the driver itself. Install both the driver and a compatible version of pymongocrypt like this:: $ python -m pip install 'pymongo[encryption]' Note that installing on Linux requires pip 19 or later for manylinux2010 wheel support. For more information about installing pymongocrypt see `the installation instructions on the project's PyPI page `_. mongocryptd ----------- The ``mongocryptd`` binary is required for automatic client-side encryption and is included as a component in the `MongoDB Enterprise Server package `_. For detailed installation instructions see `the MongoDB documentation on mongocryptd `_. ``mongocryptd`` performs the following: - Parses the automatic encryption rules specified to the database connection. If the JSON schema contains invalid automatic encryption syntax or any document validation syntax, ``mongocryptd`` returns an error. - Uses the specified automatic encryption rules to mark fields in read and write operations for encryption. - Rejects read/write operations that may return unexpected or incorrect results when applied to an encrypted field. For supported and unsupported operations, see `Read/Write Support with Automatic Field Level Encryption `_. A MongoClient configured with auto encryption will automatically spawn the ``mongocryptd`` process from the application's ``PATH``. Applications can control the spawning behavior as part of the automatic encryption options. For example to set the path to the ``mongocryptd`` process:: auto_encryption_opts = AutoEncryptionOpts( ..., mongocryptd_spawn_path='/path/to/mongocryptd') To control the logging output of ``mongocryptd`` pass options using ``mongocryptd_spawn_args``:: auto_encryption_opts = AutoEncryptionOpts( ..., mongocryptd_spawn_args=['--logpath=/path/to/mongocryptd.log', '--logappend']) If your application wishes to manage the ``mongocryptd`` process manually, it is possible to disable spawning ``mongocryptd``:: auto_encryption_opts = AutoEncryptionOpts( ..., mongocryptd_bypass_spawn=True, # URI of the local ``mongocryptd`` process. mongocryptd_uri='mongodb://localhost:27020') ``mongocryptd`` is only responsible for supporting automatic client-side field level encryption and does not itself perform any encryption or decryption. .. _automatic-client-side-encryption: Automatic Client-Side Field Level Encryption ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Automatic client-side field level encryption is enabled by creating a :class:`~pymongo.mongo_client.MongoClient` with the ``auto_encryption_opts`` option set to an instance of :class:`~pymongo.encryption_options.AutoEncryptionOpts`. The following examples show how to setup automatic client-side field level encryption using :class:`~pymongo.encryption.ClientEncryption` to create a new encryption data key. .. note:: Automatic client-side field level encryption requires MongoDB 4.2 enterprise or a MongoDB 4.2 Atlas cluster. The community version of the server supports automatic decryption as well as :ref:`explicit-client-side-encryption`. Providing Local Automatic Encryption Rules `````````````````````````````````````````` The following example shows how to specify automatic encryption rules via the ``schema_map`` option. The automatic encryption rules are expressed using a `strict subset of the JSON Schema syntax `_. 
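Before working through the complete example below, it may help to see the overall shape of the option: a ``schema_map`` is simply a dictionary mapping a collection namespace to the JSON Schema that drives automatic encryption for that collection. The following is a condensed sketch only; ``kms_providers`` and ``data_key_id`` are assumed to have been created already, exactly as shown in the full, runnable example that follows::

    from pymongo import MongoClient
    from pymongo.encryption import Algorithm
    from pymongo.encryption_options import AutoEncryptionOpts

    # kms_providers and data_key_id are assumed to exist; see the complete
    # example below for how to create the local master key and data key.
    schema_map = {
        "test.coll": {
            "bsonType": "object",
            "properties": {
                "encryptedField": {
                    "encrypt": {
                        "keyId": [data_key_id],
                        "bsonType": "string",
                        "algorithm": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
                    }
                }
            },
        }
    }
    auto_encryption_opts = AutoEncryptionOpts(
        kms_providers, "encryption.__pymongoTestKeyVault", schema_map=schema_map)
    client = MongoClient(auto_encryption_opts=auto_encryption_opts)
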
Supplying a ``schema_map`` provides more security than relying on JSON Schemas obtained from the server. It protects against a malicious server advertising a false JSON Schema, which could trick the client into sending unencrypted data that should be encrypted. JSON Schemas supplied in the ``schema_map`` only apply to configuring automatic client-side field level encryption. Other validation rules in the JSON schema will not be enforced by the driver and will result in an error.:: import os from bson.codec_options import CodecOptions from bson import json_util from pymongo import MongoClient from pymongo.encryption import (Algorithm, ClientEncryption) from pymongo.encryption_options import AutoEncryptionOpts def create_json_schema_file(kms_providers, key_vault_namespace, key_vault_client): client_encryption = ClientEncryption( kms_providers, key_vault_namespace, key_vault_client, # The CodecOptions class used for encrypting and decrypting. # This should be the same CodecOptions instance you have configured # on MongoClient, Database, or Collection. We will not be calling # encrypt() or decrypt() in this example so we can use any # CodecOptions. CodecOptions()) # Create a new data key and json schema for the encryptedField. # https://dochub.mongodb.org/core/client-side-field-level-encryption-automatic-encryption-rules data_key_id = client_encryption.create_data_key( 'local', key_alt_names=['pymongo_encryption_example_1']) schema = { "properties": { "encryptedField": { "encrypt": { "keyId": [data_key_id], "bsonType": "string", "algorithm": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic } } }, "bsonType": "object" } # Use CANONICAL_JSON_OPTIONS so that other drivers and tools will be # able to parse the MongoDB extended JSON file. json_schema_string = json_util.dumps( schema, json_options=json_util.CANONICAL_JSON_OPTIONS) with open('jsonSchema.json', 'w') as file: file.write(json_schema_string) def main(): # The MongoDB namespace (db.collection) used to store the # encrypted documents in this example. encrypted_namespace = "test.coll" # This must be the same master key that was used to create # the encryption key. local_master_key = os.urandom(96) kms_providers = {"local": {"key": local_master_key}} # The MongoDB namespace (db.collection) used to store # the encryption data keys. key_vault_namespace = "encryption.__pymongoTestKeyVault" key_vault_db_name, key_vault_coll_name = key_vault_namespace.split(".", 1) # The MongoClient used to access the key vault (key_vault_namespace). key_vault_client = MongoClient() key_vault = key_vault_client[key_vault_db_name][key_vault_coll_name] # Ensure that two data keys cannot share the same keyAltName. key_vault.drop() key_vault.create_index( "keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}) create_json_schema_file( kms_providers, key_vault_namespace, key_vault_client) # Load the JSON Schema and construct the local schema_map option. 
with open('jsonSchema.json', 'r') as file: json_schema_string = file.read() json_schema = json_util.loads(json_schema_string) schema_map = {encrypted_namespace: json_schema} auto_encryption_opts = AutoEncryptionOpts( kms_providers, key_vault_namespace, schema_map=schema_map) client = MongoClient(auto_encryption_opts=auto_encryption_opts) db_name, coll_name = encrypted_namespace.split(".", 1) coll = client[db_name][coll_name] # Clear old data coll.drop() coll.insert_one({"encryptedField": "123456789"}) print('Decrypted document: %s' % (coll.find_one(),)) unencrypted_coll = MongoClient()[db_name][coll_name] print('Encrypted document: %s' % (unencrypted_coll.find_one(),)) if __name__ == "__main__": main() Server-Side Field Level Encryption Enforcement `````````````````````````````````````````````` The MongoDB 4.2 server supports using schema validation to enforce encryption of specific fields in a collection. This schema validation will prevent an application from inserting unencrypted values for any fields marked with the ``"encrypt"`` JSON schema keyword. The following example shows how to setup automatic client-side field level encryption using :class:`~pymongo.encryption.ClientEncryption` to create a new encryption data key and create a collection with the `Automatic Encryption JSON Schema Syntax `_:: import os from bson.codec_options import CodecOptions from bson.binary import STANDARD from pymongo import MongoClient from pymongo.encryption import (Algorithm, ClientEncryption) from pymongo.encryption_options import AutoEncryptionOpts from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern def main(): # The MongoDB namespace (db.collection) used to store the # encrypted documents in this example. encrypted_namespace = "test.coll" # This must be the same master key that was used to create # the encryption key. local_master_key = os.urandom(96) kms_providers = {"local": {"key": local_master_key}} # The MongoDB namespace (db.collection) used to store # the encryption data keys. key_vault_namespace = "encryption.__pymongoTestKeyVault" key_vault_db_name, key_vault_coll_name = key_vault_namespace.split(".", 1) # The MongoClient used to access the key vault (key_vault_namespace). key_vault_client = MongoClient() key_vault = key_vault_client[key_vault_db_name][key_vault_coll_name] # Ensure that two data keys cannot share the same keyAltName. key_vault.drop() key_vault.create_index( "keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}) client_encryption = ClientEncryption( kms_providers, key_vault_namespace, key_vault_client, # The CodecOptions class used for encrypting and decrypting. # This should be the same CodecOptions instance you have configured # on MongoClient, Database, or Collection. We will not be calling # encrypt() or decrypt() in this example so we can use any # CodecOptions. CodecOptions()) # Create a new data key and json schema for the encryptedField. 
data_key_id = client_encryption.create_data_key( 'local', key_alt_names=['pymongo_encryption_example_2']) json_schema = { "properties": { "encryptedField": { "encrypt": { "keyId": [data_key_id], "bsonType": "string", "algorithm": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic } } }, "bsonType": "object" } auto_encryption_opts = AutoEncryptionOpts( kms_providers, key_vault_namespace) client = MongoClient(auto_encryption_opts=auto_encryption_opts) db_name, coll_name = encrypted_namespace.split(".", 1) db = client[db_name] # Clear old data db.drop_collection(coll_name) # Create the collection with the encryption JSON Schema. db.create_collection( coll_name, # uuid_representation=STANDARD is required to ensure that any # UUIDs in the $jsonSchema document are encoded to BSON Binary # with the standard UUID subtype 4. This is only needed when # running the "create" collection command with an encryption # JSON Schema. codec_options=CodecOptions(uuid_representation=STANDARD), write_concern=WriteConcern(w="majority"), validator={"$jsonSchema": json_schema}) coll = client[db_name][coll_name] coll.insert_one({"encryptedField": "123456789"}) print('Decrypted document: %s' % (coll.find_one(),)) unencrypted_coll = MongoClient()[db_name][coll_name] print('Encrypted document: %s' % (unencrypted_coll.find_one(),)) try: unencrypted_coll.insert_one({"encryptedField": "123456789"}) except OperationFailure as exc: print('Unencrypted insert failed: %s' % (exc.details,)) if __name__ == "__main__": main() .. _explicit-client-side-encryption: Explicit Encryption ~~~~~~~~~~~~~~~~~~~ Explicit encryption is a MongoDB community feature and does not use the ``mongocryptd`` process. Explicit encryption is provided by the :class:`~pymongo.encryption.ClientEncryption` class, for example:: import os from pymongo import MongoClient from pymongo.encryption import (Algorithm, ClientEncryption) def main(): # This must be the same master key that was used to create # the encryption key. local_master_key = os.urandom(96) kms_providers = {"local": {"key": local_master_key}} # The MongoDB namespace (db.collection) used to store # the encryption data keys. key_vault_namespace = "encryption.__pymongoTestKeyVault" key_vault_db_name, key_vault_coll_name = key_vault_namespace.split(".", 1) # The MongoClient used to read/write application data. client = MongoClient() coll = client.test.coll # Clear old data coll.drop() # Set up the key vault (key_vault_namespace) for this example. key_vault = client[key_vault_db_name][key_vault_coll_name] # Ensure that two data keys cannot share the same keyAltName. key_vault.drop() key_vault.create_index( "keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}) client_encryption = ClientEncryption( kms_providers, key_vault_namespace, # The MongoClient to use for reading/writing to the key vault. # This can be the same MongoClient used by the main application. client, # The CodecOptions class used for encrypting and decrypting. # This should be the same CodecOptions instance you have configured # on MongoClient, Database, or Collection. coll.codec_options) # Create a new data key for the encryptedField. 
data_key_id = client_encryption.create_data_key( 'local', key_alt_names=['pymongo_encryption_example_3']) # Explicitly encrypt a field: encrypted_field = client_encryption.encrypt( "123456789", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id) coll.insert_one({"encryptedField": encrypted_field}) doc = coll.find_one() print('Encrypted document: %s' % (doc,)) # Explicitly decrypt the field: doc["encryptedField"] = client_encryption.decrypt(doc["encryptedField"]) print('Decrypted document: %s' % (doc,)) # Cleanup resources. client_encryption.close() client.close() if __name__ == "__main__": main() Explicit Encryption with Automatic Decryption ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Although automatic encryption requires MongoDB 4.2 enterprise or a MongoDB 4.2 Atlas cluster, automatic *decryption* is supported for all users. To configure automatic *decryption* without automatic *encryption* set ``bypass_auto_encryption=True`` in :class:`~pymongo.encryption_options.AutoEncryptionOpts`:: import os from pymongo import MongoClient from pymongo.encryption import (Algorithm, ClientEncryption) from pymongo.encryption_options import AutoEncryptionOpts def main(): # This must be the same master key that was used to create # the encryption key. local_master_key = os.urandom(96) kms_providers = {"local": {"key": local_master_key}} # The MongoDB namespace (db.collection) used to store # the encryption data keys. key_vault_namespace = "encryption.__pymongoTestKeyVault" key_vault_db_name, key_vault_coll_name = key_vault_namespace.split(".", 1) # bypass_auto_encryption=True disable automatic encryption but keeps # the automatic _decryption_ behavior. bypass_auto_encryption will # also disable spawning mongocryptd. auto_encryption_opts = AutoEncryptionOpts( kms_providers, key_vault_namespace, bypass_auto_encryption=True) client = MongoClient(auto_encryption_opts=auto_encryption_opts) coll = client.test.coll # Clear old data coll.drop() # Set up the key vault (key_vault_namespace) for this example. key_vault = client[key_vault_db_name][key_vault_coll_name] # Ensure that two data keys cannot share the same keyAltName. key_vault.drop() key_vault.create_index( "keyAltNames", unique=True, partialFilterExpression={"keyAltNames": {"$exists": True}}) client_encryption = ClientEncryption( kms_providers, key_vault_namespace, # The MongoClient to use for reading/writing to the key vault. # This can be the same MongoClient used by the main application. client, # The CodecOptions class used for encrypting and decrypting. # This should be the same CodecOptions instance you have configured # on MongoClient, Database, or Collection. coll.codec_options) # Create a new data key for the encryptedField. data_key_id = client_encryption.create_data_key( 'local', key_alt_names=['pymongo_encryption_example_4']) # Explicitly encrypt a field: encrypted_field = client_encryption.encrypt( "123456789", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name='pymongo_encryption_example_4') coll.insert_one({"encryptedField": encrypted_field}) # Automatically decrypts any encrypted fields. doc = coll.find_one() print('Decrypted document: %s' % (doc,)) unencrypted_coll = MongoClient().test.coll print('Encrypted document: %s' % (unencrypted_coll.find_one(),)) # Cleanup resources. 
client_encryption.close() client.close() if __name__ == "__main__": main() pymongo-3.11.0/doc/examples/geo.rst000066400000000000000000000063701374256237000171720ustar00rootroot00000000000000Geospatial Indexing Example =========================== .. testsetup:: from pymongo import MongoClient client = MongoClient() client.drop_database('geo_example') This example shows how to create and use a :data:`~pymongo.GEO2D` index in PyMongo. To create a spherical (earth-like) geospatial index use :data:`~pymongo.GEOSPHERE` instead. .. mongodoc:: geo Creating a Geospatial Index --------------------------- Creating a geospatial index in pymongo is easy: .. doctest:: >>> from pymongo import MongoClient, GEO2D >>> db = MongoClient().geo_example >>> db.places.create_index([("loc", GEO2D)]) u'loc_2d' Inserting Places ---------------- Locations in MongoDB are represented using either embedded documents or lists where the first two elements are coordinates. Here, we'll insert a couple of example locations: .. doctest:: >>> result = db.places.insert_many([{"loc": [2, 5]}, ... {"loc": [30, 5]}, ... {"loc": [1, 2]}, ... {"loc": [4, 4]}]) # doctest: +ELLIPSIS >>> result.inserted_ids [ObjectId('...'), ObjectId('...'), ObjectId('...'), ObjectId('...')] .. note:: If specifying latitude and longitude coordinates in :data:`~pymongo.GEOSPHERE`, list the **longitude** first and then **latitude**. Querying -------- Using the geospatial index we can find documents near another point: .. doctest:: >>> import pprint >>> for doc in db.places.find({"loc": {"$near": [3, 6]}}).limit(3): ... pprint.pprint(doc) ... {u'_id': ObjectId('...'), u'loc': [2, 5]} {u'_id': ObjectId('...'), u'loc': [4, 4]} {u'_id': ObjectId('...'), u'loc': [1, 2]} .. note:: If using :data:`pymongo.GEOSPHERE`, using $nearSphere is recommended. The $maxDistance operator requires the use of :class:`~bson.son.SON`: .. doctest:: >>> from bson.son import SON >>> query = {"loc": SON([("$near", [3, 6]), ("$maxDistance", 100)])} >>> for doc in db.places.find(query).limit(3): ... pprint.pprint(doc) ... {u'_id': ObjectId('...'), u'loc': [2, 5]} {u'_id': ObjectId('...'), u'loc': [4, 4]} {u'_id': ObjectId('...'), u'loc': [1, 2]} It's also possible to query for all items within a given rectangle (specified by lower-left and upper-right coordinates): .. doctest:: >>> query = {"loc": {"$within": {"$box": [[2, 2], [5, 6]]}}} >>> for doc in db.places.find(query).sort('_id'): ... pprint.pprint(doc) {u'_id': ObjectId('...'), u'loc': [2, 5]} {u'_id': ObjectId('...'), u'loc': [4, 4]} Or circle (specified by center point and radius): .. doctest:: >>> query = {"loc": {"$within": {"$center": [[0, 0], 6]}}} >>> for doc in db.places.find(query).sort('_id'): ... pprint.pprint(doc) ... {u'_id': ObjectId('...'), u'loc': [2, 5]} {u'_id': ObjectId('...'), u'loc': [1, 2]} {u'_id': ObjectId('...'), u'loc': [4, 4]} geoNear queries are also supported using :class:`~bson.son.SON`:: >>> from bson.son import SON >>> db.command(SON([('geoNear', 'places'), ('near', [1, 2])])) {u'ok': 1.0, u'stats': ...} .. warning:: Starting in MongoDB version 4.0, MongoDB deprecates the **geoNear** command. Use one of the following operations instead. * $geoNear - aggregation stage. * $near - query operator. * $nearSphere - query operator. pymongo-3.11.0/doc/examples/gevent.rst000066400000000000000000000036321374256237000177060ustar00rootroot00000000000000Gevent ====== PyMongo supports `Gevent `_. Simply call Gevent's ``monkey.patch_all()`` before loading any other modules: .. 
doctest:: >>> # You must call patch_all() *before* importing any other modules >>> from gevent import monkey >>> _ = monkey.patch_all() >>> from pymongo import MongoClient >>> client = MongoClient() PyMongo uses thread and socket functions from the Python standard library. Gevent's monkey-patching replaces those standard functions so that PyMongo does asynchronous I/O with non-blocking sockets, and schedules operations on greenlets instead of threads. Avoid blocking in Hub.join -------------------------- By default, PyMongo uses threads to discover and monitor your servers' topology (see :ref:`health-monitoring`). If you execute ``monkey.patch_all()`` when your application first begins, PyMongo automatically uses greenlets instead of threads. When shutting down, if your application calls :meth:`~gevent.hub.Hub.join` on Gevent's :class:`~gevent.hub.Hub` without first terminating these background greenlets, the call to :meth:`~gevent.hub.Hub.join` blocks indefinitely. You therefore **must close or dereference** any active :class:`~pymongo.mongo_client.MongoClient` before exiting. An example solution to this issue in some application frameworks is a signal handler to end background greenlets when your application receives SIGHUP: .. code-block:: python import signal def graceful_reload(signum, traceback): """Explicitly close some global MongoClient object.""" client.close() signal.signal(signal.SIGHUP, graceful_reload) Applications using uWSGI prior to 1.9.16 are affected by this issue, or newer uWSGI versions with the ``-gevent-wait-for-hub`` option. See `the uWSGI changelog for details `_. pymongo-3.11.0/doc/examples/gridfs.rst000066400000000000000000000043731374256237000176770ustar00rootroot00000000000000GridFS Example ============== .. testsetup:: from pymongo import MongoClient client = MongoClient() client.drop_database('gridfs_example') This example shows how to use :mod:`gridfs` to store large binary objects (e.g. files) in MongoDB. .. seealso:: The API docs for :mod:`gridfs`. .. seealso:: `This blog post `_ for some motivation behind this API. Setup ----- We start by creating a :class:`~gridfs.GridFS` instance to use: .. doctest:: >>> from pymongo import MongoClient >>> import gridfs >>> >>> db = MongoClient().gridfs_example >>> fs = gridfs.GridFS(db) Every :class:`~gridfs.GridFS` instance is created with and will operate on a specific :class:`~pymongo.database.Database` instance. Saving and Retrieving Data -------------------------- The simplest way to work with :mod:`gridfs` is to use its key/value interface (the :meth:`~gridfs.GridFS.put` and :meth:`~gridfs.GridFS.get` methods). To write data to GridFS, use :meth:`~gridfs.GridFS.put`: .. doctest:: >>> a = fs.put(b"hello world") :meth:`~gridfs.GridFS.put` creates a new file in GridFS, and returns the value of the file document's ``"_id"`` key. Given that ``"_id"`` we can use :meth:`~gridfs.GridFS.get` to get back the contents of the file: .. doctest:: >>> fs.get(a).read() 'hello world' :meth:`~gridfs.GridFS.get` returns a file-like object, so we get the file's contents by calling :meth:`~gridfs.grid_file.GridOut.read`. In addition to putting a :class:`str` as a GridFS file, we can also put any file-like object (an object with a :meth:`read` method). GridFS will handle reading the file in chunk-sized segments automatically. We can also add additional attributes to the file as keyword arguments: .. 
doctest:: >>> b = fs.put(fs.get(a), filename="foo", bar="baz") >>> out = fs.get(b) >>> out.read() 'hello world' >>> out.filename u'foo' >>> out.bar u'baz' >>> out.upload_date datetime.datetime(...) The attributes we set in :meth:`~gridfs.GridFS.put` are stored in the file document, and retrievable after calling :meth:`~gridfs.GridFS.get`. Some attributes (like ``"filename"``) are special and are defined in the GridFS specification - see that document for more details. pymongo-3.11.0/doc/examples/high_availability.rst000066400000000000000000000340401374256237000220640ustar00rootroot00000000000000High Availability and PyMongo ============================= PyMongo makes it easy to write highly available applications whether you use a `single replica set `_ or a `large sharded cluster `_. Connecting to a Replica Set --------------------------- PyMongo makes working with `replica sets `_ easy. Here we'll launch a new replica set and show how to handle both initialization and normal connections with PyMongo. .. mongodoc:: rs Starting a Replica Set ~~~~~~~~~~~~~~~~~~~~~~ The main `replica set documentation `_ contains extensive information about setting up a new replica set or migrating an existing MongoDB setup, be sure to check that out. Here, we'll just do the bare minimum to get a three node replica set setup locally. .. warning:: Replica sets should always use multiple nodes in production - putting all set members on the same physical node is only recommended for testing and development. We start three ``mongod`` processes, each on a different port and with a different dbpath, but all using the same replica set name "foo". .. code-block:: bash $ mkdir -p /data/db0 /data/db1 /data/db2 $ mongod --port 27017 --dbpath /data/db0 --replSet foo .. code-block:: bash $ mongod --port 27018 --dbpath /data/db1 --replSet foo .. code-block:: bash $ mongod --port 27019 --dbpath /data/db2 --replSet foo Initializing the Set ~~~~~~~~~~~~~~~~~~~~ At this point all of our nodes are up and running, but the set has yet to be initialized. Until the set is initialized no node will become the primary, and things are essentially "offline". To initialize the set we need to connect to a single node and run the initiate command:: >>> from pymongo import MongoClient >>> c = MongoClient('localhost', 27017) .. note:: We could have connected to any of the other nodes instead, but only the node we initiate from is allowed to contain any initial data. After connecting, we run the initiate command to get things started:: >>> config = {'_id': 'foo', 'members': [ ... {'_id': 0, 'host': 'localhost:27017'}, ... {'_id': 1, 'host': 'localhost:27018'}, ... {'_id': 2, 'host': 'localhost:27019'}]} >>> c.admin.command("replSetInitiate", config) {'ok': 1.0, ...} The three ``mongod`` servers we started earlier will now coordinate and come online as a replica set. Connecting to a Replica Set ~~~~~~~~~~~~~~~~~~~~~~~~~~~ The initial connection as made above is a special case for an uninitialized replica set. Normally we'll want to connect differently. A connection to a replica set can be made using the :meth:`~pymongo.mongo_client.MongoClient` constructor, specifying one or more members of the set, along with the replica set name. Any of the following connects to the replica set we just created:: >>> MongoClient('localhost', replicaset='foo') MongoClient(host=['localhost:27017'], replicaset='foo', ...) >>> MongoClient('localhost:27018', replicaset='foo') MongoClient(['localhost:27018'], replicaset='foo', ...) 
>>> MongoClient('localhost', 27019, replicaset='foo') MongoClient(['localhost:27019'], replicaset='foo', ...) >>> MongoClient('mongodb://localhost:27017,localhost:27018/?replicaSet=foo') MongoClient(['localhost:27017', 'localhost:27018'], replicaset='foo', ...) The addresses passed to :meth:`~pymongo.mongo_client.MongoClient` are called the *seeds*. As long as at least one of the seeds is online, MongoClient discovers all the members in the replica set, and determines which is the current primary and which are secondaries or arbiters. Each seed must be the address of a single mongod. Multihomed and round robin DNS addresses are **not** supported. The :class:`~pymongo.mongo_client.MongoClient` constructor is non-blocking: the constructor returns immediately while the client connects to the replica set using background threads. Note how, if you create a client and immediately print the string representation of its :attr:`~pymongo.mongo_client.MongoClient.nodes` attribute, the list may be empty initially. If you wait a moment, MongoClient discovers the whole replica set:: >>> from time import sleep >>> c = MongoClient(replicaset='foo'); print(c.nodes); sleep(0.1); print(c.nodes) frozenset([]) frozenset([(u'localhost', 27019), (u'localhost', 27017), (u'localhost', 27018)]) You need not wait for replica set discovery in your application, however. If you need to do any operation with a MongoClient, such as a :meth:`~pymongo.collection.Collection.find` or an :meth:`~pymongo.collection.Collection.insert_one`, the client waits to discover a suitable member before it attempts the operation. Handling Failover ~~~~~~~~~~~~~~~~~ When a failover occurs, PyMongo will automatically attempt to find the new primary node and perform subsequent operations on that node. This can't happen completely transparently, however. Here we'll perform an example failover to illustrate how everything behaves. First, we'll connect to the replica set and perform a couple of basic operations:: >>> db = MongoClient("localhost", replicaSet='foo').test >>> db.test.insert_one({"x": 1}).inserted_id ObjectId('...') >>> db.test.find_one() {u'x': 1, u'_id': ObjectId('...')} By checking the host and port, we can see that we're connected to *localhost:27017*, which is the current primary:: >>> db.client.address ('localhost', 27017) Now let's bring down that node and see what happens when we run our query again:: >>> db.test.find_one() Traceback (most recent call last): pymongo.errors.AutoReconnect: ... We get an :class:`~pymongo.errors.AutoReconnect` exception. This means that the driver was not able to connect to the old primary (which makes sense, as we killed the server), but that it will attempt to automatically reconnect on subsequent operations. When this exception is raised our application code needs to decide whether to retry the operation or to simply continue, accepting the fact that the operation might have failed. On subsequent attempts to run the query we might continue to see this exception. Eventually, however, the replica set will failover and elect a new primary (this should take no more than a couple of seconds in general). At that point the driver will connect to the new primary and the operation will succeed:: >>> db.test.find_one() {u'x': 1, u'_id': ObjectId('...')} >>> db.client.address ('localhost', 27018) Bring the former primary back up. It will rejoin the set as a secondary. Now we can move to the next section: distributing reads to secondaries. .. 
_secondary-reads: Secondary Reads ~~~~~~~~~~~~~~~ By default an instance of MongoClient sends queries to the primary member of the replica set. To use secondaries for queries we have to change the read preference:: >>> client = MongoClient( ... 'localhost:27017', ... replicaSet='foo', ... readPreference='secondaryPreferred') >>> client.read_preference SecondaryPreferred(tag_sets=None) Now all queries will be sent to the secondary members of the set. If there are no secondary members the primary will be used as a fallback. If you have queries you would prefer to never send to the primary you can specify that using the ``secondary`` read preference. By default the read preference of a :class:`~pymongo.database.Database` is inherited from its MongoClient, and the read preference of a :class:`~pymongo.collection.Collection` is inherited from its Database. To use a different read preference use the :meth:`~pymongo.mongo_client.MongoClient.get_database` method, or the :meth:`~pymongo.database.Database.get_collection` method:: >>> from pymongo import ReadPreference >>> client.read_preference SecondaryPreferred(tag_sets=None) >>> db = client.get_database('test', read_preference=ReadPreference.SECONDARY) >>> db.read_preference Secondary(tag_sets=None) >>> coll = db.get_collection('test', read_preference=ReadPreference.PRIMARY) >>> coll.read_preference Primary() You can also change the read preference of an existing :class:`~pymongo.collection.Collection` with the :meth:`~pymongo.collection.Collection.with_options` method:: >>> coll2 = coll.with_options(read_preference=ReadPreference.NEAREST) >>> coll.read_preference Primary() >>> coll2.read_preference Nearest(tag_sets=None) Note that since most database commands can only be sent to the primary of a replica set, the :meth:`~pymongo.database.Database.command` method does not obey the Database's :attr:`~pymongo.database.Database.read_preference`, but you can pass an explicit read preference to the method:: >>> db.command('dbstats', read_preference=ReadPreference.NEAREST) {...} Reads are configured using three options: **read preference**, **tag sets**, and **local threshold**. **Read preference**: Read preference is configured using one of the classes from :mod:`~pymongo.read_preferences` (:class:`~pymongo.read_preferences.Primary`, :class:`~pymongo.read_preferences.PrimaryPreferred`, :class:`~pymongo.read_preferences.Secondary`, :class:`~pymongo.read_preferences.SecondaryPreferred`, or :class:`~pymongo.read_preferences.Nearest`). For convenience, we also provide :class:`~pymongo.read_preferences.ReadPreference` with the following attributes: - ``PRIMARY``: Read from the primary. This is the default read preference, and provides the strongest consistency. If no primary is available, raise :class:`~pymongo.errors.AutoReconnect`. - ``PRIMARY_PREFERRED``: Read from the primary if available, otherwise read from a secondary. - ``SECONDARY``: Read from a secondary. If no matching secondary is available, raise :class:`~pymongo.errors.AutoReconnect`. - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise from the primary. - ``NEAREST``: Read from any available member. **Tag sets**: Replica-set members can be `tagged `_ according to any criteria you choose. By default, PyMongo ignores tags when choosing a member to read from, but your read preference can be configured with a ``tag_sets`` parameter. ``tag_sets`` must be a list of dictionaries, each dict providing tag values that the replica set member must match. 
PyMongo tries each set of tags in turn until it finds a set of tags with at least one matching member. For example, to prefer reads from the New York data center, but fall back to the San Francisco data center, tag your replica set members according to their location and create a MongoClient like so:: >>> from pymongo.read_preferences import Secondary >>> db = client.get_database( ... 'test', read_preference=Secondary([{'dc': 'ny'}, {'dc': 'sf'}])) >>> db.read_preference Secondary(tag_sets=[{'dc': 'ny'}, {'dc': 'sf'}]) MongoClient tries to find secondaries in New York, then San Francisco, and raises :class:`~pymongo.errors.AutoReconnect` if none are available. As an additional fallback, specify a final, empty tag set, ``{}``, which means "read from any member that matches the mode, ignoring tags." See :mod:`~pymongo.read_preferences` for more information. .. _distributes reads to secondaries: **Local threshold**: If multiple members match the read preference and tag sets, PyMongo reads from among the nearest members, chosen according to ping time. By default, only members whose ping times are within 15 milliseconds of the nearest are used for queries. You can choose to distribute reads among members with higher latencies by setting ``localThresholdMS`` to a larger number:: >>> client = pymongo.MongoClient( ... replicaSet='repl0', ... readPreference='secondaryPreferred', ... localThresholdMS=35) In this case, PyMongo distributes reads among matching members within 35 milliseconds of the closest member's ping time. .. note:: ``localThresholdMS`` is ignored when talking to a replica set *through* a mongos. The equivalent is the localThreshold_ command line option. .. _localThreshold: http://docs.mongodb.org/manual/reference/mongos/#cmdoption--localThreshold .. _health-monitoring: Health Monitoring ''''''''''''''''' When MongoClient is initialized it launches background threads to monitor the replica set for changes in: * Health: detect when a member goes down or comes up, or if a different member becomes primary * Configuration: detect when members are added or removed, and detect changes in members' tags * Latency: track a moving average of each member's ping time Replica-set monitoring ensures queries are continually routed to the proper members as the state of the replica set changes. .. _mongos-load-balancing: mongos Load Balancing --------------------- An instance of :class:`~pymongo.mongo_client.MongoClient` can be configured with a list of addresses of mongos servers: >>> client = MongoClient('mongodb://host1,host2,host3') Each member of the list must be a single mongos server. Multihomed and round robin DNS addresses are **not** supported. The client continuously monitors all the mongoses' availability, and its network latency to each. PyMongo distributes operations evenly among the set of mongoses within its ``localThresholdMS`` (similar to how it `distributes reads to secondaries`_ in a replica set). By default the threshold is 15 ms. The lowest-latency server, and all servers with latencies no more than ``localThresholdMS`` beyond the lowest-latency server's, receive operations equally. For example, if we have three mongoses: - host1: 20 ms - host2: 35 ms - host3: 40 ms By default the ``localThresholdMS`` is 15 ms, so PyMongo uses host1 and host2 evenly. It uses host1 because its network latency to the driver is shortest. It uses host2 because its latency is within 15 ms of the lowest-latency server's. But it excuses host3: host3 is 20ms beyond the lowest-latency server. 
If we set ``localThresholdMS`` to 30 ms all servers are within the threshold: >>> client = MongoClient('mongodb://host1,host2,host3/?localThresholdMS=30') .. warning:: Do **not** connect PyMongo to a pool of mongos instances through a load balancer. A single socket connection must always be routed to the same mongos instance for proper cursor support. pymongo-3.11.0/doc/examples/index.rst000066400000000000000000000012421374256237000175200ustar00rootroot00000000000000Examples ======== The examples in this section are intended to give in depth overviews of how to accomplish specific tasks with MongoDB and PyMongo. Unless otherwise noted, all examples assume that a MongoDB instance is running on the default host and port. Assuming you have `downloaded and installed `_ MongoDB, you can start it like so: .. code-block:: bash $ mongod .. toctree:: :maxdepth: 1 aggregation authentication collations copydb custom_type bulk datetimes geo gevent gridfs high_availability mod_wsgi server_selection tailable tls encryption uuid pymongo-3.11.0/doc/examples/mod_wsgi.rst000066400000000000000000000050771374256237000202330ustar00rootroot00000000000000.. _pymongo-and-mod_wsgi: PyMongo and mod_wsgi ==================== To run your application under `mod_wsgi `_, follow these guidelines: * Run ``mod_wsgi`` in daemon mode with the ``WSGIDaemonProcess`` directive. * Assign each application to a separate daemon with ``WSGIProcessGroup``. * Use ``WSGIApplicationGroup %{GLOBAL}`` to ensure your application is running in the daemon's main Python interpreter, not a sub interpreter. For example, this ``mod_wsgi`` configuration ensures an application runs in the main interpreter:: WSGIDaemonProcess my_process WSGIScriptAlias /my_app /path/to/app.wsgi WSGIProcessGroup my_process WSGIApplicationGroup %{GLOBAL} If you have multiple applications that use PyMongo, put each in a separate daemon, still in the global application group:: WSGIDaemonProcess my_process WSGIScriptAlias /my_app /path/to/app.wsgi WSGIProcessGroup my_process WSGIDaemonProcess my_other_process WSGIScriptAlias /my_other_app /path/to/other_app.wsgi WSGIProcessGroup my_other_process WSGIApplicationGroup %{GLOBAL} Background: ``mod_wsgi`` can run in "embedded" mode when only WSGIScriptAlias is set, or "daemon" mode with WSGIDaemonProcess. In daemon mode, ``mod_wsgi`` can run your application in the Python main interpreter, or in sub interpreters. The correct way to run a PyMongo application is in daemon mode, using the main interpreter. Python C extensions in general have issues running in multiple Python sub interpreters. These difficulties are explained in the documentation for `Py_NewInterpreter `_ and in the `Multiple Python Sub Interpreters `_ section of the ``mod_wsgi`` documentation. Beginning with PyMongo 2.7, the C extension for BSON detects when it is running in a sub interpreter and activates a workaround, which adds a small cost to BSON decoding. To avoid this cost, use ``WSGIApplicationGroup %{GLOBAL}`` to ensure your application runs in the main interpreter. Since your program runs in the main interpreter it should not share its process with any other applications, lest they interfere with each other's state. Each application should have its own daemon process, as shown in the example above. 
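For completeness, a minimal sketch of the ``app.wsgi`` file referenced by ``WSGIScriptAlias`` might look like the following (the database and collection names here are placeholders and the response body is purely illustrative, not part of the official documentation). The important point is that the :class:`~pymongo.mongo_client.MongoClient` is created once at module load time and reused by every request:

.. code-block:: python

    # app.wsgi -- illustrative sketch only; adjust names for your deployment.
    from pymongo import MongoClient

    # Create the client once, when mod_wsgi loads this module.
    client = MongoClient()

    def application(environ, start_response):
        # Reuse the shared client; its connection pool handles concurrent requests.
        count = client.my_database.my_collection.count_documents({})
        body = ('document count: %d' % count).encode('utf-8')
        start_response('200 OK', [('Content-Type', 'text/plain'),
                                  ('Content-Length', str(len(body)))])
        return [body]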
pymongo-3.11.0/doc/examples/server_selection.rst000066400000000000000000000077311374256237000217750ustar00rootroot00000000000000Server Selector Example ======================= Users can exert fine-grained control over the `server selection algorithm`_ by setting the `server_selector` option on the :class:`~pymongo.MongoClient` to an appropriate callable. This example shows how to use this functionality to prefer servers running on ``localhost``. .. warning:: Use of custom server selector functions is a power user feature. Misusing custom server selectors can have unintended consequences such as degraded read/write performance. .. testsetup:: from pymongo import MongoClient .. _server selection algorithm: https://docs.mongodb.com/manual/core/read-preference-mechanics/ Example: Selecting Servers Running on ``localhost`` --------------------------------------------------- To start, we need to write the server selector function that will be used. The server selector function should accept a list of :class:`~pymongo.server_description.ServerDescription` objects and return a list of server descriptions that are suitable for the read or write operation. A server selector must not create or modify :class:`~pymongo.server_description.ServerDescription` objects, and must return the selected instances unchanged. In this example, we write a server selector that prioritizes servers running on ``localhost``. This can be desirable when using a sharded cluster with multiple ``mongos``, as locally run queries are likely to see lower latency and higher throughput. Please note, however, that it is highly dependent on the application if preferring ``localhost`` is beneficial or not. In addition to comparing the hostname with ``localhost``, our server selector function accounts for the edge case when no servers are running on ``localhost``. In this case, we allow the default server selection logic to prevail by passing through the received server description list unchanged. Failure to do this would render the client unable to communicate with MongoDB in the event that no servers were running on ``localhost``. The described server selection logic is implemented in the following server selector function: .. doctest:: >>> def server_selector(server_descriptions): ... servers = [ ... server for server in server_descriptions ... if server.address[0] == 'localhost' ... ] ... if not servers: ... return server_descriptions ... return servers Finally, we can create a :class:`~pymongo.MongoClient` instance with this server selector. .. doctest:: >>> client = MongoClient(server_selector=server_selector) Server Selection Process ------------------------ This section dives deeper into the server selection process for reads and writes. In the case of a write, the driver performs the following operations (in order) during the selection process: #. Select all writeable servers from the list of known hosts. For a replica set this is the primary, while for a sharded cluster this is all the known mongoses. #. Apply the user-defined server selector function. Note that the custom server selector is **not** called if there are no servers left from the previous filtering stage. #. Apply the ``localThresholdMS`` setting to the list of remaining hosts. This whittles the host list down to only contain servers whose latency is at most ``localThresholdMS`` milliseconds higher than the lowest observed latency. #. Select a server at random from the remaining host list. The desired operation is then performed against the selected server. 
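The numbered steps above can also be summarized in a short, self-contained sketch. Note that this is only an illustration of the selection logic, not PyMongo's internal implementation; the tuple format and the ``select_for_write`` helper are invented for the example:

.. code-block:: python

    import random

    # Model each server as (address, is_writable, average_latency_ms).
    SERVERS = [
        ('host1:27017', True, 20.0),
        ('host2:27017', True, 35.0),
        ('host3:27017', True, 40.0),
    ]

    def select_for_write(servers, server_selector=None, local_threshold_ms=15.0):
        # 1. Keep only writable servers (the primary, or all known mongoses).
        candidates = [s for s in servers if s[1]]
        # 2. Apply the user-defined selector only if any candidates remain.
        if candidates and server_selector is not None:
            candidates = server_selector(candidates)
        if not candidates:
            return None
        # 3. Keep servers within local_threshold_ms of the lowest observed latency.
        fastest = min(latency for _, _, latency in candidates)
        candidates = [s for s in candidates if s[2] <= fastest + local_threshold_ms]
        # 4. Pick one of the remaining servers at random.
        return random.choice(candidates)

    print(select_for_write(SERVERS))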
In the case of **reads** the process is identical except for the first step. Here, instead of selecting all writeable servers, we select all servers matching the user's :class:`~pymongo.read_preferences.ReadPreference` from the list of known hosts. As an example, for a 3-member replica set with a :class:`~pymongo.read_preferences.Secondary` read preference, we would select all available secondaries. .. _server selection algorithm: https://docs.mongodb.com/manual/core/read-preference-mechanics/pymongo-3.11.0/doc/examples/tailable.rst000066400000000000000000000033341374256237000201720ustar00rootroot00000000000000Tailable Cursors ================ By default, MongoDB will automatically close a cursor when the client has exhausted all results in the cursor. However, for `capped collections `_ you may use a `tailable cursor `_ that remains open after the client exhausts the results in the initial cursor. The following is a basic example of using a tailable cursor to tail the oplog of a replica set member:: import time import pymongo client = pymongo.MongoClient() oplog = client.local.oplog.rs first = oplog.find().sort('$natural', pymongo.ASCENDING).limit(-1).next() print(first) ts = first['ts'] while True: # For a regular capped collection CursorType.TAILABLE_AWAIT is the # only option required to create a tailable cursor. When querying the # oplog, the oplog_replay option enables an optimization to quickly # find the 'ts' value we're looking for. The oplog_replay option # can only be used when querying the oplog. Starting in MongoDB 4.4 # this option is ignored by the server as queries against the oplog # are optimized automatically by the MongoDB query engine. cursor = oplog.find({'ts': {'$gt': ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True) while cursor.alive: for doc in cursor: ts = doc['ts'] print(doc) # We end up here if the find() returned no documents or if the # tailable cursor timed out (no new documents were added to the # collection for more than 1 second). time.sleep(1) pymongo-3.11.0/doc/examples/tls.rst000066400000000000000000000224151374256237000172200ustar00rootroot00000000000000TLS/SSL and PyMongo =================== PyMongo supports connecting to MongoDB over TLS/SSL. This guide covers the configuration options supported by PyMongo. See `the server documentation `_ to configure MongoDB. Dependencies ............ For connections using TLS/SSL, PyMongo may require third party dependencies as determined by your version of Python. With PyMongo 3.3+, you can install PyMongo and any TLS/SSL-related dependencies using the following pip command:: $ python -m pip install pymongo[tls] Starting with PyMongo 3.11 this installs `PyOpenSSL `_, `requests`_ and `service_identity `_ for users of Python versions older than 2.7.9. PyOpenSSL supports SNI for these old Python versions allowing applications to connect to Atlas free and shared tier instances. Earlier versions of PyMongo require you to manually install the dependencies listed below. Python 2.x `````````` The `ipaddress`_ module is required on all platforms. When using CPython < 2.7.9 or PyPy < 2.5.1: - On Windows, the `wincertstore`_ module is required. - On all other platforms, the `certifi`_ module is required. .. _ipaddress: https://pypi.python.org/pypi/ipaddress .. _wincertstore: https://pypi.python.org/pypi/wincertstore .. _certifi: https://pypi.python.org/pypi/certifi .. warning:: Industry best practices recommend, and some regulations require, the use of TLS 1.1 or newer.
Though no application changes are required for PyMongo to make use of the newest protocols, some operating systems or versions may not provide an OpenSSL version new enough to support them. Users of macOS older than 10.13 (High Sierra) will need to install Python from `python.org`_, `homebrew`_, `macports`_, or another similar source. Users of Linux or other non-macOS Unix can check their OpenSSL version like this:: $ openssl version If the version number is less than 1.0.1 support for TLS 1.1 or newer is not available. Contact your operating system vendor for a solution or upgrade to a newer distribution. You can check your Python interpreter by installing the `requests`_ module and executing the following command:: python -c "import requests; print(requests.get('https://www.howsmyssl.com/a/check', verify=False).json()['tls_version'])" You should see "TLS 1.X" where X is >= 1. You can read more about TLS versions and their security implications here: ``_ .. _python.org: https://www.python.org/downloads/ .. _homebrew: https://brew.sh/ .. _macports: https://www.macports.org/ .. _requests: https://pypi.python.org/pypi/requests Basic configuration ................... In many cases connecting to MongoDB over TLS/SSL requires nothing more than passing ``ssl=True`` as a keyword argument to :class:`~pymongo.mongo_client.MongoClient`:: >>> client = pymongo.MongoClient('example.com', ssl=True) Or passing ``ssl=true`` in the URI:: >>> client = pymongo.MongoClient('mongodb://example.com/?ssl=true') This configures PyMongo to connect to the server using TLS, verify the server's certificate and verify that the host you are attempting to connect to is listed by that certificate. Certificate verification policy ............................... By default, PyMongo is configured to require a certificate from the server when TLS is enabled. This is configurable using the `ssl_cert_reqs` option. To disable this requirement pass ``ssl.CERT_NONE`` as a keyword parameter:: >>> import ssl >>> client = pymongo.MongoClient('example.com', ... ssl=True, ... ssl_cert_reqs=ssl.CERT_NONE) Or, in the URI:: >>> uri = 'mongodb://example.com/?ssl=true&ssl_cert_reqs=CERT_NONE' >>> client = pymongo.MongoClient(uri) Specifying a CA file .................... In some cases you may want to configure PyMongo to use a specific set of CA certificates. This is most often the case when you are acting as your own certificate authority rather than using server certificates signed by a well known authority. The `ssl_ca_certs` option takes a path to a CA file. It can be passed as a keyword argument:: >>> client = pymongo.MongoClient('example.com', ... ssl=True, ... ssl_ca_certs='/path/to/ca.pem') Or, in the URI:: >>> uri = 'mongodb://example.com/?ssl=true&ssl_ca_certs=/path/to/ca.pem' >>> client = pymongo.MongoClient(uri) Specifying a certificate revocation list ........................................ Python 2.7.9+ (pypy 2.5.1+) and 3.4+ provide support for certificate revocation lists. The `ssl_crlfile` option takes a path to a CRL file. It can be passed as a keyword argument:: >>> client = pymongo.MongoClient('example.com', ... ssl=True, ... ssl_crlfile='/path/to/crl.pem') Or, in the URI:: >>> uri = 'mongodb://example.com/?ssl=true&ssl_crlfile=/path/to/crl.pem' >>> client = pymongo.MongoClient(uri) .. note:: Certificate revocation lists and :ref:`OCSP` cannot be used together. Client certificates ................... 
PyMongo can be configured to present a client certificate using the `ssl_certfile` option:: >>> client = pymongo.MongoClient('example.com', ... ssl=True, ... ssl_certfile='/path/to/client.pem') If the private key for the client certificate is stored in a separate file use the `ssl_keyfile` option:: >>> client = pymongo.MongoClient('example.com', ... ssl=True, ... ssl_certfile='/path/to/client.pem', ... ssl_keyfile='/path/to/key.pem') Python 2.7.9+ (pypy 2.5.1+) and 3.3+ support providing a password or passphrase to decrypt encrypted private keys. Use the `ssl_pem_passphrase` option:: >>> client = pymongo.MongoClient('example.com', ... ssl=True, ... ssl_certfile='/path/to/client.pem', ... ssl_keyfile='/path/to/key.pem', ... ssl_pem_passphrase=) These options can also be passed as part of the MongoDB URI. .. _OCSP: OCSP .... Starting with PyMongo 3.11, if PyMongo was installed with the "ocsp" extra:: python -m pip install pymongo[ocsp] certificate revocation checking is enabled by way of `OCSP (Online Certification Status Protocol) `_. MongoDB 4.4+ `staples OCSP responses `_ to the TLS handshake which PyMongo will verify, failing the TLS handshake if the stapled OCSP response is invalid or indicates that the peer certificate is revoked. When connecting to a server version older than 4.4, or when a 4.4+ version of MongoDB does not staple an OCSP response, PyMongo will attempt to connect directly to an OCSP endpoint if the peer certificate specified one. The TLS handshake will only fail in this case if the response indicates that the certificate is revoked. Invalid or malformed responses will be ignored, favoring availability over maximum security. Troubleshooting TLS Errors .......................... TLS errors often fall into three categories - certificate verification failure, protocol version mismatch or certificate revocation checking failure. An error message similar to the following means that OpenSSL was not able to verify the server's certificate:: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed This often occurs because OpenSSL does not have access to the system's root certificates or the certificates are out of date. Linux users should ensure that they have the latest root certificate updates installed from their Linux vendor. macOS users using Python 3.6.0 or newer downloaded from python.org `may have to run a script included with python `_ to install root certificates:: open "/Applications/Python /Install Certificates.command" Users of older PyPy portable versions may have to `set an environment variable `_ to tell OpenSSL where to find root certificates. This is easily done using the `certifi module `_ from pypi:: $ pypy -m pip install certifi $ export SSL_CERT_FILE=$(pypy -c "import certifi; print(certifi.where())") An error message similar to the following message means that the OpenSSL version used by Python does not support a new enough TLS protocol to connect to the server:: [SSL: TLSV1_ALERT_PROTOCOL_VERSION] tlsv1 alert protocol version Industry best practices recommend, and some regulations require, that older TLS protocols be disabled in some MongoDB deployments. Some deployments may disable TLS 1.0, others may disable TLS 1.0 and TLS 1.1. See the warning earlier in this document for troubleshooting steps and solutions. An error message similar to the following message means that certificate revocation checking failed:: [('SSL routines', 'tls_process_initial_server_flight', 'invalid status response')] See :ref:`OCSP` for more details. 
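When working through these troubleshooting steps, it can also help to inspect the OpenSSL build linked into your Python interpreter locally, without any network access. The following is a small sketch of such a check (it only reports version information; as noted above, TLS 1.1+ requires OpenSSL 1.0.1 or newer):

.. code-block:: python

    import ssl

    # The OpenSSL version string and version tuple compiled into this interpreter.
    print(ssl.OPENSSL_VERSION)
    print('OpenSSL 1.0.1 or newer:', ssl.OPENSSL_VERSION_INFO >= (1, 0, 1))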
pymongo-3.11.0/doc/examples/uuid.rst000066400000000000000000000516771374256237000174000ustar00rootroot00000000000000.. _handling-uuid-data-example: Handling UUID Data ================== PyMongo ships with built-in support for dealing with UUID types. It is straightforward to store native :class:`uuid.UUID` objects to MongoDB and retrieve them as native :class:`uuid.UUID` objects:: from pymongo import MongoClient from bson.binary import UuidRepresentation from uuid import uuid4 # use the 'standard' representation for cross-language compatibility. client = MongoClient(uuid_representation=UuidRepresentation.STANDARD) collection = client.get_database('uuid_db').get_collection('uuid_coll') # remove all documents from collection collection.delete_many({}) # create a native uuid object uuid_obj = uuid4() # save the native uuid object to MongoDB collection.insert_one({'uuid': uuid_obj}) # retrieve the stored uuid object from MongoDB document = collection.find_one({}) # check that the retrieved UUID matches the inserted UUID assert document['uuid'] == uuid_obj Native :class:`uuid.UUID` objects can also be used as part of MongoDB queries:: document = collection.find({'uuid': uuid_obj}) assert document['uuid'] == uuid_obj The above examples illustrate the simplest of use-cases - one where the UUID is generated by, and used in the same application. However, the situation can be significantly more complex when dealing with a MongoDB deployment that contains UUIDs created by other drivers as the Java and CSharp drivers have historically encoded UUIDs using a byte-order that is different from the one used by PyMongo. Applications that require interoperability across these drivers must specify the appropriate :class:`~bson.binary.UuidRepresentation`. In the following sections, we describe how drivers have historically differed in their encoding of UUIDs, and how applications can use the :class:`~bson.binary.UuidRepresentation` configuration option to maintain cross-language compatibility. .. attention:: New applications that do not share a MongoDB deployment with any other application and that have never stored UUIDs in MongoDB should use the ``standard`` UUID representation for cross-language compatibility. See :ref:`configuring-uuid-representation` for details on how to configure the :class:`~bson.binary.UuidRepresentation`. .. _example-legacy-uuid: Legacy Handling of UUID Data ---------------------------- Historically, MongoDB Drivers have used different byte-ordering while serializing UUID types to :class:`~bson.binary.Binary`. Consider, for instance, a UUID with the following canonical textual representation:: 00112233-4455-6677-8899-aabbccddeeff This UUID would historically be serialized by the Python driver as:: 00112233-4455-6677-8899-aabbccddeeff The same UUID would historically be serialized by the C# driver as:: 33221100-5544-7766-8899-aabbccddeeff Finally, the same UUID would historically be serialized by the Java driver as:: 77665544-3322-1100-ffee-ddccbbaa9988 .. note:: For in-depth information about the the byte-order historically used by different drivers, see the `Handling of Native UUID Types Specification `_. This difference in the byte-order of UUIDs encoded by different drivers can result in highly unintuitive behavior in some scenarios. We detail two such scenarios in the next sections. 
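Before walking through those scenarios, the byte-order differences described above can be observed locally, without a MongoDB deployment, by encoding a single UUID with each legacy representation using :meth:`~bson.binary.Binary.from_uuid` (the hex printing below is only for illustration):

.. code-block:: python

    import binascii
    from uuid import UUID
    from bson.binary import Binary, UuidRepresentation

    u = UUID('00112233-4455-6677-8899-aabbccddeeff')

    for name in ('PYTHON_LEGACY', 'CSHARP_LEGACY', 'JAVA_LEGACY'):
        representation = getattr(UuidRepresentation, name)
        # Each legacy representation stores the same UUID as Binary
        # subtype 3, but with a different byte-order.
        encoded = Binary.from_uuid(u, representation)
        print(name, binascii.hexlify(encoded))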
Scenario 1: Applications Share a MongoDB Deployment ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Consider the following situation: * Application ``C`` written in C# generates a UUID and uses it as the ``_id`` of a document that it proceeds to insert into the ``uuid_test`` collection of the ``example_db`` database. Let's assume that the canonical textual representation of the generated UUID is:: 00112233-4455-6677-8899-aabbccddeeff * Application ``P`` written in Python attempts to ``find`` the document written by application ``C`` in the following manner:: from uuid import UUID collection = client.example_db.uuid_test result = collection.find_one({'_id': UUID('00112233-4455-6677-8899-aabbccddeeff')}) In this instance, ``result`` will never be the document that was inserted by application ``C`` in the previous step. This is because of the different byte-order used by the C# driver for representing UUIDs as BSON Binary. The following query, on the other hand, will successfully find this document:: result = collection.find_one({'_id': UUID('33221100-5544-7766-8899-aabbccddeeff')}) This example demonstrates how the differing byte-order used by different drivers can hamper interoperability. To workaround this problem, users should configure their ``MongoClient`` with the appropriate :class:`~bson.binary.UuidRepresentation` (in this case, ``client`` in application ``P`` can be configured to use the :data:`~bson.binary.UuidRepresentation.CSHARP_LEGACY` representation to avoid the unintuitive behavior) as described in :ref:`configuring-uuid-representation`. Scenario 2: Round-Tripping UUIDs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In the following examples, we see how using a misconfigured :class:`~bson.binary.UuidRepresentation` can cause an application to inadvertently change the :class:`~bson.binary.Binary` subtype, and in some cases, the bytes of the :class:`~bson.binary.Binary` field itself when round-tripping documents containing UUIDs. Consider the following situation:: from bson.codec_options import CodecOptions, DEFAULT_CODEC_OPTIONS from bson.binary import Binary, UuidRepresentation from uuid import uuid4 # Using UuidRepresentation.PYTHON_LEGACY stores a Binary subtype-3 UUID python_opts = CodecOptions(uuid_representation=UuidRepresentation.PYTHON_LEGACY) input_uuid = uuid4() collection = client.testdb.get_collection('test', codec_options=python_opts) collection.insert_one({'_id': 'foo', 'uuid': input_uuid}) assert collection.find_one({'uuid': Binary(input_uuid.bytes, 3)})['_id'] == 'foo' # Retrieving this document using UuidRepresentation.STANDARD returns a native UUID std_opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) std_collection = client.testdb.get_collection('test', codec_options=std_opts) doc = std_collection.find_one({'_id': 'foo'}) assert doc['uuid'] == input_uuid # Round-tripping the retrieved document silently changes the Binary subtype to 4 std_collection.replace_one({'_id': 'foo'}, doc) assert collection.find_one({'uuid': Binary(input_uuid.bytes, 3)}) is None round_tripped_doc = collection.find_one({'uuid': Binary(input_uuid.bytes, 4)}) assert doc == round_tripped_doc In this example, round-tripping the document using the incorrect :class:`~bson.binary.UuidRepresentation` (``STANDARD`` instead of ``PYTHON_LEGACY``) changes the :class:`~bson.binary.Binary` subtype as a side-effect. **Note that this can also happen when the situation is reversed - i.e. 
when the original document is written using ``STANDARD`` representation and then round-tripped using the ``PYTHON_LEGACY`` representation.** In the next example, we see the consequences of incorrectly using a representation that modifies byte-order (``CSHARP_LEGACY`` or ``JAVA_LEGACY``) when round-tripping documents:: from bson.codec_options import CodecOptions, DEFAULT_CODEC_OPTIONS from bson.binary import Binary, UuidRepresentation from uuid import uuid4 # Using UuidRepresentation.STANDARD stores a Binary subtype-4 UUID std_opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) input_uuid = uuid4() collection = client.testdb.get_collection('test', codec_options=std_opts) collection.insert_one({'_id': 'baz', 'uuid': input_uuid}) assert collection.find_one({'uuid': Binary(input_uuid.bytes, 4)})['_id'] == 'baz' # Retrieving this document using UuidRepresentation.JAVA_LEGACY returns a native UUID # without modifying the UUID byte-order java_opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) java_collection = client.testdb.get_collection('test', codec_options=java_opts) doc = java_collection.find_one({'_id': 'baz'}) assert doc['uuid'] == input_uuid # Round-tripping the retrieved document silently changes the Binary bytes and subtype java_collection.replace_one({'_id': 'baz'}, doc) assert collection.find_one({'uuid': Binary(input_uuid.bytes, 3)}) is None assert collection.find_one({'uuid': Binary(input_uuid.bytes, 4)}) is None round_tripped_doc = collection.find_one({'_id': 'baz'}) assert round_tripped_doc['uuid'] == Binary(input_uuid.bytes, 3).as_uuid(UuidRepresentation.JAVA_LEGACY) In this case, using the incorrect :class:`~bson.binary.UuidRepresentation` (``JAVA_LEGACY`` instead of ``STANDARD``) changes the :class:`~bson.binary.Binary` bytes and subtype as a side-effect. **Note that this happens when any representation that manipulates byte-order (``CSHARP_LEGACY`` or ``JAVA_LEGACY``) is incorrectly used to round-trip UUIDs written with ``STANDARD``. When the situation is reversed - i.e. when the original document is written using ``CSHARP_LEGACY`` or ``JAVA_LEGACY`` and then round-tripped using ``STANDARD`` - only the :class:`~bson.binary.Binary` subtype is changed.** .. note:: Starting in PyMongo 4.0, these issue will be resolved as the ``STANDARD`` representation will decode Binary subtype 3 fields as :class:`~bson.binary.Binary` objects of subtype 3 (instead of :class:`uuid.UUID`), and each of the ``LEGACY_*`` representations will decode Binary subtype 4 fields to :class:`~bson.binary.Binary` objects of subtype 4 (instead of :class:`uuid.UUID`). .. _configuring-uuid-representation: Configuring a UUID Representation --------------------------------- Users can workaround the problems described above by configuring their applications with the appropriate :class:`~bson.binary.UuidRepresentation`. Configuring the representation modifies PyMongo's behavior while encoding :class:`uuid.UUID` objects to BSON and decoding Binary subtype 3 and 4 fields from BSON. Applications can set the UUID representation in one of the following ways: #. At the ``MongoClient`` level using the ``uuidRepresentation`` URI option, e.g.:: client = MongoClient("mongodb://a:27107/?uuidRepresentation=javaLegacy") Valid values are: .. 
list-table:: :header-rows: 1 * - Value - UUID Representation * - ``pythonLegacy`` - :ref:`python-legacy-representation-details` * - ``javaLegacy`` - :ref:`java-legacy-representation-details` * - ``csharpLegacy`` - :ref:`csharp-legacy-representation-details` * - ``standard`` - :ref:`standard-representation-details` * - ``unspecified`` - :ref:`unspecified-representation-details` #. Using the ``uuid_representation`` kwarg option, e.g.:: from bson.binary import UuidRepresentation client = MongoClient(uuid_representation=UuidRepresentation.PYTHON_LEGACY) #. By supplying a suitable :class:`~bson.codec_options.CodecOptions` instance, e.g.:: from bson.codec_options import CodecOptions csharp_opts = CodecOptions(uuid_representation=UuidRepresentation.CSHARP_LEGACY) csharp_database = client.get_database('csharp_db', codec_options=csharp_opts) csharp_collection = client.testdb.get_collection('csharp_coll', codec_options=csharp_opts) Supported UUID Representations ------------------------------ .. list-table:: :header-rows: 1 * - UUID Representation - Default? - Encode :class:`uuid.UUID` to - Decode :class:`~bson.binary.Binary` subtype 4 to - Decode :class:`~bson.binary.Binary` subtype 3 to * - :ref:`python-legacy-representation-details` - Yes, in PyMongo>=2.9,<4 - :class:`~bson.binary.Binary` subtype 3 with standard byte-order - :class:`uuid.UUID` in PyMongo<4; :class:`~bson.binary.Binary` subtype 4 in PyMongo>=4 - :class:`uuid.UUID` * - :ref:`java-legacy-representation-details` - No - :class:`~bson.binary.Binary` subtype 3 with Java legacy byte-order - :class:`uuid.UUID` in PyMongo<4; :class:`~bson.binary.Binary` subtype 4 in PyMongo>=4 - :class:`uuid.UUID` * - :ref:`csharp-legacy-representation-details` - No - :class:`~bson.binary.Binary` subtype 3 with C# legacy byte-order - :class:`uuid.UUID` in PyMongo<4; :class:`~bson.binary.Binary` subtype 4 in PyMongo>=4 - :class:`uuid.UUID` * - :ref:`standard-representation-details` - No - :class:`~bson.binary.Binary` subtype 4 - :class:`uuid.UUID` - :class:`uuid.UUID` in PyMongo<4; :class:`~bson.binary.Binary` subtype 3 in PyMongo>=4 * - :ref:`unspecified-representation-details` - Yes, in PyMongo>=4 - Raise :exc:`ValueError` - :class:`~bson.binary.Binary` subtype 4 - :class:`uuid.UUID` in PyMongo<4; :class:`~bson.binary.Binary` subtype 3 in PyMongo>=4 We now detail the behavior and use-case for each supported UUID representation. .. _python-legacy-representation-details: ``PYTHON_LEGACY`` ^^^^^^^^^^^^^^^^^ .. attention:: This uuid representation should be used when reading UUIDs generated by existing applications that use the Python driver but **don't** explicitly set a UUID representation. .. attention:: :data:`~bson.binary.UuidRepresentation.PYTHON_LEGACY` has been the default uuid representation since PyMongo 2.9. The :data:`~bson.binary.UuidRepresentation.PYTHON_LEGACY` representation corresponds to the legacy representation of UUIDs used by PyMongo. This representation conforms with `RFC 4122 Section 4.1.2 `_. 
The following example illustrates the use of this representation:: from bson.codec_options import CodecOptions, DEFAULT_CODEC_OPTIONS from bson.binary import UuidRepresentation # No configured UUID representation collection = client.python_legacy.get_collection('test', codec_options=DEFAULT_CODEC_OPTIONS) # Using UuidRepresentation.PYTHON_LEGACY pylegacy_opts = CodecOptions(uuid_representation=UuidRepresentation.PYTHON_LEGACY) pylegacy_collection = client.python_legacy.get_collection('test', codec_options=pylegacy_opts) # UUIDs written by PyMongo with no UuidRepresentation configured can be queried using PYTHON_LEGACY uuid_1 = uuid4() collection.insert_one({'uuid': uuid_1}) document = pylegacy_collection.find_one({'uuid': uuid_1}) # UUIDs written using PYTHON_LEGACY can be read by PyMongo with no UuidRepresentation configured uuid_2 = uuid4() pylegacy_collection.insert_one({'uuid': uuid_2}) document = collection.find_one({'uuid': uuid_2}) ``PYTHON_LEGACY`` encodes native :class:`uuid.UUID` objects to :class:`~bson.binary.Binary` subtype 3 objects, preserving the same byte-order as :attr:`~uuid.UUID.bytes`:: from bson.binary import Binary document = collection.find_one({'uuid': Binary(uuid_2.bytes, subtype=3)}) assert document['uuid'] == uuid_2 .. _java-legacy-representation-details: ``JAVA_LEGACY`` ^^^^^^^^^^^^^^^ .. attention:: This UUID representation should be used when reading UUIDs written to MongoDB by the legacy applications (i.e. applications that don't use the ``STANDARD`` representation) using the Java driver. The :data:`~bson.binary.UuidRepresentation.JAVA_LEGACY` representation corresponds to the legacy representation of UUIDs used by the MongoDB Java Driver. .. note:: The ``JAVA_LEGACY`` representation reverses the order of bytes 0-7, and bytes 8-15. As an example, consider the same UUID described in :ref:`example-legacy-uuid`. Let us assume that an application used the Java driver without an explicitly specified UUID representation to insert the example UUID ``00112233-4455-6677-8899-aabbccddeeff`` into MongoDB. If we try to read this value using PyMongo with no UUID representation specified, we end up with an entirely different UUID:: UUID('77665544-3322-1100-ffee-ddccbbaa9988') However, if we explicitly set the representation to :data:`~bson.binary.UuidRepresentation.JAVA_LEGACY`, we get the correct result:: UUID('00112233-4455-6677-8899-aabbccddeeff') PyMongo uses the specified UUID representation to reorder the BSON bytes and load them correctly. ``JAVA_LEGACY`` encodes native :class:`uuid.UUID` objects to :class:`~bson.binary.Binary` subtype 3 objects, while performing the same byte-reordering as the legacy Java driver's UUID to BSON encoder. .. _csharp-legacy-representation-details: ``CSHARP_LEGACY`` ^^^^^^^^^^^^^^^^^ .. attention:: This UUID representation should be used when reading UUIDs written to MongoDB by the legacy applications (i.e. applications that don't use the ``STANDARD`` representation) using the C# driver. The :data:`~bson.binary.UuidRepresentation.CSHARP_LEGACY` representation corresponds to the legacy representation of UUIDs used by the MongoDB C# Driver. .. note:: The ``CSHARP_LEGACY`` representation reverses the order of bytes 0-3, bytes 4-5, and bytes 6-7. As an example, consider the same UUID described in :ref:`example-legacy-uuid`. Let us assume that an application used the C# driver without an explicitly specified UUID representation to insert the example UUID ``00112233-4455-6677-8899-aabbccddeeff`` into MongoDB.
If we try to read this value using PyMongo with no UUID representation specified, we end up with an entirely different UUID:: UUID('33221100-5544-7766-8899-aabbccddeeff') However, if we explicitly set the representation to :data:`~bson.binary.UuidRepresentation.CSHARP_LEGACY`, we get the correct result:: UUID('00112233-4455-6677-8899-aabbccddeeff') PyMongo uses the specified UUID representation to reorder the BSON bytes and load them correctly. ``CSHARP_LEGACY`` encodes native :class:`uuid.UUID` objects to :class:`~bson.binary.Binary` subtype 3 objects, while performing the same byte-reordering as the legacy C# driver's UUID to BSON encoder. .. _standard-representation-details: ``STANDARD`` ^^^^^^^^^^^^ .. attention:: This UUID representation should be used by new applications that have never stored UUIDs in MongoDB. The :data:`~bson.binary.UuidRepresentation.STANDARD` representation enables cross-language compatibility by ensuring the same byte-ordering when encoding UUIDs from all drivers. UUIDs written by a driver with this representation configured will be handled correctly by every other provided it is also configured with the ``STANDARD`` representation. ``STANDARD`` encodes native :class:`uuid.UUID` objects to :class:`~bson.binary.Binary` subtype 4 objects. .. _unspecified-representation-details: ``UNSPECIFIED`` ^^^^^^^^^^^^^^^ .. attention:: Starting in PyMongo 4.0, :data:`~bson.binary.UuidRepresentation.UNSPECIFIED` will be the default UUID representation used by PyMongo. The :data:`~bson.binary.UuidRepresentation.UNSPECIFIED` representation prevents the incorrect interpretation of UUID bytes by stopping short of automatically converting UUID fields in BSON to native UUID types. Loading a UUID when using this representation returns a :class:`~bson.binary.Binary` object instead. If required, users can coerce the decoded :class:`~bson.binary.Binary` objects into native UUIDs using the :meth:`~bson.binary.Binary.as_uuid` method and specifying the appropriate representation format. The following example shows what this might look like for a UUID stored by the C# driver:: from bson.codec_options import CodecOptions, DEFAULT_CODEC_OPTIONS from bson.binary import Binary, UuidRepresentation from uuid import uuid4 # Using UuidRepresentation.CSHARP_LEGACY csharp_opts = CodecOptions(uuid_representation=UuidRepresentation.CSHARP_LEGACY) # Store a legacy C#-formatted UUID input_uuid = uuid4() collection = client.testdb.get_collection('test', codec_options=csharp_opts) collection.insert_one({'_id': 'foo', 'uuid': input_uuid}) # Using UuidRepresentation.UNSPECIFIED unspec_opts = CodecOptions(uuid_representation=UuidRepresentation.UNSPECIFIED) unspec_collection = client.testdb.get_collection('test', codec_options=unspec_opts) # UUID fields are decoded as Binary when UuidRepresentation.UNSPECIFIED is configured document = unspec_collection.find_one({'_id': 'foo'}) decoded_field = document['uuid'] assert isinstance(decoded_field, Binary) # Binary.as_uuid() can be used to coerce the decoded value to a native UUID decoded_uuid = decoded_field.as_uuid(UuidRepresentation.CSHARP_LEGACY) assert decoded_uuid == input_uuid Native :class:`uuid.UUID` objects cannot directly be encoded to :class:`~bson.binary.Binary` when the UUID representation is ``UNSPECIFIED`` and attempting to do so will result in an exception:: unspec_collection.insert_one({'_id': 'bar', 'uuid': uuid4()}) Traceback (most recent call last): ... ValueError: cannot encode native uuid.UUID with UuidRepresentation.UNSPECIFIED. 
UUIDs can be manually converted to bson.Binary instances using bson.Binary.from_uuid() or a different UuidRepresentation can be configured. See the documentation for UuidRepresentation for more information. Instead, applications using :data:`~bson.binary.UuidRepresentation.UNSPECIFIED` must explicitly coerce a native UUID using the :meth:`~bson.binary.Binary.from_uuid` method:: explicit_binary = Binary.from_uuid(uuid4(), UuidRepresentation.PYTHON_LEGACY) unspec_collection.insert_one({'_id': 'bar', 'uuid': explicit_binary}) pymongo-3.11.0/doc/faq.rst000066400000000000000000000533411374256237000153510ustar00rootroot00000000000000Frequently Asked Questions ========================== .. contents:: Is PyMongo thread-safe? ----------------------- PyMongo is thread-safe and provides built-in connection pooling for threaded applications. .. _pymongo-fork-safe: Is PyMongo fork-safe? --------------------- PyMongo is not fork-safe. Care must be taken when using instances of :class:`~pymongo.mongo_client.MongoClient` with ``fork()``. Specifically, instances of MongoClient must not be copied from a parent process to a child process. Instead, the parent process and each child process must create their own instances of MongoClient. Instances of MongoClient copied from the parent process have a high probability of deadlock in the child process due to the inherent incompatibilities between ``fork()``, threads, and locks described :ref:`below `. PyMongo will attempt to issue a warning if there is a chance of this deadlock occurring. .. _pymongo-fork-safe-details: MongoClient spawns multiple threads to run background tasks such as monitoring connected servers. These threads share state that is protected by instances of :class:`~threading.Lock`, which are themselves `not fork-safe`_. The driver is therefore subject to the same limitations as any other multithreaded code that uses :class:`~threading.Lock` (and mutexes in general). One of these limitations is that the locks become useless after ``fork()``. During the fork, all locks are copied over to the child process in the same state as they were in the parent: if they were locked, the copied locks are also locked. The child created by ``fork()`` only has one thread, so any locks that were taken out by other threads in the parent will never be released in the child. The next time the child process attempts to acquire one of these locks, deadlock occurs. For a long but interesting read about the problems of Python locks in multithreaded contexts with ``fork()``, see http://bugs.python.org/issue6721. .. _not fork-safe: http://bugs.python.org/issue6721 .. _connection-pooling: How does connection pooling work in PyMongo? -------------------------------------------- Every :class:`~pymongo.mongo_client.MongoClient` instance has a built-in connection pool per server in your MongoDB topology. These pools open sockets on demand to support the number of concurrent MongoDB operations that your multi-threaded application requires. There is no thread-affinity for sockets. The size of each connection pool is capped at ``maxPoolSize``, which defaults to 100. If there are ``maxPoolSize`` connections to a server and all are in use, the next request to that server will wait until one of the connections becomes available. The client instance opens one additional socket per server in your MongoDB topology for monitoring the server's state. For example, a client connected to a 3-node replica set opens 3 monitoring sockets. 
It also opens as many sockets as needed to support a multi-threaded application's concurrent operations on each server, up to ``maxPoolSize``. With a ``maxPoolSize`` of 100, if the application only uses the primary (the default), then only the primary connection pool grows and the total connections is at most 103. If the application uses a :class:`~pymongo.read_preferences.ReadPreference` to query the secondaries, their pools also grow and the total connections can reach 303. It is possible to set the minimum number of concurrent connections to each server with ``minPoolSize``, which defaults to 0. The connection pool will be initialized with this number of sockets. If sockets are closed due to any network errors, causing the total number of sockets (both in use and idle) to drop below the minimum, more sockets are opened until the minimum is reached. The maximum number of milliseconds that a connection can remain idle in the pool before being removed and replaced can be set with ``maxIdleTime``, which defaults to `None` (no limit). The default configuration for a :class:`~pymongo.mongo_client.MongoClient` works for most applications:: client = MongoClient(host, port) Create this client **once** for each process, and reuse it for all operations. It is a common mistake to create a new client for each request, which is very inefficient. To support extremely high numbers of concurrent MongoDB operations within one process, increase ``maxPoolSize``:: client = MongoClient(host, port, maxPoolSize=200) ... or make it unbounded:: client = MongoClient(host, port, maxPoolSize=None) Once the pool reaches its maximum size, additional threads have to wait for sockets to become available. PyMongo does not limit the number of threads that can wait for sockets to become available and it is the application's responsibility to limit the size of its thread pool to bound queuing during a load spike. Threads are allowed to wait for any length of time unless ``waitQueueTimeoutMS`` is defined:: client = MongoClient(host, port, waitQueueTimeoutMS=100) A thread that waits more than 100ms (in this example) for a socket raises :exc:`~pymongo.errors.ConnectionFailure`. Use this option if it is more important to bound the duration of operations during a load spike than it is to complete every operation. When :meth:`~pymongo.mongo_client.MongoClient.close` is called by any thread, all idle sockets are closed, and all sockets that are in use will be closed as they are returned to the pool. Does PyMongo support Python 3? ------------------------------ PyMongo supports CPython 3.4+ and PyPy3.5+. See the :doc:`python3` for details. Does PyMongo support asynchronous frameworks like Gevent, asyncio, Tornado, or Twisted? --------------------------------------------------------------------------------------- PyMongo fully supports :doc:`Gevent `. To use MongoDB with `asyncio `_ or `Tornado `_, see the `Motor `_ project. For `Twisted `_, see `TxMongo `_. Its stated mission is to keep feature parity with PyMongo. .. _writes-and-ids: Why does PyMongo add an _id field to all of my documents? --------------------------------------------------------- When a document is inserted to MongoDB using :meth:`~pymongo.collection.Collection.insert_one`, :meth:`~pymongo.collection.Collection.insert_many`, or :meth:`~pymongo.collection.Collection.bulk_write`, and that document does not include an ``_id`` field, PyMongo automatically adds one for you, set to an instance of :class:`~bson.objectid.ObjectId`. 
For example:: >>> my_doc = {'x': 1} >>> collection.insert_one(my_doc) >>> my_doc {'x': 1, '_id': ObjectId('560db337fba522189f171720')} Users often discover this behavior when calling :meth:`~pymongo.collection.Collection.insert_many` with a list of references to a single document raises :exc:`~pymongo.errors.BulkWriteError`. Several Python idioms lead to this pitfall:: >>> doc = {} >>> collection.insert_many(doc for _ in range(10)) Traceback (most recent call last): ... pymongo.errors.BulkWriteError: batch op errors occurred >>> doc {'_id': ObjectId('560f171cfba52279f0b0da0c')} >>> docs = [{}] >>> collection.insert_many(docs * 10) Traceback (most recent call last): ... pymongo.errors.BulkWriteError: batch op errors occurred >>> docs [{'_id': ObjectId('560f1933fba52279f0b0da0e')}] PyMongo adds an ``_id`` field in this manner for a few reasons: - All MongoDB documents are required to have an ``_id`` field. - If PyMongo were to insert a document without an ``_id`` MongoDB would add one itself, but it would not report the value back to PyMongo. - Copying the document to insert before adding the ``_id`` field would be prohibitively expensive for most high write volume applications. If you don't want PyMongo to add an ``_id`` to your documents, insert only documents that already have an ``_id`` field, added by your application. Key order in subdocuments -- why does my query work in the shell but not PyMongo? --------------------------------------------------------------------------------- .. testsetup:: key-order from bson.son import SON from pymongo.mongo_client import MongoClient collection = MongoClient().test.collection collection.drop() collection.insert_one({'_id': 1.0, 'subdocument': SON([('b', 1.0), ('a', 1.0)])}) The key-value pairs in a BSON document can have any order (except that ``_id`` is always first). The mongo shell preserves key order when reading and writing data. Observe that "b" comes before "a" when we create the document and when it is displayed: .. code-block:: javascript > // mongo shell. > db.collection.insert( { "_id" : 1, "subdocument" : { "b" : 1, "a" : 1 } } ) WriteResult({ "nInserted" : 1 }) > db.collection.find() { "_id" : 1, "subdocument" : { "b" : 1, "a" : 1 } } PyMongo represents BSON documents as Python dicts by default, and the order of keys in dicts is not defined. That is, a dict declared with the "a" key first is the same, to Python, as one with "b" first: >>> print({'a': 1.0, 'b': 1.0}) {'a': 1.0, 'b': 1.0} >>> print({'b': 1.0, 'a': 1.0}) {'a': 1.0, 'b': 1.0} Therefore, Python dicts are not guaranteed to show keys in the order they are stored in BSON. Here, "a" is shown before "b": >>> print(collection.find_one()) {u'_id': 1.0, u'subdocument': {u'a': 1.0, u'b': 1.0}} To preserve order when reading BSON, use the :class:`~bson.son.SON` class, which is a dict that remembers its key order. First, get a handle to the collection, configured to use :class:`~bson.son.SON` instead of dict: .. doctest:: key-order :options: +NORMALIZE_WHITESPACE >>> from bson import CodecOptions, SON >>> opts = CodecOptions(document_class=SON) >>> opts CodecOptions(document_class=, tz_aware=False, uuid_representation=UuidRepresentation.PYTHON_LEGACY, unicode_decode_error_handler='strict', tzinfo=None, type_registry=TypeRegistry(type_codecs=[], fallback_encoder=None)) >>> collection_son = collection.with_options(codec_options=opts) Now, documents and subdocuments in query results are represented with :class:`~bson.son.SON` objects: .. 
doctest:: key-order >>> print(collection_son.find_one()) SON([(u'_id', 1.0), (u'subdocument', SON([(u'b', 1.0), (u'a', 1.0)]))]) The subdocument's actual storage layout is now visible: "b" is before "a". Because a dict's key order is not defined, you cannot predict how it will be serialized **to** BSON. But MongoDB considers subdocuments equal only if their keys have the same order. So if you use a dict to query on a subdocument it may not match: >>> collection.find_one({'subdocument': {'a': 1.0, 'b': 1.0}}) is None True Swapping the key order in your query makes no difference: >>> collection.find_one({'subdocument': {'b': 1.0, 'a': 1.0}}) is None True ... because, as we saw above, Python considers the two dicts the same. There are two solutions. First, you can match the subdocument field-by-field: >>> collection.find_one({'subdocument.a': 1.0, ... 'subdocument.b': 1.0}) {u'_id': 1.0, u'subdocument': {u'a': 1.0, u'b': 1.0}} The query matches any subdocument with an "a" of 1.0 and a "b" of 1.0, regardless of the order you specify them in Python or the order they are stored in BSON. Additionally, this query now matches subdocuments with additional keys besides "a" and "b", whereas the previous query required an exact match. The second solution is to use a :class:`~bson.son.SON` to specify the key order: >>> query = {'subdocument': SON([('b', 1.0), ('a', 1.0)])} >>> collection.find_one(query) {u'_id': 1.0, u'subdocument': {u'a': 1.0, u'b': 1.0}} The key order you use when you create a :class:`~bson.son.SON` is preserved when it is serialized to BSON and used as a query. Thus you can create a subdocument that exactly matches the subdocument in the collection. .. seealso:: `MongoDB Manual entry on subdocument matching `_. What does *CursorNotFound* cursor id not valid at server mean? -------------------------------------------------------------- Cursors in MongoDB can timeout on the server if they've been open for a long time without any operations being performed on them. This can lead to an :class:`~pymongo.errors.CursorNotFound` exception being raised when attempting to iterate the cursor. How do I change the timeout value for cursors? ---------------------------------------------- MongoDB doesn't support custom timeouts for cursors, but cursor timeouts can be turned off entirely. Pass ``no_cursor_timeout=True`` to :meth:`~pymongo.collection.Collection.find`. How can I store :mod:`decimal.Decimal` instances? ------------------------------------------------- PyMongo >= 3.4 supports the Decimal128 BSON type introduced in MongoDB 3.4. See :mod:`~bson.decimal128` for more information. MongoDB <= 3.2 only supports IEEE 754 floating points - the same as the Python float type. The only way PyMongo could store Decimal instances to these versions of MongoDB would be to convert them to this standard, so you'd really only be storing floats anyway - we force users to do this conversion explicitly so that they are aware that it is happening. I'm saving ``9.99`` but when I query my document contains ``9.9900000000000002`` - what's going on here? -------------------------------------------------------------------------------------------------------- The database representation is ``9.99`` as an IEEE floating point (which is common to MongoDB and Python as well as most other modern languages). 
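As an aside, if you need the exact decimal value ``9.99`` preserved, the
:class:`~bson.decimal128.Decimal128` type discussed above sidesteps the floating point
representation entirely on MongoDB >= 3.4. A minimal sketch, assuming ``collection`` is any
:class:`~pymongo.collection.Collection` and the field name is illustrative::

    from decimal import Decimal
    from bson.decimal128 import Decimal128

    collection.insert_one({'price': Decimal128(Decimal('9.99'))})
    doc = collection.find_one()
    print(doc['price'].to_decimal())  # Decimal('9.99'), exact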
The problem is that ``9.99`` cannot be represented exactly with a double precision floating point - this is true in some versions of Python as well: >>> 9.99 9.9900000000000002 The result that you get when you save ``9.99`` with PyMongo is exactly the same as the result you'd get saving it with the JavaScript shell or any of the other languages (and as the data you're working with when you type ``9.99`` into a Python program). Can you add attribute style access for documents? ------------------------------------------------- This request has come up a number of times but we've decided not to implement anything like this. The relevant `jira case `_ has some information about the decision, but here is a brief summary: 1. This will pollute the attribute namespace for documents, so could lead to subtle bugs / confusing errors when using a key with the same name as a dictionary method. 2. The only reason we even use SON objects instead of regular dictionaries is to maintain key ordering, since the server requires this for certain operations. So we're hesitant to needlessly complicate SON (at some point it's hypothetically possible we might want to revert back to using dictionaries alone, without breaking backwards compatibility for everyone). 3. It's easy (and Pythonic) for new users to deal with documents, since they behave just like dictionaries. If we start changing their behavior it adds a barrier to entry for new users - another class to learn. What is the correct way to handle time zones with PyMongo? ---------------------------------------------------------- See :doc:`examples/datetimes` for examples on how to handle :class:`~datetime.datetime` objects correctly. How can I save a :mod:`datetime.date` instance? ----------------------------------------------- PyMongo doesn't support saving :mod:`datetime.date` instances, since there is no BSON type for dates without times. Rather than having the driver enforce a convention for converting :mod:`datetime.date` instances to :mod:`datetime.datetime` instances for you, any conversion should be performed in your client code. .. _web-application-querying-by-objectid: When I query for a document by ObjectId in my web application I get no result ----------------------------------------------------------------------------- It's common in web applications to encode documents' ObjectIds in URLs, like:: "/posts/50b3bda58a02fb9a84d8991e" Your web framework will pass the ObjectId portion of the URL to your request handler as a string, so it must be converted to :class:`~bson.objectid.ObjectId` before it is passed to :meth:`~pymongo.collection.Collection.find_one`. It is a common mistake to forget to do this conversion. Here's how to do it correctly in Flask_ (other web frameworks are similar):: from pymongo import MongoClient from bson.objectid import ObjectId from flask import Flask, render_template client = MongoClient() app = Flask(__name__) @app.route("/posts/<_id>") def show_post(_id): # NOTE!: converting _id from string to ObjectId before passing to find_one post = client.db.posts.find_one({'_id': ObjectId(_id)}) return render_template('post.html', post=post) if __name__ == "__main__": app.run() .. _Flask: http://flask.pocoo.org/ .. seealso:: :ref:`querying-by-objectid` How can I use PyMongo from Django? ---------------------------------- `Django `_ is a popular Python web framework. Django includes an ORM, :mod:`django.db`. Currently, there's no official MongoDB backend for Django. 
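As an aside to the :mod:`datetime.date` question above, the usual client-side convention is to
promote the date to a :class:`~datetime.datetime` at midnight before inserting. A minimal sketch,
assuming ``collection`` is any :class:`~pymongo.collection.Collection` and the field name is
illustrative::

    import datetime

    def date_to_datetime(d):
        # BSON has no date-only type, so store a datetime at midnight of that day.
        return datetime.datetime(d.year, d.month, d.day)

    collection.insert_one({'when': date_to_datetime(datetime.date.today())})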
`django-mongodb-engine `_ is an unofficial MongoDB backend that supports Django aggregations, (atomic) updates, embedded objects, Map/Reduce and GridFS. It allows you to use most of Django's built-in features, including the ORM, admin, authentication, site and session frameworks and caching. However, it's easy to use MongoDB (and PyMongo) from Django without using a Django backend. Certain features of Django that require :mod:`django.db` (admin, authentication and sessions) will not work using just MongoDB, but most of what Django provides can still be used. One project which should make working with MongoDB and Django easier is `mango `_. Mango is a set of MongoDB backends for Django sessions and authentication (bypassing :mod:`django.db` entirely). .. _using-with-mod-wsgi: Does PyMongo work with **mod_wsgi**? ------------------------------------ Yes. See the configuration guide for :ref:`pymongo-and-mod_wsgi`. Does PyMongo work with PythonAnywhere? -------------------------------------- No. PyMongo creates Python threads which `PythonAnywhere `_ does not support. For more information see `PYTHON-1495 `_. How can I use something like Python's :mod:`json` module to encode my documents to JSON? ---------------------------------------------------------------------------------------- :mod:`~bson.json_util` is PyMongo's built in, flexible tool for using Python's :mod:`json` module with BSON documents and `MongoDB Extended JSON `_. The :mod:`json` module won't work out of the box with all documents from PyMongo as PyMongo supports some special types (like :class:`~bson.objectid.ObjectId` and :class:`~bson.dbref.DBRef`) that are not supported in JSON. `python-bsonjs `_ is a fast BSON to MongoDB Extended JSON converter built on top of `libbson `_. `python-bsonjs` does not depend on PyMongo and can offer a nice performance improvement over :mod:`~bson.json_util`. `python-bsonjs` works best with PyMongo when using :class:`~bson.raw_bson.RawBSONDocument`. Why do I get OverflowError decoding dates stored by another language's driver? ------------------------------------------------------------------------------ PyMongo decodes BSON datetime values to instances of Python's :class:`datetime.datetime`. Instances of :class:`datetime.datetime` are limited to years between :data:`datetime.MINYEAR` (usually 1) and :data:`datetime.MAXYEAR` (usually 9999). Some MongoDB drivers (e.g. the PHP driver) can store BSON datetimes with year values far outside those supported by :class:`datetime.datetime`. There are a few ways to work around this issue. One option is to filter out documents with values outside of the range supported by :class:`datetime.datetime`:: >>> from datetime import datetime >>> coll = client.test.dates >>> cur = coll.find({'dt': {'$gte': datetime.min, '$lte': datetime.max}}) Another option, assuming you don't need the datetime field, is to filter out just that field:: >>> cur = coll.find({}, projection={'dt': False}) .. _multiprocessing: Using PyMongo with Multiprocessing ---------------------------------- On Unix systems the multiprocessing module spawns processes using ``fork()``. Care must be taken when using instances of :class:`~pymongo.mongo_client.MongoClient` with ``fork()``. Specifically, instances of MongoClient must not be copied from a parent process to a child process. Instead, the parent process and each child process must create their own instances of MongoClient. For example:: # Each process creates its own instance of MongoClient. 
def func(): db = pymongo.MongoClient().mydb # Do something with db. proc = multiprocessing.Process(target=func) proc.start() **Never do this**:: client = pymongo.MongoClient() # Each child process attempts to copy a global MongoClient # created in the parent process. Never do this. def func(): db = client.mydb # Do something with db. proc = multiprocessing.Process(target=func) proc.start() Instances of MongoClient copied from the parent process have a high probability of deadlock in the child process due to :ref:`inherent incompatibilities between fork(), threads, and locks `. PyMongo will attempt to issue a warning if there is a chance of this deadlock occurring. .. seealso:: :ref:`pymongo-fork-safe` pymongo-3.11.0/doc/index.rst000066400000000000000000000064341374256237000157120ustar00rootroot00000000000000PyMongo |release| Documentation =============================== Overview -------- **PyMongo** is a Python distribution containing tools for working with `MongoDB `_, and is the recommended way to work with MongoDB from Python. This documentation attempts to explain everything you need to know to use **PyMongo**. .. todo:: a list of PyMongo's features :doc:`installation` Instructions on how to get the distribution. :doc:`tutorial` Start here for a quick overview. :doc:`examples/index` Examples of how to perform specific tasks. :doc:`atlas` Using PyMongo with MongoDB Atlas. :doc:`examples/tls` Using PyMongo with TLS / SSL. :doc:`examples/encryption` Using PyMongo with client side encryption. :doc:`faq` Some questions that come up often. :doc:`migrate-to-pymongo3` A PyMongo 2.x to 3.x migration guide. :doc:`python3` Frequently asked questions about python 3 support. :doc:`compatibility-policy` Explanation of deprecations, and how to keep pace with changes in PyMongo's API. :doc:`api/index` The complete API documentation, organized by module. :doc:`tools` A listing of Python tools and libraries that have been written for MongoDB. :doc:`developer/index` Developer guide for contributors to PyMongo. Getting Help ------------ If you're having trouble or have questions about PyMongo, ask your question on our `MongoDB Community Forum `_. You may also want to consider a `commercial support subscription `_. Once you get an answer, it'd be great if you could work it back into this documentation and contribute! Issues ------ All issues should be reported (and can be tracked / voted for / commented on) at the main `MongoDB JIRA bug tracker `_, in the "Python Driver" project. Feature Requests / Feedback --------------------------- Use our `feedback engine `_ to send us feature requests and general feedback about PyMongo. Contributing ------------ **PyMongo** has a large :doc:`community ` and contributions are always encouraged. Contributions can be as simple as minor tweaks to this documentation. To contribute, fork the project on `GitHub `_ and send a pull request. Changes ------- See the :doc:`changelog` for a full list of changes to PyMongo. For older versions of the documentation please see the `archive list `_. About This Documentation ------------------------ This documentation is generated using the `Sphinx `_ documentation generator. The source files for the documentation are located in the *doc/* directory of the **PyMongo** distribution. To generate the docs locally run the following command from the root directory of the **PyMongo** source: .. code-block:: bash $ python setup.py doc Indices and tables ------------------ * :ref:`genindex` * :ref:`modindex` * :ref:`search` .. 
toctree:: :hidden: atlas installation tutorial examples/index faq compatibility-policy api/index tools contributors changelog python3 migrate-to-pymongo3 developer/index pymongo-3.11.0/doc/installation.rst000066400000000000000000000230271374256237000173010ustar00rootroot00000000000000Installing / Upgrading ====================== .. highlight:: bash **PyMongo** is in the `Python Package Index `_. .. warning:: **Do not install the "bson" package from pypi.** PyMongo comes with its own bson package; doing "pip install bson" or "easy_install bson" installs a third-party package that is incompatible with PyMongo. Installing with pip ------------------- We recommend using `pip `_ to install pymongo on all platforms:: $ python -m pip install pymongo To get a specific version of pymongo:: $ python -m pip install pymongo==3.5.1 To upgrade using pip:: $ python -m pip install --upgrade pymongo .. note:: pip does not support installing python packages in .egg format. If you would like to install PyMongo from a .egg provided on pypi use easy_install instead. Installing with easy_install ---------------------------- To use ``easy_install`` from `setuptools `_ do:: $ python -m easy_install pymongo To upgrade do:: $ python -m easy_install -U pymongo Dependencies ------------ PyMongo supports CPython 2.7, 3.4+, PyPy, and PyPy3.5+. Optional dependencies: GSSAPI authentication requires `pykerberos `_ on Unix or `WinKerberos `_ on Windows. The correct dependency can be installed automatically along with PyMongo:: $ python -m pip install pymongo[gssapi] :ref:`MONGODB-AWS` authentication requires `pymongo-auth-aws `_:: $ python -m pip install pymongo[aws] Support for mongodb+srv:// URIs requires `dnspython `_:: $ python -m pip install pymongo[srv] TLS / SSL support may require `ipaddress `_ and `certifi `_ or `wincertstore `_ depending on the Python version in use. The necessary dependencies can be installed along with PyMongo:: $ python -m pip install pymongo[tls] .. note:: Users of Python versions older than 2.7.9 will also receive the dependencies for OCSP when using the tls extra. :ref:`OCSP` requires `PyOpenSSL `_, `requests `_ and `service_identity `_:: $ python -m pip install pymongo[ocsp] Wire protocol compression with snappy requires `python-snappy `_:: $ python -m pip install pymongo[snappy] Wire protocol compression with zstandard requires `zstandard `_:: $ python -m pip install pymongo[zstd] :ref:`Client-Side Field Level Encryption` requires `pymongocrypt `_:: $ python -m pip install pymongo[encryption] You can install all dependencies automatically with the following command:: $ python -m pip install pymongo[gssapi,aws,ocsp,snappy,srv,tls,zstd,encryption] Other optional packages: - `backports.pbkdf2 `_, improves authentication performance with SCRAM-SHA-1 and SCRAM-SHA-256. It especially improves performance on Python versions older than 2.7.8. - `monotonic `_ adds support for a monotonic clock, which improves reliability in environments where clock adjustments are frequent. Not needed in Python 3. Installing from source ---------------------- If you'd rather install directly from the source (i.e. to stay on the bleeding edge), install the C extension dependencies then check out the latest source from GitHub and install the driver from the resulting tree:: $ git clone git://github.com/mongodb/mongo-python-driver.git pymongo $ cd pymongo/ $ python setup.py install Installing from source on Unix .............................. 
To build the optional C extensions on Linux or another non-macOS Unix you must have the GNU C compiler (gcc) installed. Depending on your flavor of Unix (or Linux distribution) you may also need a python development package that provides the necessary header files for your version of Python. The package name may vary from distro to distro. Debian and Ubuntu users should issue the following command:: $ sudo apt-get install build-essential python-dev Users of Red Hat based distributions (RHEL, CentOS, Amazon Linux, Oracle Linux, Fedora, etc.) should issue the following command:: $ sudo yum install gcc python-devel Installing from source on macOS / OSX ..................................... If you want to install PyMongo with C extensions from source you will need the command line developer tools. On modern versions of macOS they can be installed by running the following in Terminal (found in /Applications/Utilities/):: xcode-select --install For older versions of OSX you may need Xcode. See the notes below for various OSX and Xcode versions. **Snow Leopard (10.6)** - Xcode 3 with 'UNIX Development Support'. **Snow Leopard Xcode 4**: The Python versions shipped with OSX 10.6.x are universal binaries. They support i386, PPC, and x86_64. Xcode 4 removed support for PPC, causing the distutils version shipped with Apple's builds of Python to fail to build the C extensions if you have Xcode 4 installed. There is a workaround:: # For some Python builds from python.org $ env ARCHFLAGS='-arch i386 -arch x86_64' python -m easy_install pymongo See `http://bugs.python.org/issue11623 `_ for a more detailed explanation. **Lion (10.7) and newer** - PyMongo's C extensions can be built against versions of Python 2.7 >= 2.7.4 or Python 3.4+ downloaded from python.org. In all cases Xcode must be installed with 'UNIX Development Support'. **Xcode 5.1**: Starting with version 5.1 the version of clang that ships with Xcode throws an error when it encounters compiler flags it doesn't recognize. This may cause C extension builds to fail with an error similar to:: clang: error: unknown argument: '-mno-fused-madd' [-Wunused-command-line-argument-hard-error-in-future] There are workarounds:: # Apple specified workaround for Xcode 5.1 # easy_install $ ARCHFLAGS=-Wno-error=unused-command-line-argument-hard-error-in-future easy_install pymongo # or pip $ ARCHFLAGS=-Wno-error=unused-command-line-argument-hard-error-in-future pip install pymongo # Alternative workaround using CFLAGS # easy_install $ CFLAGS=-Qunused-arguments easy_install pymongo # or pip $ CFLAGS=-Qunused-arguments pip install pymongo Installing from source on Windows ................................. If you want to install PyMongo with C extensions from source the following requirements apply to both CPython and ActiveState's ActivePython: 64-bit Windows ~~~~~~~~~~~~~~ For Python 3.5 and newer install Visual Studio 2015. For Python 3.4 install Visual Studio 2010. You must use the full version of Visual Studio 2010 as Visual C++ Express does not provide 64-bit compilers. Make sure that you check the "x64 Compilers and Tools" option under Visual C++. For Python 2.7 install the `Microsoft Visual C++ Compiler for Python 2.7`_. 32-bit Windows ~~~~~~~~~~~~~~ For Python 3.5 and newer install Visual Studio 2015. For Python 3.4 install Visual C++ 2010 Express. For Python 2.7 install the `Microsoft Visual C++ Compiler for Python 2.7`_ .. _`Microsoft Visual C++ Compiler for Python 2.7`: https://www.microsoft.com/en-us/download/details.aspx?id=44266 .. 
_install-no-c: Installing Without C Extensions ------------------------------- By default, the driver attempts to build and install optional C extensions (used for increasing performance) when it is installed. If any extension fails to build the driver will be installed anyway but a warning will be printed. If you wish to install PyMongo without the C extensions, even if the extensions build properly, it can be done using a command line option to *setup.py*:: $ python setup.py --no_ext install Building PyMongo egg Packages ----------------------------- Some organizations do not allow compilers and other build tools on production systems. To install PyMongo on these systems with C extensions you may need to build custom egg packages. Make sure that you have installed the dependencies listed above for your operating system then run the following command in the PyMongo source directory:: $ python setup.py bdist_egg The egg package can be found in the dist/ subdirectory. The file name will resemble “pymongo-3.6-py2.7-linux-x86_64.egg” but may have a different name depending on your platform and the version of python you use to compile. .. warning:: These “binary distributions,” will only work on systems that resemble the environment on which you built the package. In other words, ensure that operating systems and versions of Python and architecture (i.e. “32” or “64” bit) match. Copy this file to the target system and issue the following command to install the package:: $ sudo python -m easy_install pymongo-3.6-py2.7-linux-x86_64.egg Installing a beta or release candidate -------------------------------------- MongoDB, Inc. may occasionally tag a beta or release candidate for testing by the community before final release. These releases will not be uploaded to pypi but can be found on the `GitHub tags page `_. They can be installed by passing the full URL for the tag to pip:: $ python -m pip install https://github.com/mongodb/mongo-python-driver/archive/3.11.0rc0.tar.gz pymongo-3.11.0/doc/migrate-to-pymongo3.rst000066400000000000000000000374121374256237000204240ustar00rootroot00000000000000PyMongo 3 Migration Guide ========================= .. contents:: .. testsetup:: from pymongo import MongoClient, ReadPreference client = MongoClient() collection = client.my_database.my_collection PyMongo 3 is a partial rewrite bringing a large number of improvements. It also brings a number of backward breaking changes. This guide provides a roadmap for migrating an existing application from PyMongo 2.x to 3.x or writing libraries that will work with both PyMongo 2.x and 3.x. PyMongo 2.9 ----------- The first step in any successful migration involves upgrading to, or requiring, at least PyMongo 2.9. If your project has a requirements.txt file, add the line "pymongo >= 2.9, < 3.0" until you have completely migrated to PyMongo 3. Most of the key new methods and options from PyMongo 3.0 are backported in PyMongo 2.9 making migration much easier. Enable Deprecation Warnings --------------------------- Starting with PyMongo 2.9, :exc:`DeprecationWarning` is raised by most methods removed in PyMongo 3.0. Make sure you enable runtime warnings to see where deprecated functions and methods are being used in your application:: python -Wd Warnings can also be changed to errors:: python -Wd -Werror .. note:: Not all deprecated features raise :exc:`DeprecationWarning` when used. 
For example, the :meth:`~pymongo.collection.Collection.find` options renamed in PyMongo 3.0 do not raise :exc:`DeprecationWarning` when used in PyMongo 2.x. See also `Removed features with no migration path`_. CRUD API -------- Changes to find() and find_one() ................................ "spec" renamed "filter" ~~~~~~~~~~~~~~~~~~~~~~~ The `spec` option has been renamed to `filter`. Code like this:: >>> cursor = collection.find(spec={"a": 1}) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> cursor = collection.find(filter={"a": 1}) or this with any version of PyMongo: .. doctest:: >>> cursor = collection.find({"a": 1}) "fields" renamed "projection" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The `fields` option has been renamed to `projection`. Code like this:: >>> cursor = collection.find({"a": 1}, fields={"_id": False}) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> cursor = collection.find({"a": 1}, projection={"_id": False}) or this with any version of PyMongo: .. doctest:: >>> cursor = collection.find({"a": 1}, {"_id": False}) "partial" renamed "allow_partial_results" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The `partial` option has been renamed to `allow_partial_results`. Code like this:: >>> cursor = collection.find({"a": 1}, partial=True) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> cursor = collection.find({"a": 1}, allow_partial_results=True) "timeout" replaced by "no_cursor_timeout" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The `timeout` option has been replaced by `no_cursor_timeout`. Code like this:: >>> cursor = collection.find({"a": 1}, timeout=False) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> cursor = collection.find({"a": 1}, no_cursor_timeout=True) "network_timeout" is removed ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The `network_timeout` option has been removed. This option was always the wrong solution for timing out long running queries and should never be used in production. Starting with **MongoDB 2.6** you can use the $maxTimeMS query modifier. Code like this:: # Set a 5 second select() timeout. >>> cursor = collection.find({"a": 1}, network_timeout=5) can be changed to this with PyMongo 2.9 or later: .. doctest:: # Set a 5 second (5000 millisecond) server side query timeout. >>> cursor = collection.find({"a": 1}, modifiers={"$maxTimeMS": 5000}) or with PyMongo 3.5 or later: >>> cursor = collection.find({"a": 1}, max_time_ms=5000) or with any version of PyMongo: .. doctest:: >>> cursor = collection.find({"$query": {"a": 1}, "$maxTimeMS": 5000}) .. seealso:: `$maxTimeMS `_ Tailable cursors ~~~~~~~~~~~~~~~~ The `tailable` and `await_data` options have been replaced by `cursor_type`. Code like this:: >>> cursor = collection.find({"a": 1}, tailable=True) >>> cursor = collection.find({"a": 1}, tailable=True, await_data=True) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> from pymongo import CursorType >>> cursor = collection.find({"a": 1}, cursor_type=CursorType.TAILABLE) >>> cursor = collection.find({"a": 1}, cursor_type=CursorType.TAILABLE_AWAIT) Other removed options ~~~~~~~~~~~~~~~~~~~~~ The `slave_okay`, `read_preference`, `tag_sets`, and `secondary_acceptable_latency_ms` options have been removed. See the `Read Preferences`_ section for solutions. The aggregate method always returns a cursor ............................................ PyMongo 2.6 added an option to return an iterable cursor from :meth:`~pymongo.collection.Collection.aggregate`. 
In PyMongo 3 :meth:`~pymongo.collection.Collection.aggregate` always returns a cursor. Use the `cursor` option for consistent behavior with PyMongo 2.9 and later: .. doctest:: >>> for result in collection.aggregate([], cursor={}): ... pass Read Preferences ---------------- The "slave_okay" option is removed .................................. The `slave_okay` option is removed from PyMongo's API. The secondaryPreferred read preference provides the same behavior. Code like this:: >>> client = MongoClient(slave_okay=True) can be changed to this with PyMongo 2.9 or newer: .. doctest:: >>> client = MongoClient(readPreference="secondaryPreferred") The "read_preference" attribute is immutable ............................................ Code like this:: >>> from pymongo import ReadPreference >>> db = client.my_database >>> db.read_preference = ReadPreference.SECONDARY can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> db = client.get_database("my_database", ... read_preference=ReadPreference.SECONDARY) Code like this:: >>> cursor = collection.find({"a": 1}, ... read_preference=ReadPreference.SECONDARY) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> coll2 = collection.with_options(read_preference=ReadPreference.SECONDARY) >>> cursor = coll2.find({"a": 1}) .. seealso:: :meth:`~pymongo.database.Database.get_collection` The "tag_sets" option and attribute are removed ............................................... The `tag_sets` MongoClient option is removed. The `read_preference` option can be used instead. Code like this:: >>> client = MongoClient( ... read_preference=ReadPreference.SECONDARY, ... tag_sets=[{"dc": "ny"}, {"dc": "sf"}]) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> from pymongo.read_preferences import Secondary >>> client = MongoClient(read_preference=Secondary([{"dc": "ny"}])) To change the tags sets for a Database or Collection, code like this:: >>> db = client.my_database >>> db.read_preference = ReadPreference.SECONDARY >>> db.tag_sets = [{"dc": "ny"}] can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> db = client.get_database("my_database", ... read_preference=Secondary([{"dc": "ny"}])) Code like this:: >>> cursor = collection.find( ... {"a": 1}, ... read_preference=ReadPreference.SECONDARY, ... tag_sets=[{"dc": "ny"}]) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> from pymongo.read_preferences import Secondary >>> coll2 = collection.with_options( ... read_preference=Secondary([{"dc": "ny"}])) >>> cursor = coll2.find({"a": 1}) .. seealso:: :meth:`~pymongo.database.Database.get_collection` The "secondary_acceptable_latency_ms" option and attribute are removed ...................................................................... PyMongo 2.x supports `secondary_acceptable_latency_ms` as an option to methods throughout the driver, but mongos only supports a global latency option. PyMongo 3.x has changed to match the behavior of mongos, allowing migration from a single server, to a replica set, to a sharded cluster without a surprising change in server selection behavior. A new option, `localThresholdMS`, is available through MongoClient and should be used in place of `secondaryAcceptableLatencyMS`. Code like this:: >>> client = MongoClient(readPreference="nearest", ... secondaryAcceptableLatencyMS=100) can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> client = MongoClient(readPreference="nearest", ... 
localThresholdMS=100) Write Concern ------------- The "safe" option is removed ............................ In PyMongo 3 the `safe` option is removed from the entire API. :class:`~pymongo.mongo_client.MongoClient` has always defaulted to acknowledged write operations and continues to do so in PyMongo 3. The "write_concern" attribute is immutable .......................................... The `write_concern` attribute is immutable in PyMongo 3. Code like this:: >>> client = MongoClient() >>> client.write_concern = {"w": "majority"} can be changed to this with any version of PyMongo: .. doctest:: >>> client = MongoClient(w="majority") Code like this:: >>> db = client.my_database >>> db.write_concern = {"w": "majority"} can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> from pymongo import WriteConcern >>> db = client.get_database("my_database", ... write_concern=WriteConcern(w="majority")) The new CRUD API write methods do not accept write concern options. Code like this:: >>> oid = collection.insert({"a": 2}, w="majority") can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> from pymongo import WriteConcern >>> coll2 = collection.with_options( ... write_concern=WriteConcern(w="majority")) >>> oid = coll2.insert({"a": 2}) .. seealso:: :meth:`~pymongo.database.Database.get_collection` Codec Options ------------- The "document_class" attribute is removed ......................................... Code like this:: >>> from bson.son import SON >>> client = MongoClient() >>> client.document_class = SON can be replaced by this in any version of PyMongo: .. doctest:: >>> from bson.son import SON >>> client = MongoClient(document_class=SON) or to change the `document_class` for a :class:`~pymongo.database.Database` with PyMongo 2.9 or later: .. doctest:: >>> from bson.codec_options import CodecOptions >>> from bson.son import SON >>> db = client.get_database("my_database", CodecOptions(SON)) .. seealso:: :meth:`~pymongo.database.Database.get_collection` and :meth:`~pymongo.collection.Collection.with_options` The "uuid_subtype" option and attribute are removed ................................................... Code like this:: >>> from bson.binary import JAVA_LEGACY >>> db = client.my_database >>> db.uuid_subtype = JAVA_LEGACY can be replaced by this with PyMongo 2.9 or later: .. doctest:: >>> from bson.binary import JAVA_LEGACY >>> from bson.codec_options import CodecOptions >>> db = client.get_database("my_database", ... CodecOptions(uuid_representation=JAVA_LEGACY)) .. seealso:: :meth:`~pymongo.database.Database.get_collection` and :meth:`~pymongo.collection.Collection.with_options` MongoClient ----------- MongoClient connects asynchronously ................................... In PyMongo 3, the :class:`~pymongo.mongo_client.MongoClient` constructor no longer blocks while connecting to the server or servers, and it no longer raises :exc:`~pymongo.errors.ConnectionFailure` if they are unavailable, nor :exc:`~pymongo.errors.ConfigurationError` if the user’s credentials are wrong. Instead, the constructor returns immediately and launches the connection process on background threads. The `connect` option is added to control whether these threads are started immediately, or when the client is first used. For consistent behavior in PyMongo 2.x and PyMongo 3.x, code like this:: >>> from pymongo.errors import ConnectionFailure >>> try: ... client = MongoClient() ... except ConnectionFailure: ... 
print("Server not available") >>> can be changed to this with PyMongo 2.9 or later: .. doctest:: >>> from pymongo.errors import ConnectionFailure >>> client = MongoClient(connect=False) >>> try: ... result = client.admin.command("ismaster") ... except ConnectionFailure: ... print("Server not available") >>> Any operation can be used to determine if the server is available. We choose the "ismaster" command here because it is cheap and does not require auth, so it is a simple way to check whether the server is available. The max_pool_size parameter is removed ...................................... PyMongo 3 replaced the max_pool_size parameter with support for the MongoDB URI `maxPoolSize` option. Code like this:: >>> client = MongoClient(max_pool_size=10) can be replaced by this with PyMongo 2.9 or later: .. doctest:: >>> client = MongoClient(maxPoolSize=10) >>> client = MongoClient("mongodb://localhost:27017/?maxPoolSize=10") The "disconnect" method is removed .................................. Code like this:: >>> client.disconnect() can be replaced by this with PyMongo 2.9 or later: .. doctest:: >>> client.close() The host and port attributes are removed ........................................ Code like this:: >>> host = client.host >>> port = client.port can be replaced by this with PyMongo 2.9 or later: .. doctest:: >>> address = client.address >>> host, port = address or (None, None) BSON ---- "as_class", "tz_aware", and "uuid_subtype" are removed ...................................................... The `as_class`, `tz_aware`, and `uuid_subtype` parameters have been removed from the functions provided in :mod:`bson`. Furthermore, the :func:`~bson.encode` and :func:`~bson.decode` functions have been added as more performant alternatives to the :meth:`bson.BSON.encode` and :meth:`bson.BSON.decode` methods. Code like this:: >>> from bson import BSON >>> from bson.son import SON >>> encoded = BSON.encode({"a": 1}, as_class=SON) can be replaced by this in PyMongo 2.9 or later: .. doctest:: >>> from bson import encode >>> from bson.codec_options import CodecOptions >>> from bson.son import SON >>> encoded = encode({"a": 1}, codec_options=CodecOptions(SON)) Removed features with no migration path --------------------------------------- MasterSlaveConnection is removed ................................ Master slave deployments are deprecated in MongoDB. Starting with MongoDB 3.0 a replica set can have up to 50 members and that limit is likely to be removed in later releases. We recommend migrating to replica sets instead. Requests are removed .................... The client methods `start_request`, `in_request`, and `end_request` are removed. Requests were designed to make read-your-writes consistency more likely with the w=0 write concern. Additionally, a thread in a request used the same member for all secondary reads in a replica set. To ensure read-your-writes consistency in PyMongo 3.0, do not override the default write concern with w=0, and do not override the default read preference of PRIMARY. The "compile_re" option is removed .................................. In PyMongo 3 regular expressions are never compiled to Python match objects. The "use_greenlets" option is removed ..................................... The `use_greenlets` option was meant to allow use of PyMongo with Gevent without the use of gevent.monkey.patch_threads(). This option caused a lot of confusion and made it difficult to support alternative asyncio libraries like Eventlet. 
Users of Gevent should use gevent.monkey.patch_all() instead. .. seealso:: :doc:`examples/gevent` pymongo-3.11.0/doc/mongo_extensions.py000066400000000000000000000056661374256237000200270ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MongoDB specific extensions to Sphinx.""" from docutils import nodes from docutils.parsers import rst from sphinx import addnodes class mongodoc(nodes.Admonition, nodes.Element): pass class mongoref(nodes.reference): pass def visit_mongodoc_node(self, node): self.visit_admonition(node, "seealso") def depart_mongodoc_node(self, node): self.depart_admonition(node) def visit_mongoref_node(self, node): atts = {"class": "reference external", "href": node["refuri"], "name": node["name"]} self.body.append(self.starttag(node, 'a', '', **atts)) def depart_mongoref_node(self, node): self.body.append('') if not isinstance(node.parent, nodes.TextElement): self.body.append('\n') class MongodocDirective(rst.Directive): has_content = True required_arguments = 0 optional_arguments = 0 final_argument_whitespace = False option_spec = {} def run(self): node = mongodoc() title = 'The MongoDB documentation on' node += nodes.title(title, title) self.state.nested_parse(self.content, self.content_offset, node) return [node] def process_mongodoc_nodes(app, doctree, fromdocname): for node in doctree.traverse(mongodoc): anchor = None for name in node.parent.parent.traverse(addnodes.desc_signature): anchor = name["ids"][0] break if not anchor: for name in node.parent.traverse(nodes.section): anchor = name["ids"][0] break for para in node.traverse(nodes.paragraph): tag = str(para.traverse()[1]) link = mongoref("", "") link["refuri"] = "http://dochub.mongodb.org/core/%s" % tag link["name"] = anchor link.append(nodes.emphasis(tag, tag)) new_para = nodes.paragraph() new_para += link node.replace(para, new_para) def setup(app): app.add_node(mongodoc, html=(visit_mongodoc_node, depart_mongodoc_node), latex=(visit_mongodoc_node, depart_mongodoc_node), text=(visit_mongodoc_node, depart_mongodoc_node)) app.add_node(mongoref, html=(visit_mongoref_node, depart_mongoref_node)) app.add_directive("mongodoc", MongodocDirective) app.connect("doctree-resolved", process_mongodoc_nodes) pymongo-3.11.0/doc/pydoctheme/000077500000000000000000000000001374256237000162035ustar00rootroot00000000000000pymongo-3.11.0/doc/pydoctheme/static/000077500000000000000000000000001374256237000174725ustar00rootroot00000000000000pymongo-3.11.0/doc/pydoctheme/static/pydoctheme.css000066400000000000000000000052671374256237000223570ustar00rootroot00000000000000@import url("default.css"); body { background-color: white; margin-left: 1em; margin-right: 1em; } div.related { margin-bottom: 1.2em; padding: 0.5em 0; border-top: 1px solid #ccc; margin-top: 0.5em; } div.related a:hover { color: #0095C4; } div.related:first-child { border-top: 0; border-bottom: 1px solid #ccc; } div.sphinxsidebar { background-color: #eeeeee; border-radius: 5px; line-height: 130%; font-size: smaller; 
} div.sphinxsidebar h3, div.sphinxsidebar h4 { margin-top: 1.5em; } div.sphinxsidebarwrapper > h3:first-child { margin-top: 0.2em; } div.sphinxsidebarwrapper > ul > li > ul > li { margin-bottom: 0.4em; } div.sphinxsidebar a:hover { color: #0095C4; } div.sphinxsidebar input { font-family: 'Lucida Grande',Arial,sans-serif; border: 1px solid #999999; font-size: smaller; border-radius: 3px; } div.sphinxsidebar input[type=text] { max-width: 150px; } div.body { padding: 0 0 0 1.2em; } div.body p { line-height: 140%; } div.body h1, div.body h2, div.body h3, div.body h4, div.body h5, div.body h6 { margin: 0; border: 0; padding: 0.3em 0; } div.body hr { border: 0; background-color: #ccc; height: 1px; } div.body pre { border-radius: 3px; border: 1px solid #ac9; } div.body div.admonition, div.body div.impl-detail { border-radius: 3px; } div.body div.impl-detail > p { margin: 0; } div.body div.seealso { border: 1px solid #dddd66; } div.body a { color: #0072aa; } div.body a:visited { color: #6363bb; } div.body a:hover { color: #00B0E4; } tt, code, pre { font-family: monospace, sans-serif; font-size: 96.5%; } div.body tt, div.body code { border-radius: 3px; } div.body tt.descname, div.body code.descname { font-size: 120%; } div.body tt.xref, div.body a tt, div.body code.xref, div.body a code { font-weight: normal; } .deprecated { border-radius: 3px; } table.docutils { border: 1px solid #ddd; min-width: 20%; border-radius: 3px; margin-top: 10px; margin-bottom: 10px; } table.docutils td, table.docutils th { border: 1px solid #ddd !important; border-radius: 3px; } table p, table li { text-align: left !important; } table.docutils th { background-color: #eee; padding: 0.3em 0.5em; } table.docutils td { background-color: white; padding: 0.3em 0.5em; } table.footnote, table.footnote td { border: 0 !important; } div.footer { line-height: 150%; margin-top: -2em; text-align: right; width: auto; margin-right: 10px; } div.footer a:hover { color: #0095C4; } .refcount { color: #060; } .stableabi { color: #229; } pymongo-3.11.0/doc/pydoctheme/theme.conf000066400000000000000000000010451374256237000201540ustar00rootroot00000000000000[theme] inherit = default stylesheet = pydoctheme.css pygments_style = sphinx [options] bodyfont = 'Lucida Grande', Arial, sans-serif headfont = 'Lucida Grande', Arial, sans-serif footerbgcolor = white footertextcolor = #555555 relbarbgcolor = white relbartextcolor = #666666 relbarlinkcolor = #444444 sidebarbgcolor = white sidebartextcolor = #444444 sidebarlinkcolor = #444444 bgcolor = white textcolor = #222222 linkcolor = #0090c0 visitedlinkcolor = #00608f headtextcolor = #1a1a1a headbgcolor = white headlinkcolor = #aaaaaa googletag = False pymongo-3.11.0/doc/python3.rst000066400000000000000000000107101374256237000161770ustar00rootroot00000000000000Python 3 FAQ ============ .. contents:: What Python 3 versions are supported? ------------------------------------- PyMongo supports CPython 3.4+ and PyPy3.5+. Are there any PyMongo behavior changes with Python 3? ----------------------------------------------------- Only one intentional change. Instances of :class:`bytes` are encoded as BSON type 5 (Binary data) with subtype 0. In Python 3 they are decoded back to :class:`bytes`. In Python 2 they are decoded to :class:`~bson.binary.Binary` with subtype 0. For example, let's insert a :class:`bytes` instance using Python 3 then read it back. 
Notice the byte string is decoded back to :class:`bytes`:: Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04) [GCC 4.9.3] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import pymongo >>> c = pymongo.MongoClient() >>> c.test.bintest.insert_one({'binary': b'this is a byte string'}).inserted_id ObjectId('4f9086b1fba5222021000000') >>> c.test.bintest.find_one() {'binary': b'this is a byte string', '_id': ObjectId('4f9086b1fba5222021000000')} Now retrieve the same document in Python 2. Notice the byte string is decoded to :class:`~bson.binary.Binary`:: Python 2.7.6 (default, Feb 26 2014, 10:36:22) [GCC 4.7.3] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> import pymongo >>> c = pymongo.MongoClient() >>> c.test.bintest.find_one() {u'binary': Binary('this is a byte string', 0), u'_id': ObjectId('4f9086b1fba5222021000000')} There is a similar change in behavior in parsing JSON binary with subtype 0. In Python 3 they are decoded into :class:`bytes`. In Python 2 they are decoded to :class:`~bson.binary.Binary` with subtype 0. For example, let's decode a JSON binary subtype 0 using Python 3. Notice the byte string is decoded to :class:`bytes`:: Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04) [GCC 4.2.1 (Apple Inc. build 5666) (dot 3)] on darwin Type "help", "copyright", "credits" or "license" for more information. >>> from bson.json_util import loads >>> loads('{"b": {"$binary": "dGhpcyBpcyBhIGJ5dGUgc3RyaW5n", "$type": "00"}}') {'b': b'this is a byte string'} Now decode the same JSON in Python 2 . Notice the byte string is decoded to :class:`~bson.binary.Binary`:: Python 2.7.10 (default, Feb 7 2017, 00:08:15) [GCC 4.2.1 Compatible Apple LLVM 8.0.0 (clang-800.0.34)] on darwin Type "help", "copyright", "credits" or "license" for more information. >>> from bson.json_util import loads >>> loads('{"b": {"$binary": "dGhpcyBpcyBhIGJ5dGUgc3RyaW5n", "$type": "00"}}') {u'b': Binary('this is a byte string', 0)} Why can't I share pickled ObjectIds between some versions of Python 2 and 3? ---------------------------------------------------------------------------- Instances of :class:`~bson.objectid.ObjectId` pickled using Python 2 can always be unpickled using Python 3. If you pickled an ObjectId using Python 2 and want to unpickle it using Python 3 you must pass ``encoding='latin-1'`` to pickle.loads:: Python 2.7.6 (default, Feb 26 2014, 10:36:22) [GCC 4.7.3] on linux2 Type "help", "copyright", "credits" or "license" for more information. >>> import pickle >>> from bson.objectid import ObjectId >>> oid = ObjectId() >>> oid ObjectId('4f919ba2fba5225b84000000') >>> pickle.dumps(oid) 'ccopy_reg\n_reconstructor\np0\n(cbson.objectid\...' Python 3.6.1 (v3.6.1:69c0db5050, Mar 21 2017, 01:21:04) [GCC 4.9.3] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import pickle >>> pickle.loads(b'ccopy_reg\n_reconstructor\np0\n(cbson.objectid\...', encoding='latin-1') ObjectId('4f919ba2fba5225b84000000') If you need to pickle ObjectIds using Python 3 and unpickle them using Python 2 you must use ``protocol <= 2``:: Python 3.6.5 (default, Jun 21 2018, 15:09:09) [GCC 7.3.0] on linux Type "help", "copyright", "credits" or "license" for more information. >>> import pickle >>> from bson.objectid import ObjectId >>> oid = ObjectId() >>> oid ObjectId('4f96f20c430ee6bd06000000') >>> pickle.dumps(oid, protocol=2) b'\x80\x02cbson.objectid\nObjectId\nq\x00)\x81q\x01c_codecs\nencode\...' 
Python 2.7.15 (default, Jun 21 2018, 15:00:48)
[GCC 7.3.0] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle
>>> pickle.loads('\x80\x02cbson.objectid\nObjectId\nq\x00)\x81q\x01c_codecs\nencode\...')
ObjectId('4f96f20c430ee6bd06000000')
pymongo-3.11.0/doc/static/delighted.js
Ňճ;{+0b^OsIzJ+=z_UBr!K.Ϯnn/xDbbT)jjs۵knVG}Os饗*2\sM>ܓT|2ܗ!P Au$ž-ܢNЏ?x&x񎅈{EiXo¾kNM0!G)aa-l*˯W^>C$JcSN_|Sa$n ]Ȗo\8)LT $h& j R4|%(41)``"!CE %J@Nsνwfcc;wfc?{fݽZPDH֭*bK$82dDwy~ K:QwUŦ<܏ S) U~ᖺ[m}W-%]&n|[RyIxhѢEfIJYP7,MJ@ 믿.3-jVXafWjzQaqUQ.6olQ) U~ʧ7n,t˗/`>A1=@if:4hPSsNV T>ݥK/ecC,ńao@}=zYTy}U~Kl|Msˌ,[ !ҢE ӭ[7):,;v0.{=| x\Re$n)TTT|qSλ-[ 6dM6 ?;wJL/~F#6mT&M͛mfdiF,?]pfժU¸Nѕ;c]v(.fNG/qF 8P\(1^9PrK4Q'~g7PQMm>|п:rBbvx 77T@h4C[/!筟hĴlReB 9I P:uz<H,H[MjƍR~>}ٚDq?^z+J~ Uʌ@yi!`3a„+V8V9ַͱJ3b *XBPǦtlә8q\ ; 3tP*Ofdsrwuӿdž9LԱKwquʖ{bt>/;hԱa]ʵdN߾}|26|$.{:gmU<@c| A}'ls:Or۵k'%{夥ԝ^gABey]=o~םw"%9g#*_*xYʺA!Vى+½E 98_޸mbyI'y&id1 /\0bߎٗ?*mao*( gX\!.X_.̊H-I2HngϞfn1^q&C?:A!V_ʞJAg1h)qB`РAA8V l\\lr⦂ؼr9b ?e˖~-͛;6αJO]T,\wkpx?ֈ"p[:yq{qo&'e{[WXvU֒ ϵV[ܹs* .B3&vY^2s?Wpщo-9%(֭[g[>X5_PĤuvsu%OW%O(#I|2f5ʱ{q].F姞zjrӴF۽I.?)Fw8(jϢZwj׊9AlK1轵#y{)B؛ 7$Jb€BX 1@#ۂ[Jqy=#6(c̘1r]>ԇD- k"y~<[rԫIFVZڎ;"UkpĵNh w ~ KkTC 2𣞔QQщz*%2䞵QF9_?6f|\r 3ŋ,o֬YH89#8 K9R-#_PSW",( rBJKK%yϻQU-5QޡC1;={aSU|E@pja‚4ɋ /er@+7 1LIĒ:^B@.`HN3^ؒmNqaQWT V L" Q:$Bpi۶lVf$an t 1eʱi&Wء ca5w}eN- gb= =GJf(OԆʁrM ߴ⛠*5{[F3.]4Z&avYeGK/@wFH\(=r"(^xfٱMOL 8AT V9DŽ СCͤIӛ.k7D@LoŝRQE U~ 536-*NLǴ# "`I2uGڮPGX1E8\9TP~Mnlq%K)<^vҩS847mԙ_GM\<?υzx8/OƶjiËAjyǍWce6hvZR*ӰGB&f,@DҥKj EO}oGr˅զep*UW&?@ZUѼጌ*ppZ"^M81"- !RU~©E b~nʊZ{nj֯_/QZEeiӢ0 I~N`P'`#!C$SɗիVjD@W0KG9t"3t"@a=>t旾1F%K5E/E]-[3fTg<ڷ`Gju}ޑ(%Qé)JNnG#/%3|p3uT{y+۶m3 }T_C`ȏǗo^x C%pPZKm~3ݱcGYZh"ӱcGӬYԟJUqԵ oʔ)ծq_Þ={Q9X6mjXG0ɥGaVT+`干= / 3: L%2P蛴#~N:Ɍ;6T(#S[۶mC7합K7@_?~A~ ѭl oRz%/!۷o7j*괎6&(0={_Kff=TES{VE O|nVӧOs7 ?#e9!U~ H ֭0!U(QF63k5lʯkFC5s1첋T›7ojx۷rao>pӟ$J^zuc?sG'hwO_G[͜9SXiXnҤIz&(Uz^{^ |y뭷͘1c̷-sG?0s|r䀚87# ardwc7nl7onpi>CK Cfo л Ӭ^ZUVkך> GPl6l l%cU+IlْQ泌~\*bVΝ;"OW*!3?ԒbWXW^ feE꺤0;ZB0s0QD74\'Z$:ڵ( ړ2+U`!!Y XWX!3fV(Keq#*ŒY($7њYkoN8Y>G1h* 6#(dQQFwU~ mYy淿c=̏cQzALϘ1CyL_zs1{ ( EJ6j#6wС'##GK.4h ڃN4{[88k-A!]f58 CC???3'O7fԨQ}fʔ)aԘĥ\Hi; @xwa:u$$~aK/W#=ڼ 830KǼϯ\y商,yO; @̻X[haƏo^~eK.榛nJ-*jZps/3?XMR88n38vXÌ4&QWwF,#8(_uUbrh"aC4($."\$41+s=vm78q|K+wޢ }*:Ũi/"̚5 4He|IY=4*4rx뭷_BӾŸ'n:D"GtȻ:%6D-"`[ի,wRP x !wc$N NZC)Ucǎc)E3%"rJ17N"BJ,&)iG(Qؗ.\(ssL5gjfL@=^{I({iAKtAfn0g}'~|BD Kގu`LEԱ!b:-ZN>dRK9Y~uU'AջuV*qWU͎B0Ī(Bw^c#.̫-[ԯ_߱.9׿xۃ:6_I(Sj-|[+ tw}k-kVIxbpC7O>%z 4# 81ҥK=n@r f*z)qFӧOZ!ؓ&M[N*}'$Dp@]xBm@6'LU|3^`=R(-ԩSEAAğK+ƤŬ-^ۙ<'Q`CDJZ"O?T,q^N}ݎaԎ۷owݱuK/mYdLlT\veGYd<*YۜBXg_ %_Z|;E[f`{'dFɍ.Oeܹ\qzɧx˜[?<{m$H9ڒ_C*F-:/x]weM&Gt.H. 29yK{1(P_ IXZ`]f {a&"\|*F$/"ZH MZg{ U~hX"^Qnך{ǴnZ(IDC>`;,qrNV䜦NI_a&810я~$ ,(AAyeV ;>{I8@}<ZJ ڴiԩׯ$2#tMDzUT[{ .&R#=B Q֯*¯+@P>r!*@y⯵+@PW&ZE@(/Q #*6IENDB`pymongo-3.11.0/doc/static/sidebar.js000066400000000000000000000142141374256237000173020ustar00rootroot00000000000000/* * sidebar.js * ~~~~~~~~~~ * * This script makes the Sphinx sidebar collapsible and implements intelligent * scrolling. * * .sphinxsidebar contains .sphinxsidebarwrapper. This script adds in * .sphixsidebar, after .sphinxsidebarwrapper, the #sidebarbutton used to * collapse and expand the sidebar. * * When the sidebar is collapsed the .sphinxsidebarwrapper is hidden and the * width of the sidebar and the margin-left of the document are decreased. * When the sidebar is expanded the opposite happens. This script saves a * per-browser/per-session cookie used to remember the position of the sidebar * among the pages. Once the browser is closed the cookie is deleted and the * position reset to the default (expanded). * * :copyright: Copyright 2007-2011 by the Sphinx team, see AUTHORS. * :license: BSD, see LICENSE for details. * */ $(function() { // global elements used by the functions. 
// the 'sidebarbutton' element is defined as global after its // creation, in the add_sidebar_button function var jwindow = $(window); var jdocument = $(document); var bodywrapper = $('.bodywrapper'); var sidebar = $('.sphinxsidebar'); var sidebarwrapper = $('.sphinxsidebarwrapper'); // original margin-left of the bodywrapper and width of the sidebar // with the sidebar expanded var bw_margin_expanded = bodywrapper.css('margin-left'); var ssb_width_expanded = sidebar.width(); // margin-left of the bodywrapper and width of the sidebar // with the sidebar collapsed var bw_margin_collapsed = '.8em'; var ssb_width_collapsed = '.8em'; // colors used by the current theme var dark_color = '#AAAAAA'; var light_color = '#CCCCCC'; function get_viewport_height() { if (window.innerHeight) return window.innerHeight; else return jwindow.height(); } function sidebar_is_collapsed() { return sidebarwrapper.is(':not(:visible)'); } function toggle_sidebar() { if (sidebar_is_collapsed()) expand_sidebar(); else collapse_sidebar(); // adjust the scrolling of the sidebar scroll_sidebar(); } function collapse_sidebar() { sidebarwrapper.hide(); sidebar.css('width', ssb_width_collapsed); bodywrapper.css('margin-left', bw_margin_collapsed); sidebarbutton.css({ 'margin-left': '0', 'height': bodywrapper.height(), 'border-radius': '5px' }); sidebarbutton.find('span').text('»'); sidebarbutton.attr('title', _('Expand sidebar')); document.cookie = 'sidebar=collapsed'; } function expand_sidebar() { bodywrapper.css('margin-left', bw_margin_expanded); sidebar.css('width', ssb_width_expanded); sidebarwrapper.show(); sidebarbutton.css({ 'margin-left': ssb_width_expanded-12, 'height': bodywrapper.height(), 'border-radius': '0 5px 5px 0' }); sidebarbutton.find('span').text('«'); sidebarbutton.attr('title', _('Collapse sidebar')); //sidebarwrapper.css({'padding-top': // Math.max(window.pageYOffset - sidebarwrapper.offset().top, 10)}); document.cookie = 'sidebar=expanded'; } function add_sidebar_button() { sidebarwrapper.css({ 'float': 'left', 'margin-right': '0', 'width': ssb_width_expanded - 28 }); // create the button sidebar.append( '
<div id="sidebarbutton"><span>«</span></div>
' ); var sidebarbutton = $('#sidebarbutton'); // find the height of the viewport to center the '<<' in the page var viewport_height = get_viewport_height(); var sidebar_offset = sidebar.offset().top; var sidebar_height = Math.max(bodywrapper.height(), sidebar.height()); sidebarbutton.find('span').css({ 'display': 'block', 'position': 'fixed', 'top': Math.min(viewport_height/2, sidebar_height/2 + sidebar_offset) - 10 }); sidebarbutton.click(toggle_sidebar); sidebarbutton.attr('title', _('Collapse sidebar')); sidebarbutton.css({ 'border-radius': '0 5px 5px 0', 'color': '#444444', 'background-color': '#CCCCCC', 'font-size': '1.2em', 'cursor': 'pointer', 'height': sidebar_height, 'padding-top': '1px', 'padding-left': '1px', 'margin-left': ssb_width_expanded - 12 }); sidebarbutton.hover( function () { $(this).css('background-color', dark_color); }, function () { $(this).css('background-color', light_color); } ); } function set_position_from_cookie() { if (!document.cookie) return; var items = document.cookie.split(';'); for(var k=0; k wintop && curbot > winbot) { sidebarwrapper.css('top', $u.max([wintop - offset - 10, 0])); } else if (curtop < wintop && curbot < winbot) { sidebarwrapper.css('top', $u.min([winbot - sidebar_height - offset - 20, jdocument.height() - sidebar_height - 200])); } } } jwindow.scroll(scroll_sidebar); }); pymongo-3.11.0/doc/tools.rst000066400000000000000000000205051374256237000157360ustar00rootroot00000000000000Tools ===== Many tools have been written for working with **PyMongo**. If you know of or have created a tool for working with MongoDB from Python please list it here. .. note:: We try to keep this list current. As such, projects that have not been updated recently or appear to be unmaintained will occasionally be removed from the list or moved to the back (to keep the list from becoming too intimidating). If a project gets removed that is still being developed or is in active use please let us know or add it back. ORM-like Layers --------------- Some people have found that they prefer to work with a layer that has more features than PyMongo provides. Often, things like models and validation are desired. To that end, several different ORM-like layers have been written by various authors. It is our recommendation that new users begin by working directly with PyMongo, as described in the rest of this documentation. Many people have found that the features of PyMongo are enough for their needs. Even if you eventually come to the decision to use one of these layers, the time spent working directly with the driver will have increased your understanding of how MongoDB actually works. PyMODM `PyMODM `_ is an ORM-like framework on top of PyMongo. PyMODM is maintained by engineers at MongoDB, Inc. and is quick to adopt new MongoDB features. PyMODM is a "core" ODM, meaning that it provides simple, extensible functionality that can be leveraged by other libraries to target platforms like Django. At the same time, PyMODM is powerful enough to be used for developing applications on its own. Complete documentation is available on `readthedocs `_ in addition to a `Gitter channel `_ for discussing the project. Humongolus `Humongolus `_ is a lightweight ORM framework for Python and MongoDB. The name comes from the combination of MongoDB and `Homunculus `_ (the concept of a miniature though fully formed human body). Humongolus allows you to create models/schemas with robust validation. 
It attempts to be as pythonic as possible and exposes the pymongo cursor objects whenever possible. The code is available for download `at GitHub `_. Tutorials and usage examples are also available at GitHub. Ming `Ming `_ (the Merciless) is a library that allows you to enforce schemas on a MongoDB database in your Python application. It was developed by `SourceForge `_ in the course of their migration to MongoDB. See the `introductory blog post `_ for more details. MongoEngine `MongoEngine `_ is another ORM-like layer on top of PyMongo. It allows you to define schemas for documents and query collections using syntax inspired by the Django ORM. The code is available on `GitHub `_; for more information, see the `tutorial `_. MotorEngine `MotorEngine `_ is a port of MongoEngine to Motor, for asynchronous access with Tornado. It implements the same modeling APIs to be data-portable, meaning that a model defined in MongoEngine can be read in MotorEngine. The source is `available on GitHub `_. uMongo `uMongo `_ is a Python MongoDB ODM. Its inception comes from two needs: the lack of async ODM and the difficulty to do document (un)serialization with existing ODMs. Works with multiple drivers: PyMongo, TxMongo, motor_asyncio, and mongomock. The source `is available on GitHub `_ No longer maintained """""""""""""""""""" MongoKit The `MongoKit `_ framework is an ORM-like layer on top of PyMongo. There is also a MongoKit `google group `_. MongoAlchemy `MongoAlchemy `_ is another ORM-like layer on top of PyMongo. Its API is inspired by `SQLAlchemy `_. The code is available `on GitHub `_; for more information, see `the tutorial `_. Minimongo `minimongo `_ is a lightweight, pythonic interface to MongoDB. It retains pymongo's query and update API, and provides a number of additional features, including a simple document-oriented interface, connection pooling, index management, and collection & database naming helpers. The `source is on GitHub `_. Manga `Manga `_ aims to be a simpler ORM-like layer on top of PyMongo. The syntax for defining schema is inspired by the Django ORM, but Pymongo's query language is maintained. The source `is on GitHub `_. Framework Tools --------------- This section lists tools and adapters that have been designed to work with various Python frameworks and libraries. * `Djongo `_ is a connector for using Django with MongoDB as the database backend. Use the Django Admin GUI to add and modify documents in MongoDB. The `Djongo Source Code `_ is hosted on GitHub and the `Djongo package `_ is on pypi. * `Django MongoDB Engine `_ is a MongoDB database backend for Django that completely integrates with its ORM. For more information `see the tutorial `_. * `mango `_ provides MongoDB backends for Django sessions and authentication (bypassing :mod:`django.db` entirely). * `Django MongoEngine `_ is a MongoDB backend for Django, an `example: `_. For more information ``_ * `mongodb_beaker `_ is a project to enable using MongoDB as a backend for `beaker's `_ caching / session system. `The source is on GitHub `_. * `Log4Mongo `_ is a flexible Python logging handler that can store logs in MongoDB using normal and capped collections. * `MongoLog `_ is a Python logging handler that stores logs in MongoDB using a capped collection. * `c5t `_ is a content-management system using TurboGears and MongoDB. * `rod.recipe.mongodb `_ is a ZC Buildout recipe for downloading and installing MongoDB. 
* `repoze-what-plugins-mongodb `_ is a project working to support a plugin for using MongoDB as a backend for :mod:`repoze.what`. * `mongobox `_ is a tool to run a sandboxed MongoDB instance from within a python app. * `Flask-MongoAlchemy `_ Add Flask support for MongoDB using MongoAlchemy. * `Flask-MongoKit `_ Flask extension to better integrate MongoKit into Flask. * `Flask-PyMongo `_ Flask-PyMongo bridges Flask and PyMongo. Alternative Drivers ------------------- These are alternatives to PyMongo. * `Motor `_ is a full-featured, non-blocking MongoDB driver for Python Tornado applications. * `TxMongo `_ is an asynchronous Twisted Python driver for MongoDB. * `MongoMock `_ is a small library to help testing Python code that interacts with MongoDB via Pymongo. pymongo-3.11.0/doc/tutorial.rst000066400000000000000000000326641374256237000164520ustar00rootroot00000000000000Tutorial ======== .. testsetup:: from pymongo import MongoClient client = MongoClient() client.drop_database('test-database') This tutorial is intended as an introduction to working with **MongoDB** and **PyMongo**. Prerequisites ------------- Before we start, make sure that you have the **PyMongo** distribution :doc:`installed `. In the Python shell, the following should run without raising an exception: .. doctest:: >>> import pymongo This tutorial also assumes that a MongoDB instance is running on the default host and port. Assuming you have `downloaded and installed `_ MongoDB, you can start it like so: .. code-block:: bash $ mongod Making a Connection with MongoClient ------------------------------------ The first step when working with **PyMongo** is to create a :class:`~pymongo.mongo_client.MongoClient` to the running **mongod** instance. Doing so is easy: .. doctest:: >>> from pymongo import MongoClient >>> client = MongoClient() The above code will connect on the default host and port. We can also specify the host and port explicitly, as follows: .. doctest:: >>> client = MongoClient('localhost', 27017) Or use the MongoDB URI format: .. doctest:: >>> client = MongoClient('mongodb://localhost:27017/') Getting a Database ------------------ A single instance of MongoDB can support multiple independent `databases `_. When working with PyMongo you access databases using attribute style access on :class:`~pymongo.mongo_client.MongoClient` instances: .. doctest:: >>> db = client.test_database If your database name is such that using attribute style access won't work (like ``test-database``), you can use dictionary style access instead: .. doctest:: >>> db = client['test-database'] Getting a Collection -------------------- A `collection `_ is a group of documents stored in MongoDB, and can be thought of as roughly the equivalent of a table in a relational database. Getting a collection in PyMongo works the same as getting a database: .. doctest:: >>> collection = db.test_collection or (using dictionary style access): .. doctest:: >>> collection = db['test-collection'] An important note about collections (and databases) in MongoDB is that they are created lazily - none of the above commands have actually performed any operations on the MongoDB server. Collections and databases are created when the first document is inserted into them. Documents --------- Data in MongoDB is represented (and stored) using JSON-style documents. In PyMongo we use dictionaries to represent documents. As an example, the following dictionary might be used to represent a blog post: .. doctest:: >>> import datetime >>> post = {"author": "Mike", ... 
"text": "My first blog post!", ... "tags": ["mongodb", "python", "pymongo"], ... "date": datetime.datetime.utcnow()} Note that documents can contain native Python types (like :class:`datetime.datetime` instances) which will be automatically converted to and from the appropriate `BSON `_ types. .. todo:: link to table of Python <-> BSON types Inserting a Document -------------------- To insert a document into a collection we can use the :meth:`~pymongo.collection.Collection.insert_one` method: .. doctest:: >>> posts = db.posts >>> post_id = posts.insert_one(post).inserted_id >>> post_id ObjectId('...') When a document is inserted a special key, ``"_id"``, is automatically added if the document doesn't already contain an ``"_id"`` key. The value of ``"_id"`` must be unique across the collection. :meth:`~pymongo.collection.Collection.insert_one` returns an instance of :class:`~pymongo.results.InsertOneResult`. For more information on ``"_id"``, see the `documentation on _id `_. After inserting the first document, the *posts* collection has actually been created on the server. We can verify this by listing all of the collections in our database: .. doctest:: >>> db.list_collection_names() [u'posts'] Getting a Single Document With :meth:`~pymongo.collection.Collection.find_one` ------------------------------------------------------------------------------ The most basic type of query that can be performed in MongoDB is :meth:`~pymongo.collection.Collection.find_one`. This method returns a single document matching a query (or ``None`` if there are no matches). It is useful when you know there is only one matching document, or are only interested in the first match. Here we use :meth:`~pymongo.collection.Collection.find_one` to get the first document from the posts collection: .. doctest:: >>> import pprint >>> pprint.pprint(posts.find_one()) {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'mongodb', u'python', u'pymongo'], u'text': u'My first blog post!'} The result is a dictionary matching the one that we inserted previously. .. note:: The returned document contains an ``"_id"``, which was automatically added on insert. :meth:`~pymongo.collection.Collection.find_one` also supports querying on specific elements that the resulting document must match. To limit our results to a document with author "Mike" we do: .. doctest:: >>> pprint.pprint(posts.find_one({"author": "Mike"})) {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'mongodb', u'python', u'pymongo'], u'text': u'My first blog post!'} If we try with a different author, like "Eliot", we'll get no result: .. doctest:: >>> posts.find_one({"author": "Eliot"}) >>> .. _querying-by-objectid: Querying By ObjectId -------------------- We can also find a post by its ``_id``, which in our example is an ObjectId: .. doctest:: >>> post_id ObjectId(...) >>> pprint.pprint(posts.find_one({"_id": post_id})) {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'mongodb', u'python', u'pymongo'], u'text': u'My first blog post!'} Note that an ObjectId is not the same as its string representation: .. doctest:: >>> post_id_as_str = str(post_id) >>> posts.find_one({"_id": post_id_as_str}) # No result >>> A common task in web applications is to get an ObjectId from the request URL and find the matching document. 
It's necessary in this case to **convert the ObjectId from a string** before passing it to ``find_one``:: from bson.objectid import ObjectId # The web framework gets post_id from the URL and passes it as a string def get(post_id): # Convert from string to ObjectId: document = client.db.collection.find_one({'_id': ObjectId(post_id)}) .. seealso:: :ref:`web-application-querying-by-objectid` A Note On Unicode Strings ------------------------- You probably noticed that the regular Python strings we stored earlier look different when retrieved from the server (e.g. u'Mike' instead of 'Mike'). A short explanation is in order. MongoDB stores data in `BSON format `_. BSON strings are UTF-8 encoded so PyMongo must ensure that any strings it stores contain only valid UTF-8 data. Regular strings () are validated and stored unaltered. Unicode strings () are encoded UTF-8 first. The reason our example string is represented in the Python shell as u'Mike' instead of 'Mike' is that PyMongo decodes each BSON string to a Python unicode string, not a regular str. `You can read more about Python unicode strings here `_. Bulk Inserts ------------ In order to make querying a little more interesting, let's insert a few more documents. In addition to inserting a single document, we can also perform *bulk insert* operations, by passing a list as the first argument to :meth:`~pymongo.collection.Collection.insert_many`. This will insert each document in the list, sending only a single command to the server: .. doctest:: >>> new_posts = [{"author": "Mike", ... "text": "Another post!", ... "tags": ["bulk", "insert"], ... "date": datetime.datetime(2009, 11, 12, 11, 14)}, ... {"author": "Eliot", ... "title": "MongoDB is fun", ... "text": "and pretty easy too!", ... "date": datetime.datetime(2009, 11, 10, 10, 45)}] >>> result = posts.insert_many(new_posts) >>> result.inserted_ids [ObjectId('...'), ObjectId('...')] There are a couple of interesting things to note about this example: - The result from :meth:`~pymongo.collection.Collection.insert_many` now returns two :class:`~bson.objectid.ObjectId` instances, one for each inserted document. - ``new_posts[1]`` has a different "shape" than the other posts - there is no ``"tags"`` field and we've added a new field, ``"title"``. This is what we mean when we say that MongoDB is *schema-free*. Querying for More Than One Document ----------------------------------- To get more than a single document as the result of a query we use the :meth:`~pymongo.collection.Collection.find` method. :meth:`~pymongo.collection.Collection.find` returns a :class:`~pymongo.cursor.Cursor` instance, which allows us to iterate over all matching documents. For example, we can iterate over every document in the ``posts`` collection: .. doctest:: >>> for post in posts.find(): ... pprint.pprint(post) ... {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'mongodb', u'python', u'pymongo'], u'text': u'My first blog post!'} {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'bulk', u'insert'], u'text': u'Another post!'} {u'_id': ObjectId('...'), u'author': u'Eliot', u'date': datetime.datetime(...), u'text': u'and pretty easy too!', u'title': u'MongoDB is fun'} Just like we did with :meth:`~pymongo.collection.Collection.find_one`, we can pass a document to :meth:`~pymongo.collection.Collection.find` to limit the returned results. Here, we get only those documents whose author is "Mike": .. 
doctest:: >>> for post in posts.find({"author": "Mike"}): ... pprint.pprint(post) ... {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'mongodb', u'python', u'pymongo'], u'text': u'My first blog post!'} {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'bulk', u'insert'], u'text': u'Another post!'} Counting -------- If we just want to know how many documents match a query we can perform a :meth:`~pymongo.collection.Collection.count_documents` operation instead of a full query. We can get a count of all of the documents in a collection: .. doctest:: >>> posts.count_documents({}) 3 or just of those documents that match a specific query: .. doctest:: >>> posts.count_documents({"author": "Mike"}) 2 Range Queries ------------- MongoDB supports many different types of `advanced queries `_. As an example, lets perform a query where we limit results to posts older than a certain date, but also sort the results by author: .. doctest:: >>> d = datetime.datetime(2009, 11, 12, 12) >>> for post in posts.find({"date": {"$lt": d}}).sort("author"): ... pprint.pprint(post) ... {u'_id': ObjectId('...'), u'author': u'Eliot', u'date': datetime.datetime(...), u'text': u'and pretty easy too!', u'title': u'MongoDB is fun'} {u'_id': ObjectId('...'), u'author': u'Mike', u'date': datetime.datetime(...), u'tags': [u'bulk', u'insert'], u'text': u'Another post!'} Here we use the special ``"$lt"`` operator to do a range query, and also call :meth:`~pymongo.cursor.Cursor.sort` to sort the results by author. Indexing -------- Adding indexes can help accelerate certain queries and can also add additional functionality to querying and storing documents. In this example, we'll demonstrate how to create a `unique index `_ on a key that rejects documents whose value for that key already exists in the index. First, we'll need to create the index: .. doctest:: >>> result = db.profiles.create_index([('user_id', pymongo.ASCENDING)], ... unique=True) >>> sorted(list(db.profiles.index_information())) [u'_id_', u'user_id_1'] Notice that we have two indexes now: one is the index on ``_id`` that MongoDB creates automatically, and the other is the index on ``user_id`` we just created. Now let's set up some user profiles: .. doctest:: >>> user_profiles = [ ... {'user_id': 211, 'name': 'Luke'}, ... {'user_id': 212, 'name': 'Ziltoid'}] >>> result = db.profiles.insert_many(user_profiles) The index prevents us from inserting a document whose ``user_id`` is already in the collection: .. doctest:: :options: +IGNORE_EXCEPTION_DETAIL >>> new_profile = {'user_id': 213, 'name': 'Drew'} >>> duplicate_profile = {'user_id': 212, 'name': 'Tommy'} >>> result = db.profiles.insert_one(new_profile) # This is fine. >>> result = db.profiles.insert_one(duplicate_profile) Traceback (most recent call last): DuplicateKeyError: E11000 duplicate key error index: test_database.profiles.$user_id_1 dup key: { : 212 } .. seealso:: The MongoDB documentation on `indexes `_ pymongo-3.11.0/ez_setup.py000066400000000000000000000303711374256237000155110ustar00rootroot00000000000000#!/usr/bin/env python """ Setuptools bootstrapping installer. Maintained at https://github.com/pypa/setuptools/tree/bootstrap. Run this script to install or upgrade setuptools. This method is DEPRECATED. Check https://github.com/pypa/setuptools/issues/581 for more details. 
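For illustration only, one possible legacy invocation, using only options that ``_parse_args`` below actually defines (``--user`` and ``--version``; 33.1.1 is this script's DEFAULT_VERSION)::

    python ez_setup.py --user --version 33.1.1

This is a sketch of the deprecated path, not a recommendation; pip is the supported way to install setuptools.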
""" import os import shutil import sys import tempfile import zipfile import optparse import subprocess import platform import textwrap import contextlib from distutils import log try: from urllib.request import urlopen except ImportError: from urllib2 import urlopen try: from site import USER_SITE except ImportError: USER_SITE = None # 33.1.1 is the last version that supports setuptools self upgrade/installation. DEFAULT_VERSION = "33.1.1" DEFAULT_URL = "https://pypi.io/packages/source/s/setuptools/" DEFAULT_SAVE_DIR = os.curdir DEFAULT_DEPRECATION_MESSAGE = "ez_setup.py is deprecated and when using it setuptools will be pinned to {0} since it's the last version that supports setuptools self upgrade/installation, check https://github.com/pypa/setuptools/issues/581 for more info; use pip to install setuptools" MEANINGFUL_INVALID_ZIP_ERR_MSG = 'Maybe {0} is corrupted, delete it and try again.' log.warn(DEFAULT_DEPRECATION_MESSAGE.format(DEFAULT_VERSION)) def _python_cmd(*args): """ Execute a command. Return True if the command succeeded. """ args = (sys.executable,) + args return subprocess.call(args) == 0 def _install(archive_filename, install_args=()): """Install Setuptools.""" with archive_context(archive_filename): # installing log.warn('Installing Setuptools') if not _python_cmd('setup.py', 'install', *install_args): log.warn('Something went wrong during the installation.') log.warn('See the error message above.') # exitcode will be 2 return 2 def _build_egg(egg, archive_filename, to_dir): """Build Setuptools egg.""" with archive_context(archive_filename): # building an egg log.warn('Building a Setuptools egg in %s', to_dir) _python_cmd('setup.py', '-q', 'bdist_egg', '--dist-dir', to_dir) # returning the result log.warn(egg) if not os.path.exists(egg): raise IOError('Could not build the egg.') class ContextualZipFile(zipfile.ZipFile): """Supplement ZipFile class to support context manager for Python 2.6.""" def __enter__(self): return self def __exit__(self, type, value, traceback): self.close() def __new__(cls, *args, **kwargs): """Construct a ZipFile or ContextualZipFile as appropriate.""" if hasattr(zipfile.ZipFile, '__exit__'): return zipfile.ZipFile(*args, **kwargs) return super(ContextualZipFile, cls).__new__(cls) @contextlib.contextmanager def archive_context(filename): """ Unzip filename to a temporary directory, set to the cwd. The unzipped target is cleaned up after. """ tmpdir = tempfile.mkdtemp() log.warn('Extracting in %s', tmpdir) old_wd = os.getcwd() try: os.chdir(tmpdir) try: with ContextualZipFile(filename) as archive: archive.extractall() except zipfile.BadZipfile as err: if not err.args: err.args = ('', ) err.args = err.args + ( MEANINGFUL_INVALID_ZIP_ERR_MSG.format(filename), ) raise # going in the directory subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0]) os.chdir(subdir) log.warn('Now working in %s', subdir) yield finally: os.chdir(old_wd) shutil.rmtree(tmpdir) def _do_download(version, download_base, to_dir, download_delay): """Download Setuptools.""" py_desig = 'py{sys.version_info[0]}.{sys.version_info[1]}'.format(sys=sys) tp = 'setuptools-{version}-{py_desig}.egg' egg = os.path.join(to_dir, tp.format(**locals())) if not os.path.exists(egg): archive = download_setuptools(version, download_base, to_dir, download_delay) _build_egg(egg, archive, to_dir) sys.path.insert(0, egg) # Remove previously-imported pkg_resources if present (see # https://bitbucket.org/pypa/setuptools/pull-request/7/ for details). 
if 'pkg_resources' in sys.modules: _unload_pkg_resources() import setuptools setuptools.bootstrap_install_from = egg def use_setuptools( version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=DEFAULT_SAVE_DIR, download_delay=15): """ Ensure that a setuptools version is installed. Return None. Raise SystemExit if the requested version or later cannot be installed. """ to_dir = os.path.abspath(to_dir) # prior to importing, capture the module state for # representative modules. rep_modules = 'pkg_resources', 'setuptools' imported = set(sys.modules).intersection(rep_modules) try: import pkg_resources pkg_resources.require("setuptools>=" + version) # a suitable version is already installed return except ImportError: # pkg_resources not available; setuptools is not installed; download pass except pkg_resources.DistributionNotFound: # no version of setuptools was found; allow download pass except pkg_resources.VersionConflict as VC_err: if imported: _conflict_bail(VC_err, version) # otherwise, unload pkg_resources to allow the downloaded version to # take precedence. del pkg_resources _unload_pkg_resources() return _do_download(version, download_base, to_dir, download_delay) def _conflict_bail(VC_err, version): """ Setuptools was imported prior to invocation, so it is unsafe to unload it. Bail out. """ conflict_tmpl = textwrap.dedent(""" The required version of setuptools (>={version}) is not available, and can't be installed while this script is running. Please install a more recent version first, using 'easy_install -U setuptools'. (Currently using {VC_err.args[0]!r}) """) msg = conflict_tmpl.format(**locals()) sys.stderr.write(msg) sys.exit(2) def _unload_pkg_resources(): sys.meta_path = [ importer for importer in sys.meta_path if importer.__class__.__module__ != 'pkg_resources.extern' ] del_modules = [ name for name in sys.modules if name.startswith('pkg_resources') ] for mod_name in del_modules: del sys.modules[mod_name] def _clean_check(cmd, target): """ Run the command to download target. If the command fails, clean up before re-raising the error. """ try: subprocess.check_call(cmd) except subprocess.CalledProcessError: if os.access(target, os.F_OK): os.unlink(target) raise def download_file_powershell(url, target): """ Download the file at url to target using Powershell. Powershell will validate trust. Raise an exception if the command cannot complete. 
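A minimal, illustrative call shape (the URL simply combines this module's DEFAULT_URL with a setuptools archive name; any downloadable URL/target pair would do)::

    download_file_powershell(
        'https://pypi.io/packages/source/s/setuptools/setuptools-33.1.1.zip',
        'setuptools-33.1.1.zip')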
""" target = os.path.abspath(target) ps_cmd = ( "[System.Net.WebRequest]::DefaultWebProxy.Credentials = " "[System.Net.CredentialCache]::DefaultCredentials; " '(new-object System.Net.WebClient).DownloadFile("%(url)s", "%(target)s")' % locals() ) cmd = [ 'powershell', '-Command', ps_cmd, ] _clean_check(cmd, target) def has_powershell(): """Determine if Powershell is available.""" if platform.system() != 'Windows': return False cmd = ['powershell', '-Command', 'echo test'] with open(os.path.devnull, 'wb') as devnull: try: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) except Exception: return False return True download_file_powershell.viable = has_powershell def download_file_curl(url, target): cmd = ['curl', url, '--location', '--silent', '--output', target] _clean_check(cmd, target) def has_curl(): cmd = ['curl', '--version'] with open(os.path.devnull, 'wb') as devnull: try: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) except Exception: return False return True download_file_curl.viable = has_curl def download_file_wget(url, target): cmd = ['wget', url, '--quiet', '--output-document', target] _clean_check(cmd, target) def has_wget(): cmd = ['wget', '--version'] with open(os.path.devnull, 'wb') as devnull: try: subprocess.check_call(cmd, stdout=devnull, stderr=devnull) except Exception: return False return True download_file_wget.viable = has_wget def download_file_insecure(url, target): """Use Python to download the file, without connection authentication.""" src = urlopen(url) try: # Read all the data in one block. data = src.read() finally: src.close() # Write all the data in one block to avoid creating a partial file. with open(target, "wb") as dst: dst.write(data) download_file_insecure.viable = lambda: True def get_best_downloader(): downloaders = ( download_file_powershell, download_file_curl, download_file_wget, download_file_insecure, ) viable_downloaders = (dl for dl in downloaders if dl.viable()) return next(viable_downloaders, None) def download_setuptools( version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=DEFAULT_SAVE_DIR, delay=15, downloader_factory=get_best_downloader): """ Download setuptools from a specified location and return its filename. `version` should be a valid setuptools version number that is available as an sdist for download under the `download_base` URL (which should end with a '/'). `to_dir` is the directory where the egg will be downloaded. `delay` is the number of seconds to pause before an actual download attempt. ``downloader_factory`` should be a function taking no arguments and returning a function for downloading a URL to a target. """ # making sure we use the absolute path to_dir = os.path.abspath(to_dir) zip_name = "setuptools-%s.zip" % version url = download_base + zip_name saveto = os.path.join(to_dir, zip_name) if not os.path.exists(saveto): # Avoid repeated downloads log.warn("Downloading %s", url) downloader = downloader_factory() downloader(url, saveto) return os.path.realpath(saveto) def _build_install_args(options): """ Build the arguments to 'python setup.py install' on the setuptools package. Returns list of command line arguments. 
""" return ['--user'] if options.user_install else [] def _parse_args(): """Parse the command line for options.""" parser = optparse.OptionParser() parser.add_option( '--user', dest='user_install', action='store_true', default=False, help='install in user site package') parser.add_option( '--download-base', dest='download_base', metavar="URL", default=DEFAULT_URL, help='alternative URL from where to download the setuptools package') parser.add_option( '--insecure', dest='downloader_factory', action='store_const', const=lambda: download_file_insecure, default=get_best_downloader, help='Use internal, non-validating downloader' ) parser.add_option( '--version', help="Specify which version to download", default=DEFAULT_VERSION, ) parser.add_option( '--to-dir', help="Directory to save (and re-use) package", default=DEFAULT_SAVE_DIR, ) options, args = parser.parse_args() # positional arguments are ignored return options def _download_args(options): """Return args for download_setuptools function from cmdline args.""" return dict( version=options.version, download_base=options.download_base, downloader_factory=options.downloader_factory, to_dir=options.to_dir, ) def main(): """Install or upgrade setuptools and EasyInstall.""" options = _parse_args() archive = download_setuptools(**_download_args(options)) return _install(archive, _build_install_args(options)) if __name__ == '__main__': sys.exit(main()) pymongo-3.11.0/gridfs/000077500000000000000000000000001374256237000145535ustar00rootroot00000000000000pymongo-3.11.0/gridfs/__init__.py000066400000000000000000001100271374256237000166650ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """GridFS is a specification for storing large objects in Mongo. The :mod:`gridfs` package is an implementation of GridFS on top of :mod:`pymongo`, exposing a file-like interface. .. mongodoc:: gridfs """ from bson.py3compat import abc from gridfs.errors import NoFile from gridfs.grid_file import (GridIn, GridOut, GridOutCursor, DEFAULT_CHUNK_SIZE, _clear_entity_type_registry, _disallow_transactions) from pymongo import (ASCENDING, DESCENDING) from pymongo.common import UNAUTHORIZED_CODES, validate_string from pymongo.database import Database from pymongo.errors import ConfigurationError, OperationFailure class GridFS(object): """An instance of GridFS on top of a single Database. """ def __init__(self, database, collection="fs", disable_md5=False): """Create a new instance of :class:`GridFS`. Raises :class:`TypeError` if `database` is not an instance of :class:`~pymongo.database.Database`. :Parameters: - `database`: database to use - `collection` (optional): root collection to use - `disable_md5` (optional): When True, MD5 checksums will not be computed for uploaded files. Useful in environments where MD5 cannot be used for regulatory or other reasons. Defaults to False. .. versionchanged:: 3.11 Running a GridFS operation in a transaction now always raises an error. GridFS does not support multi-document transactions. 
.. versionchanged:: 3.1 Indexes are only ensured on the first write to the DB. .. versionchanged:: 3.0 `database` must use an acknowledged :attr:`~pymongo.database.Database.write_concern` .. mongodoc:: gridfs """ if not isinstance(database, Database): raise TypeError("database must be an instance of Database") database = _clear_entity_type_registry(database) if not database.write_concern.acknowledged: raise ConfigurationError('database must use ' 'acknowledged write_concern') self.__collection = database[collection] self.__files = self.__collection.files self.__chunks = self.__collection.chunks self.__disable_md5 = disable_md5 def new_file(self, **kwargs): """Create a new file in GridFS. Returns a new :class:`~gridfs.grid_file.GridIn` instance to which data can be written. Any keyword arguments will be passed through to :meth:`~gridfs.grid_file.GridIn`. If the ``"_id"`` of the file is manually specified, it must not already exist in GridFS. Otherwise :class:`~gridfs.errors.FileExists` is raised. :Parameters: - `**kwargs` (optional): keyword arguments for file creation """ return GridIn( self.__collection, disable_md5=self.__disable_md5, **kwargs) def put(self, data, **kwargs): """Put data in GridFS as a new file. Equivalent to doing:: try: f = new_file(**kwargs) f.write(data) finally: f.close() `data` can be either an instance of :class:`str` (:class:`bytes` in python 3) or a file-like object providing a :meth:`read` method. If an `encoding` keyword argument is passed, `data` can also be a :class:`unicode` (:class:`str` in python 3) instance, which will be encoded as `encoding` before being written. Any keyword arguments will be passed through to the created file - see :meth:`~gridfs.grid_file.GridIn` for possible arguments. Returns the ``"_id"`` of the created file. If the ``"_id"`` of the file is manually specified, it must not already exist in GridFS. Otherwise :class:`~gridfs.errors.FileExists` is raised. :Parameters: - `data`: data to be written as a file. - `**kwargs` (optional): keyword arguments for file creation .. versionchanged:: 3.0 w=0 writes to GridFS are now prohibited. """ grid_file = GridIn( self.__collection, disable_md5=self.__disable_md5, **kwargs) try: grid_file.write(data) finally: grid_file.close() return grid_file._id def get(self, file_id, session=None): """Get a file from GridFS by ``"_id"``. Returns an instance of :class:`~gridfs.grid_file.GridOut`, which provides a file-like interface for reading. :Parameters: - `file_id`: ``"_id"`` of the file to get - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ gout = GridOut(self.__collection, file_id, session=session) # Raise NoFile now, instead of on first attribute access. gout._ensure_file() return gout def get_version(self, filename=None, version=-1, session=None, **kwargs): """Get a file from GridFS by ``"filename"`` or metadata fields. Returns a version of the file in GridFS whose filename matches `filename` and whose metadata fields match the supplied keyword arguments, as an instance of :class:`~gridfs.grid_file.GridOut`. Version numbering is a convenience atop the GridFS API provided by MongoDB. If more than one file matches the query (either by `filename` alone, by metadata fields, or by a combination of both), then version ``-1`` will be the most recently uploaded matching file, ``-2`` the second most recently uploaded, etc. Version ``0`` will be the first version uploaded, ``1`` the second version, etc. 
So if three versions have been uploaded, then version ``0`` is the same as version ``-3``, version ``1`` is the same as version ``-2``, and version ``2`` is the same as version ``-1``. Raises :class:`~gridfs.errors.NoFile` if no such version of that file exists. :Parameters: - `filename`: ``"filename"`` of the file to get, or `None` - `version` (optional): version of the file to get (defaults to -1, the most recent version uploaded) - `session` (optional): a :class:`~pymongo.client_session.ClientSession` - `**kwargs` (optional): find files by custom metadata. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.1 ``get_version`` no longer ensures indexes. """ query = kwargs if filename is not None: query["filename"] = filename _disallow_transactions(session) cursor = self.__files.find(query, session=session) if version < 0: skip = abs(version) - 1 cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING) else: cursor.limit(-1).skip(version).sort("uploadDate", ASCENDING) try: doc = next(cursor) return GridOut( self.__collection, file_document=doc, session=session) except StopIteration: raise NoFile("no version %d for filename %r" % (version, filename)) def get_last_version(self, filename=None, session=None, **kwargs): """Get the most recent version of a file in GridFS by ``"filename"`` or metadata fields. Equivalent to calling :meth:`get_version` with the default `version` (``-1``). :Parameters: - `filename`: ``"filename"`` of the file to get, or `None` - `session` (optional): a :class:`~pymongo.client_session.ClientSession` - `**kwargs` (optional): find files by custom metadata. .. versionchanged:: 3.6 Added ``session`` parameter. """ return self.get_version(filename=filename, session=session, **kwargs) # TODO add optional safe mode for chunk removal? def delete(self, file_id, session=None): """Delete a file from GridFS by ``"_id"``. Deletes all data belonging to the file with ``"_id"``: `file_id`. .. warning:: Any processes/threads reading from the file while this method is executing will likely see an invalid/corrupt file. Care should be taken to avoid concurrent reads to a file while it is being deleted. .. note:: Deletes of non-existent files are considered successful since the end result is the same: no file with that _id remains. :Parameters: - `file_id`: ``"_id"`` of the file to delete - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.1 ``delete`` no longer ensures indexes. """ _disallow_transactions(session) self.__files.delete_one({"_id": file_id}, session=session) self.__chunks.delete_many({"files_id": file_id}, session=session) def list(self, session=None): """List the names of all files stored in this instance of :class:`GridFS`. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.1 ``list`` no longer ensures indexes. """ _disallow_transactions(session) # With an index, distinct includes documents with no filename # as None. return [ name for name in self.__files.distinct("filename", session=session) if name is not None] def find_one(self, filter=None, session=None, *args, **kwargs): """Get a single file from gridfs. All arguments to :meth:`find` are also valid arguments for :meth:`find_one`, although any `limit` argument will be ignored. Returns a single :class:`~gridfs.grid_file.GridOut`, or ``None`` if no matching file is found. 
For example:: file = fs.find_one({"filename": "lisa.txt"}) :Parameters: - `filter` (optional): a dictionary specifying the query to be performing OR any other type to be used as the value for a query for ``"_id"`` in the file collection. - `*args` (optional): any additional positional arguments are the same as the arguments to :meth:`find`. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` - `**kwargs` (optional): any additional keyword arguments are the same as the arguments to :meth:`find`. .. versionchanged:: 3.6 Added ``session`` parameter. """ if filter is not None and not isinstance(filter, abc.Mapping): filter = {"_id": filter} _disallow_transactions(session) for f in self.find(filter, *args, session=session, **kwargs): return f return None def find(self, *args, **kwargs): """Query GridFS for files. Returns a cursor that iterates across files matching arbitrary queries on the files collection. Can be combined with other modifiers for additional control. For example:: for grid_out in fs.find({"filename": "lisa.txt"}, no_cursor_timeout=True): data = grid_out.read() would iterate through all versions of "lisa.txt" stored in GridFS. Note that setting no_cursor_timeout to True may be important to prevent the cursor from timing out during long multi-file processing work. As another example, the call:: most_recent_three = fs.find().sort("uploadDate", -1).limit(3) would return a cursor to the three most recently uploaded files in GridFS. Follows a similar interface to :meth:`~pymongo.collection.Collection.find` in :class:`~pymongo.collection.Collection`. If a :class:`~pymongo.client_session.ClientSession` is passed to :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances are associated with that session. :Parameters: - `filter` (optional): a SON object specifying elements which must be present for a document to be included in the result set - `skip` (optional): the number of files to omit (from the start of the result set) when returning the results - `limit` (optional): the maximum number of results to return - `no_cursor_timeout` (optional): if False (the default), any returned cursor is closed by the server after 10 minutes of inactivity. If set to True, the returned cursor will never time out on the server. Care should be taken to ensure that cursors with no_cursor_timeout turned on are properly closed. - `sort` (optional): a list of (key, direction) pairs specifying the sort order for this query. See :meth:`~pymongo.cursor.Cursor.sort` for details. Raises :class:`TypeError` if any of the arguments are of improper type. Returns an instance of :class:`~gridfs.grid_file.GridOutCursor` corresponding to this query. .. versionchanged:: 3.0 Removed the read_preference, tag_sets, and secondary_acceptable_latency_ms options. .. versionadded:: 2.7 .. mongodoc:: find """ return GridOutCursor(self.__collection, *args, **kwargs) def exists(self, document_or_id=None, session=None, **kwargs): """Check if a file exists in this instance of :class:`GridFS`. The file to check for can be specified by the value of its ``_id`` key, or by passing in a query document. A query document can be passed in as dictionary, or by using keyword arguments. 
Thus, the following three calls are equivalent: >>> fs.exists(file_id) >>> fs.exists({"_id": file_id}) >>> fs.exists(_id=file_id) As are the following two calls: >>> fs.exists({"filename": "mike.txt"}) >>> fs.exists(filename="mike.txt") And the following two: >>> fs.exists({"foo": {"$gt": 12}}) >>> fs.exists(foo={"$gt": 12}) Returns ``True`` if a matching file exists, ``False`` otherwise. Calls to :meth:`exists` will not automatically create appropriate indexes; application developers should be sure to create indexes if needed and as appropriate. :Parameters: - `document_or_id` (optional): query document, or _id of the document to check for - `session` (optional): a :class:`~pymongo.client_session.ClientSession` - `**kwargs` (optional): keyword arguments are used as a query document, if they're present. .. versionchanged:: 3.6 Added ``session`` parameter. """ _disallow_transactions(session) if kwargs: f = self.__files.find_one(kwargs, ["_id"], session=session) else: f = self.__files.find_one(document_or_id, ["_id"], session=session) return f is not None class GridFSBucket(object): """An instance of GridFS on top of a single Database.""" def __init__(self, db, bucket_name="fs", chunk_size_bytes=DEFAULT_CHUNK_SIZE, write_concern=None, read_preference=None, disable_md5=False): """Create a new instance of :class:`GridFSBucket`. Raises :exc:`TypeError` if `database` is not an instance of :class:`~pymongo.database.Database`. Raises :exc:`~pymongo.errors.ConfigurationError` if `write_concern` is not acknowledged. :Parameters: - `database`: database to use. - `bucket_name` (optional): The name of the bucket. Defaults to 'fs'. - `chunk_size_bytes` (optional): The chunk size in bytes. Defaults to 255KB. - `write_concern` (optional): The :class:`~pymongo.write_concern.WriteConcern` to use. If ``None`` (the default) db.write_concern is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) db.read_preference is used. - `disable_md5` (optional): When True, MD5 checksums will not be computed for uploaded files. Useful in environments where MD5 cannot be used for regulatory or other reasons. Defaults to False. .. versionchanged:: 3.11 Running a GridFS operation in a transaction now always raises an error. GridFSBucket does not support multi-document transactions. .. versionadded:: 3.1 .. mongodoc:: gridfs """ if not isinstance(db, Database): raise TypeError("database must be an instance of Database") db = _clear_entity_type_registry(db) wtc = write_concern if write_concern is not None else db.write_concern if not wtc.acknowledged: raise ConfigurationError('write concern must be acknowledged') self._bucket_name = bucket_name self._collection = db[bucket_name] self._disable_md5 = disable_md5 self._chunks = self._collection.chunks.with_options( write_concern=write_concern, read_preference=read_preference) self._files = self._collection.files.with_options( write_concern=write_concern, read_preference=read_preference) self._chunk_size_bytes = chunk_size_bytes def open_upload_stream(self, filename, chunk_size_bytes=None, metadata=None, session=None): """Opens a Stream that the application can write the contents of the file to. The user must specify the filename, and can choose to add any additional information in the metadata field of the file document or modify the chunk size. 
For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) grid_in = fs.open_upload_stream( "test_file", chunk_size_bytes=4, metadata={"contentType": "text/plain"}) grid_in.write("data I want to store!") grid_in.close() # uploaded on close Returns an instance of :class:`~gridfs.grid_file.GridIn`. Raises :exc:`~gridfs.errors.NoFile` if no such version of that file exists. Raises :exc:`~ValueError` if `filename` is not a string. :Parameters: - `filename`: The name of the file to upload. - `chunk_size_bytes` (options): The number of bytes per chunk of this file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`. - `metadata` (optional): User data for the 'metadata' field of the files collection document. If not provided the metadata field will be omitted from the files collection document. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ validate_string("filename", filename) opts = {"filename": filename, "chunk_size": (chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes)} if metadata is not None: opts["metadata"] = metadata return GridIn( self._collection, session=session, disable_md5=self._disable_md5, **opts) def open_upload_stream_with_id( self, file_id, filename, chunk_size_bytes=None, metadata=None, session=None): """Opens a Stream that the application can write the contents of the file to. The user must specify the file id and filename, and can choose to add any additional information in the metadata field of the file document or modify the chunk size. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) grid_in = fs.open_upload_stream_with_id( ObjectId(), "test_file", chunk_size_bytes=4, metadata={"contentType": "text/plain"}) grid_in.write("data I want to store!") grid_in.close() # uploaded on close Returns an instance of :class:`~gridfs.grid_file.GridIn`. Raises :exc:`~gridfs.errors.NoFile` if no such version of that file exists. Raises :exc:`~ValueError` if `filename` is not a string. :Parameters: - `file_id`: The id to use for this file. The id must not have already been used for another file. - `filename`: The name of the file to upload. - `chunk_size_bytes` (options): The number of bytes per chunk of this file. Defaults to the chunk_size_bytes in :class:`GridFSBucket`. - `metadata` (optional): User data for the 'metadata' field of the files collection document. If not provided the metadata field will be omitted from the files collection document. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ validate_string("filename", filename) opts = {"_id": file_id, "filename": filename, "chunk_size": (chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes)} if metadata is not None: opts["metadata"] = metadata return GridIn( self._collection, session=session, disable_md5=self._disable_md5, **opts) def upload_from_stream(self, filename, source, chunk_size_bytes=None, metadata=None, session=None): """Uploads a user file to a GridFS bucket. Reads the contents of the user file from `source` and uploads it to the file `filename`. Source can be a string or file-like object. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) file_id = fs.upload_from_stream( "test_file", "data I want to store!", chunk_size_bytes=4, metadata={"contentType": "text/plain"}) Returns the _id of the uploaded file. 
Raises :exc:`~gridfs.errors.NoFile` if no such version of that file exists. Raises :exc:`~ValueError` if `filename` is not a string. :Parameters: - `filename`: The name of the file to upload. - `source`: The source stream of the content to be uploaded. Must be a file-like object that implements :meth:`read` or a string. - `chunk_size_bytes` (options): The number of bytes per chunk of this file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`. - `metadata` (optional): User data for the 'metadata' field of the files collection document. If not provided the metadata field will be omitted from the files collection document. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ with self.open_upload_stream( filename, chunk_size_bytes, metadata, session=session) as gin: gin.write(source) return gin._id def upload_from_stream_with_id(self, file_id, filename, source, chunk_size_bytes=None, metadata=None, session=None): """Uploads a user file to a GridFS bucket with a custom file id. Reads the contents of the user file from `source` and uploads it to the file `filename`. Source can be a string or file-like object. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) file_id = fs.upload_from_stream( ObjectId(), "test_file", "data I want to store!", chunk_size_bytes=4, metadata={"contentType": "text/plain"}) Raises :exc:`~gridfs.errors.NoFile` if no such version of that file exists. Raises :exc:`~ValueError` if `filename` is not a string. :Parameters: - `file_id`: The id to use for this file. The id must not have already been used for another file. - `filename`: The name of the file to upload. - `source`: The source stream of the content to be uploaded. Must be a file-like object that implements :meth:`read` or a string. - `chunk_size_bytes` (options): The number of bytes per chunk of this file. Defaults to the chunk_size_bytes of :class:`GridFSBucket`. - `metadata` (optional): User data for the 'metadata' field of the files collection document. If not provided the metadata field will be omitted from the files collection document. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ with self.open_upload_stream_with_id( file_id, filename, chunk_size_bytes, metadata, session=session) as gin: gin.write(source) def open_download_stream(self, file_id, session=None): """Opens a Stream from which the application can read the contents of the stored file specified by file_id. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) # get _id of file to read. file_id = fs.upload_from_stream("test_file", "data I want to store!") grid_out = fs.open_download_stream(file_id) contents = grid_out.read() Returns an instance of :class:`~gridfs.grid_file.GridOut`. Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists. :Parameters: - `file_id`: The _id of the file to be downloaded. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ gout = GridOut(self._collection, file_id, session=session) # Raise NoFile now, instead of on first attribute access. gout._ensure_file() return gout def download_to_stream(self, file_id, destination, session=None): """Downloads the contents of the stored file specified by file_id and writes the contents to `destination`. 
For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) # Get _id of file to read file_id = fs.upload_from_stream("test_file", "data I want to store!") # Get file to write to file = open('myfile','wb+') fs.download_to_stream(file_id, file) file.seek(0) contents = file.read() Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists. :Parameters: - `file_id`: The _id of the file to be downloaded. - `destination`: a file-like object implementing :meth:`write`. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ with self.open_download_stream(file_id, session=session) as gout: for chunk in gout: destination.write(chunk) def delete(self, file_id, session=None): """Given a file_id, delete this stored file's files collection document and associated chunks from a GridFS bucket. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) # Get _id of file to delete file_id = fs.upload_from_stream("test_file", "data I want to store!") fs.delete(file_id) Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists. :Parameters: - `file_id`: The _id of the file to be deleted. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. """ _disallow_transactions(session) res = self._files.delete_one({"_id": file_id}, session=session) self._chunks.delete_many({"files_id": file_id}, session=session) if not res.deleted_count: raise NoFile( "no file could be deleted because none matched %s" % file_id) def find(self, *args, **kwargs): """Find and return the files collection documents that match ``filter`` Returns a cursor that iterates across files matching arbitrary queries on the files collection. Can be combined with other modifiers for additional control. For example:: for grid_data in fs.find({"filename": "lisa.txt"}, no_cursor_timeout=True): data = grid_data.read() would iterate through all versions of "lisa.txt" stored in GridFS. Note that setting no_cursor_timeout to True may be important to prevent the cursor from timing out during long multi-file processing work. As another example, the call:: most_recent_three = fs.find().sort("uploadDate", -1).limit(3) would return a cursor to the three most recently uploaded files in GridFS. Follows a similar interface to :meth:`~pymongo.collection.Collection.find` in :class:`~pymongo.collection.Collection`. If a :class:`~pymongo.client_session.ClientSession` is passed to :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances are associated with that session. :Parameters: - `filter`: Search query. - `batch_size` (optional): The number of documents to return per batch. - `limit` (optional): The maximum number of documents to return. - `no_cursor_timeout` (optional): The server normally times out idle cursors after an inactivity period (10 minutes) to prevent excess memory use. Set this option to True to prevent that. - `skip` (optional): The number of documents to skip before returning. - `sort` (optional): The order by which to sort results. Defaults to None. """ return GridOutCursor(self._collection, *args, **kwargs) def open_download_stream_by_name(self, filename, revision=-1, session=None): """Opens a Stream from which the application can read the contents of `filename` and optional `revision`.
For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) grid_out = fs.open_download_stream_by_name("test_file") contents = grid_out.read() Returns an instance of :class:`~gridfs.grid_file.GridOut`. Raises :exc:`~gridfs.errors.NoFile` if no such version of that file exists. Raises :exc:`~ValueError` filename is not a string. :Parameters: - `filename`: The name of the file to read from. - `revision` (optional): Which revision (documents with the same filename and different uploadDate) of the file to retrieve. Defaults to -1 (the most recent revision). - `session` (optional): a :class:`~pymongo.client_session.ClientSession` :Note: Revision numbers are defined as follows: - 0 = the original stored file - 1 = the first revision - 2 = the second revision - etc... - -2 = the second most recent revision - -1 = the most recent revision .. versionchanged:: 3.6 Added ``session`` parameter. """ validate_string("filename", filename) query = {"filename": filename} _disallow_transactions(session) cursor = self._files.find(query, session=session) if revision < 0: skip = abs(revision) - 1 cursor.limit(-1).skip(skip).sort("uploadDate", DESCENDING) else: cursor.limit(-1).skip(revision).sort("uploadDate", ASCENDING) try: grid_file = next(cursor) return GridOut( self._collection, file_document=grid_file, session=session) except StopIteration: raise NoFile( "no version %d for filename %r" % (revision, filename)) def download_to_stream_by_name(self, filename, destination, revision=-1, session=None): """Write the contents of `filename` (with optional `revision`) to `destination`. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) # Get file to write to file = open('myfile','wb') fs.download_to_stream_by_name("test_file", file) Raises :exc:`~gridfs.errors.NoFile` if no such version of that file exists. Raises :exc:`~ValueError` if `filename` is not a string. :Parameters: - `filename`: The name of the file to read from. - `destination`: A file-like object that implements :meth:`write`. - `revision` (optional): Which revision (documents with the same filename and different uploadDate) of the file to retrieve. Defaults to -1 (the most recent revision). - `session` (optional): a :class:`~pymongo.client_session.ClientSession` :Note: Revision numbers are defined as follows: - 0 = the original stored file - 1 = the first revision - 2 = the second revision - etc... - -2 = the second most recent revision - -1 = the most recent revision .. versionchanged:: 3.6 Added ``session`` parameter. """ with self.open_download_stream_by_name( filename, revision, session=session) as gout: for chunk in gout: destination.write(chunk) def rename(self, file_id, new_filename, session=None): """Renames the stored file with the specified file_id. For example:: my_db = MongoClient().test fs = GridFSBucket(my_db) # Get _id of file to rename file_id = fs.upload_from_stream("test_file", "data I want to store!") fs.rename(file_id, "new_test_name") Raises :exc:`~gridfs.errors.NoFile` if no file with file_id exists. :Parameters: - `file_id`: The _id of the file to be renamed. - `new_filename`: The new name of the file. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. 
""" _disallow_transactions(session) result = self._files.update_one({"_id": file_id}, {"$set": {"filename": new_filename}}, session=session) if not result.matched_count: raise NoFile("no files could be renamed %r because none " "matched file_id %i" % (new_filename, file_id)) pymongo-3.11.0/gridfs/errors.py000066400000000000000000000020401374256237000164350ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Exceptions raised by the :mod:`gridfs` package""" from pymongo.errors import PyMongoError class GridFSError(PyMongoError): """Base class for all GridFS exceptions.""" class CorruptGridFile(GridFSError): """Raised when a file in :class:`~gridfs.GridFS` is malformed.""" class NoFile(GridFSError): """Raised when trying to read from a non-existent file.""" class FileExists(GridFSError): """Raised when trying to create a file that already exists.""" pymongo-3.11.0/gridfs/grid_file.py000066400000000000000000000744531374256237000170660ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for representing files stored in GridFS.""" import datetime import hashlib import io import math import os from bson.int64 import Int64 from bson.son import SON from bson.binary import Binary from bson.objectid import ObjectId from bson.py3compat import text_type, StringIO from gridfs.errors import CorruptGridFile, FileExists, NoFile from pymongo import ASCENDING from pymongo.collection import Collection from pymongo.cursor import Cursor from pymongo.errors import (ConfigurationError, CursorNotFound, DuplicateKeyError, InvalidOperation, OperationFailure) from pymongo.read_preferences import ReadPreference try: _SEEK_SET = os.SEEK_SET _SEEK_CUR = os.SEEK_CUR _SEEK_END = os.SEEK_END # before 2.5 except AttributeError: _SEEK_SET = 0 _SEEK_CUR = 1 _SEEK_END = 2 EMPTY = b"" NEWLN = b"\n" """Default chunk size, in bytes.""" # Slightly under a power of 2, to work well with server's record allocations. 
DEFAULT_CHUNK_SIZE = 255 * 1024 _C_INDEX = SON([("files_id", ASCENDING), ("n", ASCENDING)]) _F_INDEX = SON([("filename", ASCENDING), ("uploadDate", ASCENDING)]) def _grid_in_property(field_name, docstring, read_only=False, closed_only=False): """Create a GridIn property.""" def getter(self): if closed_only and not self._closed: raise AttributeError("can only get %r on a closed file" % field_name) # Protect against PHP-237 if field_name == 'length': return self._file.get(field_name, 0) return self._file.get(field_name, None) def setter(self, value): if self._closed: self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}}) self._file[field_name] = value if read_only: docstring += "\n\nThis attribute is read-only." elif closed_only: docstring = "%s\n\n%s" % (docstring, "This attribute is read-only and " "can only be read after :meth:`close` " "has been called.") if not read_only and not closed_only: return property(getter, setter, doc=docstring) return property(getter, doc=docstring) def _grid_out_property(field_name, docstring): """Create a GridOut property.""" def getter(self): self._ensure_file() # Protect against PHP-237 if field_name == 'length': return self._file.get(field_name, 0) return self._file.get(field_name, None) docstring += "\n\nThis attribute is read-only." return property(getter, doc=docstring) def _clear_entity_type_registry(entity, **kwargs): """Clear the given database/collection object's type registry.""" codecopts = entity.codec_options.with_options(type_registry=None) return entity.with_options(codec_options=codecopts, **kwargs) def _disallow_transactions(session): if session and session.in_transaction: raise InvalidOperation( 'GridFS does not support multi-document transactions') class GridIn(object): """Class to write data to GridFS. """ def __init__( self, root_collection, session=None, disable_md5=False, **kwargs): """Write a file to GridFS Application developers should generally not need to instantiate this class directly - instead see the methods provided by :class:`~gridfs.GridFS`. Raises :class:`TypeError` if `root_collection` is not an instance of :class:`~pymongo.collection.Collection`. Any of the file level options specified in the `GridFS Spec `_ may be passed as keyword arguments. Any additional keyword arguments will be set as additional fields on the file document. Valid keyword arguments include: - ``"_id"``: unique ID for this file (default: :class:`~bson.objectid.ObjectId`) - this ``"_id"`` must not have already been used for another file - ``"filename"``: human name for the file - ``"contentType"`` or ``"content_type"``: valid mime-type for the file - ``"chunkSize"`` or ``"chunk_size"``: size of each of the chunks, in bytes (default: 255 kb) - ``"encoding"``: encoding used for this file. In Python 2, any :class:`unicode` that is written to the file will be converted to a :class:`str`. In Python 3, any :class:`str` that is written to the file will be converted to :class:`bytes`. :Parameters: - `root_collection`: root collection to write to - `session` (optional): a :class:`~pymongo.client_session.ClientSession` to use for all commands - `disable_md5` (optional): When True, an MD5 checksum will not be computed for the uploaded file. Useful in environments where MD5 cannot be used for regulatory or other reasons. Defaults to False. - `**kwargs` (optional): file level options (see above) .. versionchanged:: 3.6 Added ``session`` parameter. .. 
versionchanged:: 3.0 `root_collection` must use an acknowledged :attr:`~pymongo.collection.Collection.write_concern` """ if not isinstance(root_collection, Collection): raise TypeError("root_collection must be an " "instance of Collection") if not root_collection.write_concern.acknowledged: raise ConfigurationError('root_collection must use ' 'acknowledged write_concern') _disallow_transactions(session) # Handle alternative naming if "content_type" in kwargs: kwargs["contentType"] = kwargs.pop("content_type") if "chunk_size" in kwargs: kwargs["chunkSize"] = kwargs.pop("chunk_size") coll = _clear_entity_type_registry( root_collection, read_preference=ReadPreference.PRIMARY) if not disable_md5: kwargs["md5"] = hashlib.md5() # Defaults kwargs["_id"] = kwargs.get("_id", ObjectId()) kwargs["chunkSize"] = kwargs.get("chunkSize", DEFAULT_CHUNK_SIZE) object.__setattr__(self, "_session", session) object.__setattr__(self, "_coll", coll) object.__setattr__(self, "_chunks", coll.chunks) object.__setattr__(self, "_file", kwargs) object.__setattr__(self, "_buffer", StringIO()) object.__setattr__(self, "_position", 0) object.__setattr__(self, "_chunk_number", 0) object.__setattr__(self, "_closed", False) object.__setattr__(self, "_ensured_index", False) def __create_index(self, collection, index_key, unique): doc = collection.find_one(projection={"_id": 1}, session=self._session) if doc is None: try: index_keys = [index_spec['key'] for index_spec in collection.list_indexes(session=self._session)] except OperationFailure: index_keys = [] if index_key not in index_keys: collection.create_index( index_key.items(), unique=unique, session=self._session) def __ensure_indexes(self): if not object.__getattribute__(self, "_ensured_index"): _disallow_transactions(self._session) self.__create_index(self._coll.files, _F_INDEX, False) self.__create_index(self._coll.chunks, _C_INDEX, True) object.__setattr__(self, "_ensured_index", True) def abort(self): """Remove all chunks/files that may have been uploaded and close. """ self._coll.chunks.delete_many( {"files_id": self._file['_id']}, session=self._session) self._coll.files.delete_one( {"_id": self._file['_id']}, session=self._session) object.__setattr__(self, "_closed", True) @property def closed(self): """Is this file closed? """ return self._closed _id = _grid_in_property("_id", "The ``'_id'`` value for this file.", read_only=True) filename = _grid_in_property("filename", "Name of this file.") name = _grid_in_property("filename", "Alias for `filename`.") content_type = _grid_in_property("contentType", "Mime-type for this file.") length = _grid_in_property("length", "Length (in bytes) of this file.", closed_only=True) chunk_size = _grid_in_property("chunkSize", "Chunk size for this file.", read_only=True) upload_date = _grid_in_property("uploadDate", "Date that this file was uploaded.", closed_only=True) md5 = _grid_in_property("md5", "MD5 of the contents of this file " "if an md5 sum was created.", closed_only=True) def __getattr__(self, name): if name in self._file: return self._file[name] raise AttributeError("GridIn object has no attribute '%s'" % name) def __setattr__(self, name, value): # For properties of this instance like _buffer, or descriptors set on # the class like filename, use regular __setattr__ if name in self.__dict__ or name in self.__class__.__dict__: object.__setattr__(self, name, value) else: # All other attributes are part of the document in db.fs.files. # Store them to be sent to server on close() or if closed, send # them now. 
self._file[name] = value if self._closed: self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) def __flush_data(self, data): """Flush `data` to a chunk. """ self.__ensure_indexes() if 'md5' in self._file: self._file['md5'].update(data) if not data: return assert(len(data) <= self.chunk_size) chunk = {"files_id": self._file["_id"], "n": self._chunk_number, "data": Binary(data)} try: self._chunks.insert_one(chunk, session=self._session) except DuplicateKeyError: self._raise_file_exists(self._file['_id']) self._chunk_number += 1 self._position += len(data) def __flush_buffer(self): """Flush the buffer contents out to a chunk. """ self.__flush_data(self._buffer.getvalue()) self._buffer.close() self._buffer = StringIO() def __flush(self): """Flush the file to the database. """ try: self.__flush_buffer() if "md5" in self._file: self._file["md5"] = self._file["md5"].hexdigest() # The GridFS spec says length SHOULD be an Int64. self._file["length"] = Int64(self._position) self._file["uploadDate"] = datetime.datetime.utcnow() return self._coll.files.insert_one( self._file, session=self._session) except DuplicateKeyError: self._raise_file_exists(self._id) def _raise_file_exists(self, file_id): """Raise a FileExists exception for the given file_id.""" raise FileExists("file with _id %r already exists" % file_id) def close(self): """Flush the file and close it. A closed file cannot be written any more. Calling :meth:`close` more than once is allowed. """ if not self._closed: self.__flush() object.__setattr__(self, "_closed", True) def read(self, size=-1): raise io.UnsupportedOperation('read') def readable(self): return False def seekable(self): return False def write(self, data): """Write data to the file. There is no return value. `data` can be either a string of bytes or a file-like object (implementing :meth:`read`). If the file has an :attr:`encoding` attribute, `data` can also be a :class:`unicode` (:class:`str` in python 3) instance, which will be encoded as :attr:`encoding` before being written. Due to buffering, the data may not actually be written to the database until the :meth:`close` method is called. Raises :class:`ValueError` if this file is already closed. Raises :class:`TypeError` if `data` is not an instance of :class:`str` (:class:`bytes` in python 3), a file-like object, or an instance of :class:`unicode` (:class:`str` in python 3). Unicode data is only allowed if the file has an :attr:`encoding` attribute. 
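        As a short, illustrative sketch of the encoding behaviour (``db`` is
        assumed to be an existing :class:`~pymongo.database.Database` and the
        file name is made up for the example)::

          import gridfs

          fs = gridfs.GridFS(db)
          grid_in = fs.new_file(filename="notes.txt", encoding="utf-8")
          grid_in.write(u"this text is encoded to UTF-8 bytes before storage")
          grid_in.close()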
:Parameters: - `data`: string of bytes or file-like object to be written to the file """ if self._closed: raise ValueError("cannot write to a closed file") try: # file-like read = data.read except AttributeError: # string if not isinstance(data, (text_type, bytes)): raise TypeError("can only write strings or file-like objects") if isinstance(data, text_type): try: data = data.encode(self.encoding) except AttributeError: raise TypeError("must specify an encoding for file in " "order to write %s" % (text_type.__name__,)) read = StringIO(data).read if self._buffer.tell() > 0: # Make sure to flush only when _buffer is complete space = self.chunk_size - self._buffer.tell() if space: try: to_write = read(space) except: self.abort() raise self._buffer.write(to_write) if len(to_write) < space: return # EOF or incomplete self.__flush_buffer() to_write = read(self.chunk_size) while to_write and len(to_write) == self.chunk_size: self.__flush_data(to_write) to_write = read(self.chunk_size) self._buffer.write(to_write) def writelines(self, sequence): """Write a sequence of strings to the file. Does not add seperators. """ for line in sequence: self.write(line) def writeable(self): return True def __enter__(self): """Support for the context manager protocol. """ return self def __exit__(self, exc_type, exc_val, exc_tb): """Support for the context manager protocol. Close the file and allow exceptions to propagate. """ self.close() # propagate exceptions return False class GridOut(object): """Class to read data out of GridFS. """ def __init__(self, root_collection, file_id=None, file_document=None, session=None): """Read a file from GridFS Application developers should generally not need to instantiate this class directly - instead see the methods provided by :class:`~gridfs.GridFS`. Either `file_id` or `file_document` must be specified, `file_document` will be given priority if present. Raises :class:`TypeError` if `root_collection` is not an instance of :class:`~pymongo.collection.Collection`. :Parameters: - `root_collection`: root collection to read from - `file_id` (optional): value of ``"_id"`` for the file to read - `file_document` (optional): file document from `root_collection.files` - `session` (optional): a :class:`~pymongo.client_session.ClientSession` to use for all commands .. versionchanged:: 3.8 For better performance and to better follow the GridFS spec, :class:`GridOut` now uses a single cursor to read all the chunks in the file. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.0 Creating a GridOut does not immediately retrieve the file metadata from the server. Metadata is fetched when first needed. 
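        As a minimal sketch of direct construction (``db`` and ``file_id`` are
        assumptions for the example; the higher-level :class:`~gridfs.GridFS`
        and :class:`~gridfs.GridFSBucket` APIs are usually preferable)::

          from gridfs.grid_file import GridOut

          grid_out = GridOut(db.fs, file_id)
          data = grid_out.read()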
""" if not isinstance(root_collection, Collection): raise TypeError("root_collection must be an " "instance of Collection") _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) self.__chunks = root_collection.chunks self.__files = root_collection.files self.__file_id = file_id self.__buffer = EMPTY self.__chunk_iter = None self.__position = 0 self._file = file_document self._session = session _id = _grid_out_property("_id", "The ``'_id'`` value for this file.") filename = _grid_out_property("filename", "Name of this file.") name = _grid_out_property("filename", "Alias for `filename`.") content_type = _grid_out_property("contentType", "Mime-type for this file.") length = _grid_out_property("length", "Length (in bytes) of this file.") chunk_size = _grid_out_property("chunkSize", "Chunk size for this file.") upload_date = _grid_out_property("uploadDate", "Date that this file was first uploaded.") aliases = _grid_out_property("aliases", "List of aliases for this file.") metadata = _grid_out_property("metadata", "Metadata attached to this file.") md5 = _grid_out_property("md5", "MD5 of the contents of this file " "if an md5 sum was created.") def _ensure_file(self): if not self._file: _disallow_transactions(self._session) self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session) if not self._file: raise NoFile("no file in gridfs collection %r with _id %r" % (self.__files, self.__file_id)) def __getattr__(self, name): self._ensure_file() if name in self._file: return self._file[name] raise AttributeError("GridOut object has no attribute '%s'" % name) def readable(self): return True def readchunk(self): """Reads a chunk at a time. If the current position is within a chunk the remainder of the chunk is returned. """ received = len(self.__buffer) chunk_data = EMPTY chunk_size = int(self.chunk_size) if received > 0: chunk_data = self.__buffer elif self.__position < int(self.length): chunk_number = int((received + self.__position) / chunk_size) if self.__chunk_iter is None: self.__chunk_iter = _GridOutChunkIterator( self, self.__chunks, self._session, chunk_number) chunk = self.__chunk_iter.next() chunk_data = chunk["data"][self.__position % chunk_size:] if not chunk_data: raise CorruptGridFile("truncated chunk") self.__position += len(chunk_data) self.__buffer = EMPTY return chunk_data def read(self, size=-1): """Read at most `size` bytes from the file (less if there isn't enough data). The bytes are returned as an instance of :class:`str` (:class:`bytes` in python 3). If `size` is negative or omitted all data is read. :Parameters: - `size` (optional): the number of bytes to read .. versionchanged:: 3.8 This method now only checks for extra chunks after reading the entire file. Previously, this method would check for extra chunks on every call. """ self._ensure_file() remainder = int(self.length) - self.__position if size < 0 or size > remainder: size = remainder if size == 0: return EMPTY received = 0 data = StringIO() while received < size: chunk_data = self.readchunk() received += len(chunk_data) data.write(chunk_data) # Detect extra chunks after reading the entire file. if size == remainder and self.__chunk_iter: try: self.__chunk_iter.next() except StopIteration: pass self.__position -= received - size # Return 'size' bytes and store the rest. data.seek(size) self.__buffer = data.read() data.seek(0) return data.read(size) def readline(self, size=-1): """Read one line or up to `size` bytes from the file. 
:Parameters: - `size` (optional): the maximum number of bytes to read """ remainder = int(self.length) - self.__position if size < 0 or size > remainder: size = remainder if size == 0: return EMPTY received = 0 data = StringIO() while received < size: chunk_data = self.readchunk() pos = chunk_data.find(NEWLN, 0, size) if pos != -1: size = received + pos + 1 received += len(chunk_data) data.write(chunk_data) if pos != -1: break self.__position -= received - size # Return 'size' bytes and store the rest. data.seek(size) self.__buffer = data.read() data.seek(0) return data.read(size) def tell(self): """Return the current position of this file. """ return self.__position def seek(self, pos, whence=_SEEK_SET): """Set the current position of this file. :Parameters: - `pos`: the position (or offset if using relative positioning) to seek to - `whence` (optional): where to seek from. :attr:`os.SEEK_SET` (``0``) for absolute file positioning, :attr:`os.SEEK_CUR` (``1``) to seek relative to the current position, :attr:`os.SEEK_END` (``2``) to seek relative to the file's end. """ if whence == _SEEK_SET: new_pos = pos elif whence == _SEEK_CUR: new_pos = self.__position + pos elif whence == _SEEK_END: new_pos = int(self.length) + pos else: raise IOError(22, "Invalid value for `whence`") if new_pos < 0: raise IOError(22, "Invalid value for `pos` - must be positive") # Optimization, continue using the same buffer and chunk iterator. if new_pos == self.__position: return self.__position = new_pos self.__buffer = EMPTY if self.__chunk_iter: self.__chunk_iter.close() self.__chunk_iter = None def seekable(self): return True def __iter__(self): """Return an iterator over all of this file's data. The iterator will return chunk-sized instances of :class:`str` (:class:`bytes` in python 3). This can be useful when serving files using a webserver that handles such an iterator efficiently. .. note:: This is different from :py:class:`io.IOBase` which iterates over *lines* in the file. Use :meth:`GridOut.readline` to read line by line instead of chunk by chunk. .. versionchanged:: 3.8 The iterator now raises :class:`CorruptGridFile` when encountering any truncated, missing, or extra chunk in a file. The previous behavior was to only raise :class:`CorruptGridFile` on a missing chunk. """ return GridOutIterator(self, self.__chunks, self._session) def close(self): """Make GridOut more generically file-like.""" if self.__chunk_iter: self.__chunk_iter.close() self.__chunk_iter = None def write(self, value): raise io.UnsupportedOperation('write') def __enter__(self): """Makes it possible to use :class:`GridOut` files with the context manager protocol. """ return self def __exit__(self, exc_type, exc_val, exc_tb): """Makes it possible to use :class:`GridOut` files with the context manager protocol. """ self.close() return False class _GridOutChunkIterator(object): """Iterates over a file's chunks using a single cursor. Raises CorruptGridFile when encountering any truncated, missing, or extra chunk in a file. 
""" def __init__(self, grid_out, chunks, session, next_chunk): self._id = grid_out._id self._chunk_size = int(grid_out.chunk_size) self._length = int(grid_out.length) self._chunks = chunks self._session = session self._next_chunk = next_chunk self._num_chunks = math.ceil(float(self._length) / self._chunk_size) self._cursor = None def expected_chunk_length(self, chunk_n): if chunk_n < self._num_chunks - 1: return self._chunk_size return self._length - (self._chunk_size * (self._num_chunks - 1)) def __iter__(self): return self def _create_cursor(self): filter = {"files_id": self._id} if self._next_chunk > 0: filter["n"] = {"$gte": self._next_chunk} _disallow_transactions(self._session) self._cursor = self._chunks.find(filter, sort=[("n", 1)], session=self._session) def _next_with_retry(self): """Return the next chunk and retry once on CursorNotFound. We retry on CursorNotFound to maintain backwards compatibility in cases where two calls to read occur more than 10 minutes apart (the server's default cursor timeout). """ if self._cursor is None: self._create_cursor() try: return self._cursor.next() except CursorNotFound: self._cursor.close() self._create_cursor() return self._cursor.next() def next(self): try: chunk = self._next_with_retry() except StopIteration: if self._next_chunk >= self._num_chunks: raise raise CorruptGridFile("no chunk #%d" % self._next_chunk) if chunk["n"] != self._next_chunk: self.close() raise CorruptGridFile( "Missing chunk: expected chunk #%d but found " "chunk with n=%d" % (self._next_chunk, chunk["n"])) if chunk["n"] >= self._num_chunks: # According to spec, ignore extra chunks if they are empty. if len(chunk["data"]): self.close() raise CorruptGridFile( "Extra chunk found: expected %d chunks but found " "chunk with n=%d" % (self._num_chunks, chunk["n"])) expected_length = self.expected_chunk_length(chunk["n"]) if len(chunk["data"]) != expected_length: self.close() raise CorruptGridFile( "truncated chunk #%d: expected chunk length to be %d but " "found chunk with length %d" % ( chunk["n"], expected_length, len(chunk["data"]))) self._next_chunk += 1 return chunk __next__ = next def close(self): if self._cursor: self._cursor.close() self._cursor = None class GridOutIterator(object): def __init__(self, grid_out, chunks, session): self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0) def __iter__(self): return self def next(self): chunk = self.__chunk_iter.next() return bytes(chunk["data"]) __next__ = next class GridOutCursor(Cursor): """A cursor / iterator for returning GridOut objects as the result of an arbitrary query against the GridFS files collection. """ def __init__(self, collection, filter=None, skip=0, limit=0, no_cursor_timeout=False, sort=None, batch_size=0, session=None): """Create a new cursor, similar to the normal :class:`~pymongo.cursor.Cursor`. Should not be called directly by application developers - see the :class:`~gridfs.GridFS` method :meth:`~gridfs.GridFS.find` instead. .. versionadded 2.7 .. mongodoc:: cursors """ _disallow_transactions(session) collection = _clear_entity_type_registry(collection) # Hold on to the base "fs" collection to create GridOut objects later. self.__root_collection = collection super(GridOutCursor, self).__init__( collection.files, filter, skip=skip, limit=limit, no_cursor_timeout=no_cursor_timeout, sort=sort, batch_size=batch_size, session=session) def next(self): """Get next GridOut object from cursor. 
""" _disallow_transactions(self.session) # Work around "super is not iterable" issue in Python 3.x next_file = super(GridOutCursor, self).next() return GridOut(self.__root_collection, file_document=next_file, session=self.session) __next__ = next def add_option(self, *args, **kwargs): raise NotImplementedError("Method does not exist for GridOutCursor") def remove_option(self, *args, **kwargs): raise NotImplementedError("Method does not exist for GridOutCursor") def _clone_base(self, session): """Creates an empty GridOutCursor for information to be copied into. """ return GridOutCursor(self.__root_collection, session=session) pymongo-3.11.0/pymongo.egg-info/000077500000000000000000000000001374256237000164575ustar00rootroot00000000000000pymongo-3.11.0/pymongo.egg-info/PKG-INFO000066400000000000000000000251201374256237000175540ustar00rootroot00000000000000Metadata-Version: 2.1 Name: pymongo Version: 3.11.0 Summary: Python driver for MongoDB Home-page: http://github.com/mongodb/mongo-python-driver Author: Mike Dirolf Author-email: mongodb-user@googlegroups.com Maintainer: Bernie Hackett Maintainer-email: bernie@mongodb.com License: Apache License, Version 2.0 Description: ======= PyMongo ======= :Info: See `the mongo site `_ for more information. See `GitHub `_ for the latest source. :Documentation: Available at `pymongo.readthedocs.io `_ :Author: Mike Dirolf :Maintainer: Bernie Hackett About ===== The PyMongo distribution contains tools for interacting with MongoDB database from Python. The ``bson`` package is an implementation of the `BSON format `_ for Python. The ``pymongo`` package is a native Python driver for MongoDB. The ``gridfs`` package is a `gridfs `_ implementation on top of ``pymongo``. PyMongo supports MongoDB 2.6, 3.0, 3.2, 3.4, 3.6, 4.0, 4.2, and 4.4. Support / Feedback ================== For issues with, questions about, or feedback for PyMongo, please look into our `support channels `_. Please do not email any of the PyMongo developers directly with issues or questions - you're more likely to get an answer on the `MongoDB Community Forums `_. Bugs / Feature Requests ======================= Think you’ve found a bug? Want to see a new feature in PyMongo? Please open a case in our issue management tool, JIRA: - `Create an account and login `_. - Navigate to `the PYTHON project `_. - Click **Create Issue** - Please provide as much information as possible about the issue type and how to reproduce it. Bug reports in JIRA for all driver projects (i.e. PYTHON, CSHARP, JAVA) and the Core Server (i.e. SERVER) project are **public**. How To Ask For Help ------------------- Please include all of the following information when opening an issue: - Detailed steps to reproduce the problem, including full traceback, if possible. - The exact python version used, with patch level:: $ python -c "import sys; print(sys.version)" - The exact version of PyMongo used, with patch level:: $ python -c "import pymongo; print(pymongo.version); print(pymongo.has_c())" - The operating system and version (e.g. Windows 7, OSX 10.8, ...) - Web framework or asynchronous network library used, if any, with version (e.g. Django 1.7, mod_wsgi 4.3.0, gevent 1.0.1, Tornado 4.0.2, ...) Security Vulnerabilities ------------------------ If you’ve identified a security vulnerability in a driver or any other MongoDB project, please report it according to the `instructions here `_. 
Installation ============ PyMongo can be installed with `pip `_:: $ python -m pip install pymongo Or ``easy_install`` from `setuptools `_:: $ python -m easy_install pymongo You can also download the project source and do:: $ python setup.py install Do **not** install the "bson" package from pypi. PyMongo comes with its own bson package; doing "easy_install bson" installs a third-party package that is incompatible with PyMongo. Dependencies ============ PyMongo supports CPython 2.7, 3.4+, PyPy, and PyPy3.5+. Optional dependencies: GSSAPI authentication requires `pykerberos `_ on Unix or `WinKerberos `_ on Windows. The correct dependency can be installed automatically along with PyMongo:: $ python -m pip install pymongo[gssapi] MONGODB-AWS authentication requires `pymongo-auth-aws `_:: $ python -m pip install pymongo[aws] Support for mongodb+srv:// URIs requires `dnspython `_:: $ python -m pip install pymongo[srv] TLS / SSL support may require `ipaddress `_ and `certifi `_ or `wincertstore `_ depending on the Python version in use. The necessary dependencies can be installed along with PyMongo:: $ python -m pip install pymongo[tls] .. note:: Users of Python versions older than 2.7.9 will also receive the dependencies for OCSP when using the tls extra. OCSP (Online Certificate Status Protocol) requires `PyOpenSSL `_, `requests `_ and `service_identity `_:: $ python -m pip install pymongo[ocsp] Wire protocol compression with snappy requires `python-snappy `_:: $ python -m pip install pymongo[snappy] Wire protocol compression with zstandard requires `zstandard `_:: $ python -m pip install pymongo[zstd] Client-Side Field Level Encryption requires `pymongocrypt `_:: $ python -m pip install pymongo[encryption] You can install all dependencies automatically with the following command:: $ python -m pip install pymongo[gssapi,aws,ocsp,snappy,srv,tls,zstd,encryption] Other optional packages: - `backports.pbkdf2 `_, improves authentication performance with SCRAM-SHA-1 and SCRAM-SHA-256. It especially improves performance on Python versions older than 2.7.8. - `monotonic `_ adds support for a monotonic clock, which improves reliability in environments where clock adjustments are frequent. Not needed in Python 3. Additional dependencies are: - (to generate documentation) sphinx_ Examples ======== Here's a basic example (for more see the *examples* section of the docs): .. code-block:: python >>> import pymongo >>> client = pymongo.MongoClient("localhost", 27017) >>> db = client.test >>> db.name u'test' >>> db.my_collection Collection(Database(MongoClient('localhost', 27017), u'test'), u'my_collection') >>> db.my_collection.insert_one({"x": 10}).inserted_id ObjectId('4aba15ebe23f6b53b0000000') >>> db.my_collection.insert_one({"x": 8}).inserted_id ObjectId('4aba160ee23f6b543e000000') >>> db.my_collection.insert_one({"x": 11}).inserted_id ObjectId('4aba160ee23f6b543e000002') >>> db.my_collection.find_one() {u'x': 10, u'_id': ObjectId('4aba15ebe23f6b53b0000000')} >>> for item in db.my_collection.find(): ... print(item["x"]) ... 10 8 11 >>> db.my_collection.create_index("x") u'x_1' >>> for item in db.my_collection.find().sort("x", pymongo.ASCENDING): ... print(item["x"]) ... 8 10 11 >>> [item["x"] for item in db.my_collection.find().limit(2).skip(1)] [8, 11] Documentation ============= Documentation is available at `pymongo.readthedocs.io `_. To build the documentation, you will need to install sphinx_. Documentation can be generated by running **python setup.py doc**. 
Generated documentation can be found in the *doc/build/html/* directory. Testing ======= The easiest way to run the tests is to run **python setup.py test** in the root of the distribution. To verify that PyMongo works with Gevent's monkey-patching:: $ python green_framework_test.py gevent Or with Eventlet's:: $ python green_framework_test.py eventlet .. _sphinx: http://sphinx.pocoo.org/ Keywords: mongo,mongodb,pymongo,gridfs,bson Platform: UNKNOWN Classifier: Development Status :: 5 - Production/Stable Classifier: Intended Audience :: Developers Classifier: License :: OSI Approved :: Apache Software License Classifier: Operating System :: MacOS :: MacOS X Classifier: Operating System :: Microsoft :: Windows Classifier: Operating System :: POSIX Classifier: Programming Language :: Python :: 2 Classifier: Programming Language :: Python :: 2.7 Classifier: Programming Language :: Python :: 3 Classifier: Programming Language :: Python :: 3.4 Classifier: Programming Language :: Python :: 3.5 Classifier: Programming Language :: Python :: 3.6 Classifier: Programming Language :: Python :: 3.7 Classifier: Programming Language :: Python :: 3.8 Classifier: Programming Language :: Python :: Implementation :: CPython Classifier: Programming Language :: Python :: Implementation :: PyPy Classifier: Topic :: Database Requires-Python: >=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.* Provides-Extra: tls Provides-Extra: encryption Provides-Extra: aws Provides-Extra: gssapi Provides-Extra: snappy Provides-Extra: srv Provides-Extra: zstd Provides-Extra: ocsp pymongo-3.11.0/pymongo.egg-info/SOURCES.txt000066400000000000000000000157151374256237000203540ustar00rootroot00000000000000LICENSE MANIFEST.in README.rst THIRD-PARTY-NOTICES ez_setup.py setup.py bson/__init__.py bson/_cbsonmodule.c bson/_cbsonmodule.h bson/binary.py bson/bson-endian.h bson/bson-stdint-win32.h bson/buffer.c bson/buffer.h bson/code.py bson/codec_options.py bson/dbref.py bson/decimal128.py bson/encoding_helpers.c bson/encoding_helpers.h bson/errors.py bson/int64.py bson/json_util.py bson/max_key.py bson/min_key.py bson/objectid.py bson/py3compat.py bson/raw_bson.py bson/regex.py bson/son.py bson/time64.c bson/time64.h bson/time64_config.h bson/time64_limits.h bson/timestamp.py bson/tz_util.py doc/__init__.py doc/atlas.rst doc/changelog.rst doc/compatibility-policy.rst doc/conf.py doc/contributors.rst doc/faq.rst doc/index.rst doc/installation.rst doc/migrate-to-pymongo3.rst doc/mongo_extensions.py doc/python3.rst doc/tools.rst doc/tutorial.rst doc/api/index.rst doc/api/bson/binary.rst doc/api/bson/code.rst doc/api/bson/codec_options.rst doc/api/bson/dbref.rst doc/api/bson/decimal128.rst doc/api/bson/errors.rst doc/api/bson/index.rst doc/api/bson/int64.rst doc/api/bson/json_util.rst doc/api/bson/max_key.rst doc/api/bson/min_key.rst doc/api/bson/objectid.rst doc/api/bson/raw_bson.rst doc/api/bson/regex.rst doc/api/bson/son.rst doc/api/bson/timestamp.rst doc/api/bson/tz_util.rst doc/api/gridfs/errors.rst doc/api/gridfs/grid_file.rst doc/api/gridfs/index.rst doc/api/pymongo/bulk.rst doc/api/pymongo/change_stream.rst doc/api/pymongo/client_session.rst doc/api/pymongo/collation.rst doc/api/pymongo/collection.rst doc/api/pymongo/command_cursor.rst doc/api/pymongo/cursor.rst doc/api/pymongo/cursor_manager.rst doc/api/pymongo/database.rst doc/api/pymongo/driver_info.rst doc/api/pymongo/encryption.rst doc/api/pymongo/encryption_options.rst doc/api/pymongo/errors.rst doc/api/pymongo/event_loggers.rst doc/api/pymongo/index.rst doc/api/pymongo/ismaster.rst 
doc/api/pymongo/message.rst doc/api/pymongo/mongo_client.rst doc/api/pymongo/mongo_replica_set_client.rst doc/api/pymongo/monitoring.rst doc/api/pymongo/operations.rst doc/api/pymongo/pool.rst doc/api/pymongo/read_concern.rst doc/api/pymongo/read_preferences.rst doc/api/pymongo/results.rst doc/api/pymongo/server_description.rst doc/api/pymongo/son_manipulator.rst doc/api/pymongo/topology_description.rst doc/api/pymongo/uri_parser.rst doc/api/pymongo/write_concern.rst doc/developer/index.rst doc/developer/periodic_executor.rst doc/examples/aggregation.rst doc/examples/authentication.rst doc/examples/bulk.rst doc/examples/collations.rst doc/examples/copydb.rst doc/examples/custom_type.rst doc/examples/datetimes.rst doc/examples/encryption.rst doc/examples/geo.rst doc/examples/gevent.rst doc/examples/gridfs.rst doc/examples/high_availability.rst doc/examples/index.rst doc/examples/mod_wsgi.rst doc/examples/server_selection.rst doc/examples/tailable.rst doc/examples/tls.rst doc/examples/uuid.rst doc/pydoctheme/theme.conf doc/pydoctheme/static/pydoctheme.css doc/static/delighted.js doc/static/periodic-executor-refs.png doc/static/sidebar.js gridfs/__init__.py gridfs/errors.py gridfs/grid_file.py pymongo/__init__.py pymongo/_cmessagemodule.c pymongo/aggregation.py pymongo/auth.py pymongo/auth_aws.py pymongo/bulk.py pymongo/change_stream.py pymongo/client_options.py pymongo/client_session.py pymongo/collation.py pymongo/collection.py pymongo/command_cursor.py pymongo/common.py pymongo/compression_support.py pymongo/cursor.py pymongo/cursor_manager.py pymongo/daemon.py pymongo/database.py pymongo/driver_info.py pymongo/encryption.py pymongo/encryption_options.py pymongo/errors.py pymongo/event_loggers.py pymongo/helpers.py pymongo/ismaster.py pymongo/max_staleness_selectors.py pymongo/message.py pymongo/mongo_client.py pymongo/mongo_replica_set_client.py pymongo/monitor.py pymongo/monitoring.py pymongo/monotonic.py pymongo/network.py pymongo/ocsp_cache.py pymongo/ocsp_support.py pymongo/operations.py pymongo/periodic_executor.py pymongo/pool.py pymongo/pyopenssl_context.py pymongo/read_concern.py pymongo/read_preferences.py pymongo/response.py pymongo/results.py pymongo/saslprep.py pymongo/server.py pymongo/server_description.py pymongo/server_selectors.py pymongo/server_type.py pymongo/settings.py pymongo/socket_checker.py pymongo/son_manipulator.py pymongo/srv_resolver.py pymongo/ssl_context.py pymongo/ssl_match_hostname.py pymongo/ssl_support.py pymongo/thread_util.py pymongo/topology.py pymongo/topology_description.py pymongo/uri_parser.py pymongo/write_concern.py pymongo.egg-info/PKG-INFO pymongo.egg-info/SOURCES.txt pymongo.egg-info/dependency_links.txt pymongo.egg-info/requires.txt pymongo.egg-info/top_level.txt test/__init__.py test/barrier.py test/pymongo_mocks.py test/qcheck.py test/test_auth.py test/test_auth_spec.py test/test_binary.py test/test_bson.py test/test_bson_corpus.py test/test_bulk.py test/test_change_stream.py test/test_client.py test/test_client_context.py test/test_cmap.py test/test_code.py test/test_collation.py test/test_collection.py test/test_command_monitoring_spec.py test/test_common.py test/test_connections_survive_primary_stepdown_spec.py test/test_crud_v1.py test/test_crud_v2.py test/test_cursor.py test/test_cursor_manager.py test/test_custom_types.py test/test_database.py test/test_dbref.py test/test_decimal128.py test/test_discovery_and_monitoring.py test/test_dns.py test/test_encryption.py test/test_errors.py test/test_examples.py test/test_grid_file.py 
test/test_gridfs.py test/test_gridfs_bucket.py test/test_gridfs_spec.py test/test_heartbeat_monitoring.py test/test_json_util.py test/test_legacy_api.py test/test_max_staleness.py test/test_mongos_load_balancing.py test/test_monitor.py test/test_monitoring.py test/test_monotonic.py test/test_objectid.py test/test_ocsp_cache.py test/test_pooling.py test/test_pymongo.py test/test_raw_bson.py test/test_read_concern.py test/test_read_preferences.py test/test_read_write_concern_spec.py test/test_replica_set_client.py test/test_replica_set_reconfig.py test/test_retryable_reads.py test/test_retryable_writes.py test/test_saslprep.py test/test_sdam_monitoring_spec.py test/test_server.py test/test_server_description.py test/test_server_selection.py test/test_server_selection_rtt.py test/test_session.py test/test_son.py test/test_son_manipulator.py test/test_srv_polling.py test/test_ssl.py test/test_streaming_protocol.py test/test_threads.py test/test_timestamp.py test/test_topology.py test/test_transactions.py test/test_uri_parser.py test/test_uri_spec.py test/test_write_concern.py test/utils.py test/utils_selection_tests.py test/utils_spec_runner.py test/version.py test/atlas/test_connection.py test/auth_aws/test_auth_aws.py test/certificates/ca.pem test/certificates/client.pem test/certificates/crl.pem test/certificates/password_protected.pem test/certificates/server.pem test/certificates/trusted-ca.pem test/mod_wsgi_test/test_client.py test/ocsp/test_ocsp.py test/performance/perf_test.py test/unicode/test_utf8.py test/uri_options/ca.pem test/uri_options/cert.pem test/uri_options/client.pem tools/README.rst tools/benchmark.py tools/clean.py tools/fail_if_no_c.py tools/ocsptest.pypymongo-3.11.0/pymongo.egg-info/dependency_links.txt000066400000000000000000000000011374256237000225250ustar00rootroot00000000000000 pymongo-3.11.0/pymongo.egg-info/requires.txt000066400000000000000000000003621374256237000210600ustar00rootroot00000000000000 [aws] pymongo-auth-aws<2.0.0 [encryption] pymongocrypt<2.0.0 [gssapi] pykerberos [ocsp] pyopenssl>=17.2.0 requests<3.0.0 service_identity>=18.1.0 [snappy] python-snappy [srv] dnspython<1.17.0,>=1.16.0 [tls] ipaddress [zstd] zstandard pymongo-3.11.0/pymongo.egg-info/top_level.txt000066400000000000000000000000241374256237000212050ustar00rootroot00000000000000bson gridfs pymongo pymongo-3.11.0/pymongo/000077500000000000000000000000001374256237000147655ustar00rootroot00000000000000pymongo-3.11.0/pymongo/__init__.py000066400000000000000000000063361374256237000171060ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Python driver for MongoDB.""" ASCENDING = 1 """Ascending sort order.""" DESCENDING = -1 """Descending sort order.""" GEO2D = "2d" """Index specifier for a 2-dimensional `geospatial index`_. .. _geospatial index: http://docs.mongodb.org/manual/core/2d/ """ GEOHAYSTACK = "geoHaystack" """**DEPRECATED** - Index specifier for a 2-dimensional `haystack index`_. 
**DEPRECATED** - :attr:`GEOHAYSTACK` is deprecated and will be removed in PyMongo 4.0. geoHaystack indexes (and the geoSearch command) were deprecated in MongoDB 4.4. Instead, create a 2d index and use $geoNear or $geoWithin. See https://dochub.mongodb.org/core/4.4-deprecate-geoHaystack. .. versionchanged:: 3.11 Deprecated. .. _haystack index: http://docs.mongodb.org/manual/core/geohaystack/ """ GEOSPHERE = "2dsphere" """Index specifier for a `spherical geospatial index`_. .. versionadded:: 2.5 .. _spherical geospatial index: http://docs.mongodb.org/manual/core/2dsphere/ """ HASHED = "hashed" """Index specifier for a `hashed index`_. .. versionadded:: 2.5 .. _hashed index: http://docs.mongodb.org/manual/core/index-hashed/ """ TEXT = "text" """Index specifier for a `text index`_. .. seealso:: MongoDB's `Atlas Search `_ which offers more advanced text search functionality. .. versionadded:: 2.7.1 .. _text index: http://docs.mongodb.org/manual/core/index-text/ """ OFF = 0 """No database profiling.""" SLOW_ONLY = 1 """Only profile slow operations.""" ALL = 2 """Profile all operations.""" version_tuple = (3, 11, 0) def get_version_string(): if isinstance(version_tuple[-1], str): return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] return '.'.join(map(str, version_tuple)) __version__ = version = get_version_string() """Current version of PyMongo.""" from pymongo.collection import ReturnDocument from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION, MAX_SUPPORTED_WIRE_VERSION) from pymongo.cursor import CursorType from pymongo.mongo_client import MongoClient from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.operations import (IndexModel, InsertOne, DeleteOne, DeleteMany, UpdateOne, UpdateMany, ReplaceOne) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern def has_c(): """Is the C extension installed?""" try: from pymongo import _cmessage return True except ImportError: return False pymongo-3.11.0/pymongo/_cmessagemodule.c000066400000000000000000001625421374256237000202770ustar00rootroot00000000000000/* * Copyright 2009-present MongoDB, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ /* * This file contains C implementations of some of the functions * needed by the message module. If possible, these implementations * should be used to speed up message creation. */ #define PY_SSIZE_T_CLEAN #include "Python.h" #include "_cbsonmodule.h" #include "buffer.h" struct module_state { PyObject* _cbson; }; /* See comments about module initialization in _cbsonmodule.c */ #if PY_MAJOR_VERSION >= 3 #define GETSTATE(m) ((struct module_state*)PyModule_GetState(m)) #else #define GETSTATE(m) (&_state) static struct module_state _state; #endif #define DOC_TOO_LARGE_FMT "BSON document too large (%d bytes)" \ " - the connected server supports" \ " BSON document sizes up to %ld bytes." /* Get an error class from the pymongo.errors module. 
* * Returns a new ref */ static PyObject* _error(char* name) { PyObject* error; PyObject* errors = PyImport_ImportModule("pymongo.errors"); if (!errors) { return NULL; } error = PyObject_GetAttrString(errors, name); Py_DECREF(errors); return error; } /* The same as buffer_write_bytes except that it also validates * "size" will fit in an int. * Returns 0 on failure */ static int buffer_write_bytes_ssize_t(buffer_t buffer, const char* data, Py_ssize_t size) { int downsize = _downcast_and_check(size, 0); if (size == -1) { return 0; } return buffer_write_bytes(buffer, data, downsize); } /* add a lastError message on the end of the buffer. * returns 0 on failure */ static int add_last_error(PyObject* self, buffer_t buffer, int request_id, char* ns, Py_ssize_t nslen, codec_options_t* options, PyObject* args) { struct module_state *state = GETSTATE(self); int message_start; int document_start; int message_length; int document_length; PyObject* key = NULL; PyObject* value = NULL; Py_ssize_t pos = 0; PyObject* one; char *p = strchr(ns, '.'); /* Length of the database portion of ns. */ nslen = p ? (int)(p - ns) : nslen; message_start = buffer_save_space(buffer, 4); if (message_start == -1) { return 0; } if (!buffer_write_int32(buffer, (int32_t)request_id) || !buffer_write_bytes(buffer, "\x00\x00\x00\x00" /* responseTo */ "\xd4\x07\x00\x00" /* opcode */ "\x00\x00\x00\x00", /* options */ 12) || !buffer_write_bytes_ssize_t(buffer, ns, nslen) || /* database */ !buffer_write_bytes(buffer, ".$cmd\x00" /* collection name */ "\x00\x00\x00\x00" /* skip */ "\xFF\xFF\xFF\xFF", /* limit (-1) */ 14)) { return 0; } /* save space for length */ document_start = buffer_save_space(buffer, 4); if (document_start == -1) { return 0; } /* getlasterror: 1 */ if (!(one = PyLong_FromLong(1))) return 0; if (!write_pair(state->_cbson, buffer, "getlasterror", 12, one, 0, options, 1)) { Py_DECREF(one); return 0; } Py_DECREF(one); /* getlasterror options */ while (PyDict_Next(args, &pos, &key, &value)) { if (!decode_and_write_pair(state->_cbson, buffer, key, value, 0, options, 0)) { return 0; } } /* EOD */ if (!buffer_write_bytes(buffer, "\x00", 1)) { return 0; } message_length = buffer_get_position(buffer) - message_start; document_length = buffer_get_position(buffer) - document_start; buffer_write_int32_at_position( buffer, message_start, (int32_t)message_length); buffer_write_int32_at_position( buffer, document_start, (int32_t)document_length); return 1; } static int init_insert_buffer(buffer_t buffer, int request_id, int options, const char* coll_name, Py_ssize_t coll_name_len, int compress) { int length_location = 0; if (!compress) { /* Save space for message length */ int length_location = buffer_save_space(buffer, 4); if (length_location == -1) { return length_location; } if (!buffer_write_int32(buffer, (int32_t)request_id) || !buffer_write_bytes(buffer, "\x00\x00\x00\x00" "\xd2\x07\x00\x00", 8)) { return -1; } } if (!buffer_write_int32(buffer, (int32_t)options) || !buffer_write_bytes_ssize_t(buffer, coll_name, coll_name_len + 1)) { return -1; } return length_location; } static PyObject* _cbson_insert_message(PyObject* self, PyObject* args) { /* Used by the Bulk API to insert into pre-2.6 servers. Collection.insert * uses _cbson_do_batched_insert. 
*/ struct module_state *state = GETSTATE(self); /* NOTE just using a random number as the request_id */ int request_id = rand(); char* collection_name = NULL; Py_ssize_t collection_name_length; PyObject* docs; PyObject* doc; PyObject* iterator; int before, cur_size, max_size = 0; int flags = 0; unsigned char check_keys; unsigned char safe; unsigned char continue_on_error; codec_options_t options; PyObject* last_error_args; buffer_t buffer = NULL; int length_location, message_length; PyObject* result = NULL; if (!PyArg_ParseTuple(args, "et#ObbObO&", "utf-8", &collection_name, &collection_name_length, &docs, &check_keys, &safe, &last_error_args, &continue_on_error, convert_codec_options, &options)) { return NULL; } if (continue_on_error) { flags += 1; } buffer = buffer_new(); if (!buffer) { goto fail; } length_location = init_insert_buffer(buffer, request_id, flags, collection_name, collection_name_length, 0); if (length_location == -1) { goto fail; } iterator = PyObject_GetIter(docs); if (iterator == NULL) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "input is not iterable"); Py_DECREF(InvalidOperation); } goto fail; } while ((doc = PyIter_Next(iterator)) != NULL) { before = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, doc, check_keys, &options, 1)) { Py_DECREF(doc); Py_DECREF(iterator); goto fail; } Py_DECREF(doc); cur_size = buffer_get_position(buffer) - before; max_size = (cur_size > max_size) ? cur_size : max_size; } Py_DECREF(iterator); if (PyErr_Occurred()) { goto fail; } if (!max_size) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "cannot do an empty bulk insert"); Py_DECREF(InvalidOperation); } goto fail; } message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); if (safe) { if (!add_last_error(self, buffer, request_id, collection_name, collection_name_length, &options, last_error_args)) { goto fail; } } /* objectify buffer */ result = Py_BuildValue("i" BYTES_FORMAT_STRING "i", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), max_size); fail: PyMem_Free(collection_name); destroy_codec_options(&options); if (buffer) { buffer_free(buffer); } return result; } static PyObject* _cbson_update_message(PyObject* self, PyObject* args) { /* NOTE just using a random number as the request_id */ struct module_state *state = GETSTATE(self); int request_id = rand(); char* collection_name = NULL; Py_ssize_t collection_name_length; int before, cur_size, max_size = 0; PyObject* doc; PyObject* spec; unsigned char multi; unsigned char upsert; unsigned char safe; unsigned char check_keys; codec_options_t options; PyObject* last_error_args; int flags; buffer_t buffer = NULL; int length_location, message_length; PyObject* result = NULL; if (!PyArg_ParseTuple(args, "et#bbOObObO&", "utf-8", &collection_name, &collection_name_length, &upsert, &multi, &spec, &doc, &safe, &last_error_args, &check_keys, convert_codec_options, &options)) { return NULL; } flags = 0; if (upsert) { flags += 1; } if (multi) { flags += 2; } buffer = buffer_new(); if (!buffer) { goto fail; } // save space for message length length_location = buffer_save_space(buffer, 4); if (length_location == -1) { goto fail; } if (!buffer_write_int32(buffer, (int32_t)request_id) || !buffer_write_bytes(buffer, "\x00\x00\x00\x00" "\xd1\x07\x00\x00" "\x00\x00\x00\x00", 
12) || !buffer_write_bytes_ssize_t(buffer, collection_name, collection_name_length + 1) || !buffer_write_int32(buffer, (int32_t)flags)) { goto fail; } before = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, spec, 0, &options, 1)) { goto fail; } max_size = buffer_get_position(buffer) - before; before = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, doc, check_keys, &options, 1)) { goto fail; } cur_size = buffer_get_position(buffer) - before; max_size = (cur_size > max_size) ? cur_size : max_size; message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); if (safe) { if (!add_last_error(self, buffer, request_id, collection_name, collection_name_length, &options, last_error_args)) { goto fail; } } /* objectify buffer */ result = Py_BuildValue("i" BYTES_FORMAT_STRING "i", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), max_size); fail: PyMem_Free(collection_name); destroy_codec_options(&options); if (buffer) { buffer_free(buffer); } return result; } static PyObject* _cbson_query_message(PyObject* self, PyObject* args) { /* NOTE just using a random number as the request_id */ struct module_state *state = GETSTATE(self); int request_id = rand(); PyObject* cluster_time = NULL; unsigned int flags; char* collection_name = NULL; Py_ssize_t collection_name_length; int begin, cur_size, max_size = 0; int num_to_skip; int num_to_return; PyObject* query; PyObject* field_selector; codec_options_t options; buffer_t buffer = NULL; int length_location, message_length; unsigned char check_keys = 0; PyObject* result = NULL; if (!PyArg_ParseTuple(args, "Iet#iiOOO&|b", &flags, "utf-8", &collection_name, &collection_name_length, &num_to_skip, &num_to_return, &query, &field_selector, convert_codec_options, &options, &check_keys)) { return NULL; } buffer = buffer_new(); if (!buffer) { goto fail; } // save space for message length length_location = buffer_save_space(buffer, 4); if (length_location == -1) { goto fail; } /* Pop $clusterTime from dict and write it at the end, avoiding an error * from the $-prefix and check_keys. * * If "dict" is a defaultdict we don't want to call PyMapping_GetItemString * on it. That would **create** an _id where one didn't previously exist * (PYTHON-871). */ if (PyDict_Check(query)) { cluster_time = PyDict_GetItemString(query, "$clusterTime"); if (cluster_time) { /* PyDict_GetItemString returns a borrowed reference. 
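 * Take our own reference before the key is deleted from the query,
 * otherwise removing the dict entry could drop the last reference and
 * leave cluster_time dangling. The key is put back with
 * PyMapping_SetItemString and the reference released once the query
 * document has been written to the buffer.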
*/ Py_INCREF(cluster_time); if (-1 == PyMapping_DelItemString(query, "$clusterTime")) { goto fail; } } } else if (PyMapping_HasKeyString(query, "$clusterTime")) { cluster_time = PyMapping_GetItemString(query, "$clusterTime"); if (!cluster_time || -1 == PyMapping_DelItemString(query, "$clusterTime")) { goto fail; } } if (!buffer_write_int32(buffer, (int32_t)request_id) || !buffer_write_bytes(buffer, "\x00\x00\x00\x00\xd4\x07\x00\x00", 8) || !buffer_write_int32(buffer, (int32_t)flags) || !buffer_write_bytes_ssize_t(buffer, collection_name, collection_name_length + 1) || !buffer_write_int32(buffer, (int32_t)num_to_skip) || !buffer_write_int32(buffer, (int32_t)num_to_return)) { goto fail; } begin = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, query, check_keys, &options, 1)) { goto fail; } /* back up a byte and write $clusterTime */ if (cluster_time) { int length; char zero = 0; buffer_update_position(buffer, buffer_get_position(buffer) - 1); if (!write_pair(state->_cbson, buffer, "$clusterTime", 12, cluster_time, 0, &options, 1)) { goto fail; } if (!buffer_write_bytes(buffer, &zero, 1)) { goto fail; } length = buffer_get_position(buffer) - begin; buffer_write_int32_at_position(buffer, begin, (int32_t)length); /* undo popping $clusterTime */ if (-1 == PyMapping_SetItemString( query, "$clusterTime", cluster_time)) { goto fail; } Py_CLEAR(cluster_time); } max_size = buffer_get_position(buffer) - begin; if (field_selector != Py_None) { begin = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, field_selector, 0, &options, 1)) { goto fail; } cur_size = buffer_get_position(buffer) - begin; max_size = (cur_size > max_size) ? cur_size : max_size; } message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); /* objectify buffer */ result = Py_BuildValue("i" BYTES_FORMAT_STRING "i", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), max_size); fail: PyMem_Free(collection_name); destroy_codec_options(&options); if (buffer) { buffer_free(buffer); } Py_XDECREF(cluster_time); return result; } static PyObject* _cbson_get_more_message(PyObject* self, PyObject* args) { /* NOTE just using a random number as the request_id */ int request_id = rand(); char* collection_name = NULL; Py_ssize_t collection_name_length; int num_to_return; long long cursor_id; buffer_t buffer = NULL; int length_location, message_length; PyObject* result = NULL; if (!PyArg_ParseTuple(args, "et#iL", "utf-8", &collection_name, &collection_name_length, &num_to_return, &cursor_id)) { return NULL; } buffer = buffer_new(); if (!buffer) { goto fail; } // save space for message length length_location = buffer_save_space(buffer, 4); if (length_location == -1) { goto fail; } if (!buffer_write_int32(buffer, (int32_t)request_id) || !buffer_write_bytes(buffer, "\x00\x00\x00\x00" "\xd5\x07\x00\x00" "\x00\x00\x00\x00", 12) || !buffer_write_bytes_ssize_t(buffer, collection_name, collection_name_length + 1) || !buffer_write_int32(buffer, (int32_t)num_to_return) || !buffer_write_int64(buffer, (int64_t)cursor_id)) { goto fail; } message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); /* objectify buffer */ result = Py_BuildValue("i" BYTES_FORMAT_STRING, request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer)); fail: PyMem_Free(collection_name); if (buffer) { buffer_free(buffer); } return 
result; } /* * NOTE this method handles multiple documents in a type one payload but * it does not perform batch splitting and the total message size is * only checked *after* generating the entire message. */ static PyObject* _cbson_op_msg(PyObject* self, PyObject* args) { struct module_state *state = GETSTATE(self); /* NOTE just using a random number as the request_id */ int request_id = rand(); unsigned int flags; PyObject* command; char* identifier = NULL; Py_ssize_t identifier_length = 0; PyObject* docs; PyObject* doc; unsigned char check_keys = 0; codec_options_t options; buffer_t buffer = NULL; int length_location, message_length; int total_size = 0; int max_doc_size = 0; PyObject* result = NULL; PyObject* iterator = NULL; /*flags, command, identifier, docs, check_keys, opts*/ if (!PyArg_ParseTuple(args, "IOet#ObO&", &flags, &command, "utf-8", &identifier, &identifier_length, &docs, &check_keys, convert_codec_options, &options)) { return NULL; } buffer = buffer_new(); if (!buffer) { goto fail; } // save space for message length length_location = buffer_save_space(buffer, 4); if (length_location == -1) { goto fail; } if (!buffer_write_int32(buffer, (int32_t)request_id) || !buffer_write_bytes(buffer, "\x00\x00\x00\x00" /* responseTo */ "\xdd\x07\x00\x00" /* 2013 */, 8)) { goto fail; } if (!buffer_write_int32(buffer, (int32_t)flags) || !buffer_write_bytes(buffer, "\x00", 1) /* Payload type 0 */) { goto fail; } total_size = write_dict(state->_cbson, buffer, command, 0, &options, 1); if (!total_size) { goto fail; } if (identifier_length) { int payload_one_length_location, payload_length; /* Payload type 1 */ if (!buffer_write_bytes(buffer, "\x01", 1)) { goto fail; } /* save space for payload 0 length */ payload_one_length_location = buffer_save_space(buffer, 4); /* C string identifier */ if (!buffer_write_bytes_ssize_t(buffer, identifier, identifier_length + 1)) { goto fail; } iterator = PyObject_GetIter(docs); if (iterator == NULL) { goto fail; } while ((doc = PyIter_Next(iterator)) != NULL) { int encoded_doc_size = write_dict( state->_cbson, buffer, doc, check_keys, &options, 1); if (!encoded_doc_size) { Py_CLEAR(doc); goto fail; } if (encoded_doc_size > max_doc_size) { max_doc_size = encoded_doc_size; } Py_CLEAR(doc); } payload_length = buffer_get_position(buffer) - payload_one_length_location; buffer_write_int32_at_position( buffer, payload_one_length_location, (int32_t)payload_length); total_size += payload_length; } message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); /* objectify buffer */ result = Py_BuildValue("i" BYTES_FORMAT_STRING "ii", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), total_size, max_doc_size); fail: Py_XDECREF(iterator); if (buffer) { buffer_free(buffer); } PyMem_Free(identifier); destroy_codec_options(&options); return result; } static void _set_document_too_large(int size, long max) { PyObject* DocumentTooLarge = _error("DocumentTooLarge"); if (DocumentTooLarge) { #if PY_MAJOR_VERSION >= 3 PyObject* error = PyUnicode_FromFormat(DOC_TOO_LARGE_FMT, size, max); #else PyObject* error = PyString_FromFormat(DOC_TOO_LARGE_FMT, size, max); #endif if (error) { PyErr_SetObject(DocumentTooLarge, error); Py_DECREF(error); } Py_DECREF(DocumentTooLarge); } } static PyObject* _send_insert(PyObject* self, PyObject* ctx, PyObject* gle_args, buffer_t buffer, char* coll_name, Py_ssize_t coll_len, int request_id, int safe, codec_options_t* options, 
PyObject* to_publish, int compress) { if (safe) { if (!add_last_error(self, buffer, request_id, coll_name, coll_len, options, gle_args)) { return NULL; } } /* The max_doc_size parameter for legacy_bulk_insert is the max size of * any document in buffer. We enforced max size already, pass 0 here. */ return PyObject_CallMethod(ctx, "legacy_bulk_insert", "i" BYTES_FORMAT_STRING "iNOi", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), 0, PyBool_FromLong((long)safe), to_publish, compress); } static PyObject* _cbson_do_batched_insert(PyObject* self, PyObject* args) { struct module_state *state = GETSTATE(self); /* NOTE just using a random number as the request_id */ int request_id = rand(); int send_safe, flags = 0; int length_location, message_length; Py_ssize_t collection_name_length; int compress; char* collection_name = NULL; PyObject* docs; PyObject* doc; PyObject* iterator; PyObject* ctx; PyObject* last_error_args; PyObject* result; PyObject* max_bson_size_obj; PyObject* max_message_size_obj; PyObject* compress_obj; PyObject* to_publish = NULL; unsigned char check_keys; unsigned char safe; unsigned char continue_on_error; codec_options_t options; unsigned char empty = 1; long max_bson_size; long max_message_size; buffer_t buffer; PyObject *exc_type = NULL, *exc_value = NULL, *exc_trace = NULL; if (!PyArg_ParseTuple(args, "et#ObbObO&O", "utf-8", &collection_name, &collection_name_length, &docs, &check_keys, &safe, &last_error_args, &continue_on_error, convert_codec_options, &options, &ctx)) { return NULL; } if (continue_on_error) { flags += 1; } /* * If we are doing unacknowledged writes *and* continue_on_error * is True it's pointless (and slower) to send GLE. */ send_safe = (safe || !continue_on_error); max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size"); #if PY_MAJOR_VERSION >= 3 max_bson_size = PyLong_AsLong(max_bson_size_obj); #else max_bson_size = PyInt_AsLong(max_bson_size_obj); #endif Py_XDECREF(max_bson_size_obj); if (max_bson_size == -1) { destroy_codec_options(&options); PyMem_Free(collection_name); return NULL; } max_message_size_obj = PyObject_GetAttrString(ctx, "max_message_size"); #if PY_MAJOR_VERSION >= 3 max_message_size = PyLong_AsLong(max_message_size_obj); #else max_message_size = PyInt_AsLong(max_message_size_obj); #endif Py_XDECREF(max_message_size_obj); if (max_message_size == -1) { destroy_codec_options(&options); PyMem_Free(collection_name); return NULL; } compress_obj = PyObject_GetAttrString(ctx, "compress"); compress = PyObject_IsTrue(compress_obj); Py_XDECREF(compress_obj); if (compress == -1) { destroy_codec_options(&options); PyMem_Free(collection_name); return NULL; } compress = compress && !(safe || send_safe); buffer = buffer_new(); if (!buffer) { destroy_codec_options(&options); PyMem_Free(collection_name); return NULL; } length_location = init_insert_buffer(buffer, request_id, flags, collection_name, collection_name_length, compress); if (length_location == -1) { goto insertfail; } if (!(to_publish = PyList_New(0))) { goto insertfail; } iterator = PyObject_GetIter(docs); if (iterator == NULL) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "input is not iterable"); Py_DECREF(InvalidOperation); } goto insertfail; } while ((doc = PyIter_Next(iterator)) != NULL) { int before = buffer_get_position(buffer); int cur_size; if (!write_dict(state->_cbson, buffer, doc, check_keys, &options, 1)) { goto iterfail; } cur_size = buffer_get_position(buffer) - 
before; if (cur_size > max_bson_size) { /* If we've encoded anything send it before raising. */ if (!empty) { buffer_update_position(buffer, before); if (!compress) { message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); } result = _send_insert(self, ctx, last_error_args, buffer, collection_name, collection_name_length, request_id, send_safe, &options, to_publish, compress); if (!result) goto iterfail; Py_DECREF(result); } _set_document_too_large(cur_size, max_bson_size); goto iterfail; } empty = 0; /* We have enough data, send this batch. */ if (buffer_get_position(buffer) > max_message_size) { int new_request_id = rand(); int message_start; buffer_t new_buffer = buffer_new(); if (!new_buffer) { goto iterfail; } message_start = init_insert_buffer(new_buffer, new_request_id, flags, collection_name, collection_name_length, compress); if (message_start == -1) { buffer_free(new_buffer); goto iterfail; } /* Copy the overflow encoded document into the new buffer. */ if (!buffer_write_bytes(new_buffer, (const char*)buffer_get_buffer(buffer) + before, cur_size)) { buffer_free(new_buffer); goto iterfail; } /* Roll back to the beginning of this document. */ buffer_update_position(buffer, before); if (!compress) { message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); } result = _send_insert(self, ctx, last_error_args, buffer, collection_name, collection_name_length, request_id, send_safe, &options, to_publish, compress); buffer_free(buffer); buffer = new_buffer; request_id = new_request_id; length_location = message_start; Py_DECREF(to_publish); if (!(to_publish = PyList_New(0))) { goto insertfail; } if (!result) { PyObject *etype = NULL, *evalue = NULL, *etrace = NULL; PyObject* OperationFailure; PyErr_Fetch(&etype, &evalue, &etrace); OperationFailure = _error("OperationFailure"); if (OperationFailure) { if (PyErr_GivenExceptionMatches(etype, OperationFailure)) { if (!safe || continue_on_error) { Py_DECREF(OperationFailure); if (!safe) { /* We're doing unacknowledged writes and * continue_on_error is False. Just return. */ Py_DECREF(etype); Py_XDECREF(evalue); Py_XDECREF(etrace); Py_DECREF(to_publish); Py_DECREF(iterator); Py_DECREF(doc); buffer_free(buffer); PyMem_Free(collection_name); Py_RETURN_NONE; } /* continue_on_error is True, store the error * details to re-raise after the final batch */ Py_XDECREF(exc_type); Py_XDECREF(exc_value); Py_XDECREF(exc_trace); exc_type = etype; exc_value = evalue; exc_trace = etrace; if (PyList_Append(to_publish, doc) < 0) { goto iterfail; } Py_CLEAR(doc); continue; } } Py_DECREF(OperationFailure); } /* This isn't OperationFailure, we couldn't * import OperationFailure, or we are doing * acknowledged writes. Re-raise immediately. 
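 *
 * To summarize the handling of a failed intermediate batch: an
 * OperationFailure on an unacknowledged write returns None right away;
 * with acknowledged writes and continue_on_error the failure is stashed
 * in exc_type/exc_value/exc_trace and re-raised only after the final
 * batch; any other error is re-raised here immediately.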
*/ PyErr_Restore(etype, evalue, etrace); goto iterfail; } else { Py_DECREF(result); } } if (PyList_Append(to_publish, doc) < 0) { goto iterfail; } Py_CLEAR(doc); } Py_DECREF(iterator); if (PyErr_Occurred()) { goto insertfail; } if (empty) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "cannot do an empty bulk insert"); Py_DECREF(InvalidOperation); } goto insertfail; } if (!compress) { message_length = buffer_get_position(buffer) - length_location; buffer_write_int32_at_position( buffer, length_location, (int32_t)message_length); } /* Send the last (or only) batch */ result = _send_insert(self, ctx, last_error_args, buffer, collection_name, collection_name_length, request_id, safe, &options, to_publish, compress); Py_DECREF(to_publish); PyMem_Free(collection_name); buffer_free(buffer); if (!result) { Py_XDECREF(exc_type); Py_XDECREF(exc_value); Py_XDECREF(exc_trace); return NULL; } else { Py_DECREF(result); } if (exc_type) { /* Re-raise any previously stored exception * due to continue_on_error being True */ PyErr_Restore(exc_type, exc_value, exc_trace); return NULL; } Py_RETURN_NONE; iterfail: Py_XDECREF(doc); Py_DECREF(iterator); insertfail: Py_XDECREF(exc_type); Py_XDECREF(exc_value); Py_XDECREF(exc_trace); Py_XDECREF(to_publish); buffer_free(buffer); PyMem_Free(collection_name); return NULL; } #define _INSERT 0 #define _UPDATE 1 #define _DELETE 2 /* OP_MSG ----------------------------------------------- */ static int _batched_op_msg( unsigned char op, unsigned char check_keys, unsigned char ack, PyObject* command, PyObject* docs, PyObject* ctx, PyObject* to_publish, codec_options_t options, buffer_t buffer, struct module_state *state) { long max_bson_size; long max_write_batch_size; long max_message_size; int idx = 0; int size_location; int position; int length; PyObject* max_bson_size_obj = NULL; PyObject* max_write_batch_size_obj = NULL; PyObject* max_message_size_obj = NULL; PyObject* doc = NULL; PyObject* iterator = NULL; char* flags = ack ? 
"\x00\x00\x00\x00" : "\x02\x00\x00\x00"; max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size"); #if PY_MAJOR_VERSION >= 3 max_bson_size = PyLong_AsLong(max_bson_size_obj); #else max_bson_size = PyInt_AsLong(max_bson_size_obj); #endif Py_XDECREF(max_bson_size_obj); if (max_bson_size == -1) { return 0; } max_write_batch_size_obj = PyObject_GetAttrString(ctx, "max_write_batch_size"); #if PY_MAJOR_VERSION >= 3 max_write_batch_size = PyLong_AsLong(max_write_batch_size_obj); #else max_write_batch_size = PyInt_AsLong(max_write_batch_size_obj); #endif Py_XDECREF(max_write_batch_size_obj); if (max_write_batch_size == -1) { return 0; } max_message_size_obj = PyObject_GetAttrString(ctx, "max_message_size"); #if PY_MAJOR_VERSION >= 3 max_message_size = PyLong_AsLong(max_message_size_obj); #else max_message_size = PyInt_AsLong(max_message_size_obj); #endif Py_XDECREF(max_message_size_obj); if (max_message_size == -1) { return 0; } if (!buffer_write_bytes(buffer, flags, 4)) { return 0; } /* Type 0 Section */ if (!buffer_write_bytes(buffer, "\x00", 1)) { return 0; } if (!write_dict(state->_cbson, buffer, command, 0, &options, 0)) { return 0; } /* Type 1 Section */ if (!buffer_write_bytes(buffer, "\x01", 1)) { return 0; } /* Save space for size */ size_location = buffer_save_space(buffer, 4); if (size_location == -1) { return 0; } switch (op) { case _INSERT: { if (!buffer_write_bytes(buffer, "documents\x00", 10)) goto fail; break; } case _UPDATE: { /* MongoDB does key validation for update. */ check_keys = 0; if (!buffer_write_bytes(buffer, "updates\x00", 8)) goto fail; break; } case _DELETE: { /* Never check keys in a delete command. */ check_keys = 0; if (!buffer_write_bytes(buffer, "deletes\x00", 8)) goto fail; break; } default: { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "Unknown command"); Py_DECREF(InvalidOperation); } return 0; } } iterator = PyObject_GetIter(docs); if (iterator == NULL) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "input is not iterable"); Py_DECREF(InvalidOperation); } return 0; } while ((doc = PyIter_Next(iterator)) != NULL) { int cur_doc_begin = buffer_get_position(buffer); int cur_size; int doc_too_large = 0; int unacked_doc_too_large = 0; if (!write_dict(state->_cbson, buffer, doc, check_keys, &options, 1)) { goto fail; } cur_size = buffer_get_position(buffer) - cur_doc_begin; /* Does the first document exceed max_message_size? */ doc_too_large = (idx == 0 && (buffer_get_position(buffer) > max_message_size)); /* When OP_MSG is used unacknowledged we have to check * document size client side or applications won't be notified. * Otherwise we let the server deal with documents that are too large * since ordered=False causes those documents to be skipped instead of * halting the bulk write operation. * */ unacked_doc_too_large = (!ack && cur_size > max_bson_size); if (doc_too_large || unacked_doc_too_large) { if (op == _INSERT) { _set_document_too_large(cur_size, max_bson_size); } else { PyObject* DocumentTooLarge = _error("DocumentTooLarge"); if (DocumentTooLarge) { /* * There's nothing intelligent we can say * about size for update and delete. */ PyErr_Format( DocumentTooLarge, "%s command document too large", (op == _UPDATE) ? "update": "delete"); Py_DECREF(DocumentTooLarge); } } goto fail; } /* We have enough data, return this batch. 
*/ if (buffer_get_position(buffer) > max_message_size) { /* * Roll the existing buffer back to the beginning * of the last document encoded. */ buffer_update_position(buffer, cur_doc_begin); Py_CLEAR(doc); break; } if (PyList_Append(to_publish, doc) < 0) { goto fail; } Py_CLEAR(doc); idx += 1; /* We have enough documents, return this batch. */ if (idx == max_write_batch_size) { break; } } Py_CLEAR(iterator); if (PyErr_Occurred()) { goto fail; } position = buffer_get_position(buffer); length = position - size_location; buffer_write_int32_at_position(buffer, size_location, (int32_t)length); return 1; fail: Py_XDECREF(doc); Py_XDECREF(iterator); return 0; } static PyObject* _cbson_encode_batched_op_msg(PyObject* self, PyObject* args) { unsigned char op; unsigned char check_keys; unsigned char ack; PyObject* command; PyObject* docs; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); if (!PyArg_ParseTuple(args, "bOObbO&O", &op, &command, &docs, &check_keys, &ack, convert_codec_options, &options, &ctx)) { return NULL; } if (!(buffer = buffer_new())) { destroy_codec_options(&options); return NULL; } if (!(to_publish = PyList_New(0))) { goto fail; } if (!_batched_op_msg( op, check_keys, ack, command, docs, ctx, to_publish, options, buffer, state)) { goto fail; } result = Py_BuildValue(BYTES_FORMAT_STRING "O", buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), to_publish); fail: destroy_codec_options(&options); buffer_free(buffer); Py_XDECREF(to_publish); return result; } static PyObject* _cbson_batched_op_msg(PyObject* self, PyObject* args) { unsigned char op; unsigned char check_keys; unsigned char ack; int request_id; int position; PyObject* command; PyObject* docs; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); if (!PyArg_ParseTuple(args, "bOObbO&O", &op, &command, &docs, &check_keys, &ack, convert_codec_options, &options, &ctx)) { return NULL; } if (!(buffer = buffer_new())) { destroy_codec_options(&options); return NULL; } /* Save space for message length and request id */ if ((buffer_save_space(buffer, 8)) == -1) { goto fail; } if (!buffer_write_bytes(buffer, "\x00\x00\x00\x00" /* responseTo */ "\xdd\x07\x00\x00", /* opcode */ 8)) { goto fail; } if (!(to_publish = PyList_New(0))) { goto fail; } if (!_batched_op_msg( op, check_keys, ack, command, docs, ctx, to_publish, options, buffer, state)) { goto fail; } request_id = rand(); position = buffer_get_position(buffer); buffer_write_int32_at_position(buffer, 0, (int32_t)position); buffer_write_int32_at_position(buffer, 4, (int32_t)request_id); result = Py_BuildValue("i" BYTES_FORMAT_STRING "O", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), to_publish); fail: destroy_codec_options(&options); buffer_free(buffer); Py_XDECREF(to_publish); return result; } /* End OP_MSG -------------------------------------------- */ static int _batched_write_command( char* ns, Py_ssize_t ns_len, unsigned char op, int check_keys, PyObject* command, PyObject* docs, PyObject* ctx, PyObject* to_publish, codec_options_t options, buffer_t buffer, struct module_state *state) { long max_bson_size; long max_cmd_size; long max_write_batch_size; long max_split_size; int idx = 0; int cmd_len_loc; int lst_len_loc; int position; int length; PyObject* max_bson_size_obj = NULL; PyObject* 
max_write_batch_size_obj = NULL; PyObject* max_split_size_obj = NULL; PyObject* doc = NULL; PyObject* iterator = NULL; max_bson_size_obj = PyObject_GetAttrString(ctx, "max_bson_size"); #if PY_MAJOR_VERSION >= 3 max_bson_size = PyLong_AsLong(max_bson_size_obj); #else max_bson_size = PyInt_AsLong(max_bson_size_obj); #endif Py_XDECREF(max_bson_size_obj); if (max_bson_size == -1) { return 0; } /* * Max BSON object size + 16k - 2 bytes for ending NUL bytes * XXX: This should come from the server - SERVER-10643 */ max_cmd_size = max_bson_size + 16382; max_write_batch_size_obj = PyObject_GetAttrString(ctx, "max_write_batch_size"); #if PY_MAJOR_VERSION >= 3 max_write_batch_size = PyLong_AsLong(max_write_batch_size_obj); #else max_write_batch_size = PyInt_AsLong(max_write_batch_size_obj); #endif Py_XDECREF(max_write_batch_size_obj); if (max_write_batch_size == -1) { return 0; } // max_split_size is the size at which to perform a batch split. // Normally this this value is equal to max_bson_size (16MiB). However, // when auto encryption is enabled max_split_size is reduced to 2MiB. max_split_size_obj = PyObject_GetAttrString(ctx, "max_split_size"); #if PY_MAJOR_VERSION >= 3 max_split_size = PyLong_AsLong(max_split_size_obj); #else max_split_size = PyInt_AsLong(max_split_size_obj); #endif Py_XDECREF(max_split_size_obj); if (max_split_size == -1) { return 0; } if (!buffer_write_bytes(buffer, "\x00\x00\x00\x00", /* flags */ 4) || !buffer_write_bytes_ssize_t(buffer, ns, ns_len + 1) || /* namespace */ !buffer_write_bytes(buffer, "\x00\x00\x00\x00" /* skip */ "\xFF\xFF\xFF\xFF", /* limit (-1) */ 8)) { return 0; } /* Position of command document length */ cmd_len_loc = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, command, 0, &options, 0)) { return 0; } /* Write type byte for array */ *(buffer_get_buffer(buffer) + (buffer_get_position(buffer) - 1)) = 0x4; switch (op) { case _INSERT: { if (!buffer_write_bytes(buffer, "documents\x00", 10)) goto fail; break; } case _UPDATE: { /* MongoDB does key validation for update. */ check_keys = 0; if (!buffer_write_bytes(buffer, "updates\x00", 8)) goto fail; break; } case _DELETE: { /* Never check keys in a delete command. */ check_keys = 0; if (!buffer_write_bytes(buffer, "deletes\x00", 8)) goto fail; break; } default: { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "Unknown command"); Py_DECREF(InvalidOperation); } return 0; } } /* Save space for list document */ lst_len_loc = buffer_save_space(buffer, 4); if (lst_len_loc == -1) { return 0; } iterator = PyObject_GetIter(docs); if (iterator == NULL) { PyObject* InvalidOperation = _error("InvalidOperation"); if (InvalidOperation) { PyErr_SetString(InvalidOperation, "input is not iterable"); Py_DECREF(InvalidOperation); } return 0; } while ((doc = PyIter_Next(iterator)) != NULL) { int sub_doc_begin = buffer_get_position(buffer); int cur_doc_begin; int cur_size; int enough_data = 0; char key[16]; INT2STRING(key, idx); if (!buffer_write_bytes(buffer, "\x03", 1) || !buffer_write_bytes(buffer, key, (int)strlen(key) + 1)) { goto fail; } cur_doc_begin = buffer_get_position(buffer); if (!write_dict(state->_cbson, buffer, doc, check_keys, &options, 1)) { goto fail; } /* We have enough data, return this batch. * max_cmd_size accounts for the two trailing null bytes. */ cur_size = buffer_get_position(buffer) - cur_doc_begin; /* This single document is too large for the command. 
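 * max_cmd_size allows roughly 16K of headroom above max_bson_size for
 * the command wrapper, so a document may slightly exceed max_bson_size
 * without tripping this check. For inserts the exact size is reported
 * via _set_document_too_large; for updates and deletes there is no
 * meaningful per-document size to report, so a generic DocumentTooLarge
 * error is raised instead.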
*/ if (cur_size > max_cmd_size) { if (op == _INSERT) { _set_document_too_large(cur_size, max_bson_size); } else { PyObject* DocumentTooLarge = _error("DocumentTooLarge"); if (DocumentTooLarge) { /* * There's nothing intelligent we can say * about size for update and delete. */ PyErr_Format( DocumentTooLarge, "%s command document too large", (op == _UPDATE) ? "update": "delete"); Py_DECREF(DocumentTooLarge); } } goto fail; } enough_data = (idx >= 1 && (buffer_get_position(buffer) > max_split_size)); if (enough_data) { /* * Roll the existing buffer back to the beginning * of the last document encoded. */ buffer_update_position(buffer, sub_doc_begin); Py_CLEAR(doc); break; } if (PyList_Append(to_publish, doc) < 0) { goto fail; } Py_CLEAR(doc); idx += 1; /* We have enough documents, return this batch. */ if (idx == max_write_batch_size) { break; } } Py_CLEAR(iterator); if (PyErr_Occurred()) { goto fail; } if (!buffer_write_bytes(buffer, "\x00\x00", 2)) { goto fail; } position = buffer_get_position(buffer); length = position - lst_len_loc - 1; buffer_write_int32_at_position(buffer, lst_len_loc, (int32_t)length); length = position - cmd_len_loc; buffer_write_int32_at_position(buffer, cmd_len_loc, (int32_t)length); return 1; fail: Py_XDECREF(doc); Py_XDECREF(iterator); return 0; } static PyObject* _cbson_encode_batched_write_command(PyObject* self, PyObject* args) { char *ns = NULL; unsigned char op; unsigned char check_keys; Py_ssize_t ns_len; PyObject* command; PyObject* docs; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8", &ns, &ns_len, &op, &command, &docs, &check_keys, convert_codec_options, &options, &ctx)) { return NULL; } if (!(buffer = buffer_new())) { PyMem_Free(ns); destroy_codec_options(&options); return NULL; } if (!(to_publish = PyList_New(0))) { goto fail; } if (!_batched_write_command( ns, ns_len, op, check_keys, command, docs, ctx, to_publish, options, buffer, state)) { goto fail; } result = Py_BuildValue(BYTES_FORMAT_STRING "O", buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), to_publish); fail: PyMem_Free(ns); destroy_codec_options(&options); buffer_free(buffer); Py_XDECREF(to_publish); return result; } static PyObject* _cbson_batched_write_command(PyObject* self, PyObject* args) { char *ns = NULL; unsigned char op; unsigned char check_keys; Py_ssize_t ns_len; int request_id; int position; PyObject* command; PyObject* docs; PyObject* ctx = NULL; PyObject* to_publish = NULL; PyObject* result = NULL; codec_options_t options; buffer_t buffer; struct module_state *state = GETSTATE(self); if (!PyArg_ParseTuple(args, "et#bOObO&O", "utf-8", &ns, &ns_len, &op, &command, &docs, &check_keys, convert_codec_options, &options, &ctx)) { return NULL; } if (!(buffer = buffer_new())) { PyMem_Free(ns); destroy_codec_options(&options); return NULL; } /* Save space for message length and request id */ if ((buffer_save_space(buffer, 8)) == -1) { goto fail; } if (!buffer_write_bytes(buffer, "\x00\x00\x00\x00" /* responseTo */ "\xd4\x07\x00\x00", /* opcode */ 8)) { goto fail; } if (!(to_publish = PyList_New(0))) { goto fail; } if (!_batched_write_command( ns, ns_len, op, check_keys, command, docs, ctx, to_publish, options, buffer, state)) { goto fail; } request_id = rand(); position = buffer_get_position(buffer); buffer_write_int32_at_position(buffer, 0, (int32_t)position); buffer_write_int32_at_position(buffer, 4, 
(int32_t)request_id); result = Py_BuildValue("i" BYTES_FORMAT_STRING "O", request_id, buffer_get_buffer(buffer), (Py_ssize_t)buffer_get_position(buffer), to_publish); fail: PyMem_Free(ns); destroy_codec_options(&options); buffer_free(buffer); Py_XDECREF(to_publish); return result; } static PyMethodDef _CMessageMethods[] = { {"_insert_message", _cbson_insert_message, METH_VARARGS, "Create an insert message to be sent to MongoDB"}, {"_update_message", _cbson_update_message, METH_VARARGS, "create an update message to be sent to MongoDB"}, {"_query_message", _cbson_query_message, METH_VARARGS, "create a query message to be sent to MongoDB"}, {"_get_more_message", _cbson_get_more_message, METH_VARARGS, "create a get more message to be sent to MongoDB"}, {"_op_msg", _cbson_op_msg, METH_VARARGS, "create an OP_MSG message to be sent to MongoDB"}, {"_do_batched_insert", _cbson_do_batched_insert, METH_VARARGS, "insert a batch of documents, splitting the batch as needed"}, {"_batched_write_command", _cbson_batched_write_command, METH_VARARGS, "Create the next batched insert, update, or delete command"}, {"_encode_batched_write_command", _cbson_encode_batched_write_command, METH_VARARGS, "Encode the next batched insert, update, or delete command"}, {"_batched_op_msg", _cbson_batched_op_msg, METH_VARARGS, "Create the next batched insert, update, or delete using OP_MSG"}, {"_encode_batched_op_msg", _cbson_encode_batched_op_msg, METH_VARARGS, "Encode the next batched insert, update, or delete using OP_MSG"}, {NULL, NULL, 0, NULL} }; #if PY_MAJOR_VERSION >= 3 #define INITERROR return NULL static int _cmessage_traverse(PyObject *m, visitproc visit, void *arg) { Py_VISIT(GETSTATE(m)->_cbson); return 0; } static int _cmessage_clear(PyObject *m) { Py_CLEAR(GETSTATE(m)->_cbson); return 0; } static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "_cmessage", NULL, sizeof(struct module_state), _CMessageMethods, NULL, _cmessage_traverse, _cmessage_clear, NULL }; PyMODINIT_FUNC PyInit__cmessage(void) #else #define INITERROR return PyMODINIT_FUNC init_cmessage(void) #endif { PyObject *_cbson = NULL; PyObject *c_api_object = NULL; PyObject *m = NULL; /* Store a reference to the _cbson module since it's needed to call some * of its functions */ _cbson = PyImport_ImportModule("bson._cbson"); if (_cbson == NULL) { goto fail; } /* Import C API of _cbson * The header file accesses _cbson_API to call the functions */ c_api_object = PyObject_GetAttrString(_cbson, "_C_API"); if (c_api_object == NULL) { goto fail; } #if PY_VERSION_HEX >= 0x03010000 _cbson_API = (void **)PyCapsule_GetPointer(c_api_object, "_cbson._C_API"); #else _cbson_API = (void **)PyCObject_AsVoidPtr(c_api_object); #endif if (_cbson_API == NULL) { goto fail; } #if PY_MAJOR_VERSION >= 3 /* Returns a new reference. */ m = PyModule_Create(&moduledef); #else /* Returns a borrowed reference. */ m = Py_InitModule("_cmessage", _CMessageMethods); #endif if (m == NULL) { goto fail; } GETSTATE(m)->_cbson = _cbson; Py_DECREF(c_api_object); #if PY_MAJOR_VERSION >= 3 return m; #else return; #endif fail: #if PY_MAJOR_VERSION >= 3 Py_XDECREF(m); #endif Py_XDECREF(c_api_object); Py_XDECREF(_cbson); INITERROR; } pymongo-3.11.0/pymongo/aggregation.py000066400000000000000000000211341374256237000176270ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Perform aggregation operations on a collection or database.""" from bson.son import SON from pymongo import common from pymongo.collation import validate_collation_or_none from pymongo.errors import ConfigurationError from pymongo.read_preferences import ReadPreference class _AggregationCommand(object): """The internal abstract base class for aggregation cursors. Should not be called directly by application developers. Use :meth:`pymongo.collection.Collection.aggregate`, or :meth:`pymongo.database.Database.aggregate` instead. """ def __init__(self, target, cursor_class, pipeline, options, explicit_session, user_fields=None, result_processor=None): if "explain" in options: raise ConfigurationError("The explain option is not supported. " "Use Database.command instead.") self._target = target common.validate_list('pipeline', pipeline) self._pipeline = pipeline self._performs_write = False if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]): self._performs_write = True common.validate_is_mapping('options', options) self._options = options # This is the batchSize that will be used for setting the initial # batchSize for the cursor, as well as the subsequent getMores. self._batch_size = common.validate_non_negative_integer_or_none( "batchSize", self._options.pop("batchSize", None)) # If the cursor option is already specified, avoid overriding it. self._options.setdefault("cursor", {}) # If the pipeline performs a write, we ignore the initial batchSize # since the server doesn't return results in this case. if self._batch_size is not None and not self._performs_write: self._options["cursor"]["batchSize"] = self._batch_size self._cursor_class = cursor_class self._explicit_session = explicit_session self._user_fields = user_fields self._result_processor = result_processor self._collation = validate_collation_or_none( options.pop('collation', None)) self._max_await_time_ms = options.pop('maxAwaitTimeMS', None) @property def _aggregation_target(self): """The argument to pass to the aggregate command.""" raise NotImplementedError @property def _cursor_namespace(self): """The namespace in which the aggregate command is run.""" raise NotImplementedError @property def _cursor_collection(self, cursor_doc): """The Collection used for the aggregate command cursor.""" raise NotImplementedError @property def _database(self): """The database against which the aggregation command is run.""" raise NotImplementedError @staticmethod def _check_compat(sock_info): """Check whether the server version in-use supports aggregation.""" pass def _process_result(self, result, session, server, sock_info, slave_ok): if self._result_processor: self._result_processor( result, session, server, sock_info, slave_ok) def get_read_preference(self, session): if self._performs_write: return ReadPreference.PRIMARY return self._target._read_preference_for(session) def get_cursor(self, session, server, sock_info, slave_ok): # Ensure command compatibility. self._check_compat(sock_info) # Serialize command. 
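        # For illustration only (collection name and pipeline values are
        # hypothetical): a collection-level aggregation produces a command
        # document shaped like
        #   SON([("aggregate", "coll"), ("pipeline", [{"$match": {}}]),
        #        ("cursor", {})])
        # where "cursor" carries the initial batchSize when one was given;
        # the read/write concern and collation computed below are applied
        # when the command is run.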
cmd = SON([("aggregate", self._aggregation_target), ("pipeline", self._pipeline)]) cmd.update(self._options) # Apply this target's read concern if: # readConcern has not been specified as a kwarg and either # - server version is >= 4.2 or # - server version is >= 3.2 and pipeline doesn't use $out if (('readConcern' not in cmd) and ((sock_info.max_wire_version >= 4 and not self._performs_write) or (sock_info.max_wire_version >= 8))): read_concern = self._target.read_concern else: read_concern = None # Apply this target's write concern if: # writeConcern has not been specified as a kwarg and pipeline doesn't # perform a write operation if 'writeConcern' not in cmd and self._performs_write: write_concern = self._target._write_concern_for(session) else: write_concern = None # Run command. result = sock_info.command( self._database.name, cmd, slave_ok, self.get_read_preference(session), self._target.codec_options, parse_write_concern_error=True, read_concern=read_concern, write_concern=write_concern, collation=self._collation, session=session, client=self._database.client, user_fields=self._user_fields) self._process_result(result, session, server, sock_info, slave_ok) # Extract cursor from result or mock/fake one if necessary. if 'cursor' in result: cursor = result['cursor'] else: # Pre-MongoDB 2.6 or unacknowledged write. Fake a cursor. cursor = { "id": 0, "firstBatch": result.get("result", []), "ns": self._cursor_namespace, } # Create and return cursor instance. return self._cursor_class( self._cursor_collection(cursor), cursor, sock_info.address, batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, session=session, explicit_session=self._explicit_session) class _CollectionAggregationCommand(_AggregationCommand): def __init__(self, *args, **kwargs): # Pop additional option and initialize parent class. use_cursor = kwargs.pop("use_cursor", True) super(_CollectionAggregationCommand, self).__init__(*args, **kwargs) # Remove the cursor document if the user has set use_cursor to False. self._use_cursor = use_cursor if not self._use_cursor: self._options.pop("cursor", None) @property def _aggregation_target(self): return self._target.name @property def _cursor_namespace(self): return self._target.full_name def _cursor_collection(self, cursor): """The Collection used for the aggregate command cursor.""" return self._target @property def _database(self): return self._target.database class _CollectionRawAggregationCommand(_CollectionAggregationCommand): def __init__(self, *args, **kwargs): super(_CollectionRawAggregationCommand, self).__init__(*args, **kwargs) # For raw-batches, we set the initial batchSize for the cursor to 0. if self._use_cursor and not self._performs_write: self._options["cursor"]["batchSize"] = 0 class _DatabaseAggregationCommand(_AggregationCommand): @property def _aggregation_target(self): return 1 @property def _cursor_namespace(self): return "%s.$cmd.aggregate" % (self._target.name,) @property def _database(self): return self._target def _cursor_collection(self, cursor): """The Collection used for the aggregate command cursor.""" # Collection level aggregate may not always return the "ns" field # according to our MockupDB tests. Let's handle that case for db level # aggregate too by defaulting to the .$cmd.aggregate namespace. _, collname = cursor.get("ns", self._cursor_namespace).split(".", 1) return self._database[collname] @staticmethod def _check_compat(sock_info): # Older server version don't raise a descriptive error, so we raise # one instead. 
if not sock_info.max_wire_version >= 6: err_msg = "Database.aggregate() is only supported on MongoDB 3.6+." raise ConfigurationError(err_msg) pymongo-3.11.0/pymongo/auth.py000066400000000000000000000600331374256237000163020ustar00rootroot00000000000000# Copyright 2013-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Authentication helpers.""" import functools import hashlib import hmac import os import socket try: from urllib import quote except ImportError: from urllib.parse import quote HAVE_KERBEROS = True _USE_PRINCIPAL = False try: import winkerberos as kerberos if tuple(map(int, kerberos.__version__.split('.')[:2])) >= (0, 5): _USE_PRINCIPAL = True except ImportError: try: import kerberos except ImportError: HAVE_KERBEROS = False from base64 import standard_b64decode, standard_b64encode from collections import namedtuple from bson.binary import Binary from bson.py3compat import string_type, _unicode, PY3 from bson.son import SON from pymongo.auth_aws import _authenticate_aws from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep MECHANISMS = frozenset( ['GSSAPI', 'MONGODB-CR', 'MONGODB-X509', 'MONGODB-AWS', 'PLAIN', 'SCRAM-SHA-1', 'SCRAM-SHA-256', 'DEFAULT']) """The authentication mechanisms supported by PyMongo.""" class _Cache(object): __slots__ = ("data",) _hash_val = hash('_Cache') def __init__(self): self.data = None def __eq__(self, other): # Two instances must always compare equal. if isinstance(other, _Cache): return True return NotImplemented def __ne__(self, other): if isinstance(other, _Cache): return False return NotImplemented def __hash__(self): return self._hash_val MongoCredential = namedtuple( 'MongoCredential', ['mechanism', 'source', 'username', 'password', 'mechanism_properties', 'cache']) """A hashable namedtuple of values used for authentication.""" GSSAPIProperties = namedtuple('GSSAPIProperties', ['service_name', 'canonicalize_host_name', 'service_realm']) """Mechanism properties for GSSAPI authentication.""" _AWSProperties = namedtuple('AWSProperties', ['aws_session_token']) """Mechanism properties for MONGODB-AWS authentication.""" def _build_credentials_tuple(mech, source, user, passwd, extra, database): """Build and return a mechanism specific credentials tuple. """ if mech not in ('MONGODB-X509', 'MONGODB-AWS') and user is None: raise ConfigurationError("%s requires a username." % (mech,)) if mech == 'GSSAPI': if source is not None and source != '$external': raise ValueError( "authentication source must be $external or None for GSSAPI") properties = extra.get('authmechanismproperties', {}) service_name = properties.get('SERVICE_NAME', 'mongodb') canonicalize = properties.get('CANONICALIZE_HOST_NAME', False) service_realm = properties.get('SERVICE_REALM') props = GSSAPIProperties(service_name=service_name, canonicalize_host_name=canonicalize, service_realm=service_realm) # Source is always $external. 
return MongoCredential(mech, '$external', user, passwd, props, None) elif mech == 'MONGODB-X509': if passwd is not None: raise ConfigurationError( "Passwords are not supported by MONGODB-X509") if source is not None and source != '$external': raise ValueError( "authentication source must be " "$external or None for MONGODB-X509") # Source is always $external, user can be None. return MongoCredential(mech, '$external', user, None, None, None) elif mech == 'MONGODB-AWS': if user is not None and passwd is None: raise ConfigurationError( "username without a password is not supported by MONGODB-AWS") if source is not None and source != '$external': raise ConfigurationError( "authentication source must be " "$external or None for MONGODB-AWS") properties = extra.get('authmechanismproperties', {}) aws_session_token = properties.get('AWS_SESSION_TOKEN') props = _AWSProperties(aws_session_token=aws_session_token) # user can be None for temporary link-local EC2 credentials. return MongoCredential(mech, '$external', user, passwd, props, None) elif mech == 'PLAIN': source_database = source or database or '$external' return MongoCredential(mech, source_database, user, passwd, None, None) else: source_database = source or database or 'admin' if passwd is None: raise ConfigurationError("A password is required.") return MongoCredential( mech, source_database, user, passwd, None, _Cache()) if PY3: def _xor(fir, sec): """XOR two byte strings together (python 3.x).""" return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) _from_bytes = int.from_bytes _to_bytes = int.to_bytes else: from binascii import (hexlify as _hexlify, unhexlify as _unhexlify) def _xor(fir, sec): """XOR two byte strings together (python 2.x).""" return b"".join([chr(ord(x) ^ ord(y)) for x, y in zip(fir, sec)]) def _from_bytes(value, dummy, _int=int, _hexlify=_hexlify): """An implementation of int.from_bytes for python 2.x.""" return _int(_hexlify(value), 16) def _to_bytes(value, length, dummy, _unhexlify=_unhexlify): """An implementation of int.to_bytes for python 2.x.""" fmt = '%%0%dx' % (2 * length,) return _unhexlify(fmt % value) try: # The fastest option, if it's been compiled to use OpenSSL's HMAC. from backports.pbkdf2 import pbkdf2_hmac as _hi except ImportError: try: # Python 2.7.8+, or Python 3.4+. from hashlib import pbkdf2_hmac as _hi except ImportError: def _hi(hash_name, data, salt, iterations): """A simple implementation of PBKDF2-HMAC.""" mac = hmac.HMAC(data, None, getattr(hashlib, hash_name)) def _digest(msg, mac=mac): """Get a digest for msg.""" _mac = mac.copy() _mac.update(msg) return _mac.digest() from_bytes = _from_bytes to_bytes = _to_bytes _u1 = _digest(salt + b'\x00\x00\x00\x01') _ui = from_bytes(_u1, 'big') for _ in range(iterations - 1): _u1 = _digest(_u1) _ui ^= from_bytes(_u1, 'big') return to_bytes(_ui, mac.digest_size, 'big') try: from hmac import compare_digest except ImportError: if PY3: def _xor_bytes(a, b): return a ^ b else: def _xor_bytes(a, b, _ord=ord): return _ord(a) ^ _ord(b) # Python 2.x < 2.7.7 # Note: This method is intentionally obtuse to prevent timing attacks. Do # not refactor it! 
# References: # - http://bugs.python.org/issue14532 # - http://bugs.python.org/issue14955 # - http://bugs.python.org/issue15061 def compare_digest(a, b, _xor_bytes=_xor_bytes): left = None right = b if len(a) == len(b): left = a result = 0 if len(a) != len(b): left = b result = 1 for x, y in zip(left, right): result |= _xor_bytes(x, y) return result == 0 def _parse_scram_response(response): """Split a scram response into key, value pairs.""" return dict(item.split(b"=", 1) for item in response.split(b",")) def _authenticate_scram_start(credentials, mechanism): username = credentials.username user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") nonce = standard_b64encode(os.urandom(32)) first_bare = b"n=" + user + b",r=" + nonce cmd = SON([('saslStart', 1), ('mechanism', mechanism), ('payload', Binary(b"n,," + first_bare)), ('autoAuthorize', 1), ('options', {'skipEmptyExchange': True})]) return nonce, first_bare, cmd def _authenticate_scram(credentials, sock_info, mechanism): """Authenticate using SCRAM.""" username = credentials.username if mechanism == 'SCRAM-SHA-256': digest = "sha256" digestmod = hashlib.sha256 data = saslprep(credentials.password).encode("utf-8") else: digest = "sha1" digestmod = hashlib.sha1 data = _password_digest(username, credentials.password).encode("utf-8") source = credentials.source cache = credentials.cache # Make local _hmac = hmac.HMAC ctx = sock_info.auth_ctx.get(credentials) if ctx and ctx.speculate_succeeded(): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate else: nonce, first_bare, cmd = _authenticate_scram_start( credentials, mechanism) res = sock_info.command(source, cmd) server_first = res['payload'] parsed = _parse_scram_response(server_first) iterations = int(parsed[b'i']) if iterations < 4096: raise OperationFailure("Server returned an invalid iteration count.") salt = parsed[b's'] rnonce = parsed[b'r'] if not rnonce.startswith(nonce): raise OperationFailure("Server returned an invalid nonce.") without_proof = b"c=biws,r=" + rnonce if cache.data: client_key, server_key, csalt, citerations = cache.data else: client_key, server_key, csalt, citerations = None, None, None, None # Salt and / or iterations could change for a number of different # reasons. Either changing invalidates the cache. if not client_key or salt != csalt or iterations != citerations: salted_pass = _hi( digest, data, standard_b64decode(salt), iterations) client_key = _hmac(salted_pass, b"Client Key", digestmod).digest() server_key = _hmac(salted_pass, b"Server Key", digestmod).digest() cache.data = (client_key, server_key, salt, iterations) stored_key = digestmod(client_key).digest() auth_msg = b",".join((first_bare, server_first, without_proof)) client_sig = _hmac(stored_key, auth_msg, digestmod).digest() client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig)) client_final = b",".join((without_proof, client_proof)) server_sig = standard_b64encode( _hmac(server_key, auth_msg, digestmod).digest()) cmd = SON([('saslContinue', 1), ('conversationId', res['conversationId']), ('payload', Binary(client_final))]) res = sock_info.command(source, cmd) parsed = _parse_scram_response(res['payload']) if not compare_digest(parsed[b'v'], server_sig): raise OperationFailure("Server returned an invalid signature.") # A third empty challenge may be required if the server does not support # skipEmptyExchange: SERVER-44857. 
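    # In that case the client sends one final saslContinue with an empty
    # payload; the server must then report done=True or authentication
    # fails.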
if not res['done']: cmd = SON([('saslContinue', 1), ('conversationId', res['conversationId']), ('payload', Binary(b''))]) res = sock_info.command(source, cmd) if not res['done']: raise OperationFailure('SASL conversation failed to complete.') def _password_digest(username, password): """Get a password digest to use for authentication. """ if not isinstance(password, string_type): raise TypeError("password must be an " "instance of %s" % (string_type.__name__,)) if len(password) == 0: raise ValueError("password can't be empty") if not isinstance(username, string_type): raise TypeError("password must be an " "instance of %s" % (string_type.__name__,)) md5hash = hashlib.md5() data = "%s:mongo:%s" % (username, password) md5hash.update(data.encode('utf-8')) return _unicode(md5hash.hexdigest()) def _auth_key(nonce, username, password): """Get an auth key to use for authentication. """ digest = _password_digest(username, password) md5hash = hashlib.md5() data = "%s%s%s" % (nonce, username, digest) md5hash.update(data.encode('utf-8')) return _unicode(md5hash.hexdigest()) def _canonicalize_hostname(hostname): """Canonicalize hostname following MIT-krb5 behavior.""" # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520 af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME)[0] try: name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD) except socket.gaierror: return canonname.lower() return name[0].lower() def _authenticate_gssapi(credentials, sock_info): """Authenticate using GSSAPI. """ if not HAVE_KERBEROS: raise ConfigurationError('The "kerberos" module must be ' 'installed to use GSSAPI authentication.') try: username = credentials.username password = credentials.password props = credentials.mechanism_properties # Starting here and continuing through the while loop below - establish # the security context. See RFC 4752, Section 3.1, first paragraph. host = sock_info.address[0] if props.canonicalize_host_name: host = _canonicalize_hostname(host) service = props.service_name + '@' + host if props.service_realm is not None: service = service + '@' + props.service_realm if password is not None: if _USE_PRINCIPAL: # Note that, though we use unquote_plus for unquoting URI # options, we use quote here. Microsoft's UrlUnescape (used # by WinKerberos) doesn't support +. principal = ":".join((quote(username), quote(password))) result, ctx = kerberos.authGSSClientInit( service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG) else: if '@' in username: user, domain = username.split('@', 1) else: user, domain = username, None result, ctx = kerberos.authGSSClientInit( service, gssflags=kerberos.GSS_C_MUTUAL_FLAG, user=user, domain=domain, password=password) else: result, ctx = kerberos.authGSSClientInit( service, gssflags=kerberos.GSS_C_MUTUAL_FLAG) if result != kerberos.AUTH_GSS_COMPLETE: raise OperationFailure('Kerberos context failed to initialize.') try: # pykerberos uses a weird mix of exceptions and return values # to indicate errors. # 0 == continue, 1 == complete, -1 == error # Only authGSSClientStep can return 0. if kerberos.authGSSClientStep(ctx, '') != 0: raise OperationFailure('Unknown kerberos ' 'failure in step function.') # Start a SASL conversation with mongod/s # Note: pykerberos deals with base64 encoded byte strings. # Since mongo accepts base64 strings as the payload we don't # have to use bson.binary.Binary. 
payload = kerberos.authGSSClientResponse(ctx) cmd = SON([('saslStart', 1), ('mechanism', 'GSSAPI'), ('payload', payload), ('autoAuthorize', 1)]) response = sock_info.command('$external', cmd) # Limit how many times we loop to catch protocol / library issues for _ in range(10): result = kerberos.authGSSClientStep(ctx, str(response['payload'])) if result == -1: raise OperationFailure('Unknown kerberos ' 'failure in step function.') payload = kerberos.authGSSClientResponse(ctx) or '' cmd = SON([('saslContinue', 1), ('conversationId', response['conversationId']), ('payload', payload)]) response = sock_info.command('$external', cmd) if result == kerberos.AUTH_GSS_COMPLETE: break else: raise OperationFailure('Kerberos ' 'authentication failed to complete.') # Once the security context is established actually authenticate. # See RFC 4752, Section 3.1, last two paragraphs. if kerberos.authGSSClientUnwrap(ctx, str(response['payload'])) != 1: raise OperationFailure('Unknown kerberos ' 'failure during GSS_Unwrap step.') if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1: raise OperationFailure('Unknown kerberos ' 'failure during GSS_Wrap step.') payload = kerberos.authGSSClientResponse(ctx) cmd = SON([('saslContinue', 1), ('conversationId', response['conversationId']), ('payload', payload)]) sock_info.command('$external', cmd) finally: kerberos.authGSSClientClean(ctx) except kerberos.KrbError as exc: raise OperationFailure(str(exc)) def _authenticate_plain(credentials, sock_info): """Authenticate using SASL PLAIN (RFC 4616) """ source = credentials.source username = credentials.username password = credentials.password payload = ('\x00%s\x00%s' % (username, password)).encode('utf-8') cmd = SON([('saslStart', 1), ('mechanism', 'PLAIN'), ('payload', Binary(payload)), ('autoAuthorize', 1)]) sock_info.command(source, cmd) def _authenticate_cram_md5(credentials, sock_info): """Authenticate using CRAM-MD5 (RFC 2195) """ source = credentials.source username = credentials.username password = credentials.password # The password used as the mac key is the # same as what we use for MONGODB-CR passwd = _password_digest(username, password) cmd = SON([('saslStart', 1), ('mechanism', 'CRAM-MD5'), ('payload', Binary(b'')), ('autoAuthorize', 1)]) response = sock_info.command(source, cmd) # MD5 as implicit default digest for digestmod is deprecated # in python 3.4 mac = hmac.HMAC(key=passwd.encode('utf-8'), digestmod=hashlib.md5) mac.update(response['payload']) challenge = username.encode('utf-8') + b' ' + mac.hexdigest().encode('utf-8') cmd = SON([('saslContinue', 1), ('conversationId', response['conversationId']), ('payload', Binary(challenge))]) sock_info.command(source, cmd) def _authenticate_x509(credentials, sock_info): """Authenticate using MONGODB-X509. """ ctx = sock_info.auth_ctx.get(credentials) if ctx and ctx.speculate_succeeded(): # MONGODB-X509 is done after the speculative auth step. return cmd = _X509Context(credentials).speculate_command() if credentials.username is None and sock_info.max_wire_version < 5: raise ConfigurationError( "A username is required for MONGODB-X509 authentication " "when connected to MongoDB versions older than 3.4.") sock_info.command('$external', cmd) def _authenticate_mongo_cr(credentials, sock_info): """Authenticate using MONGODB-CR. 
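    Legacy challenge/response flow: fetch a server nonce with
    {'getnonce': 1}, then send an 'authenticate' command whose key is
    md5hex(nonce + username + md5hex(username + ':mongo:' + password)),
    as computed by _auth_key and _password_digest above.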
""" source = credentials.source username = credentials.username password = credentials.password # Get a nonce response = sock_info.command(source, {'getnonce': 1}) nonce = response['nonce'] key = _auth_key(nonce, username, password) # Actually authenticate query = SON([('authenticate', 1), ('user', username), ('nonce', nonce), ('key', key)]) sock_info.command(source, query) def _authenticate_default(credentials, sock_info): if sock_info.max_wire_version >= 7: if credentials in sock_info.negotiated_mechanisms: mechs = sock_info.negotiated_mechanisms[credentials] else: source = credentials.source cmd = SON([ ('ismaster', 1), ('saslSupportedMechs', source + '.' + credentials.username)]) mechs = sock_info.command( source, cmd, publish_events=False).get( 'saslSupportedMechs', []) if 'SCRAM-SHA-256' in mechs: return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256') else: return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1') elif sock_info.max_wire_version >= 3: return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1') else: return _authenticate_mongo_cr(credentials, sock_info) _AUTH_MAP = { 'CRAM-MD5': _authenticate_cram_md5, 'GSSAPI': _authenticate_gssapi, 'MONGODB-CR': _authenticate_mongo_cr, 'MONGODB-X509': _authenticate_x509, 'MONGODB-AWS': _authenticate_aws, 'PLAIN': _authenticate_plain, 'SCRAM-SHA-1': functools.partial( _authenticate_scram, mechanism='SCRAM-SHA-1'), 'SCRAM-SHA-256': functools.partial( _authenticate_scram, mechanism='SCRAM-SHA-256'), 'DEFAULT': _authenticate_default, } class _AuthContext(object): def __init__(self, credentials): self.credentials = credentials self.speculative_authenticate = None @staticmethod def from_credentials(creds): spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism) if spec_cls: return spec_cls(creds) return None def speculate_command(self): raise NotImplementedError def parse_response(self, ismaster): self.speculative_authenticate = ismaster.speculative_authenticate def speculate_succeeded(self): return bool(self.speculative_authenticate) class _ScramContext(_AuthContext): def __init__(self, credentials, mechanism): super(_ScramContext, self).__init__(credentials) self.scram_data = None self.mechanism = mechanism def speculate_command(self): nonce, first_bare, cmd = _authenticate_scram_start( self.credentials, self.mechanism) # The 'db' field is included only on the speculative command. cmd['db'] = self.credentials.source # Save for later use. self.scram_data = (nonce, first_bare) return cmd class _X509Context(_AuthContext): def speculate_command(self): cmd = SON([('authenticate', 1), ('mechanism', 'MONGODB-X509')]) if self.credentials.username is not None: cmd['user'] = self.credentials.username return cmd _SPECULATIVE_AUTH_MAP = { 'MONGODB-X509': _X509Context, 'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'), 'SCRAM-SHA-256': functools.partial(_ScramContext, mechanism='SCRAM-SHA-256'), 'DEFAULT': functools.partial(_ScramContext, mechanism='SCRAM-SHA-256'), } def authenticate(credentials, sock_info): """Authenticate sock_info.""" mechanism = credentials.mechanism auth_func = _AUTH_MAP.get(mechanism) auth_func(credentials, sock_info) def logout(source, sock_info): """Log out from a database.""" sock_info.command(source, {'logout': 1}) pymongo-3.11.0/pymongo/auth_aws.py000066400000000000000000000060621374256237000171560ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MONGODB-AWS Authentication helpers.""" try: import pymongo_auth_aws from pymongo_auth_aws import (AwsCredential, AwsSaslContext, PyMongoAuthAwsError) _HAVE_MONGODB_AWS = True except ImportError: class AwsSaslContext(object): def __init__(self, credentials): pass _HAVE_MONGODB_AWS = False import bson from bson.binary import Binary from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure class _AwsSaslContext(AwsSaslContext): # Dependency injection: def binary_type(self): """Return the bson.binary.Binary type.""" return Binary def bson_encode(self, doc): """Encode a dictionary to BSON.""" return bson.encode(doc) def bson_decode(self, data): """Decode BSON to a dictionary.""" return bson.decode(data) def _authenticate_aws(credentials, sock_info): """Authenticate using MONGODB-AWS. """ if not _HAVE_MONGODB_AWS: raise ConfigurationError( "MONGODB-AWS authentication requires pymongo-auth-aws: " "install with: python -m pip install 'pymongo[aws]'") if sock_info.max_wire_version < 9: raise ConfigurationError( "MONGODB-AWS authentication requires MongoDB version 4.4 or later") try: ctx = _AwsSaslContext(AwsCredential( credentials.username, credentials.password, credentials.mechanism_properties.aws_session_token)) client_payload = ctx.step(None) client_first = SON([('saslStart', 1), ('mechanism', 'MONGODB-AWS'), ('payload', client_payload)]) server_first = sock_info.command('$external', client_first) res = server_first # Limit how many times we loop to catch protocol / library issues for _ in range(10): client_payload = ctx.step(res['payload']) cmd = SON([('saslContinue', 1), ('conversationId', server_first['conversationId']), ('payload', client_payload)]) res = sock_info.command('$external', cmd) if res['done']: # SASL complete. break except PyMongoAuthAwsError as exc: # Convert to OperationFailure and include pymongo-auth-aws version. raise OperationFailure('%s (pymongo-auth-aws version %s)' % ( exc, pymongo_auth_aws.__version__)) pymongo-3.11.0/pymongo/bulk.py000066400000000000000000000645351374256237000163110ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """The bulk write operations interface. .. 
versionadded:: 2.7 """ import copy from itertools import islice from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from bson.son import SON from pymongo.client_session import _validate_session_write_concern from pymongo.common import (validate_is_mapping, validate_is_document_type, validate_ok_for_replace, validate_ok_for_update) from pymongo.helpers import _RETRYABLE_ERROR_CODES from pymongo.collation import validate_collation_or_none from pymongo.errors import (BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure) from pymongo.message import (_INSERT, _UPDATE, _DELETE, _do_batched_insert, _randint, _BulkWriteContext, _EncryptedBulkWriteContext) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern _DELETE_ALL = 0 _DELETE_ONE = 1 # For backwards compatibility. See MongoDB src/mongo/base/error_codes.err _BAD_VALUE = 2 _UNKNOWN_ERROR = 8 _WRITE_CONCERN_ERROR = 64 _COMMANDS = ('insert', 'update', 'delete') # These string literals are used when we create fake server return # documents client side. We use unicode literals in python 2.x to # match the actual return values from the server. _UOP = u"op" class _Run(object): """Represents a batch of write operations. """ def __init__(self, op_type): """Initialize a new Run object. """ self.op_type = op_type self.index_map = [] self.ops = [] self.idx_offset = 0 def index(self, idx): """Get the original index of an operation in this run. :Parameters: - `idx`: The Run index that maps to the original index. """ return self.index_map[idx] def add(self, original_index, operation): """Add an operation to this Run instance. :Parameters: - `original_index`: The original index of this operation within a larger bulk operation. - `operation`: The operation document. """ self.index_map.append(original_index) self.ops.append(operation) def _merge_command(run, full_result, offset, result): """Merge a write command result into the full bulk result. """ affected = result.get("n", 0) if run.op_type == _INSERT: full_result["nInserted"] += affected elif run.op_type == _DELETE: full_result["nRemoved"] += affected elif run.op_type == _UPDATE: upserted = result.get("upserted") if upserted: n_upserted = len(upserted) for doc in upserted: doc["index"] = run.index(doc["index"] + offset) full_result["upserted"].extend(upserted) full_result["nUpserted"] += n_upserted full_result["nMatched"] += (affected - n_upserted) else: full_result["nMatched"] += affected full_result["nModified"] += result["nModified"] write_errors = result.get("writeErrors") if write_errors: for doc in write_errors: # Leave the server response intact for APM. replacement = doc.copy() idx = doc["index"] + offset replacement["index"] = run.index(idx) # Add the failed operation to the error document. replacement[_UOP] = run.ops[idx] full_result["writeErrors"].append(replacement) wc_error = result.get("writeConcernError") if wc_error: full_result["writeConcernErrors"].append(wc_error) def _raise_bulk_write_error(full_result): """Raise a BulkWriteError from the full bulk api result. """ if full_result["writeErrors"]: full_result["writeErrors"].sort( key=lambda error: error["index"]) raise BulkWriteError(full_result) class _Bulk(object): """The private guts of the bulk write API. """ def __init__(self, collection, ordered, bypass_document_validation): """Initialize a _Bulk instance. 
""" self.collection = collection.with_options( codec_options=collection.codec_options._replace( unicode_decode_error_handler='replace', document_class=dict)) self.ordered = ordered self.ops = [] self.executed = False self.bypass_doc_val = bypass_document_validation self.uses_collation = False self.uses_array_filters = False self.uses_hint = False self.is_retryable = True self.retrying = False self.started_retryable_write = False # Extra state so that we know where to pick up on a retry attempt. self.current_run = None @property def bulk_ctx_class(self): encrypter = self.collection.database.client._encrypter if encrypter and not encrypter._bypass_auto_encryption: return _EncryptedBulkWriteContext else: return _BulkWriteContext def add_insert(self, document): """Add an insert document to the list of ops. """ validate_is_document_type("document", document) # Generate ObjectId client side. if not (isinstance(document, RawBSONDocument) or '_id' in document): document['_id'] = ObjectId() self.ops.append((_INSERT, document)) def add_update(self, selector, update, multi=False, upsert=False, collation=None, array_filters=None, hint=None): """Create an update document and add it to the list of ops. """ validate_ok_for_update(update) cmd = SON([('q', selector), ('u', update), ('multi', multi), ('upsert', upsert)]) collation = validate_collation_or_none(collation) if collation is not None: self.uses_collation = True cmd['collation'] = collation if array_filters is not None: self.uses_array_filters = True cmd['arrayFilters'] = array_filters if hint is not None: self.uses_hint = True cmd['hint'] = hint if multi: # A bulk_write containing an update_many is not retryable. self.is_retryable = False self.ops.append((_UPDATE, cmd)) def add_replace(self, selector, replacement, upsert=False, collation=None, hint=None): """Create a replace document and add it to the list of ops. """ validate_ok_for_replace(replacement) cmd = SON([('q', selector), ('u', replacement), ('multi', False), ('upsert', upsert)]) collation = validate_collation_or_none(collation) if collation is not None: self.uses_collation = True cmd['collation'] = collation if hint is not None: self.uses_hint = True cmd['hint'] = hint self.ops.append((_UPDATE, cmd)) def add_delete(self, selector, limit, collation=None, hint=None): """Create a delete document and add it to the list of ops. """ cmd = SON([('q', selector), ('limit', limit)]) collation = validate_collation_or_none(collation) if collation is not None: self.uses_collation = True cmd['collation'] = collation if hint is not None: self.uses_hint = True cmd['hint'] = hint if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. self.is_retryable = False self.ops.append((_DELETE, cmd)) def gen_ordered(self): """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None for idx, (op_type, operation) in enumerate(self.ops): if run is None: run = _Run(op_type) elif run.op_type != op_type: yield run run = _Run(op_type) run.add(idx, operation) yield run def gen_unordered(self): """Generate batches of operations, batched by type of operation, in arbitrary order. 
""" operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] for idx, (op_type, operation) in enumerate(self.ops): operations[op_type].add(idx, operation) for run in operations: if run.ops: yield run def _execute_command(self, generator, write_concern, session, sock_info, op_id, retryable, full_result): if sock_info.max_wire_version < 5: if self.uses_collation: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use a collation.') if self.uses_hint: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use hint.') if sock_info.max_wire_version < 6 and self.uses_array_filters: raise ConfigurationError( 'Must be connected to MongoDB 3.6+ to use arrayFilters.') db_name = self.collection.database.name client = self.collection.database.client listeners = client._event_listeners if not self.current_run: self.current_run = next(generator) run = self.current_run # sock_info.command validates the session, but we use # sock_info.write_command. sock_info.validate_session(client, session) while run: cmd = SON([(_COMMANDS[run.op_type], self.collection.name), ('ordered', self.ordered)]) if not write_concern.is_server_default: cmd['writeConcern'] = write_concern.document if self.bypass_doc_val and sock_info.max_wire_version >= 4: cmd['bypassDocumentValidation'] = True bwc = self.bulk_ctx_class( db_name, cmd, sock_info, op_id, listeners, session, run.op_type, self.collection.codec_options) while run.idx_offset < len(run.ops): if session: # Start a new retryable write unless one was already # started for this command. if retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True session._apply_to(cmd, retryable, ReadPreference.PRIMARY) sock_info.send_cluster_time(cmd, session, client) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. result, to_send = bwc.execute(ops, client) # Retryable writeConcernErrors halt the execution of this run. wce = result.get('writeConcernError', {}) if wce.get('code', 0) in _RETRYABLE_ERROR_CODES: # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. full = copy.deepcopy(full_result) _merge_command(run, full, run.idx_offset, result) _raise_bulk_write_error(full) _merge_command(run, full_result, run.idx_offset, result) # We're no longer in a retry once a command succeeds. self.retrying = False self.started_retryable_write = False if self.ordered and "writeErrors" in result: break run.idx_offset += len(to_send) # We're supposed to continue if errors are # at the write concern level (e.g. wtimeout) if self.ordered and full_result['writeErrors']: break # Reset our state self.current_run = run = next(generator, None) def execute_command(self, generator, write_concern, session): """Execute using write commands. """ # nModified is only reported for write commands, not legacy ops. 
full_result = { "writeErrors": [], "writeConcernErrors": [], "nInserted": 0, "nUpserted": 0, "nMatched": 0, "nModified": 0, "nRemoved": 0, "upserted": [], } op_id = _randint() def retryable_bulk(session, sock_info, retryable): self._execute_command( generator, write_concern, session, sock_info, op_id, retryable, full_result) client = self.collection.database.client with client._tmp_session(session) as s: client._retry_with_session( self.is_retryable, retryable_bulk, s, self) if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result def execute_insert_no_results(self, sock_info, run, op_id, acknowledged): """Execute insert, returning no results. """ command = SON([('insert', self.collection.name), ('ordered', self.ordered)]) concern = {'w': int(self.ordered)} command['writeConcern'] = concern if self.bypass_doc_val and sock_info.max_wire_version >= 4: command['bypassDocumentValidation'] = True db = self.collection.database bwc = _BulkWriteContext( db.name, command, sock_info, op_id, db.client._event_listeners, None, _INSERT, self.collection.codec_options) # Legacy batched OP_INSERT. _do_batched_insert( self.collection.full_name, run.ops, True, acknowledged, concern, not self.ordered, self.collection.codec_options, bwc) def execute_op_msg_no_results(self, sock_info, generator): """Execute write commands with OP_MSG and w=0 writeConcern, unordered. """ db_name = self.collection.database.name client = self.collection.database.client listeners = client._event_listeners op_id = _randint() if not self.current_run: self.current_run = next(generator) run = self.current_run while run: cmd = SON([(_COMMANDS[run.op_type], self.collection.name), ('ordered', False), ('writeConcern', {'w': 0})]) bwc = self.bulk_ctx_class( db_name, cmd, sock_info, op_id, listeners, None, run.op_type, self.collection.codec_options) while run.idx_offset < len(run.ops): ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. to_send = bwc.execute_unack(ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) def execute_command_no_results(self, sock_info, generator): """Execute write commands with OP_MSG and w=0 WriteConcern, ordered. """ full_result = { "writeErrors": [], "writeConcernErrors": [], "nInserted": 0, "nUpserted": 0, "nMatched": 0, "nModified": 0, "nRemoved": 0, "upserted": [], } # Ordered bulk writes have to be acknowledged so that we stop # processing at the first error, even when the application # specified unacknowledged writeConcern. write_concern = WriteConcern() op_id = _randint() try: self._execute_command( generator, write_concern, None, sock_info, op_id, False, full_result) except OperationFailure: pass def execute_no_results(self, sock_info, generator): """Execute all operations, returning no results (w=0). """ if self.uses_collation: raise ConfigurationError( 'Collation is unsupported for unacknowledged writes.') if self.uses_array_filters: raise ConfigurationError( 'arrayFilters is unsupported for unacknowledged writes.') if self.uses_hint: raise ConfigurationError( 'hint is unsupported for unacknowledged writes.') # Cannot have both unacknowledged writes and bypass document validation. 
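# --- Illustrative usage sketch (not part of the library source) ----------
# Unacknowledged (w=0) bulk writes take the execute_no_results path above:
# ordered w=0 batches are still sent acknowledged internally so a failure
# can stop the run early, and options such as collation, arrayFilters and
# hint raise ConfigurationError because the server cannot report errors
# back. The namespace below is a placeholder.
from pymongo import MongoClient, InsertOne, UpdateOne
from pymongo.write_concern import WriteConcern

client = MongoClient("mongodb://localhost:27017")
coll = client.test.get_collection("bulk_demo", write_concern=WriteConcern(w=0))

# Fire-and-forget: the returned BulkWriteResult is unacknowledged.
result = coll.bulk_write(
    [InsertOne({"x": 1}), UpdateOne({"x": 1}, {"$set": {"y": 2}})],
    ordered=False,
)
print(result.acknowledged)  # False
# --------------------------------------------------------------------------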
if self.bypass_doc_val and sock_info.max_wire_version >= 4: raise OperationFailure("Cannot set bypass_document_validation with" " unacknowledged write concern") # OP_MSG if sock_info.max_wire_version > 5: if self.ordered: return self.execute_command_no_results(sock_info, generator) return self.execute_op_msg_no_results(sock_info, generator) coll = self.collection # If ordered is True we have to send GLE or use write # commands so we can abort on the first error. write_concern = WriteConcern(w=int(self.ordered)) op_id = _randint() next_run = next(generator) while next_run: # An ordered bulk write needs to send acknowledged writes to short # circuit the next run. However, the final message on the final # run can be unacknowledged. run = next_run next_run = next(generator, None) needs_ack = self.ordered and next_run is not None try: if run.op_type == _INSERT: self.execute_insert_no_results( sock_info, run, op_id, needs_ack) elif run.op_type == _UPDATE: for operation in run.ops: doc = operation['u'] check_keys = True if doc and next(iter(doc)).startswith('$'): check_keys = False coll._update( sock_info, operation['q'], doc, operation['upsert'], check_keys, operation['multi'], write_concern=write_concern, op_id=op_id, ordered=self.ordered, bypass_doc_val=self.bypass_doc_val) else: for operation in run.ops: coll._delete(sock_info, operation['q'], not operation['limit'], write_concern, op_id, self.ordered) except OperationFailure: if self.ordered: break def execute(self, write_concern, session): """Execute operations. """ if not self.ops: raise InvalidOperation('No operations to execute') if self.executed: raise InvalidOperation('Bulk operations can ' 'only be executed once.') self.executed = True write_concern = write_concern or self.collection.write_concern session = _validate_session_write_concern(session, write_concern) if self.ordered: generator = self.gen_ordered() else: generator = self.gen_unordered() client = self.collection.database.client if not write_concern.acknowledged: with client._socket_for_writes(session) as sock_info: self.execute_no_results(sock_info, generator) else: return self.execute_command(generator, write_concern, session) class BulkUpsertOperation(object): """An interface for adding upsert operations. """ __slots__ = ('__selector', '__bulk', '__collation') def __init__(self, selector, bulk, collation): self.__selector = selector self.__bulk = bulk self.__collation = collation def update_one(self, update): """Update one document matching the selector. :Parameters: - `update` (dict): the update operations to apply """ self.__bulk.add_update(self.__selector, update, multi=False, upsert=True, collation=self.__collation) def update(self, update): """Update all documents matching the selector. :Parameters: - `update` (dict): the update operations to apply """ self.__bulk.add_update(self.__selector, update, multi=True, upsert=True, collation=self.__collation) def replace_one(self, replacement): """Replace one entire document matching the selector criteria. :Parameters: - `replacement` (dict): the replacement document """ self.__bulk.add_replace(self.__selector, replacement, upsert=True, collation=self.__collation) class BulkWriteOperation(object): """An interface for adding update or remove operations. """ __slots__ = ('__selector', '__bulk', '__collation') def __init__(self, selector, bulk, collation): self.__selector = selector self.__bulk = bulk self.__collation = collation def update_one(self, update): """Update one document matching the selector criteria. 
:Parameters: - `update` (dict): the update operations to apply """ self.__bulk.add_update(self.__selector, update, multi=False, collation=self.__collation) def update(self, update): """Update all documents matching the selector criteria. :Parameters: - `update` (dict): the update operations to apply """ self.__bulk.add_update(self.__selector, update, multi=True, collation=self.__collation) def replace_one(self, replacement): """Replace one entire document matching the selector criteria. :Parameters: - `replacement` (dict): the replacement document """ self.__bulk.add_replace(self.__selector, replacement, collation=self.__collation) def remove_one(self): """Remove a single document matching the selector criteria. """ self.__bulk.add_delete(self.__selector, _DELETE_ONE, collation=self.__collation) def remove(self): """Remove all documents matching the selector criteria. """ self.__bulk.add_delete(self.__selector, _DELETE_ALL, collation=self.__collation) def upsert(self): """Specify that all chained update operations should be upserts. :Returns: - A :class:`BulkUpsertOperation` instance, used to add update operations to this bulk operation. """ return BulkUpsertOperation(self.__selector, self.__bulk, self.__collation) class BulkOperationBuilder(object): """**DEPRECATED**: An interface for executing a batch of write operations. """ __slots__ = '__bulk' def __init__(self, collection, ordered=True, bypass_document_validation=False): """**DEPRECATED**: Initialize a new BulkOperationBuilder instance. :Parameters: - `collection`: A :class:`~pymongo.collection.Collection` instance. - `ordered` (optional): If ``True`` all operations will be executed serially, in the order provided, and the entire execution will abort on the first error. If ``False`` operations will be executed in arbitrary order (possibly in parallel on the server), reporting any errors that occurred after attempting all operations. Defaults to ``True``. - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. .. note:: `bypass_document_validation` requires server version **>= 3.2** .. versionchanged:: 3.5 Deprecated. Use :meth:`~pymongo.collection.Collection.bulk_write` instead. .. versionchanged:: 3.2 Added bypass_document_validation support """ self.__bulk = _Bulk(collection, ordered, bypass_document_validation) def find(self, selector, collation=None): """Specify selection criteria for bulk operations. :Parameters: - `selector` (dict): the selection criteria for update and remove operations. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. :Returns: - A :class:`BulkWriteOperation` instance, used to add update and remove operations to this bulk operation. .. versionchanged:: 3.4 Added the `collation` option. """ validate_is_mapping("selector", selector) return BulkWriteOperation(selector, self.__bulk, collation) def insert(self, document): """Insert a single document. :Parameters: - `document` (dict): the document to insert .. seealso:: :ref:`writes-and-ids` """ self.__bulk.add_insert(document) def execute(self, write_concern=None): """Execute all provided operations. :Parameters: - write_concern (optional): the write concern for this bulk execution. 
""" if write_concern is not None: write_concern = WriteConcern(**write_concern) return self.__bulk.execute(write_concern, session=None) pymongo-3.11.0/pymongo/change_stream.py000066400000000000000000000366001374256237000201440ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Watch changes on a collection, a database, or the entire cluster.""" import copy from bson import _bson_to_dict from bson.raw_bson import RawBSONDocument from pymongo import common from pymongo.aggregation import (_CollectionAggregationCommand, _DatabaseAggregationCommand) from pymongo.collation import validate_collation_or_none from pymongo.command_cursor import CommandCursor from pymongo.errors import (ConnectionFailure, CursorNotFound, InvalidOperation, OperationFailure, PyMongoError) # The change streams spec considers the following server errors from the # getMore command non-resumable. All other getMore errors are resumable. _RESUMABLE_GETMORE_ERRORS = frozenset([ 6, # HostUnreachable 7, # HostNotFound 89, # NetworkTimeout 91, # ShutdownInProgress 189, # PrimarySteppedDown 262, # ExceededTimeLimit 9001, # SocketException 10107, # NotMaster 11600, # InterruptedAtShutdown 11602, # InterruptedDueToReplStateChange 13435, # NotMasterNoSlaveOk 13436, # NotMasterOrSecondary 63, # StaleShardVersion 150, # StaleEpoch 13388, # StaleConfig 234, # RetryChangeStream 133, # FailedToSatisfyReadPreference 216, # ElectionInProgress ]) class ChangeStream(object): """The internal abstract base class for change stream cursors. Should not be called directly by application developers. Use :meth:`pymongo.collection.Collection.watch`, :meth:`pymongo.database.Database.watch`, or :meth:`pymongo.mongo_client.MongoClient.watch` instead. .. versionadded:: 3.6 .. mongodoc:: changeStreams """ def __init__(self, target, pipeline, full_document, resume_after, max_await_time_ms, batch_size, collation, start_at_operation_time, session, start_after): if pipeline is None: pipeline = [] elif not isinstance(pipeline, list): raise TypeError("pipeline must be a list") common.validate_string_or_none('full_document', full_document) validate_collation_or_none(collation) common.validate_non_negative_integer_or_none("batchSize", batch_size) self._decode_custom = False self._orig_codec_options = target.codec_options if target.codec_options.type_registry._decoder_map: self._decode_custom = True # Keep the type registry so that we support encoding custom types # in the pipeline. 
self._target = target.with_options( codec_options=target.codec_options.with_options( document_class=RawBSONDocument)) else: self._target = target self._pipeline = copy.deepcopy(pipeline) self._full_document = full_document self._uses_start_after = start_after is not None self._uses_resume_after = resume_after is not None self._resume_token = copy.deepcopy(start_after or resume_after) self._max_await_time_ms = max_await_time_ms self._batch_size = batch_size self._collation = collation self._start_at_operation_time = start_at_operation_time self._session = session # Initialize cursor. self._cursor = self._create_cursor() @property def _aggregation_command_class(self): """The aggregation command class to be used.""" raise NotImplementedError @property def _client(self): """The client against which the aggregation commands for this ChangeStream will be run. """ raise NotImplementedError def _change_stream_options(self): """Return the options dict for the $changeStream pipeline stage.""" options = {} if self._full_document is not None: options['fullDocument'] = self._full_document resume_token = self.resume_token if resume_token is not None: if self._uses_start_after: options['startAfter'] = resume_token else: options['resumeAfter'] = resume_token if self._start_at_operation_time is not None: options['startAtOperationTime'] = self._start_at_operation_time return options def _command_options(self): """Return the options dict for the aggregation command.""" options = {} if self._max_await_time_ms is not None: options["maxAwaitTimeMS"] = self._max_await_time_ms if self._batch_size is not None: options["batchSize"] = self._batch_size return options def _aggregation_pipeline(self): """Return the full aggregation pipeline for this ChangeStream.""" options = self._change_stream_options() full_pipeline = [{'$changeStream': options}] full_pipeline.extend(self._pipeline) return full_pipeline def _process_result(self, result, session, server, sock_info, slave_ok): """Callback that caches the postBatchResumeToken or startAtOperationTime from a changeStream aggregate command response containing an empty batch of change documents. This is implemented as a callback because we need access to the wire version in order to determine whether to cache this value. """ if not result['cursor']['firstBatch']: if 'postBatchResumeToken' in result['cursor']: self._resume_token = result['cursor']['postBatchResumeToken'] elif (self._start_at_operation_time is None and self._uses_resume_after is False and self._uses_start_after is False and sock_info.max_wire_version >= 7): self._start_at_operation_time = result.get("operationTime") # PYTHON-2181: informative error on missing operationTime. if self._start_at_operation_time is None: raise OperationFailure( "Expected field 'operationTime' missing from command " "response : %r" % (result, )) def _run_aggregation_cmd(self, session, explicit_session): """Run the full aggregation pipeline for this ChangeStream and return the corresponding CommandCursor. 
""" cmd = self._aggregation_command_class( self._target, CommandCursor, self._aggregation_pipeline(), self._command_options(), explicit_session, result_processor=self._process_result) return self._client._retryable_read( cmd.get_cursor, self._target._read_preference_for(session), session) def _create_cursor(self): with self._client._tmp_session(self._session, close=False) as s: return self._run_aggregation_cmd( session=s, explicit_session=self._session is not None) def _resume(self): """Reestablish this change stream after a resumable error.""" try: self._cursor.close() except PyMongoError: pass self._cursor = self._create_cursor() def close(self): """Close this ChangeStream.""" self._cursor.close() def __iter__(self): return self @property def resume_token(self): """The cached resume token that will be used to resume after the most recently returned change. .. versionadded:: 3.9 """ return copy.deepcopy(self._resume_token) def next(self): """Advance the cursor. This method blocks until the next change document is returned or an unrecoverable error is raised. This method is used when iterating over all changes in the cursor. For example:: try: resume_token = None pipeline = [{'$match': {'operationType': 'insert'}}] with db.collection.watch(pipeline) as stream: for insert_change in stream: print(insert_change) resume_token = stream.resume_token except pymongo.errors.PyMongoError: # The ChangeStream encountered an unrecoverable error or the # resume attempt failed to recreate the cursor. if resume_token is None: # There is no usable resume token because there was a # failure during ChangeStream initialization. logging.error('...') else: # Use the interrupted ChangeStream's resume token to create # a new ChangeStream. The new stream will continue from the # last seen insert change without missing any events. with db.collection.watch( pipeline, resume_after=resume_token) as stream: for insert_change in stream: print(insert_change) Raises :exc:`StopIteration` if this ChangeStream is closed. """ while self.alive: doc = self.try_next() if doc is not None: return doc raise StopIteration __next__ = next @property def alive(self): """Does this cursor have the potential to return more data? .. note:: Even if :attr:`alive` is ``True``, :meth:`next` can raise :exc:`StopIteration` and :meth:`try_next` can return ``None``. .. versionadded:: 3.8 """ return self._cursor.alive def try_next(self): """Advance the cursor without blocking indefinitely. This method returns the next change document without waiting indefinitely for the next change. For example:: with db.collection.watch() as stream: while stream.alive: change = stream.try_next() # Note that the ChangeStream's resume token may be updated # even when no changes are returned. print("Current resume token: %r" % (stream.resume_token,)) if change is not None: print("Change document: %r" % (change,)) continue # We end up here when there are no recent changes. # Sleep for a while before trying again to avoid flooding # the server with getMore requests when no changes are # available. time.sleep(10) If no change document is cached locally then this method runs a single getMore command. If the getMore yields any documents, the next document is returned, otherwise, if the getMore returns no documents (because there have been no changes) then ``None`` is returned. :Returns: The next change document or ``None`` when no document is available after running a single getMore or when the cursor is closed. .. 
versionadded:: 3.8 """ # Attempt to get the next change with at most one getMore and at most # one resume attempt. try: change = self._cursor._try_next(True) except (ConnectionFailure, CursorNotFound): self._resume() change = self._cursor._try_next(False) except OperationFailure as exc: if exc._max_wire_version is None: raise is_resumable = ((exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError")) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS)) if not is_resumable: raise self._resume() change = self._cursor._try_next(False) # If no changes are available. if change is None: # We have either iterated over all documents in the cursor, # OR the most-recently returned batch is empty. In either case, # update the cached resume token with the postBatchResumeToken if # one was returned. We also clear the startAtOperationTime. if self._cursor._post_batch_resume_token is not None: self._resume_token = self._cursor._post_batch_resume_token self._start_at_operation_time = None return change # Else, changes are available. try: resume_token = change['_id'] except KeyError: self.close() raise InvalidOperation( "Cannot provide resume functionality when the resume " "token is missing.") # If this is the last change document from the current batch, cache the # postBatchResumeToken. if (not self._cursor._has_next() and self._cursor._post_batch_resume_token): resume_token = self._cursor._post_batch_resume_token # Hereafter, don't use startAfter; instead use resumeAfter. self._uses_start_after = False self._uses_resume_after = True # Cache the resume token and clear startAtOperationTime. self._resume_token = resume_token self._start_at_operation_time = None if self._decode_custom: return _bson_to_dict(change.raw, self._orig_codec_options) return change def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() class CollectionChangeStream(ChangeStream): """A change stream that watches changes on a single collection. Should not be called directly by application developers. Use helper method :meth:`pymongo.collection.Collection.watch` instead. .. versionadded:: 3.7 """ @property def _aggregation_command_class(self): return _CollectionAggregationCommand @property def _client(self): return self._target.database.client class DatabaseChangeStream(ChangeStream): """A change stream that watches changes on all collections in a database. Should not be called directly by application developers. Use helper method :meth:`pymongo.database.Database.watch` instead. .. versionadded:: 3.7 """ @property def _aggregation_command_class(self): return _DatabaseAggregationCommand @property def _client(self): return self._target.client class ClusterChangeStream(DatabaseChangeStream): """A change stream that watches changes on all collections in the cluster. Should not be called directly by application developers. Use helper method :meth:`pymongo.mongo_client.MongoClient.watch` instead. .. versionadded:: 3.7 """ def _change_stream_options(self): options = super(ClusterChangeStream, self)._change_stream_options() options["allChangesForCluster"] = True return options pymongo-3.11.0/pymongo/client_options.py000066400000000000000000000225271374256237000204000ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. 
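# --- Illustrative usage sketch (not part of the library source) ----------
# ClusterChangeStream above adds allChangesForCluster to the $changeStream
# stage; from application code it is reached through MongoClient.watch().
# try_next() returns None when no change is available instead of blocking.
import time
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")
with client.watch() as stream:
    while stream.alive:
        change = stream.try_next()
        if change is None:
            time.sleep(1)  # nothing new yet; the resume token may still advance
            continue
        print(change["operationType"])
        break
# --------------------------------------------------------------------------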
You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Tools to parse mongo client options.""" from bson.codec_options import _parse_codec_options from pymongo.auth import _build_credentials_tuple from pymongo.common import validate_boolean from pymongo import common from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError from pymongo.monitoring import _EventListeners from pymongo.pool import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.read_preferences import (make_read_preference, read_pref_mode_from_name) from pymongo.server_selectors import any_server_selector from pymongo.ssl_support import get_ssl_context from pymongo.write_concern import WriteConcern def _parse_credentials(username, password, database, options): """Parse authentication credentials.""" mechanism = options.get('authmechanism', 'DEFAULT' if username else None) source = options.get('authsource') if username or mechanism: return _build_credentials_tuple( mechanism, source, username, password, options, database) return None def _parse_read_preference(options): """Parse read preference options.""" if 'read_preference' in options: return options['read_preference'] name = options.get('readpreference', 'primary') mode = read_pref_mode_from_name(name) tags = options.get('readpreferencetags') max_staleness = options.get('maxstalenessseconds', -1) return make_read_preference(mode, tags, max_staleness) def _parse_write_concern(options): """Parse write concern options.""" concern = options.get('w') wtimeout = options.get('wtimeoutms') j = options.get('journal') fsync = options.get('fsync') return WriteConcern(concern, wtimeout, j, fsync) def _parse_read_concern(options): """Parse read concern options.""" concern = options.get('readconcernlevel') return ReadConcern(concern) def _parse_ssl_options(options): """Parse ssl options.""" use_ssl = options.get('ssl') if use_ssl is not None: validate_boolean('ssl', use_ssl) certfile = options.get('ssl_certfile') keyfile = options.get('ssl_keyfile') passphrase = options.get('ssl_pem_passphrase') ca_certs = options.get('ssl_ca_certs') cert_reqs = options.get('ssl_cert_reqs') match_hostname = options.get('ssl_match_hostname', True) crlfile = options.get('ssl_crlfile') check_ocsp_endpoint = options.get('ssl_check_ocsp_endpoint', True) ssl_kwarg_keys = [k for k in options if k.startswith('ssl_') and options[k]] if use_ssl is False and ssl_kwarg_keys: raise ConfigurationError("ssl has not been enabled but the " "following ssl parameters have been set: " "%s. Please set `ssl=True` or remove." 
% ', '.join(ssl_kwarg_keys)) if ssl_kwarg_keys and use_ssl is None: # ssl options imply ssl = True use_ssl = True if use_ssl is True: ctx = get_ssl_context( certfile, keyfile, passphrase, ca_certs, cert_reqs, crlfile, match_hostname, check_ocsp_endpoint) return ctx, match_hostname return None, match_hostname def _parse_pool_options(options): """Parse connection pool options.""" max_pool_size = options.get('maxpoolsize', common.MAX_POOL_SIZE) min_pool_size = options.get('minpoolsize', common.MIN_POOL_SIZE) max_idle_time_seconds = options.get( 'maxidletimems', common.MAX_IDLE_TIME_SEC) if max_pool_size is not None and min_pool_size > max_pool_size: raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") connect_timeout = options.get('connecttimeoutms', common.CONNECT_TIMEOUT) socket_keepalive = options.get('socketkeepalive', True) socket_timeout = options.get('sockettimeoutms') wait_queue_timeout = options.get( 'waitqueuetimeoutms', common.WAIT_QUEUE_TIMEOUT) wait_queue_multiple = options.get('waitqueuemultiple') event_listeners = options.get('event_listeners') appname = options.get('appname') driver = options.get('driver') compression_settings = CompressionSettings( options.get('compressors', []), options.get('zlibcompressionlevel', -1)) ssl_context, ssl_match_hostname = _parse_ssl_options(options) return PoolOptions(max_pool_size, min_pool_size, max_idle_time_seconds, connect_timeout, socket_timeout, wait_queue_timeout, wait_queue_multiple, ssl_context, ssl_match_hostname, socket_keepalive, _EventListeners(event_listeners), appname, driver, compression_settings) class ClientOptions(object): """ClientOptions""" def __init__(self, username, password, database, options): self.__options = options self.__codec_options = _parse_codec_options(options) self.__credentials = _parse_credentials( username, password, database, options) self.__direct_connection = options.get('directconnection') self.__local_threshold_ms = options.get( 'localthresholdms', common.LOCAL_THRESHOLD_MS) # self.__server_selection_timeout is in seconds. Must use full name for # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
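# --- Illustrative usage sketch (not part of the library source) ----------
# The parsers above turn URI/keyword options into the typed objects held by
# ClientOptions; the results are visible through the client's public
# attributes. The host and options below are placeholders.
from pymongo import MongoClient

client = MongoClient(
    "mongodb://localhost:27017/?readPreference=secondaryPreferred"
    "&w=majority&retryWrites=true&appname=demo")

print(client.read_preference)  # SecondaryPreferred(...)
print(client.write_concern)    # WriteConcern(w='majority')
print(client.retry_writes)     # True
print(client.codec_options)    # CodecOptions(...)
# --------------------------------------------------------------------------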
self.__server_selection_timeout = options.get( 'serverselectiontimeoutms', common.SERVER_SELECTION_TIMEOUT) self.__pool_options = _parse_pool_options(options) self.__read_preference = _parse_read_preference(options) self.__replica_set_name = options.get('replicaset') self.__write_concern = _parse_write_concern(options) self.__read_concern = _parse_read_concern(options) self.__connect = options.get('connect') self.__heartbeat_frequency = options.get( 'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY) self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES) self.__retry_reads = options.get('retryreads', common.RETRY_READS) self.__server_selector = options.get( 'server_selector', any_server_selector) self.__auto_encryption_opts = options.get('auto_encryption_opts') @property def _options(self): """The original options used to create this ClientOptions.""" return self.__options @property def connect(self): """Whether to begin discovering a MongoDB topology automatically.""" return self.__connect @property def codec_options(self): """A :class:`~bson.codec_options.CodecOptions` instance.""" return self.__codec_options @property def credentials(self): """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" return self.__credentials @property def direct_connection(self): """Whether to connect to the deployment in 'Single' topology.""" return self.__direct_connection @property def local_threshold_ms(self): """The local threshold for this instance.""" return self.__local_threshold_ms @property def server_selection_timeout(self): """The server selection timeout for this instance in seconds.""" return self.__server_selection_timeout @property def server_selector(self): return self.__server_selector @property def heartbeat_frequency(self): """The monitoring frequency in seconds.""" return self.__heartbeat_frequency @property def pool_options(self): """A :class:`~pymongo.pool.PoolOptions` instance.""" return self.__pool_options @property def read_preference(self): """A read preference instance.""" return self.__read_preference @property def replica_set_name(self): """Replica set name or None.""" return self.__replica_set_name @property def write_concern(self): """A :class:`~pymongo.write_concern.WriteConcern` instance.""" return self.__write_concern @property def read_concern(self): """A :class:`~pymongo.read_concern.ReadConcern` instance.""" return self.__read_concern @property def retry_writes(self): """If this instance should retry supported write operations.""" return self.__retry_writes @property def retry_reads(self): """If this instance should retry supported read operations.""" return self.__retry_reads @property def auto_encryption_opts(self): """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" return self.__auto_encryption_opts pymongo-3.11.0/pymongo/client_session.py000066400000000000000000001054241374256237000203660ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Logical sessions for ordering sequential operations. Requires MongoDB 3.6. .. versionadded:: 3.6 Causally Consistent Reads ========================= .. code-block:: python with client.start_session(causal_consistency=True) as session: collection = client.db.collection collection.update_one({'_id': 1}, {'$set': {'x': 10}}, session=session) secondary_c = collection.with_options( read_preference=ReadPreference.SECONDARY) # A secondary read waits for replication of the write. secondary_c.find_one({'_id': 1}, session=session) If `causal_consistency` is True (the default), read operations that use the session are causally after previous read and write operations. Using a causally consistent session, an application can read its own writes and is guaranteed monotonic reads, even when reading from replica set secondaries. .. mongodoc:: causal-consistency .. _transactions-ref: Transactions ============ MongoDB 4.0 adds support for transactions on replica set primaries. A transaction is associated with a :class:`ClientSession`. To start a transaction on a session, use :meth:`ClientSession.start_transaction` in a with-statement. Then, execute an operation within the transaction by passing the session to the operation: .. code-block:: python orders = client.db.orders inventory = client.db.inventory with client.start_session() as session: with session.start_transaction(): orders.insert_one({"sku": "abc123", "qty": 100}, session=session) inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}}, {"$inc": {"qty": -100}}, session=session) Upon normal completion of ``with session.start_transaction()`` block, the transaction automatically calls :meth:`ClientSession.commit_transaction`. If the block exits with an exception, the transaction automatically calls :meth:`ClientSession.abort_transaction`. In general, multi-document transactions only support read/write (CRUD) operations on existing collections. However, MongoDB 4.4 adds support for creating collections and indexes with some limitations, including an insert operation that would result in the creation of a new collection. For a complete description of all the supported and unsupported operations see the `MongoDB server's documentation for transactions `_. A session may only have a single active transaction at a time, multiple transactions on the same session can be executed in sequence. .. versionadded:: 3.7 Sharded Transactions ^^^^^^^^^^^^^^^^^^^^ PyMongo 3.9 adds support for transactions on sharded clusters running MongoDB 4.2. Sharded transactions have the same API as replica set transactions. When running a transaction against a sharded cluster, the session is pinned to the mongos server selected for the first operation in the transaction. All subsequent operations that are part of the same transaction are routed to the same mongos server. When the transaction is completed, by running either commitTransaction or abortTransaction, the session is unpinned. .. versionadded:: 3.9 .. 
mongodoc:: transactions Classes ======= """ import collections import uuid from bson.binary import Binary from bson.int64 import Int64 from bson.py3compat import abc, integer_types from bson.son import SON from bson.timestamp import Timestamp from pymongo import monotonic from pymongo.errors import (ConfigurationError, ConnectionFailure, InvalidOperation, OperationFailure, PyMongoError, WTimeoutError) from pymongo.helpers import _RETRYABLE_ERROR_CODES from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.write_concern import WriteConcern class SessionOptions(object): """Options for a new :class:`ClientSession`. :Parameters: - `causal_consistency` (optional): If True (the default), read operations are causally ordered within the session. - `default_transaction_options` (optional): The default TransactionOptions to use for transactions started on this session. """ def __init__(self, causal_consistency=True, default_transaction_options=None): self._causal_consistency = causal_consistency if default_transaction_options is not None: if not isinstance(default_transaction_options, TransactionOptions): raise TypeError( "default_transaction_options must be an instance of " "pymongo.client_session.TransactionOptions, not: %r" % (default_transaction_options,)) self._default_transaction_options = default_transaction_options @property def causal_consistency(self): """Whether causal consistency is configured.""" return self._causal_consistency @property def default_transaction_options(self): """The default TransactionOptions to use for transactions started on this session. .. versionadded:: 3.7 """ return self._default_transaction_options class TransactionOptions(object): """Options for :meth:`ClientSession.start_transaction`. :Parameters: - `read_concern` (optional): The :class:`~pymongo.read_concern.ReadConcern` to use for this transaction. If ``None`` (the default) the :attr:`read_preference` of the :class:`MongoClient` is used. - `write_concern` (optional): The :class:`~pymongo.write_concern.WriteConcern` to use for this transaction. If ``None`` (the default) the :attr:`read_preference` of the :class:`MongoClient` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` for options. Transactions which read must use :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. - `max_commit_time_ms` (optional): The maximum amount of time to allow a single commitTransaction command to run. This option is an alias for maxTimeMS option on the commitTransaction command. If ``None`` (the default) maxTimeMS is not used. .. versionchanged:: 3.9 Added the ``max_commit_time_ms`` option. .. 
versionadded:: 3.7 """ def __init__(self, read_concern=None, write_concern=None, read_preference=None, max_commit_time_ms=None): self._read_concern = read_concern self._write_concern = write_concern self._read_preference = read_preference self._max_commit_time_ms = max_commit_time_ms if read_concern is not None: if not isinstance(read_concern, ReadConcern): raise TypeError("read_concern must be an instance of " "pymongo.read_concern.ReadConcern, not: %r" % (read_concern,)) if write_concern is not None: if not isinstance(write_concern, WriteConcern): raise TypeError("write_concern must be an instance of " "pymongo.write_concern.WriteConcern, not: %r" % (write_concern,)) if not write_concern.acknowledged: raise ConfigurationError( "transactions do not support unacknowledged write concern" ": %r" % (write_concern,)) if read_preference is not None: if not isinstance(read_preference, _ServerMode): raise TypeError("%r is not valid for read_preference. See " "pymongo.read_preferences for valid " "options." % (read_preference,)) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, integer_types): raise TypeError( "max_commit_time_ms must be an integer or None") @property def read_concern(self): """This transaction's :class:`~pymongo.read_concern.ReadConcern`.""" return self._read_concern @property def write_concern(self): """This transaction's :class:`~pymongo.write_concern.WriteConcern`.""" return self._write_concern @property def read_preference(self): """This transaction's :class:`~pymongo.read_preferences.ReadPreference`. """ return self._read_preference @property def max_commit_time_ms(self): """The maxTimeMS to use when running a commitTransaction command. .. versionadded:: 3.9 """ return self._max_commit_time_ms def _validate_session_write_concern(session, write_concern): """Validate that an explicit session is not used with an unack'ed write. Returns the session to use for the next operation. """ if session: if write_concern is not None and not write_concern.acknowledged: # For unacknowledged writes without an explicit session, # drivers SHOULD NOT use an implicit session. If a driver # creates an implicit session for unacknowledged writes # without an explicit session, the driver MUST NOT send the # session ID. 
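# --- Illustrative usage sketch (not part of the library source) ----------
# SessionOptions/TransactionOptions above can seed per-session transaction
# defaults: options passed to start_session apply to every transaction the
# session starts unless overridden in start_transaction/with_transaction.
# Host and namespace are placeholders; requires a MongoDB 4.0+ replica set.
from pymongo import MongoClient, ReadPreference
from pymongo.client_session import TransactionOptions
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

client = MongoClient("mongodb://localhost:27017")
txn_opts = TransactionOptions(
    read_concern=ReadConcern("snapshot"),
    write_concern=WriteConcern(w="majority"),
    read_preference=ReadPreference.PRIMARY,
    max_commit_time_ms=5000,
)

with client.start_session(default_transaction_options=txn_opts) as session:
    with session.start_transaction():  # inherits txn_opts
        client.test.orders.insert_one({"sku": "abc123"}, session=session)
# --------------------------------------------------------------------------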
if session._implicit: return None else: raise ConfigurationError( 'Explicit sessions are incompatible with ' 'unacknowledged write concern: %r' % ( write_concern,)) return session class _TransactionContext(object): """Internal transaction context manager for start_transaction.""" def __init__(self, session): self.__session = session def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if self.__session.in_transaction: if exc_val is None: self.__session.commit_transaction() else: self.__session.abort_transaction() class _TxnState(object): NONE = 1 STARTING = 2 IN_PROGRESS = 3 COMMITTED = 4 COMMITTED_EMPTY = 5 ABORTED = 6 class _Transaction(object): """Internal class to hold transaction information in a ClientSession.""" def __init__(self, opts): self.opts = opts self.state = _TxnState.NONE self.sharded = False self.pinned_address = None self.recovery_token = None self.attempt = 0 def active(self): return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS) def reset(self): self.state = _TxnState.NONE self.sharded = False self.pinned_address = None self.recovery_token = None self.attempt = 0 def _reraise_with_unknown_commit(exc): """Re-raise an exception with the UnknownTransactionCommitResult label.""" exc._add_error_label("UnknownTransactionCommitResult") raise def _max_time_expired_error(exc): """Return true if exc is a MaxTimeMSExpired error.""" return isinstance(exc, OperationFailure) and exc.code == 50 # From the transactions spec, all the retryable writes errors plus # WriteConcernFailed. _UNKNOWN_COMMIT_ERROR_CODES = _RETRYABLE_ERROR_CODES | frozenset([ 64, # WriteConcernFailed 50, # MaxTimeMSExpired ]) # From the Convenient API for Transactions spec, with_transaction must # halt retries after 120 seconds. # This limit is non-configurable and was chosen to be twice the 60 second # default value of MongoDB's `transactionLifetimeLimitSeconds` parameter. _WITH_TRANSACTION_RETRY_TIME_LIMIT = 120 def _within_time_limit(start_time): """Are we within the with_transaction retry limit?""" return monotonic.time() - start_time < _WITH_TRANSACTION_RETRY_TIME_LIMIT class ClientSession(object): """A session for ordering sequential operations. :class:`ClientSession` instances are **not thread-safe or fork-safe**. They can only be used by one thread or process at a time. A single :class:`ClientSession` cannot be used to run multiple operations concurrently. Should not be initialized directly by application developers - to create a :class:`ClientSession`, call :meth:`~pymongo.mongo_client.MongoClient.start_session`. """ def __init__(self, client, server_session, options, authset, implicit): # A MongoClient, a _ServerSession, a SessionOptions, and a set. self._client = client self._server_session = server_session self._options = options self._authset = authset self._cluster_time = None self._operation_time = None # Is this an implicitly created session? self._implicit = implicit self._transaction = _Transaction(None) def end_session(self): """Finish this session. If a transaction has started, abort it. It is an error to use the session after the session has ended. 
""" self._end_session(lock=True) def _end_session(self, lock): if self._server_session is not None: try: if self.in_transaction: self.abort_transaction() finally: self._client._return_server_session(self._server_session, lock) self._server_session = None def _check_ended(self): if self._server_session is None: raise InvalidOperation("Cannot use ended session") def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self._end_session(lock=True) @property def client(self): """The :class:`~pymongo.mongo_client.MongoClient` this session was created from. """ return self._client @property def options(self): """The :class:`SessionOptions` this session was created with.""" return self._options @property def session_id(self): """A BSON document, the opaque server session identifier.""" self._check_ended() return self._server_session.session_id @property def cluster_time(self): """The cluster time returned by the last operation executed in this session. """ return self._cluster_time @property def operation_time(self): """The operation time returned by the last operation executed in this session. """ return self._operation_time def _inherit_option(self, name, val): """Return the inherited TransactionOption value.""" if val: return val txn_opts = self.options.default_transaction_options val = txn_opts and getattr(txn_opts, name) if val: return val return getattr(self.client, name) def with_transaction(self, callback, read_concern=None, write_concern=None, read_preference=None, max_commit_time_ms=None): """Execute a callback in a transaction. This method starts a transaction on this session, executes ``callback`` once, and then commits the transaction. For example:: def callback(session): orders = session.client.db.orders inventory = session.client.db.inventory orders.insert_one({"sku": "abc123", "qty": 100}, session=session) inventory.update_one({"sku": "abc123", "qty": {"$gte": 100}}, {"$inc": {"qty": -100}}, session=session) with client.start_session() as session: session.with_transaction(callback) To pass arbitrary arguments to the ``callback``, wrap your callable with a ``lambda`` like this:: def callback(session, custom_arg, custom_kwarg=None): # Transaction operations... with client.start_session() as session: session.with_transaction( lambda s: callback(s, "custom_arg", custom_kwarg=1)) In the event of an exception, ``with_transaction`` may retry the commit or the entire transaction, therefore ``callback`` may be invoked multiple times by a single call to ``with_transaction``. Developers should be mindful of this possiblity when writing a ``callback`` that modifies application state or has any other side-effects. Note that even when the ``callback`` is invoked multiple times, ``with_transaction`` ensures that the transaction will be committed at-most-once on the server. The ``callback`` should not attempt to start new transactions, but should simply run operations meant to be contained within a transaction. The ``callback`` should also not commit the transaction; this is handled automatically by ``with_transaction``. If the ``callback`` does commit or abort the transaction without error, however, ``with_transaction`` will return without taking further action. :class:`ClientSession` instances are **not thread-safe or fork-safe**. Consequently, the ``callback`` must not attempt to execute multiple operations concurrently. When ``callback`` raises an exception, ``with_transaction`` automatically aborts the current transaction. 
When ``callback`` or :meth:`~ClientSession.commit_transaction` raises an exception that includes the ``"TransientTransactionError"`` error label, ``with_transaction`` starts a new transaction and re-executes the ``callback``. When :meth:`~ClientSession.commit_transaction` raises an exception with the ``"UnknownTransactionCommitResult"`` error label, ``with_transaction`` retries the commit until the result of the transaction is known. This method will cease retrying after 120 seconds has elapsed. This timeout is not configurable and any exception raised by the ``callback`` or by :meth:`ClientSession.commit_transaction` after the timeout is reached will be re-raised. Applications that desire a different timeout duration should not use this method. :Parameters: - `callback`: The callable ``callback`` to run inside a transaction. The callable must accept a single argument, this session. Note, under certain error conditions the callback may be run multiple times. - `read_concern` (optional): The :class:`~pymongo.read_concern.ReadConcern` to use for this transaction. - `write_concern` (optional): The :class:`~pymongo.write_concern.WriteConcern` to use for this transaction. - `read_preference` (optional): The read preference to use for this transaction. If ``None`` (the default) the :attr:`read_preference` of this :class:`Database` is used. See :mod:`~pymongo.read_preferences` for options. :Returns: The return value of the ``callback``. .. versionadded:: 3.9 """ start_time = monotonic.time() while True: self.start_transaction( read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) except Exception as exc: if self.in_transaction: self.abort_transaction() if (isinstance(exc, PyMongoError) and exc.has_error_label("TransientTransactionError") and _within_time_limit(start_time)): # Retry the entire transaction. continue raise if not self.in_transaction: # Assume callback intentionally ended the transaction. return ret while True: try: self.commit_transaction() except PyMongoError as exc: if (exc.has_error_label("UnknownTransactionCommitResult") and _within_time_limit(start_time) and not _max_time_expired_error(exc)): # Retry the commit. continue if (exc.has_error_label("TransientTransactionError") and _within_time_limit(start_time)): # Retry the entire transaction. break raise # Commit succeeded. return ret def start_transaction(self, read_concern=None, write_concern=None, read_preference=None, max_commit_time_ms=None): """Start a multi-statement transaction. Takes the same arguments as :class:`TransactionOptions`. .. versionchanged:: 3.9 Added the ``max_commit_time_ms`` option. .. versionadded:: 3.7 """ self._check_ended() if self.in_transaction: raise InvalidOperation("Transaction already in progress") read_concern = self._inherit_option("read_concern", read_concern) write_concern = self._inherit_option("write_concern", write_concern) read_preference = self._inherit_option( "read_preference", read_preference) if max_commit_time_ms is None: opts = self.options.default_transaction_options if opts: max_commit_time_ms = opts.max_commit_time_ms self._transaction.opts = TransactionOptions( read_concern, write_concern, read_preference, max_commit_time_ms) self._transaction.reset() self._transaction.state = _TxnState.STARTING self._start_retryable_write() return _TransactionContext(self) def commit_transaction(self): """Commit a multi-statement transaction. .. 
versionadded:: 3.7 """ self._check_ended() state = self._transaction.state if state is _TxnState.NONE: raise InvalidOperation("No transaction started") elif state in (_TxnState.STARTING, _TxnState.COMMITTED_EMPTY): # Server transaction was never started, no need to send a command. self._transaction.state = _TxnState.COMMITTED_EMPTY return elif state is _TxnState.ABORTED: raise InvalidOperation( "Cannot call commitTransaction after calling abortTransaction") elif state is _TxnState.COMMITTED: # We're explicitly retrying the commit, move the state back to # "in progress" so that in_transaction returns true. self._transaction.state = _TxnState.IN_PROGRESS try: self._finish_transaction_with_retry("commitTransaction") except ConnectionFailure as exc: # We do not know if the commit was successfully applied on the # server or if it satisfied the provided write concern, set the # unknown commit error label. exc._remove_error_label("TransientTransactionError") _reraise_with_unknown_commit(exc) except WTimeoutError as exc: # We do not know if the commit has satisfied the provided write # concern, add the unknown commit error label. _reraise_with_unknown_commit(exc) except OperationFailure as exc: if exc.code not in _UNKNOWN_COMMIT_ERROR_CODES: # The server reports errorLabels in the case. raise # We do not know if the commit was successfully applied on the # server or if it satisfied the provided write concern, set the # unknown commit error label. _reraise_with_unknown_commit(exc) finally: self._transaction.state = _TxnState.COMMITTED def abort_transaction(self): """Abort a multi-statement transaction. .. versionadded:: 3.7 """ self._check_ended() state = self._transaction.state if state is _TxnState.NONE: raise InvalidOperation("No transaction started") elif state is _TxnState.STARTING: # Server transaction was never started, no need to send a command. self._transaction.state = _TxnState.ABORTED return elif state is _TxnState.ABORTED: raise InvalidOperation("Cannot call abortTransaction twice") elif state in (_TxnState.COMMITTED, _TxnState.COMMITTED_EMPTY): raise InvalidOperation( "Cannot call abortTransaction after calling commitTransaction") try: self._finish_transaction_with_retry("abortTransaction") except (OperationFailure, ConnectionFailure): # The transactions spec says to ignore abortTransaction errors. pass finally: self._transaction.state = _TxnState.ABORTED def _finish_transaction_with_retry(self, command_name): """Run commit or abort with one retry after any retryable error. :Parameters: - `command_name`: Either "commitTransaction" or "abortTransaction". """ def func(session, sock_info, retryable): return self._finish_transaction(sock_info, command_name) return self._client._retry_internal(True, func, self, None) def _finish_transaction(self, sock_info, command_name): self._transaction.attempt += 1 opts = self._transaction.opts wc = opts.write_concern cmd = SON([(command_name, 1)]) if command_name == "commitTransaction": if opts.max_commit_time_ms: cmd['maxTimeMS'] = opts.max_commit_time_ms # Transaction spec says that after the initial commit attempt, # subsequent commitTransaction commands should be upgraded to use # w:"majority" and set a default value of 10 seconds for wtimeout. 
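            # An illustrative sketch of the effect (not part of the original
            # source):
            #     WriteConcern(w=1)                -> {"w": "majority", "wtimeout": 10000}
            #     WriteConcern(w=1, wtimeout=5000) -> {"w": "majority", "wtimeout": 5000}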
if self._transaction.attempt > 1: wc_doc = wc.document wc_doc["w"] = "majority" wc_doc.setdefault("wtimeout", 10000) wc = WriteConcern(**wc_doc) if self._transaction.recovery_token: cmd['recoveryToken'] = self._transaction.recovery_token return self._client.admin._command( sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True) def _advance_cluster_time(self, cluster_time): """Internal cluster time helper.""" if self._cluster_time is None: self._cluster_time = cluster_time elif cluster_time is not None: if cluster_time["clusterTime"] > self._cluster_time["clusterTime"]: self._cluster_time = cluster_time def advance_cluster_time(self, cluster_time): """Update the cluster time for this session. :Parameters: - `cluster_time`: The :data:`~pymongo.client_session.ClientSession.cluster_time` from another `ClientSession` instance. """ if not isinstance(cluster_time, abc.Mapping): raise TypeError( "cluster_time must be a subclass of collections.Mapping") if not isinstance(cluster_time.get("clusterTime"), Timestamp): raise ValueError("Invalid cluster_time") self._advance_cluster_time(cluster_time) def _advance_operation_time(self, operation_time): """Internal operation time helper.""" if self._operation_time is None: self._operation_time = operation_time elif operation_time is not None: if operation_time > self._operation_time: self._operation_time = operation_time def advance_operation_time(self, operation_time): """Update the operation time for this session. :Parameters: - `operation_time`: The :data:`~pymongo.client_session.ClientSession.operation_time` from another `ClientSession` instance. """ if not isinstance(operation_time, Timestamp): raise TypeError("operation_time must be an instance " "of bson.timestamp.Timestamp") self._advance_operation_time(operation_time) def _process_response(self, reply): """Process a response to a command that was run with this session.""" self._advance_cluster_time(reply.get('$clusterTime')) self._advance_operation_time(reply.get('operationTime')) if self.in_transaction and self._transaction.sharded: recovery_token = reply.get('recoveryToken') if recovery_token: self._transaction.recovery_token = recovery_token @property def has_ended(self): """True if this session is finished.""" return self._server_session is None @property def in_transaction(self): """True if this session has an active multi-statement transaction. .. 
versionadded:: 3.10 """ return self._transaction.active() @property def _pinned_address(self): """The mongos address this transaction was created on.""" if self._transaction.active(): return self._transaction.pinned_address return None def _pin_mongos(self, server): """Pin this session to the given mongos Server.""" self._transaction.sharded = True self._transaction.pinned_address = server.description.address def _unpin_mongos(self): """Unpin this session from any pinned mongos address.""" self._transaction.pinned_address = None def _txn_read_preference(self): """Return read preference of this transaction or None.""" if self.in_transaction: return self._transaction.opts.read_preference return None def _apply_to(self, command, is_retryable, read_preference): self._check_ended() self._server_session.last_use = monotonic.time() command['lsid'] = self._server_session.session_id if not self.in_transaction: self._transaction.reset() if is_retryable: command['txnNumber'] = self._server_session.transaction_id return if self.in_transaction: if read_preference != ReadPreference.PRIMARY: raise InvalidOperation( 'read preference in a transaction must be primary, not: ' '%r' % (read_preference,)) if self._transaction.state == _TxnState.STARTING: # First command begins a new transaction. self._transaction.state = _TxnState.IN_PROGRESS command['startTransaction'] = True if self._transaction.opts.read_concern: rc = self._transaction.opts.read_concern.document else: rc = {} if (self.options.causal_consistency and self.operation_time is not None): rc['afterClusterTime'] = self.operation_time if rc: command['readConcern'] = rc command['txnNumber'] = self._server_session.transaction_id command['autocommit'] = False def _start_retryable_write(self): self._check_ended() self._server_session.inc_transaction_id() class _ServerSession(object): def __init__(self, generation): # Ensure id is type 4, regardless of CodecOptions.uuid_representation. self.session_id = {'id': Binary(uuid.uuid4().bytes, 4)} self.last_use = monotonic.time() self._transaction_id = 0 self.dirty = False self.generation = generation def mark_dirty(self): """Mark this session as dirty. A server session is marked dirty when a command fails with a network error. Dirty sessions are later discarded from the server session pool. """ self.dirty = True def timed_out(self, session_timeout_minutes): idle_seconds = monotonic.time() - self.last_use # Timed out if we have less than a minute to live. return idle_seconds > (session_timeout_minutes - 1) * 60 @property def transaction_id(self): """Positive 64-bit integer.""" return Int64(self._transaction_id) def inc_transaction_id(self): self._transaction_id += 1 class _ServerSessionPool(collections.deque): """Pool of _ServerSession objects. This class is not thread-safe, access it while holding the Topology lock. """ def __init__(self, *args, **kwargs): super(_ServerSessionPool, self).__init__(*args, **kwargs) self.generation = 0 def reset(self): self.generation += 1 self.clear() def pop_all(self): ids = [] while self: ids.append(self.pop().session_id) return ids def get_server_session(self, session_timeout_minutes): # Although the Driver Sessions Spec says we only clear stale sessions # in return_server_session, PyMongo can't take a lock when returning # sessions from a __del__ method (like in Cursor.__die), so it can't # clear stale sessions there. In case many sessions were returned via # __del__, check for stale sessions here too. 
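        # An illustrative sketch of the reuse policy (not from the original
        # source): with the typical 30 minute server-side session timeout,
        # a pooled session that has been idle for more than 29 minutes is
        # treated as timed out and discarded rather than reused, because
        # _ServerSession.timed_out keeps a one-minute safety margin.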
self._clear_stale(session_timeout_minutes) # The most recently used sessions are on the left. while self: s = self.popleft() if not s.timed_out(session_timeout_minutes): return s return _ServerSession(self.generation) def return_server_session(self, server_session, session_timeout_minutes): self._clear_stale(session_timeout_minutes) if not server_session.timed_out(session_timeout_minutes): self.return_server_session_no_lock(server_session) def return_server_session_no_lock(self, server_session): # Discard sessions from an old pool to avoid duplicate sessions in the # child process after a fork. if (server_session.generation == self.generation and not server_session.dirty): self.appendleft(server_session) def _clear_stale(self, session_timeout_minutes): # Clear stale sessions. The least recently used are on the right. while self: if self[-1].timed_out(session_timeout_minutes): self.pop() else: # The remaining sessions also haven't timed out. break pymongo-3.11.0/pymongo/collation.py000066400000000000000000000172001374256237000173230ustar00rootroot00000000000000# Copyright 2016 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for working with `collations`_. .. _collations: http://userguide.icu-project.org/collation/concepts """ from pymongo import common class CollationStrength(object): """ An enum that defines values for `strength` on a :class:`~pymongo.collation.Collation`. """ PRIMARY = 1 """Differentiate base (unadorned) characters.""" SECONDARY = 2 """Differentiate character accents.""" TERTIARY = 3 """Differentiate character case.""" QUATERNARY = 4 """Differentiate words with and without punctuation.""" IDENTICAL = 5 """Differentiate unicode code point (characters are exactly identical).""" class CollationAlternate(object): """ An enum that defines values for `alternate` on a :class:`~pymongo.collation.Collation`. """ NON_IGNORABLE = 'non-ignorable' """Spaces and punctuation are treated as base characters.""" SHIFTED = 'shifted' """Spaces and punctuation are *not* considered base characters. Spaces and punctuation are distinguished regardless when the :class:`~pymongo.collation.Collation` strength is at least :data:`~pymongo.collation.CollationStrength.QUATERNARY`. """ class CollationMaxVariable(object): """ An enum that defines values for `max_variable` on a :class:`~pymongo.collation.Collation`. """ PUNCT = 'punct' """Both punctuation and spaces are ignored.""" SPACE = 'space' """Spaces alone are ignored.""" class CollationCaseFirst(object): """ An enum that defines values for `case_first` on a :class:`~pymongo.collation.Collation`. """ UPPER = 'upper' """Sort uppercase characters first.""" LOWER = 'lower' """Sort lowercase characters first.""" OFF = 'off' """Default for locale or collation strength.""" class Collation(object): """Collation :Parameters: - `locale`: (string) The locale of the collation. This should be a string that identifies an `ICU locale ID` exactly. For example, ``en_US`` is valid, but ``en_us`` and ``en-US`` are not. 
Consult the MongoDB documentation for a list of supported locales. - `caseLevel`: (optional) If ``True``, turn on case sensitivity if `strength` is 1 or 2 (case sensitivity is implied if `strength` is greater than 2). Defaults to ``False``. - `caseFirst`: (optional) Specify that either uppercase or lowercase characters take precedence. Must be one of the following values: * :data:`~CollationCaseFirst.UPPER` * :data:`~CollationCaseFirst.LOWER` * :data:`~CollationCaseFirst.OFF` (the default) - `strength`: (optional) Specify the comparison strength. This is also known as the ICU comparison level. This must be one of the following values: * :data:`~CollationStrength.PRIMARY` * :data:`~CollationStrength.SECONDARY` * :data:`~CollationStrength.TERTIARY` (the default) * :data:`~CollationStrength.QUATERNARY` * :data:`~CollationStrength.IDENTICAL` Each successive level builds upon the previous. For example, a `strength` of :data:`~CollationStrength.SECONDARY` differentiates characters based both on the unadorned base character and its accents. - `numericOrdering`: (optional) If ``True``, order numbers numerically instead of in collation order (defaults to ``False``). - `alternate`: (optional) Specify whether spaces and punctuation are considered base characters. This must be one of the following values: * :data:`~CollationAlternate.NON_IGNORABLE` (the default) * :data:`~CollationAlternate.SHIFTED` - `maxVariable`: (optional) When `alternate` is :data:`~CollationAlternate.SHIFTED`, this option specifies what characters may be ignored. This must be one of the following values: * :data:`~CollationMaxVariable.PUNCT` (the default) * :data:`~CollationMaxVariable.SPACE` - `normalization`: (optional) If ``True``, normalizes text into Unicode NFD. Defaults to ``False``. - `backwards`: (optional) If ``True``, accents on characters are considered from the back of the word to the front, as it is done in some French dictionary ordering traditions. Defaults to ``False``. - `kwargs`: (optional) Keyword arguments supplying any additional options to be sent with this Collation object. .. versionadded: 3.4 """ __slots__ = ("__document",) def __init__(self, locale, caseLevel=None, caseFirst=None, strength=None, numericOrdering=None, alternate=None, maxVariable=None, normalization=None, backwards=None, **kwargs): locale = common.validate_string('locale', locale) self.__document = {'locale': locale} if caseLevel is not None: self.__document['caseLevel'] = common.validate_boolean( 'caseLevel', caseLevel) if caseFirst is not None: self.__document['caseFirst'] = common.validate_string( 'caseFirst', caseFirst) if strength is not None: self.__document['strength'] = common.validate_integer( 'strength', strength) if numericOrdering is not None: self.__document['numericOrdering'] = common.validate_boolean( 'numericOrdering', numericOrdering) if alternate is not None: self.__document['alternate'] = common.validate_string( 'alternate', alternate) if maxVariable is not None: self.__document['maxVariable'] = common.validate_string( 'maxVariable', maxVariable) if normalization is not None: self.__document['normalization'] = common.validate_boolean( 'normalization', normalization) if backwards is not None: self.__document['backwards'] = common.validate_boolean( 'backwards', backwards) self.__document.update(kwargs) @property def document(self): """The document representation of this collation. .. note:: :class:`Collation` is immutable. Mutating the value of :attr:`document` does not mutate this :class:`Collation`. 
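
        A rough doctest-style sketch (output shape assumed, not taken from
        the original documentation)::

            >>> from pymongo.collation import Collation, CollationStrength
            >>> Collation('en_US', strength=CollationStrength.SECONDARY).document
            {'locale': 'en_US', 'strength': 2}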
""" return self.__document.copy() def __repr__(self): document = self.document return 'Collation(%s)' % ( ', '.join('%s=%r' % (key, document[key]) for key in document),) def __eq__(self, other): if isinstance(other, Collation): return self.document == other.document return NotImplemented def __ne__(self, other): return not self == other def validate_collation_or_none(value): if value is None: return None if isinstance(value, Collation): return value.document if isinstance(value, dict): return value raise TypeError( 'collation must be a dict, an instance of collation.Collation, ' 'or None.') pymongo-3.11.0/pymongo/collection.py000066400000000000000000004465401374256237000175070ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Collection level utilities for Mongo.""" import datetime import warnings from bson.code import Code from bson.objectid import ObjectId from bson.py3compat import (_unicode, abc, integer_types, string_type) from bson.raw_bson import RawBSONDocument from bson.codec_options import CodecOptions from bson.son import SON from pymongo import (common, helpers, message) from pymongo.aggregation import (_CollectionAggregationCommand, _CollectionRawAggregationCommand) from pymongo.bulk import BulkOperationBuilder, _Bulk from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor from pymongo.common import ORDERED_TYPES from pymongo.collation import validate_collation_or_none from pymongo.change_stream import CollectionChangeStream from pymongo.cursor import Cursor, RawBatchCursor from pymongo.errors import (BulkWriteError, ConfigurationError, InvalidName, InvalidOperation, OperationFailure) from pymongo.helpers import (_check_write_command_response, _raise_last_error) from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS from pymongo.operations import IndexModel from pymongo.read_preferences import ReadPreference from pymongo.results import (BulkWriteResult, DeleteResult, InsertOneResult, InsertManyResult, UpdateResult) from pymongo.write_concern import WriteConcern _UJOIN = u"%s.%s" _FIND_AND_MODIFY_DOC_FIELDS = {'value': 1} _HAYSTACK_MSG = ( "geoHaystack indexes are deprecated as of MongoDB 4.4." " Instead, create a 2d index and use $geoNear or $geoWithin." " See https://dochub.mongodb.org/core/4.4-deprecate-geoHaystack") class ReturnDocument(object): """An enum used with :meth:`~pymongo.collection.Collection.find_one_and_replace` and :meth:`~pymongo.collection.Collection.find_one_and_update`. """ BEFORE = False """Return the original document before it was updated/replaced, or ``None`` if no document matches the query. """ AFTER = True """Return the updated/replaced or inserted document.""" class Collection(common.BaseObject): """A Mongo collection. """ def __init__(self, database, name, create=False, codec_options=None, read_preference=None, write_concern=None, read_concern=None, session=None, **kwargs): """Get / create a Mongo collection. 
Raises :class:`TypeError` if `name` is not an instance of :class:`basestring` (:class:`str` in python 3). Raises :class:`~pymongo.errors.InvalidName` if `name` is not a valid collection name. Any additional keyword arguments will be used as options passed to the create command. See :meth:`~pymongo.database.Database.create_collection` for valid options. If `create` is ``True``, `collation` is specified, or any additional keyword arguments are present, a ``create`` command will be sent, using ``session`` if specified. Otherwise, a ``create`` command will not be sent and the collection will be created implicitly on first use. The optional ``session`` argument is *only* used for the ``create`` command, it is not associated with the collection afterward. :Parameters: - `database`: the database to get a collection from - `name`: the name of the collection to get - `create` (optional): if ``True``, force collection creation even without options being set - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) database.codec_options is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) database.read_preference is used. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) database.write_concern is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) database.read_concern is used. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. If a collation is provided, it will be passed to the create collection command. This option is only supported on MongoDB 3.4 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession` that is used with the create collection command - `**kwargs` (optional): additional keyword arguments will be passed as options for the create collection command .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Support the `collation` option. .. versionchanged:: 3.2 Added the read_concern option. .. versionchanged:: 3.0 Added the codec_options, read_preference, and write_concern options. Removed the uuid_subtype attribute. :class:`~pymongo.collection.Collection` no longer returns an instance of :class:`~pymongo.collection.Collection` for attribute names with leading underscores. You must use dict-style lookups instead:: collection['__my_collection__'] Not: collection.__my_collection__ .. versionchanged:: 2.2 Removed deprecated argument: options .. versionadded:: 2.1 uuid_subtype attribute .. mongodoc:: collections """ super(Collection, self).__init__( codec_options or database.codec_options, read_preference or database.read_preference, write_concern or database.write_concern, read_concern or database.read_concern) if not isinstance(name, string_type): raise TypeError("name must be an instance " "of %s" % (string_type.__name__,)) if not name or ".." in name: raise InvalidName("collection names cannot be empty") if "$" in name and not (name.startswith("oplog.$main") or name.startswith("$cmd")): raise InvalidName("collection names must not " "contain '$': %r" % name) if name[0] == "." 
or name[-1] == ".": raise InvalidName("collection names must not start " "or end with '.': %r" % name) if "\x00" in name: raise InvalidName("collection names must not contain the " "null character") collation = validate_collation_or_none(kwargs.pop('collation', None)) self.__database = database self.__name = _unicode(name) self.__full_name = _UJOIN % (self.__database.name, self.__name) if create or kwargs or collation: self.__create(kwargs, collation, session) self.__write_response_codec_options = self.codec_options._replace( unicode_decode_error_handler='replace', document_class=dict) def _socket_for_reads(self, session): return self.__database.client._socket_for_reads( self._read_preference_for(session), session) def _socket_for_writes(self, session): return self.__database.client._socket_for_writes(session) def _command(self, sock_info, command, slave_ok=False, read_preference=None, codec_options=None, check=True, allowable_errors=None, read_concern=None, write_concern=None, collation=None, session=None, retryable_write=False, user_fields=None): """Internal command helper. :Parameters: - `sock_info` - A SocketInfo instance. - `command` - The command itself, as a SON instance. - `slave_ok`: whether to set the SlaveOkay wire protocol bit. - `codec_options` (optional) - An instance of :class:`~bson.codec_options.CodecOptions`. - `check`: raise OperationFailure if there are errors - `allowable_errors`: errors to ignore if `check` is True - `read_concern` (optional) - An instance of :class:`~pymongo.read_concern.ReadConcern`. - `write_concern`: An instance of :class:`~pymongo.write_concern.WriteConcern`. This option is only valid for MongoDB 3.4 and above. - `collation` (optional) - An instance of :class:`~pymongo.collation.Collation`. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `retryable_write` (optional): True if this command is a retryable write. - `user_fields` (optional): Response fields that should be decoded using the TypeDecoders from codec_options, passed to bson._decode_all_selective. :Returns: The result document. """ with self.__database.client._tmp_session(session) as s: return sock_info.command( self.__database.name, command, slave_ok, read_preference or self._read_preference_for(session), codec_options or self.codec_options, check, allowable_errors, read_concern=read_concern, write_concern=write_concern, parse_write_concern_error=True, collation=collation, session=s, client=self.__database.client, retryable_write=retryable_write, user_fields=user_fields) def __create(self, options, collation, session): """Sends a create command with the given options. """ cmd = SON([("create", self.__name)]) if options: if "size" in options: options["size"] = float(options["size"]) cmd.update(options) with self._socket_for_writes(session) as sock_info: self._command( sock_info, cmd, read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), collation=collation, session=session) def __getattr__(self, name): """Get a sub-collection of this collection by name. Raises InvalidName if an invalid collection name is used. :Parameters: - `name`: the name of the collection to get """ if name.startswith('_'): full_name = _UJOIN % (self.__name, name) raise AttributeError( "Collection has no attribute %r. To access the %s" " collection, use database['%s']." 
% ( name, full_name, full_name)) return self.__getitem__(name) def __getitem__(self, name): return Collection(self.__database, _UJOIN % (self.__name, name), False, self.codec_options, self.read_preference, self.write_concern, self.read_concern) def __repr__(self): return "Collection(%r, %r)" % (self.__database, self.__name) def __eq__(self, other): if isinstance(other, Collection): return (self.__database == other.database and self.__name == other.name) return NotImplemented def __ne__(self, other): return not self == other @property def full_name(self): """The full name of this :class:`Collection`. The full name is of the form `database_name.collection_name`. """ return self.__full_name @property def name(self): """The name of this :class:`Collection`.""" return self.__name @property def database(self): """The :class:`~pymongo.database.Database` that this :class:`Collection` is a part of. """ return self.__database def with_options(self, codec_options=None, read_preference=None, write_concern=None, read_concern=None): """Get a clone of this collection changing the specified settings. >>> coll1.read_preference Primary() >>> from pymongo import ReadPreference >>> coll2 = coll1.with_options(read_preference=ReadPreference.SECONDARY) >>> coll1.read_preference Primary() >>> coll2.read_preference Secondary(tag_sets=None) :Parameters: - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) the :attr:`codec_options` of this :class:`Collection` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`Collection` is used. See :mod:`~pymongo.read_preferences` for options. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) the :attr:`write_concern` of this :class:`Collection` is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) the :attr:`read_concern` of this :class:`Collection` is used. """ return Collection(self.__database, self.__name, False, codec_options or self.codec_options, read_preference or self.read_preference, write_concern or self.write_concern, read_concern or self.read_concern) def initialize_unordered_bulk_op(self, bypass_document_validation=False): """**DEPRECATED** - Initialize an unordered batch of write operations. Operations will be performed on the server in arbitrary order, possibly in parallel. All operations will be attempted. :Parameters: - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. Returns a :class:`~pymongo.bulk.BulkOperationBuilder` instance. See :ref:`unordered_bulk` for examples. .. note:: `bypass_document_validation` requires server version **>= 3.2** .. versionchanged:: 3.5 Deprecated. Use :meth:`~pymongo.collection.Collection.bulk_write` instead. .. versionchanged:: 3.2 Added bypass_document_validation support .. versionadded:: 2.7 """ warnings.warn("initialize_unordered_bulk_op is deprecated", DeprecationWarning, stacklevel=2) return BulkOperationBuilder(self, False, bypass_document_validation) def initialize_ordered_bulk_op(self, bypass_document_validation=False): """**DEPRECATED** - Initialize an ordered batch of write operations. Operations will be performed on the server serially, in the order provided. If an error occurs all remaining operations are aborted. 
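
        A rough usage sketch of the deprecated builder API (shape assumed
        here; the canonical examples are in :ref:`ordered_bulk`)::

            bulk = db.test.initialize_ordered_bulk_op()
            bulk.insert({'a': 1})
            bulk.find({'a': 1}).update({'$set': {'b': 1}})
            result = bulk.execute()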
:Parameters: - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. Returns a :class:`~pymongo.bulk.BulkOperationBuilder` instance. See :ref:`ordered_bulk` for examples. .. note:: `bypass_document_validation` requires server version **>= 3.2** .. versionchanged:: 3.5 Deprecated. Use :meth:`~pymongo.collection.Collection.bulk_write` instead. .. versionchanged:: 3.2 Added bypass_document_validation support .. versionadded:: 2.7 """ warnings.warn("initialize_ordered_bulk_op is deprecated", DeprecationWarning, stacklevel=2) return BulkOperationBuilder(self, True, bypass_document_validation) def bulk_write(self, requests, ordered=True, bypass_document_validation=False, session=None): """Send a batch of write operations to the server. Requests are passed as a list of write operation instances ( :class:`~pymongo.operations.InsertOne`, :class:`~pymongo.operations.UpdateOne`, :class:`~pymongo.operations.UpdateMany`, :class:`~pymongo.operations.ReplaceOne`, :class:`~pymongo.operations.DeleteOne`, or :class:`~pymongo.operations.DeleteMany`). >>> for doc in db.test.find({}): ... print(doc) ... {u'x': 1, u'_id': ObjectId('54f62e60fba5226811f634ef')} {u'x': 1, u'_id': ObjectId('54f62e60fba5226811f634f0')} >>> # DeleteMany, UpdateOne, and UpdateMany are also available. ... >>> from pymongo import InsertOne, DeleteOne, ReplaceOne >>> requests = [InsertOne({'y': 1}), DeleteOne({'x': 1}), ... ReplaceOne({'w': 1}, {'z': 1}, upsert=True)] >>> result = db.test.bulk_write(requests) >>> result.inserted_count 1 >>> result.deleted_count 1 >>> result.modified_count 0 >>> result.upserted_ids {2: ObjectId('54f62ee28891e756a6e1abd5')} >>> for doc in db.test.find({}): ... print(doc) ... {u'x': 1, u'_id': ObjectId('54f62e60fba5226811f634f0')} {u'y': 1, u'_id': ObjectId('54f62ee2fba5226811f634f1')} {u'z': 1, u'_id': ObjectId('54f62ee28891e756a6e1abd5')} :Parameters: - `requests`: A list of write operations (see examples above). - `ordered` (optional): If ``True`` (the default) requests will be performed on the server serially, in the order provided. If an error occurs all remaining operations are aborted. If ``False`` requests will be performed on the server in arbitrary order, possibly in parallel, and all operations will be attempted. - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: An instance of :class:`~pymongo.results.BulkWriteResult`. .. seealso:: :ref:`writes-and-ids` .. note:: `bypass_document_validation` requires server version **>= 3.2** .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.2 Added bypass_document_validation support .. versionadded:: 3.0 """ common.validate_list("requests", requests) blk = _Bulk(self, ordered, bypass_document_validation) for request in requests: try: request._add_to_bulk(blk) except AttributeError: raise TypeError("%r is not a valid request" % (request,)) write_concern = self._write_concern_for(session) bulk_api_result = blk.execute(write_concern, session) if bulk_api_result is not None: return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) def _legacy_write(self, sock_info, name, cmd, op_id, bypass_doc_val, func, *args): """Internal legacy unacknowledged write helper.""" # Cannot have both unacknowledged write and bypass document validation. 
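        # (Illustrative note, not part of the original source: this guard is
        # only reached on the legacy unacknowledged path; acknowledged writes
        # instead pass bypassDocumentValidation inside the write command, as
        # in _insert_one and _update below.)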
if bypass_doc_val and sock_info.max_wire_version >= 4: raise OperationFailure("Cannot set bypass_document_validation with" " unacknowledged write concern") listeners = self.database.client._event_listeners publish = listeners.enabled_for_commands if publish: start = datetime.datetime.now() args = args + (sock_info.compression_context,) rqst_id, msg, max_size = func(*args) if publish: duration = datetime.datetime.now() - start listeners.publish_command_start( cmd, self.__database.name, rqst_id, sock_info.address, op_id) start = datetime.datetime.now() try: result = sock_info.legacy_write(rqst_id, msg, max_size, False) except Exception as exc: if publish: dur = (datetime.datetime.now() - start) + duration if isinstance(exc, OperationFailure): details = exc.details # Succeed if GLE was successful and this is a write error. if details.get("ok") and "n" in details: reply = message._convert_write_result( name, cmd, details) listeners.publish_command_success( dur, reply, name, rqst_id, sock_info.address, op_id) raise else: details = message._convert_exception(exc) listeners.publish_command_failure( dur, details, name, rqst_id, sock_info.address, op_id) raise if publish: if result is not None: reply = message._convert_write_result(name, cmd, result) else: # Comply with APM spec. reply = {'ok': 1} duration = (datetime.datetime.now() - start) + duration listeners.publish_command_success( duration, reply, name, rqst_id, sock_info.address, op_id) return result def _insert_one( self, doc, ordered, check_keys, manipulate, write_concern, op_id, bypass_doc_val, session): """Internal helper for inserting a single document.""" if manipulate: doc = self.__database._apply_incoming_manipulators(doc, self) if not isinstance(doc, RawBSONDocument) and '_id' not in doc: doc['_id'] = ObjectId() doc = self.__database._apply_incoming_copying_manipulators(doc, self) write_concern = write_concern or self.write_concern acknowledged = write_concern.acknowledged command = SON([('insert', self.name), ('ordered', ordered), ('documents', [doc])]) if not write_concern.is_server_default: command['writeConcern'] = write_concern.document def _insert_command(session, sock_info, retryable_write): if not sock_info.op_msg_enabled and not acknowledged: # Legacy OP_INSERT. return self._legacy_write( sock_info, 'insert', command, op_id, bypass_doc_val, message.insert, self.__full_name, [doc], check_keys, False, write_concern.document, False, self.__write_response_codec_options) if bypass_doc_val and sock_info.max_wire_version >= 4: command['bypassDocumentValidation'] = True result = sock_info.command( self.__database.name, command, write_concern=write_concern, codec_options=self.__write_response_codec_options, check_keys=check_keys, session=session, client=self.__database.client, retryable_write=retryable_write) _check_write_command_response(result) self.__database.client._retryable_write( acknowledged, _insert_command, session) if not isinstance(doc, RawBSONDocument): return doc.get('_id') def _insert(self, docs, ordered=True, check_keys=True, manipulate=False, write_concern=None, op_id=None, bypass_doc_val=False, session=None): """Internal insert helper.""" if isinstance(docs, abc.Mapping): return self._insert_one( docs, ordered, check_keys, manipulate, write_concern, op_id, bypass_doc_val, session) ids = [] if manipulate: def gen(): """Generator that applies SON manipulators to each document and adds _id if necessary. """ _db = self.__database for doc in docs: # Apply user-configured SON manipulators. 
This order of # operations is required for backwards compatibility, # see PYTHON-709. doc = _db._apply_incoming_manipulators(doc, self) if not (isinstance(doc, RawBSONDocument) or '_id' in doc): doc['_id'] = ObjectId() doc = _db._apply_incoming_copying_manipulators(doc, self) ids.append(doc['_id']) yield doc else: def gen(): """Generator that only tracks existing _ids.""" for doc in docs: # Don't inflate RawBSONDocument by touching fields. if not isinstance(doc, RawBSONDocument): ids.append(doc.get('_id')) yield doc write_concern = write_concern or self._write_concern_for(session) blk = _Bulk(self, ordered, bypass_doc_val) blk.ops = [(message._INSERT, doc) for doc in gen()] try: blk.execute(write_concern, session=session) except BulkWriteError as bwe: _raise_last_error(bwe.details) return ids def insert_one(self, document, bypass_document_validation=False, session=None): """Insert a single document. >>> db.test.count_documents({'x': 1}) 0 >>> result = db.test.insert_one({'x': 1}) >>> result.inserted_id ObjectId('54f112defba522406c9cc208') >>> db.test.find_one({'x': 1}) {u'x': 1, u'_id': ObjectId('54f112defba522406c9cc208')} :Parameters: - `document`: The document to insert. Must be a mutable mapping type. If the document does not have an _id field one will be added automatically. - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: - An instance of :class:`~pymongo.results.InsertOneResult`. .. seealso:: :ref:`writes-and-ids` .. note:: `bypass_document_validation` requires server version **>= 3.2** .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.2 Added bypass_document_validation support .. versionadded:: 3.0 """ common.validate_is_document_type("document", document) if not (isinstance(document, RawBSONDocument) or "_id" in document): document["_id"] = ObjectId() write_concern = self._write_concern_for(session) return InsertOneResult( self._insert(document, write_concern=write_concern, bypass_doc_val=bypass_document_validation, session=session), write_concern.acknowledged) def insert_many(self, documents, ordered=True, bypass_document_validation=False, session=None): """Insert an iterable of documents. >>> db.test.count_documents({}) 0 >>> result = db.test.insert_many([{'x': i} for i in range(2)]) >>> result.inserted_ids [ObjectId('54f113fffba522406c9cc20e'), ObjectId('54f113fffba522406c9cc20f')] >>> db.test.count_documents({}) 2 :Parameters: - `documents`: A iterable of documents to insert. - `ordered` (optional): If ``True`` (the default) documents will be inserted on the server serially, in the order provided. If an error occurs all remaining inserts are aborted. If ``False``, documents will be inserted on the server in arbitrary order, possibly in parallel, and all document inserts will be attempted. - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: An instance of :class:`~pymongo.results.InsertManyResult`. .. seealso:: :ref:`writes-and-ids` .. note:: `bypass_document_validation` requires server version **>= 3.2** .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.2 Added bypass_document_validation support .. 
versionadded:: 3.0 """ if not isinstance(documents, abc.Iterable) or not documents: raise TypeError("documents must be a non-empty list") inserted_ids = [] def gen(): """A generator that validates documents and handles _ids.""" for document in documents: common.validate_is_document_type("document", document) if not isinstance(document, RawBSONDocument): if "_id" not in document: document["_id"] = ObjectId() inserted_ids.append(document["_id"]) yield (message._INSERT, document) write_concern = self._write_concern_for(session) blk = _Bulk(self, ordered, bypass_document_validation) blk.ops = [doc for doc in gen()] blk.execute(write_concern, session=session) return InsertManyResult(inserted_ids, write_concern.acknowledged) def _update(self, sock_info, criteria, document, upsert=False, check_keys=True, multi=False, manipulate=False, write_concern=None, op_id=None, ordered=True, bypass_doc_val=False, collation=None, array_filters=None, hint=None, session=None, retryable_write=False): """Internal update / replace helper.""" common.validate_boolean("upsert", upsert) if manipulate: document = self.__database._fix_incoming(document, self) collation = validate_collation_or_none(collation) write_concern = write_concern or self.write_concern acknowledged = write_concern.acknowledged update_doc = SON([('q', criteria), ('u', document), ('multi', multi), ('upsert', upsert)]) if collation is not None: if sock_info.max_wire_version < 5: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use collations.') elif not acknowledged: raise ConfigurationError( 'Collation is unsupported for unacknowledged writes.') else: update_doc['collation'] = collation if array_filters is not None: if sock_info.max_wire_version < 6: raise ConfigurationError( 'Must be connected to MongoDB 3.6+ to use array_filters.') elif not acknowledged: raise ConfigurationError( 'arrayFilters is unsupported for unacknowledged writes.') else: update_doc['arrayFilters'] = array_filters if hint is not None: if sock_info.max_wire_version < 5: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use hint.') elif not acknowledged: raise ConfigurationError( 'hint is unsupported for unacknowledged writes.') if not isinstance(hint, string_type): hint = helpers._index_document(hint) update_doc['hint'] = hint command = SON([('update', self.name), ('ordered', ordered), ('updates', [update_doc])]) if not write_concern.is_server_default: command['writeConcern'] = write_concern.document if not sock_info.op_msg_enabled and not acknowledged: # Legacy OP_UPDATE. return self._legacy_write( sock_info, 'update', command, op_id, bypass_doc_val, message.update, self.__full_name, upsert, multi, criteria, document, False, write_concern.document, check_keys, self.__write_response_codec_options) # Update command. if bypass_doc_val and sock_info.max_wire_version >= 4: command['bypassDocumentValidation'] = True # The command result has to be published for APM unmodified # so we make a shallow copy here before adding updatedExisting. result = sock_info.command( self.__database.name, command, write_concern=write_concern, codec_options=self.__write_response_codec_options, session=session, client=self.__database.client, retryable_write=retryable_write).copy() _check_write_command_response(result) # Add the updatedExisting field for compatibility. if result.get('n') and 'upserted' not in result: result['updatedExisting'] = True else: result['updatedExisting'] = False # MongoDB >= 2.6.0 returns the upsert _id in an array # element. 
Break it out for backward compatibility. if 'upserted' in result: result['upserted'] = result['upserted'][0]['_id'] if not acknowledged: return None return result def _update_retryable( self, criteria, document, upsert=False, check_keys=True, multi=False, manipulate=False, write_concern=None, op_id=None, ordered=True, bypass_doc_val=False, collation=None, array_filters=None, hint=None, session=None): """Internal update / replace helper.""" def _update(session, sock_info, retryable_write): return self._update( sock_info, criteria, document, upsert=upsert, check_keys=check_keys, multi=multi, manipulate=manipulate, write_concern=write_concern, op_id=op_id, ordered=ordered, bypass_doc_val=bypass_doc_val, collation=collation, array_filters=array_filters, hint=hint, session=session, retryable_write=retryable_write) return self.__database.client._retryable_write( (write_concern or self.write_concern).acknowledged and not multi, _update, session) def replace_one(self, filter, replacement, upsert=False, bypass_document_validation=False, collation=None, hint=None, session=None): """Replace a single document matching the filter. >>> for doc in db.test.find({}): ... print(doc) ... {u'x': 1, u'_id': ObjectId('54f4c5befba5220aa4d6dee7')} >>> result = db.test.replace_one({'x': 1}, {'y': 1}) >>> result.matched_count 1 >>> result.modified_count 1 >>> for doc in db.test.find({}): ... print(doc) ... {u'y': 1, u'_id': ObjectId('54f4c5befba5220aa4d6dee7')} The *upsert* option can be used to insert a new document if a matching document does not exist. >>> result = db.test.replace_one({'x': 1}, {'x': 1}, True) >>> result.matched_count 0 >>> result.modified_count 0 >>> result.upserted_id ObjectId('54f11e5c8891e756a6e1abd4') >>> db.test.find_one({'x': 1}) {u'x': 1, u'_id': ObjectId('54f11e5c8891e756a6e1abd4')} :Parameters: - `filter`: A query that matches the document to replace. - `replacement`: The new document. - `upsert` (optional): If ``True``, perform an insert if no documents match the filter. - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. This option is only supported on MongoDB 3.2 and above. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: - An instance of :class:`~pymongo.results.UpdateResult`. .. versionchanged:: 3.11 Added ``hint`` parameter. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Added the `collation` option. .. versionchanged:: 3.2 Added bypass_document_validation support. .. 
versionadded:: 3.0 """ common.validate_is_mapping("filter", filter) common.validate_ok_for_replace(replacement) write_concern = self._write_concern_for(session) return UpdateResult( self._update_retryable( filter, replacement, upsert, write_concern=write_concern, bypass_doc_val=bypass_document_validation, collation=collation, hint=hint, session=session), write_concern.acknowledged) def update_one(self, filter, update, upsert=False, bypass_document_validation=False, collation=None, array_filters=None, hint=None, session=None): """Update a single document matching the filter. >>> for doc in db.test.find(): ... print(doc) ... {u'x': 1, u'_id': 0} {u'x': 1, u'_id': 1} {u'x': 1, u'_id': 2} >>> result = db.test.update_one({'x': 1}, {'$inc': {'x': 3}}) >>> result.matched_count 1 >>> result.modified_count 1 >>> for doc in db.test.find(): ... print(doc) ... {u'x': 4, u'_id': 0} {u'x': 1, u'_id': 1} {u'x': 1, u'_id': 2} :Parameters: - `filter`: A query that matches the document to update. - `update`: The modifications to apply. - `upsert` (optional): If ``True``, perform an insert if no documents match the filter. - `bypass_document_validation`: (optional) If ``True``, allows the write to opt-out of document level validation. Default is ``False``. This option is only supported on MongoDB 3.2 and above. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `array_filters` (optional): A list of filters specifying which array elements an update should apply. This option is only supported on MongoDB 3.6 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: - An instance of :class:`~pymongo.results.UpdateResult`. .. versionchanged:: 3.11 Added ``hint`` parameter. .. versionchanged:: 3.9 Added the ability to accept a pipeline as the ``update``. .. versionchanged:: 3.6 Added the ``array_filters`` and ``session`` parameters. .. versionchanged:: 3.4 Added the ``collation`` option. .. versionchanged:: 3.2 Added ``bypass_document_validation`` support. .. versionadded:: 3.0 """ common.validate_is_mapping("filter", filter) common.validate_ok_for_update(update) common.validate_list_or_none('array_filters', array_filters) write_concern = self._write_concern_for(session) return UpdateResult( self._update_retryable( filter, update, upsert, check_keys=False, write_concern=write_concern, bypass_doc_val=bypass_document_validation, collation=collation, array_filters=array_filters, hint=hint, session=session), write_concern.acknowledged) def update_many(self, filter, update, upsert=False, array_filters=None, bypass_document_validation=False, collation=None, hint=None, session=None): """Update one or more documents that match the filter. >>> for doc in db.test.find(): ... print(doc) ... {u'x': 1, u'_id': 0} {u'x': 1, u'_id': 1} {u'x': 1, u'_id': 2} >>> result = db.test.update_many({'x': 1}, {'$inc': {'x': 3}}) >>> result.matched_count 3 >>> result.modified_count 3 >>> for doc in db.test.find(): ... print(doc) ... {u'x': 4, u'_id': 0} {u'x': 4, u'_id': 1} {u'x': 4, u'_id': 2} :Parameters: - `filter`: A query that matches the documents to update. - `update`: The modifications to apply. 
- `upsert` (optional): If ``True``, perform an insert if no documents match the filter. - `bypass_document_validation` (optional): If ``True``, allows the write to opt-out of document level validation. Default is ``False``. This option is only supported on MongoDB 3.2 and above. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `array_filters` (optional): A list of filters specifying which array elements an update should apply. This option is only supported on MongoDB 3.6 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: - An instance of :class:`~pymongo.results.UpdateResult`. .. versionchanged:: 3.11 Added ``hint`` parameter. .. versionchanged:: 3.9 Added the ability to accept a pipeline as the `update`. .. versionchanged:: 3.6 Added ``array_filters`` and ``session`` parameters. .. versionchanged:: 3.4 Added the `collation` option. .. versionchanged:: 3.2 Added bypass_document_validation support. .. versionadded:: 3.0 """ common.validate_is_mapping("filter", filter) common.validate_ok_for_update(update) common.validate_list_or_none('array_filters', array_filters) write_concern = self._write_concern_for(session) return UpdateResult( self._update_retryable( filter, update, upsert, check_keys=False, multi=True, write_concern=write_concern, bypass_doc_val=bypass_document_validation, collation=collation, array_filters=array_filters, hint=hint, session=session), write_concern.acknowledged) def drop(self, session=None): """Alias for :meth:`~pymongo.database.Database.drop_collection`. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. The following two calls are equivalent: >>> db.foo.drop() >>> db.drop_collection("foo") .. versionchanged:: 3.7 :meth:`drop` now respects this :class:`Collection`'s :attr:`write_concern`. .. versionchanged:: 3.6 Added ``session`` parameter. 
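# --- Illustrative usage sketch; not part of the original module. ---
# A minimal demonstration of replace_one/update_one/update_many as
# documented above. The localhost client and the "example.users"
# namespace are assumptions made only for this sketch.
from pymongo import MongoClient

coll = MongoClient().example.users
coll.insert_many([{'name': 'alice', 'logins': 1},
                  {'name': 'bob', 'logins': 1}])

# update_one modifies the first match; upsert=True inserts when nothing matches.
res = coll.update_one({'name': 'alice'}, {'$inc': {'logins': 1}})
print(res.matched_count, res.modified_count)          # 1 1
res = coll.update_one({'name': 'carol'}, {'$set': {'logins': 0}}, upsert=True)
print(res.upserted_id is not None)                    # True

# update_many applies the same update to every matching document.
res = coll.update_many({'logins': {'$gte': 0}}, {'$inc': {'logins': 3}})
print(res.matched_count, res.modified_count)

# replace_one swaps the whole document; update operators are not allowed here.
coll.replace_one({'name': 'bob'}, {'name': 'bob', 'active': True})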
""" dbo = self.__database.client.get_database( self.__database.name, self.codec_options, self.read_preference, self.write_concern, self.read_concern) dbo.drop_collection(self.__name, session=session) def _delete( self, sock_info, criteria, multi, write_concern=None, op_id=None, ordered=True, collation=None, hint=None, session=None, retryable_write=False): """Internal delete helper.""" common.validate_is_mapping("filter", criteria) write_concern = write_concern or self.write_concern acknowledged = write_concern.acknowledged delete_doc = SON([('q', criteria), ('limit', int(not multi))]) collation = validate_collation_or_none(collation) if collation is not None: if sock_info.max_wire_version < 5: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use collations.') elif not acknowledged: raise ConfigurationError( 'Collation is unsupported for unacknowledged writes.') else: delete_doc['collation'] = collation if hint is not None: if sock_info.max_wire_version < 5: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use hint.') elif not acknowledged: raise ConfigurationError( 'hint is unsupported for unacknowledged writes.') if not isinstance(hint, string_type): hint = helpers._index_document(hint) delete_doc['hint'] = hint command = SON([('delete', self.name), ('ordered', ordered), ('deletes', [delete_doc])]) if not write_concern.is_server_default: command['writeConcern'] = write_concern.document if not sock_info.op_msg_enabled and not acknowledged: # Legacy OP_DELETE. return self._legacy_write( sock_info, 'delete', command, op_id, False, message.delete, self.__full_name, criteria, False, write_concern.document, self.__write_response_codec_options, int(not multi)) # Delete command. result = sock_info.command( self.__database.name, command, write_concern=write_concern, codec_options=self.__write_response_codec_options, session=session, client=self.__database.client, retryable_write=retryable_write) _check_write_command_response(result) return result def _delete_retryable( self, criteria, multi, write_concern=None, op_id=None, ordered=True, collation=None, hint=None, session=None): """Internal delete helper.""" def _delete(session, sock_info, retryable_write): return self._delete( sock_info, criteria, multi, write_concern=write_concern, op_id=op_id, ordered=ordered, collation=collation, hint=hint, session=session, retryable_write=retryable_write) return self.__database.client._retryable_write( (write_concern or self.write_concern).acknowledged and not multi, _delete, session) def delete_one(self, filter, collation=None, hint=None, session=None): """Delete a single document matching the filter. >>> db.test.count_documents({'x': 1}) 3 >>> result = db.test.delete_one({'x': 1}) >>> result.deleted_count 1 >>> db.test.count_documents({'x': 1}) 2 :Parameters: - `filter`: A query that matches the document to delete. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: - An instance of :class:`~pymongo.results.DeleteResult`. .. versionchanged:: 3.11 Added ``hint`` parameter. .. versionchanged:: 3.6 Added ``session`` parameter. .. 
versionchanged:: 3.4 Added the `collation` option. .. versionadded:: 3.0 """ write_concern = self._write_concern_for(session) return DeleteResult( self._delete_retryable( filter, False, write_concern=write_concern, collation=collation, hint=hint, session=session), write_concern.acknowledged) def delete_many(self, filter, collation=None, hint=None, session=None): """Delete one or more documents matching the filter. >>> db.test.count_documents({'x': 1}) 3 >>> result = db.test.delete_many({'x': 1}) >>> result.deleted_count 3 >>> db.test.count_documents({'x': 1}) 0 :Parameters: - `filter`: A query that matches the documents to delete. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: - An instance of :class:`~pymongo.results.DeleteResult`. .. versionchanged:: 3.11 Added ``hint`` parameter. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Added the `collation` option. .. versionadded:: 3.0 """ write_concern = self._write_concern_for(session) return DeleteResult( self._delete_retryable( filter, True, write_concern=write_concern, collation=collation, hint=hint, session=session), write_concern.acknowledged) def find_one(self, filter=None, *args, **kwargs): """Get a single document from the database. All arguments to :meth:`find` are also valid arguments for :meth:`find_one`, although any `limit` argument will be ignored. Returns a single document, or ``None`` if no matching document is found. The :meth:`find_one` method obeys the :attr:`read_preference` of this :class:`Collection`. :Parameters: - `filter` (optional): a dictionary specifying the query to be performed OR any other type to be used as the value for a query for ``"_id"``. - `*args` (optional): any additional positional arguments are the same as the arguments to :meth:`find`. - `**kwargs` (optional): any additional keyword arguments are the same as the arguments to :meth:`find`. >>> collection.find_one(max_time_ms=100) """ if (filter is not None and not isinstance(filter, abc.Mapping)): filter = {"_id": filter} cursor = self.find(filter, *args, **kwargs) for result in cursor.limit(-1): return result return None def find(self, *args, **kwargs): """Query the database. The `filter` argument is a prototype document that all results must match. For example: >>> db.test.find({"hello": "world"}) only matches documents that have a key "hello" with value "world". Matches can have other keys *in addition* to "hello". The `projection` argument is used to specify a subset of fields that should be included in the result documents. By limiting results to a certain subset of fields you can cut down on network traffic and decoding time. Raises :class:`TypeError` if any of the arguments are of improper type. Returns an instance of :class:`~pymongo.cursor.Cursor` corresponding to this query. The :meth:`find` method obeys the :attr:`read_preference` of this :class:`Collection`. 
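# --- Illustrative usage sketch; not part of the original module. ---
# delete_one removes at most one matching document, delete_many removes
# every match; both return a DeleteResult. The client URI and the
# "example.events" collection are assumptions for this sketch only.
from pymongo import MongoClient

coll = MongoClient().example.events
coll.insert_many([{'level': 'debug'}, {'level': 'debug'}, {'level': 'error'}])

print(coll.delete_one({'level': 'debug'}).deleted_count)    # 1
print(coll.delete_many({'level': 'debug'}).deleted_count)   # 1 (remaining debug doc)
print(coll.count_documents({}))                             # 1 (only the error doc left)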
:Parameters: - `filter` (optional): a SON object specifying elements which must be present for a document to be included in the result set - `projection` (optional): a list of field names that should be returned in the result set or a dict specifying the fields to include or exclude. If `projection` is a list "_id" will always be returned. Use a dict to exclude fields from the result (e.g. projection={'_id': False}). - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `skip` (optional): the number of documents to omit (from the start of the result set) when returning the results - `limit` (optional): the maximum number of results to return. A limit of 0 (the default) is equivalent to setting no limit. - `no_cursor_timeout` (optional): if False (the default), any returned cursor is closed by the server after 10 minutes of inactivity. If set to True, the returned cursor will never time out on the server. Care should be taken to ensure that cursors with no_cursor_timeout turned on are properly closed. - `cursor_type` (optional): the type of cursor to return. The valid options are defined by :class:`~pymongo.cursor.CursorType`: - :attr:`~pymongo.cursor.CursorType.NON_TAILABLE` - the result of this find call will return a standard cursor over the result set. - :attr:`~pymongo.cursor.CursorType.TAILABLE` - the result of this find call will be a tailable cursor - tailable cursors are only for use with capped collections. They are not closed when the last data is retrieved but are kept open and the cursor location marks the final document position. If more data is received iteration of the cursor will continue from the last document received. For details, see the `tailable cursor documentation `_. - :attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` - the result of this find call will be a tailable cursor with the await flag set. The server will wait for a few seconds after returning the full result set so that it can capture and return additional data added during the query. - :attr:`~pymongo.cursor.CursorType.EXHAUST` - the result of this find call will be an exhaust cursor. MongoDB will stream batched results to the client without waiting for the client to request each batch, reducing latency. See notes on compatibility below. - `sort` (optional): a list of (key, direction) pairs specifying the sort order for this query. See :meth:`~pymongo.cursor.Cursor.sort` for details. - `allow_partial_results` (optional): if True, mongos will return partial results if some shards are down instead of returning an error. - `oplog_replay` (optional): **DEPRECATED** - if True, set the oplogReplay query flag. Default: False. - `batch_size` (optional): Limits the number of documents returned in a single batch. - `manipulate` (optional): **DEPRECATED** - If True, apply any outgoing SON manipulators before returning. Default: True. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `return_key` (optional): If True, return only the index keys in each document. - `show_record_id` (optional): If True, adds a field ``$recordId`` in each document with the storage engine's internal record identifier. - `snapshot` (optional): **DEPRECATED** - If True, prevents the cursor from returning a document more than once because of an intervening write operation. - `hint` (optional): An index, in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). 
Pass this as an alternative to calling :meth:`~pymongo.cursor.Cursor.hint` on the cursor to tell Mongo the proper index to use for the query. - `max_time_ms` (optional): Specifies a time limit for a query operation. If the specified time is exceeded, the operation will be aborted and :exc:`~pymongo.errors.ExecutionTimeout` is raised. Pass this as an alternative to calling :meth:`~pymongo.cursor.Cursor.max_time_ms` on the cursor. - `max_scan` (optional): **DEPRECATED** - The maximum number of documents to scan. Pass this as an alternative to calling :meth:`~pymongo.cursor.Cursor.max_scan` on the cursor. - `min` (optional): A list of field, limit pairs specifying the inclusive lower bound for all keys of a specific index in order. Pass this as an alternative to calling :meth:`~pymongo.cursor.Cursor.min` on the cursor. ``hint`` must also be passed to ensure the query utilizes the correct index. - `max` (optional): A list of field, limit pairs specifying the exclusive upper bound for all keys of a specific index in order. Pass this as an alternative to calling :meth:`~pymongo.cursor.Cursor.max` on the cursor. ``hint`` must also be passed to ensure the query utilizes the correct index. - `comment` (optional): A string to attach to the query to help interpret and trace the operation in the server logs and in profile data. Pass this as an alternative to calling :meth:`~pymongo.cursor.Cursor.comment` on the cursor. - `modifiers` (optional): **DEPRECATED** - A dict specifying additional MongoDB query modifiers. Use the keyword arguments listed above instead. - `allow_disk_use` (optional): if True, MongoDB may use temporary disk files to store data exceeding the system memory limit while processing a blocking sort operation. The option has no effect if MongoDB can satisfy the specified sort using an index, or if the blocking sort requires less memory than the 100 MiB limit. This option is only supported on MongoDB 4.4 and above. .. note:: There are a number of caveats to using :attr:`~pymongo.cursor.CursorType.EXHAUST` as cursor_type: - The `limit` option can not be used with an exhaust cursor. - Exhaust cursors are not supported by mongos and can not be used with a sharded cluster. - A :class:`~pymongo.cursor.Cursor` instance created with the :attr:`~pymongo.cursor.CursorType.EXHAUST` cursor_type requires an exclusive :class:`~socket.socket` connection to MongoDB. If the :class:`~pymongo.cursor.Cursor` is discarded without being completely iterated the underlying :class:`~socket.socket` connection will be closed and discarded without being returned to the connection pool. .. versionchanged:: 3.11 Added the ``allow_disk_use`` option. Deprecated the ``oplog_replay`` option. Support for this option is deprecated in MongoDB 4.4. The query engine now automatically optimizes queries against the oplog without requiring this option to be set. .. versionchanged:: 3.7 Deprecated the ``snapshot`` option, which is deprecated in MongoDB 3.6 and removed in MongoDB 4.0. Deprecated the ``max_scan`` option. Support for this option is deprecated in MongoDB 4.0. Use ``max_time_ms`` instead to limit server-side execution time. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.5 Added the options ``return_key``, ``show_record_id``, ``snapshot``, ``hint``, ``max_time_ms``, ``max_scan``, ``min``, ``max``, and ``comment``. Deprecated the ``modifiers`` option. .. versionchanged:: 3.4 Added support for the ``collation`` option. .. 
versionchanged:: 3.0 Changed the parameter names ``spec``, ``fields``, ``timeout``, and ``partial`` to ``filter``, ``projection``, ``no_cursor_timeout``, and ``allow_partial_results`` respectively. Added the ``cursor_type``, ``oplog_replay``, and ``modifiers`` options. Removed the ``network_timeout``, ``read_preference``, ``tag_sets``, ``secondary_acceptable_latency_ms``, ``max_scan``, ``snapshot``, ``tailable``, ``await_data``, ``exhaust``, ``as_class``, and slave_okay parameters. Removed ``compile_re`` option: PyMongo now always represents BSON regular expressions as :class:`~bson.regex.Regex` objects. Use :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a BSON regular expression to a Python regular expression object. Soft deprecated the ``manipulate`` option. .. versionchanged:: 2.7 Added ``compile_re`` option. If set to False, PyMongo represented BSON regular expressions as :class:`~bson.regex.Regex` objects instead of attempting to compile BSON regular expressions as Python native regular expressions, thus preventing errors for some incompatible patterns, see `PYTHON-500`_. .. versionchanged:: 2.3 Added the ``tag_sets`` and ``secondary_acceptable_latency_ms`` parameters. .. _PYTHON-500: https://jira.mongodb.org/browse/PYTHON-500 .. mongodoc:: find """ return Cursor(self, *args, **kwargs) def find_raw_batches(self, *args, **kwargs): """Query the database and retrieve batches of raw BSON. Similar to the :meth:`find` method but returns a :class:`~pymongo.cursor.RawBatchCursor`. This example demonstrates how to work with raw batches, but in practice raw batches should be passed to an external library that can decode BSON into another data type, rather than used with PyMongo's :mod:`bson` module. >>> import bson >>> cursor = db.test.find_raw_batches() >>> for batch in cursor: ... print(bson.decode_all(batch)) .. note:: find_raw_batches does not support sessions or auto encryption. .. versionadded:: 3.6 """ # OP_MSG with document stream returns is required to support # sessions. if "session" in kwargs: raise ConfigurationError( "find_raw_batches does not support sessions") # OP_MSG is required to support encryption. if self.__database.client._encrypter: raise InvalidOperation( "find_raw_batches does not support auto encryption") return RawBatchCursor(self, *args, **kwargs) def parallel_scan(self, num_cursors, session=None, **kwargs): """**DEPRECATED**: Scan this entire collection in parallel. Returns a list of up to ``num_cursors`` cursors that can be iterated concurrently. As long as the collection is not modified during scanning, each document appears once in one of the cursors result sets. For example, to process each document in a collection using some thread-safe ``process_document()`` function: >>> def process_cursor(cursor): ... for document in cursor: ... # Some thread-safe processing function: ... process_document(document) >>> >>> # Get up to 4 cursors. ... >>> cursors = collection.parallel_scan(4) >>> threads = [ ... threading.Thread(target=process_cursor, args=(cursor,)) ... for cursor in cursors] >>> >>> for thread in threads: ... thread.start() >>> >>> for thread in threads: ... thread.join() >>> >>> # All documents have now been processed. The :meth:`parallel_scan` method obeys the :attr:`read_preference` of this :class:`Collection`. :Parameters: - `num_cursors`: the number of cursors to return - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. 
- `**kwargs`: additional options for the parallelCollectionScan command can be passed as keyword arguments. .. note:: Requires server version **>= 2.5.5**. .. versionchanged:: 3.7 Deprecated. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Added back support for arbitrary keyword arguments. MongoDB 3.4 adds support for maxTimeMS as an option to the parallelCollectionScan command. .. versionchanged:: 3.0 Removed support for arbitrary keyword arguments, since the parallelCollectionScan command has no optional arguments. """ warnings.warn("parallel_scan is deprecated. MongoDB 4.2 will remove " "the parallelCollectionScan command.", DeprecationWarning, stacklevel=2) cmd = SON([('parallelCollectionScan', self.__name), ('numCursors', num_cursors)]) cmd.update(kwargs) with self._socket_for_reads(session) as (sock_info, slave_ok): # We call sock_info.command here directly, instead of # calling self._command to avoid using an implicit session. result = sock_info.command( self.__database.name, cmd, slave_ok, self._read_preference_for(session), self.codec_options, read_concern=self.read_concern, parse_write_concern_error=True, session=session, client=self.__database.client) cursors = [] for cursor in result['cursors']: cursors.append(CommandCursor( self, cursor['cursor'], sock_info.address, session=session, explicit_session=session is not None)) return cursors def _count(self, cmd, collation=None, session=None): """Internal count helper.""" # XXX: "ns missing" checks can be removed when we drop support for # MongoDB 3.0, see SERVER-17051. def _cmd(session, server, sock_info, slave_ok): res = self._command( sock_info, cmd, slave_ok, allowable_errors=["ns missing"], codec_options=self.__write_response_codec_options, read_concern=self.read_concern, collation=collation, session=session) if res.get("errmsg", "") == "ns missing": return 0 return int(res["n"]) return self.__database.client._retryable_read( _cmd, self._read_preference_for(session), session) def _aggregate_one_result( self, sock_info, slave_ok, cmd, collation=None, session=None): """Internal helper to run an aggregate that returns a single result.""" result = self._command( sock_info, cmd, slave_ok, codec_options=self.__write_response_codec_options, read_concern=self.read_concern, collation=collation, session=session) batch = result['cursor']['firstBatch'] return batch[0] if batch else None def estimated_document_count(self, **kwargs): """Get an estimate of the number of documents in this collection using collection metadata. The :meth:`estimated_document_count` method is **not** supported in a transaction. All optional parameters should be passed as keyword arguments to this method. Valid options include: - `maxTimeMS` (int): The maximum amount of time to allow this operation to run, in milliseconds. :Parameters: - `**kwargs` (optional): See list of options above. .. versionadded:: 3.7 """ if 'session' in kwargs: raise ConfigurationError( 'estimated_document_count does not support sessions') cmd = SON([('count', self.__name)]) cmd.update(kwargs) return self._count(cmd) def count_documents(self, filter, session=None, **kwargs): """Count the number of documents in this collection. .. note:: For a fast count of the total documents in a collection see :meth:`estimated_document_count`. The :meth:`count_documents` method is supported in a transaction. All optional parameters should be passed as keyword arguments to this method. 
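# --- Illustrative usage sketch; not part of the original module. ---
# Querying with find()/find_one() and counting with count_documents()/
# estimated_document_count(), as documented above. The client URI and the
# "example.articles" collection are assumptions for this sketch only.
import pymongo
from pymongo import MongoClient

coll = MongoClient().example.articles
coll.insert_many([{'title': 't%d' % i, 'views': i} for i in range(5)])

# Projection, sort and limit may be passed as keyword arguments or
# chained on the returned Cursor.
for doc in coll.find({'views': {'$gte': 1}},
                     projection={'_id': False},
                     sort=[('views', pymongo.DESCENDING)],
                     limit=3):
    print(doc)

print(coll.find_one({'views': 2}, max_time_ms=100))

# estimated_document_count() is a fast, metadata-based total (no filter);
# count_documents() runs an aggregation and accepts a filter plus options
# such as skip and limit.
print(coll.estimated_document_count())
print(coll.count_documents({'views': {'$gte': 1}}, skip=1, limit=10))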
Valid options include: - `skip` (int): The number of matching documents to skip before returning results. - `limit` (int): The maximum number of documents to count. Must be a positive integer. If not provided, no limit is imposed. - `maxTimeMS` (int): The maximum amount of time to allow this operation to run, in milliseconds. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (string or list of tuples): The index to use. Specify either the index name as a string or the index specification as a list of tuples (e.g. [('a', pymongo.ASCENDING), ('b', pymongo.ASCENDING)]). This option is only supported on MongoDB 3.6 and above. The :meth:`count_documents` method obeys the :attr:`read_preference` of this :class:`Collection`. .. note:: When migrating from :meth:`count` to :meth:`count_documents` the following query operators must be replaced: +-------------+-------------------------------------+ | Operator | Replacement | +=============+=====================================+ | $where | `$expr`_ | +-------------+-------------------------------------+ | $near | `$geoWithin`_ with `$center`_ | +-------------+-------------------------------------+ | $nearSphere | `$geoWithin`_ with `$centerSphere`_ | +-------------+-------------------------------------+ $expr requires MongoDB 3.6+ :Parameters: - `filter` (required): A query document that selects which documents to count in the collection. Can be an empty document to count all documents. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): See list of options above. .. versionadded:: 3.7 .. _$expr: https://docs.mongodb.com/manual/reference/operator/query/expr/ .. _$geoWithin: https://docs.mongodb.com/manual/reference/operator/query/geoWithin/ .. _$center: https://docs.mongodb.com/manual/reference/operator/query/center/#op._S_center .. _$centerSphere: https://docs.mongodb.com/manual/reference/operator/query/centerSphere/#op._S_centerSphere """ pipeline = [{'$match': filter}] if 'skip' in kwargs: pipeline.append({'$skip': kwargs.pop('skip')}) if 'limit' in kwargs: pipeline.append({'$limit': kwargs.pop('limit')}) pipeline.append({'$group': {'_id': 1, 'n': {'$sum': 1}}}) cmd = SON([('aggregate', self.__name), ('pipeline', pipeline), ('cursor', {})]) if "hint" in kwargs and not isinstance(kwargs["hint"], string_type): kwargs["hint"] = helpers._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) def _cmd(session, server, sock_info, slave_ok): result = self._aggregate_one_result( sock_info, slave_ok, cmd, collation, session) if not result: return 0 return result['n'] return self.__database.client._retryable_read( _cmd, self._read_preference_for(session), session) def count(self, filter=None, session=None, **kwargs): """**DEPRECATED** - Get the number of documents in this collection. The :meth:`count` method is deprecated and **not** supported in a transaction. Please use :meth:`count_documents` or :meth:`estimated_document_count` instead. All optional count parameters should be passed as keyword arguments to this method. Valid options include: - `skip` (int): The number of matching documents to skip before returning results. - `limit` (int): The maximum number of documents to count. A limit of 0 (the default) is equivalent to setting no limit. - `maxTimeMS` (int): The maximum amount of time to allow the count command to run, in milliseconds. 
- `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (string or list of tuples): The index to use. Specify either the index name as a string or the index specification as a list of tuples (e.g. [('a', pymongo.ASCENDING), ('b', pymongo.ASCENDING)]). The :meth:`count` method obeys the :attr:`read_preference` of this :class:`Collection`. .. note:: When migrating from :meth:`count` to :meth:`count_documents` the following query operators must be replaced: +-------------+-------------------------------------+ | Operator | Replacement | +=============+=====================================+ | $where | `$expr`_ | +-------------+-------------------------------------+ | $near | `$geoWithin`_ with `$center`_ | +-------------+-------------------------------------+ | $nearSphere | `$geoWithin`_ with `$centerSphere`_ | +-------------+-------------------------------------+ $expr requires MongoDB 3.6+ :Parameters: - `filter` (optional): A query document that selects which documents to count in the collection. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): See list of options above. .. versionchanged:: 3.7 Deprecated. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Support the `collation` option. .. _$expr: https://docs.mongodb.com/manual/reference/operator/query/expr/ .. _$geoWithin: https://docs.mongodb.com/manual/reference/operator/query/geoWithin/ .. _$center: https://docs.mongodb.com/manual/reference/operator/query/center/#op._S_center .. _$centerSphere: https://docs.mongodb.com/manual/reference/operator/query/centerSphere/#op._S_centerSphere """ warnings.warn("count is deprecated. Use estimated_document_count or " "count_documents instead. Please note that $where must " "be replaced by $expr, $near must be replaced by " "$geoWithin with $center, and $nearSphere must be " "replaced by $geoWithin with $centerSphere", DeprecationWarning, stacklevel=2) cmd = SON([("count", self.__name)]) if filter is not None: if "query" in kwargs: raise ConfigurationError("can't pass both filter and query") kwargs["query"] = filter if "hint" in kwargs and not isinstance(kwargs["hint"], string_type): kwargs["hint"] = helpers._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) return self._count(cmd, collation, session) def create_indexes(self, indexes, session=None, **kwargs): """Create one or more indexes on this collection. >>> from pymongo import IndexModel, ASCENDING, DESCENDING >>> index1 = IndexModel([("hello", DESCENDING), ... ("world", ASCENDING)], name="hello_world") >>> index2 = IndexModel([("goodbye", DESCENDING)]) >>> db.test.create_indexes([index1, index2]) ["hello_world", "goodbye_-1"] :Parameters: - `indexes`: A list of :class:`~pymongo.operations.IndexModel` instances. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): optional arguments to the createIndexes command (like maxTimeMS) can be passed as keyword arguments. .. note:: `create_indexes` uses the `createIndexes`_ command introduced in MongoDB **2.6** and cannot be used with earlier versions. .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation when using MongoDB >= 3.4. .. versionchanged:: 3.6 Added ``session`` parameter. Added support for arbitrary keyword arguments. .. 
versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. .. versionadded:: 3.0 .. _createIndexes: https://docs.mongodb.com/manual/reference/command/createIndexes/ """ common.validate_list('indexes', indexes) return self.__create_indexes(indexes, session, **kwargs) def __create_indexes(self, indexes, session, **kwargs): """Internal createIndexes helper. :Parameters: - `indexes`: A list of :class:`~pymongo.operations.IndexModel` instances. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): optional arguments to the createIndexes command (like maxTimeMS) can be passed as keyword arguments. """ names = [] with self._socket_for_writes(session) as sock_info: supports_collations = sock_info.max_wire_version >= 5 supports_quorum = sock_info.max_wire_version >= 9 def gen_indexes(): for index in indexes: if not isinstance(index, IndexModel): raise TypeError( "%r is not an instance of " "pymongo.operations.IndexModel" % (index,)) document = index.document if "collation" in document and not supports_collations: raise ConfigurationError( "Must be connected to MongoDB " "3.4+ to use collations.") if 'bucketSize' in document: # The bucketSize option is required by geoHaystack. warnings.warn( _HAYSTACK_MSG, DeprecationWarning, stacklevel=4) names.append(document["name"]) yield document cmd = SON([('createIndexes', self.name), ('indexes', list(gen_indexes()))]) cmd.update(kwargs) if 'commitQuorum' in kwargs and not supports_quorum: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use the " "commitQuorum option for createIndexes") self._command( sock_info, cmd, read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, write_concern=self._write_concern_for(session), session=session) return names def create_index(self, keys, session=None, **kwargs): """Creates an index on this collection. Takes either a single key or a list of (key, direction) pairs. The key(s) must be an instance of :class:`basestring` (:class:`str` in python 3), and the direction(s) must be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`, :data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). To create a single key ascending index on the key ``'mike'`` we just use a string argument:: >>> my_collection.create_index("mike") For a compound index on ``'mike'`` descending and ``'eliot'`` ascending we need to use a list of tuples:: >>> my_collection.create_index([("mike", pymongo.DESCENDING), ... ("eliot", pymongo.ASCENDING)]) All optional index creation parameters should be passed as keyword arguments to this method. For example:: >>> my_collection.create_index([("mike", pymongo.DESCENDING)], ... background=True) Valid options include, but are not limited to: - `name`: custom name to use for this index - if none is given, a name will be generated. - `unique`: if ``True``, creates a uniqueness constraint on the index. - `background`: if ``True``, this index should be created in the background. - `sparse`: if ``True``, omit from the index any documents that lack the indexed field. - `bucketSize`: for use with geoHaystack indexes. Number of documents to group together within a certain proximity to a given longitude and latitude. - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` index. - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` index. 
- `expireAfterSeconds`: Used to create an expiring (TTL) collection. MongoDB will automatically delete documents from this collection after seconds. The indexed field must be a UTC datetime or the data will not expire. - `partialFilterExpression`: A document that specifies a filter for a partial index. Requires MongoDB >=3.2. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. Requires MongoDB >= 3.4. - `wildcardProjection`: Allows users to include or exclude specific field paths from a `wildcard index`_ using the {"$**" : 1} key pattern. Requires MongoDB >= 4.2. - `hidden`: if ``True``, this index will be hidden from the query planner and will not be evaluated as part of query plan selection. Requires MongoDB >= 4.4. See the MongoDB documentation for a full list of supported options by server version. .. warning:: `dropDups` is not supported by MongoDB 3.0 or newer. The option is silently ignored by the server and unique index builds using the option will fail if a duplicate value is detected. .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation when using MongoDB >= 3.4. :Parameters: - `keys`: a single key or a list of (key, direction) pairs specifying the index to create - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): any additional index creation options (see the above list) should be passed as keyword arguments .. versionchanged:: 3.11 Added the ``hidden`` option. .. versionchanged:: 3.6 Added ``session`` parameter. Added support for passing maxTimeMS in kwargs. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. Support the `collation` option. .. versionchanged:: 3.2 Added partialFilterExpression to support partial indexes. .. versionchanged:: 3.0 Renamed `key_or_list` to `keys`. Removed the `cache_for` option. :meth:`create_index` no longer caches index names. Removed support for the drop_dups and bucket_size aliases. .. mongodoc:: indexes .. _wildcard index: https://docs.mongodb.com/master/core/index-wildcard/#wildcard-index-core """ cmd_options = {} if "maxTimeMS" in kwargs: cmd_options["maxTimeMS"] = kwargs.pop("maxTimeMS") index = IndexModel(keys, **kwargs) return self.__create_indexes([index], session, **cmd_options)[0] def ensure_index(self, key_or_list, cache_for=300, **kwargs): """**DEPRECATED** - Ensures that an index exists on this collection. .. versionchanged:: 3.0 **DEPRECATED** """ warnings.warn("ensure_index is deprecated. Use create_index instead.", DeprecationWarning, stacklevel=2) # The types supported by datetime.timedelta. if not (isinstance(cache_for, integer_types) or isinstance(cache_for, float)): raise TypeError("cache_for must be an integer or float.") if "drop_dups" in kwargs: kwargs["dropDups"] = kwargs.pop("drop_dups") if "bucket_size" in kwargs: kwargs["bucketSize"] = kwargs.pop("bucket_size") index = IndexModel(key_or_list, **kwargs) name = index.document["name"] # Note that there is a race condition here. One thread could # check if the index is cached and be preempted before creating # and caching the index. This means multiple threads attempting # to create the same index concurrently could send the index # to the server two or more times. This has no practical impact # other than wasted round trips. 
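# --- Illustrative usage sketch; not part of the original module. ---
# Building indexes with create_indexes()/create_index() using a few of the
# options listed above. The collection name and field names are
# assumptions made only for this sketch.
from pymongo import MongoClient, IndexModel, ASCENDING, DESCENDING

coll = MongoClient().example.people

# Several indexes in one createIndexes command.
print(coll.create_indexes([
    IndexModel([('last', ASCENDING), ('first', ASCENDING)], name='name_idx'),
    IndexModel([('joined', DESCENDING)]),
]))    # e.g. ['name_idx', 'joined_-1']

# A TTL index and a unique partial index (partialFilterExpression
# requires MongoDB 3.2+).
coll.create_index('created_at', expireAfterSeconds=3600)
coll.create_index([('email', ASCENDING)], unique=True,
                  partialFilterExpression={'email': {'$exists': True}})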
if not self.__database.client._cached(self.__database.name, self.__name, name): self.__create_indexes([index], session=None) self.__database.client._cache_index(self.__database.name, self.__name, name, cache_for) return name return None def drop_indexes(self, session=None, **kwargs): """Drops all indexes on this collection. Can be used on non-existent collections or collections with no indexes. Raises OperationFailure on an error. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): optional arguments to the dropIndexes command (like maxTimeMS) can be passed as keyword arguments. .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation when using MongoDB >= 3.4. .. versionchanged:: 3.6 Added ``session`` parameter. Added support for arbitrary keyword arguments. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. """ self.__database.client._purge_index(self.__database.name, self.__name) self.drop_index("*", session=session, **kwargs) def drop_index(self, index_or_name, session=None, **kwargs): """Drops the specified index on this collection. Can be used on non-existent collections or collections with no indexes. Raises OperationFailure on an error (e.g. trying to drop an index that does not exist). `index_or_name` can be either an index name (as returned by `create_index`), or an index specifier (as passed to `create_index`). An index specifier should be a list of (key, direction) pairs. Raises TypeError if index is not an instance of (str, unicode, list). .. warning:: if a custom name was used on index creation (by passing the `name` parameter to :meth:`create_index` or :meth:`ensure_index`) the index **must** be dropped by name. :Parameters: - `index_or_name`: index (or name of index) to drop - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): optional arguments to the dropIndexes command (like maxTimeMS) can be passed as keyword arguments. .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation when using MongoDB >= 3.4. .. versionchanged:: 3.6 Added ``session`` parameter. Added support for arbitrary keyword arguments. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. """ name = index_or_name if isinstance(index_or_name, list): name = helpers._gen_index_name(index_or_name) if not isinstance(name, string_type): raise TypeError("index_or_name must be an index name or list") self.__database.client._purge_index( self.__database.name, self.__name, name) cmd = SON([("dropIndexes", self.__name), ("index", name)]) cmd.update(kwargs) with self._socket_for_writes(session) as sock_info: self._command(sock_info, cmd, read_preference=ReadPreference.PRIMARY, allowable_errors=["ns not found", 26], write_concern=self._write_concern_for(session), session=session) def reindex(self, session=None, **kwargs): """Rebuilds all indexes on this collection. **DEPRECATED** - The :meth:`~reindex` method is deprecated and will be removed in PyMongo 4.0. Use :meth:`~pymongo.database.Database.command` to run the ``reIndex`` command directly instead:: db.command({"reIndex": ""}) .. note:: Starting in MongoDB 4.6, the `reIndex` command can only be run when connected to a standalone mongod. 
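# --- Illustrative usage sketch; not part of the original module. ---
# drop_index() accepts either the index name or the original key
# specification; drop_indexes() removes every index except _id.
# Continues the hypothetical "example.people" collection from the
# previous sketch.
from pymongo import MongoClient, DESCENDING

coll = MongoClient().example.people
coll.drop_index('name_idx')                  # by name
coll.drop_index([('joined', DESCENDING)])    # by specification
coll.drop_indexes()                          # everything except _id_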
:Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): optional arguments to the reIndex command (like maxTimeMS) can be passed as keyword arguments. .. warning:: reindex blocks all other operations (indexes are built in the foreground) and will be slow for large collections. .. versionchanged:: 3.11 Deprecated. .. versionchanged:: 3.6 Added ``session`` parameter. Added support for arbitrary keyword arguments. .. versionchanged:: 3.5 We no longer apply this collection's write concern to this operation. MongoDB 3.4 silently ignored the write concern. MongoDB 3.6+ returns an error if we include the write concern. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. """ warnings.warn("The reindex method is deprecated and will be removed in " "PyMongo 4.0. Use the Database.command method to run the " "reIndex command instead.", DeprecationWarning, stacklevel=2) cmd = SON([("reIndex", self.__name)]) cmd.update(kwargs) with self._socket_for_writes(session) as sock_info: return self._command( sock_info, cmd, read_preference=ReadPreference.PRIMARY, session=session) def list_indexes(self, session=None): """Get a cursor over the index documents for this collection. >>> for index in db.test.list_indexes(): ... print(index) ... SON([('v', 2), ('key', SON([('_id', 1)])), ('name', '_id_')]) :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. :Returns: An instance of :class:`~pymongo.command_cursor.CommandCursor`. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionadded:: 3.0 """ codec_options = CodecOptions(SON) coll = self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY) read_pref = ((session and session._txn_read_preference()) or ReadPreference.PRIMARY) def _cmd(session, server, sock_info, slave_ok): cmd = SON([("listIndexes", self.__name), ("cursor", {})]) if sock_info.max_wire_version > 2: with self.__database.client._tmp_session(session, False) as s: try: cursor = self._command(sock_info, cmd, slave_ok, read_pref, codec_options, session=s)["cursor"] except OperationFailure as exc: # Ignore NamespaceNotFound errors to match the behavior # of reading from *.system.indexes. if exc.code != 26: raise cursor = {'id': 0, 'firstBatch': []} return CommandCursor(coll, cursor, sock_info.address, session=s, explicit_session=session is not None) else: res = message._first_batch( sock_info, self.__database.name, "system.indexes", {"ns": self.__full_name}, 0, slave_ok, codec_options, read_pref, cmd, self.database.client._event_listeners) cursor = res["cursor"] # Note that a collection can only have 64 indexes, so there # will never be a getMore call. return CommandCursor(coll, cursor, sock_info.address) return self.__database.client._retryable_read( _cmd, read_pref, session) def index_information(self, session=None): """Get information on this collection's indexes. Returns a dictionary where the keys are index names (as returned by create_index()) and the values are dictionaries containing information about each index. The dictionary is guaranteed to contain at least a single key, ``"key"`` which is a list of (key, direction) pairs specifying the index (as passed to create_index()). It will also contain any other metadata about the indexes, except for the ``"ns"`` and ``"name"`` keys, which are cleaned. 
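# --- Illustrative usage sketch; not part of the original module. ---
# Inspecting index metadata with list_indexes() and index_information(),
# as documented above. The collection is a hypothetical one reused from
# the earlier sketches.
from pymongo import MongoClient

coll = MongoClient().example.people

for index in coll.list_indexes():          # one SON document per index
    print(index['name'], dict(index['key']))

info = coll.index_information()            # {index name: metadata}
print(info['_id_']['key'])                 # the _id index key spec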
Example output might look like this: >>> db.test.create_index("x", unique=True) u'x_1' >>> db.test.index_information() {u'_id_': {u'key': [(u'_id', 1)]}, u'x_1': {u'unique': True, u'key': [(u'x', 1)]}} :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. """ cursor = self.list_indexes(session=session) info = {} for index in cursor: index["key"] = index["key"].items() index = dict(index) info[index.pop("name")] = index return info def options(self, session=None): """Get the options set on this collection. Returns a dictionary of options and their values - see :meth:`~pymongo.database.Database.create_collection` for more information on the possible options. Returns an empty dictionary if the collection has not been created yet. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. """ dbo = self.__database.client.get_database( self.__database.name, self.codec_options, self.read_preference, self.write_concern, self.read_concern) cursor = dbo.list_collections( session=session, filter={"name": self.__name}) result = None for doc in cursor: result = doc break if not result: return {} options = result.get("options", {}) if "create" in options: del options["create"] return options def _aggregate(self, aggregation_command, pipeline, cursor_class, session, explicit_session, **kwargs): # Remove things that are not command options. use_cursor = True if "useCursor" in kwargs: warnings.warn( "The useCursor option is deprecated " "and will be removed in PyMongo 4.0", DeprecationWarning, stacklevel=2) use_cursor = common.validate_boolean( "useCursor", kwargs.pop("useCursor", True)) cmd = aggregation_command( self, cursor_class, pipeline, kwargs, explicit_session, user_fields={'cursor': {'firstBatch': 1}}, use_cursor=use_cursor) return self.__database.client._retryable_read( cmd.get_cursor, cmd.get_read_preference(session), session, retryable=not cmd._performs_write) def aggregate(self, pipeline, session=None, **kwargs): """Perform an aggregation using the aggregation framework on this collection. All optional `aggregate command`_ parameters should be passed as keyword arguments to this method. Valid options include, but are not limited to: - `allowDiskUse` (bool): Enables writing to temporary files. When set to True, aggregation stages can write data to the _tmp subdirectory of the --dbpath directory. The default is False. - `maxTimeMS` (int): The maximum amount of time to allow the operation to run in milliseconds. - `batchSize` (int): The maximum number of documents to return per batch. Ignored if the connected mongod or mongos does not support returning aggregate results using a cursor, or `useCursor` is ``False``. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `useCursor` (bool): Deprecated. Will be removed in PyMongo 4.0. The :meth:`aggregate` method obeys the :attr:`read_preference` of this :class:`Collection`, except when ``$out`` or ``$merge`` are used, in which case :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY` is used. .. note:: This method does not support the 'explain' option. Please use :meth:`~pymongo.database.Database.command` instead. An example is included in the :ref:`aggregate-examples` documentation. .. 
note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation when using MongoDB >= 3.4. :Parameters: - `pipeline`: a list of aggregation pipeline stages - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): See list of options above. :Returns: A :class:`~pymongo.command_cursor.CommandCursor` over the result set. .. versionchanged:: 3.9 Apply this collection's read concern to pipelines containing the `$out` stage when connected to MongoDB >= 4.2. Added support for the ``$merge`` pipeline stage. Aggregations that write always use read preference :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. .. versionchanged:: 3.6 Added the `session` parameter. Added the `maxAwaitTimeMS` option. Deprecated the `useCursor` option. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. Support the `collation` option. .. versionchanged:: 3.0 The :meth:`aggregate` method always returns a CommandCursor. The pipeline argument must be a list. .. versionchanged:: 2.7 When the cursor option is used, return :class:`~pymongo.command_cursor.CommandCursor` instead of :class:`~pymongo.cursor.Cursor`. .. versionchanged:: 2.6 Added cursor support. .. versionadded:: 2.3 .. seealso:: :doc:`/examples/aggregation` .. _aggregate command: https://docs.mongodb.com/manual/reference/command/aggregate """ with self.__database.client._tmp_session(session, close=False) as s: return self._aggregate(_CollectionAggregationCommand, pipeline, CommandCursor, session=s, explicit_session=session is not None, **kwargs) def aggregate_raw_batches(self, pipeline, **kwargs): """Perform an aggregation and retrieve batches of raw BSON. Similar to the :meth:`aggregate` method but returns a :class:`~pymongo.cursor.RawBatchCursor`. This example demonstrates how to work with raw batches, but in practice raw batches should be passed to an external library that can decode BSON into another data type, rather than used with PyMongo's :mod:`bson` module. >>> import bson >>> cursor = db.test.aggregate_raw_batches([ ... {'$project': {'x': {'$multiply': [2, '$x']}}}]) >>> for batch in cursor: ... print(bson.decode_all(batch)) .. note:: aggregate_raw_batches does not support sessions or auto encryption. .. versionadded:: 3.6 """ # OP_MSG with document stream returns is required to support # sessions. if "session" in kwargs: raise ConfigurationError( "aggregate_raw_batches does not support sessions") # OP_MSG is required to support encryption. if self.__database.client._encrypter: raise InvalidOperation( "aggregate_raw_batches does not support auto encryption") return self._aggregate(_CollectionRawAggregationCommand, pipeline, RawBatchCommandCursor, session=None, explicit_session=False, **kwargs) def watch(self, pipeline=None, full_document=None, resume_after=None, max_await_time_ms=None, batch_size=None, collation=None, start_at_operation_time=None, session=None, start_after=None): """Watch changes on this collection. Performs an aggregation with an implicit initial ``$changeStream`` stage and returns a :class:`~pymongo.change_stream.CollectionChangeStream` cursor which iterates over changes on this collection. Introduced in MongoDB 3.6. .. code-block:: python with db.collection.watch() as stream: for change in stream: print(change) The :class:`~pymongo.change_stream.CollectionChangeStream` iterable blocks until the next change document is returned or an error is raised. 
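# --- Illustrative usage sketch; not part of the original module. ---
# aggregate() returns a CommandCursor over the pipeline results. The
# pipeline, option values and collection name are assumptions for this
# sketch only.
from pymongo import MongoClient

coll = MongoClient().example.articles
pipeline = [
    {'$match': {'views': {'$gte': 1}}},
    {'$group': {'_id': None, 'total': {'$sum': '$views'}}},
]
for doc in coll.aggregate(pipeline, allowDiskUse=True, maxTimeMS=5000):
    print(doc)    # e.g. {'_id': None, 'total': 10}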
If the :meth:`~pymongo.change_stream.CollectionChangeStream.next` method encounters a network error when retrieving a batch from the server, it will automatically attempt to recreate the cursor such that no change events are missed. Any error encountered during the resume attempt indicates there may be an outage and will be raised. .. code-block:: python try: with db.collection.watch( [{'$match': {'operationType': 'insert'}}]) as stream: for insert_change in stream: print(insert_change) except pymongo.errors.PyMongoError: # The ChangeStream encountered an unrecoverable error or the # resume attempt failed to recreate the cursor. logging.error('...') For a precise description of the resume process see the `change streams specification`_. .. note:: Using this helper method is preferred to directly calling :meth:`~pymongo.collection.Collection.aggregate` with a ``$changeStream`` stage, for the purpose of supporting resumability. .. warning:: This Collection's :attr:`read_concern` must be ``ReadConcern("majority")`` in order to use the ``$changeStream`` stage. :Parameters: - `pipeline` (optional): A list of aggregation pipeline stages to append to an initial ``$changeStream`` stage. Not all pipeline stages are valid after a ``$changeStream`` stage, see the MongoDB documentation on change streams for the supported stages. - `full_document` (optional): The fullDocument to pass as an option to the ``$changeStream`` stage. Allowed values: 'updateLookup'. When set to 'updateLookup', the change notification for partial updates will include both a delta describing the changes to the document, as well as a copy of the entire document that was changed from some time after the change occurred. - `resume_after` (optional): A resume token. If provided, the change stream will start returning changes that occur directly after the operation specified in the resume token. A resume token is the _id value of a change document. - `max_await_time_ms` (optional): The maximum time in milliseconds for the server to wait for changes before responding to a getMore operation. - `batch_size` (optional): The maximum number of documents to return per batch. - `collation` (optional): The :class:`~pymongo.collation.Collation` to use for the aggregation. - `start_at_operation_time` (optional): If provided, the resulting change stream will only return changes that occurred at or after the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `start_after` (optional): The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. :Returns: A :class:`~pymongo.change_stream.CollectionChangeStream` cursor. .. versionchanged:: 3.9 Added the ``start_after`` parameter. .. versionchanged:: 3.7 Added the ``start_at_operation_time`` parameter. .. versionadded:: 3.6 .. mongodoc:: changeStreams .. _change streams specification: https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.rst """ return CollectionChangeStream( self, pipeline, full_document, resume_after, max_await_time_ms, batch_size, collation, start_at_operation_time, session, start_after) def group(self, key, condition, initial, reduce, finalize=None, **kwargs): """Perform a query similar to an SQL *group by* operation. **DEPRECATED** - The group command was deprecated in MongoDB 3.4. 
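# --- Illustrative usage sketch; not part of the original module. ---
# watch() opens a change stream (MongoDB 3.6+; a replica set or sharded
# cluster is assumed here, since change streams are unavailable on a
# standalone mongod). The loop blocks waiting for matching changes.
from pymongo import MongoClient
from pymongo.errors import PyMongoError

coll = MongoClient().example.articles
try:
    with coll.watch([{'$match': {'operationType': 'insert'}}],
                    max_await_time_ms=1000) as stream:
        for change in stream:
            print(change['fullDocument'])
except PyMongoError:
    # Unrecoverable error, or the automatic resume attempt failed.
    pass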
The :meth:`~group` method is deprecated and will be removed in PyMongo 4.0. Use :meth:`~aggregate` with the `$group` stage or :meth:`~map_reduce` instead. .. versionchanged:: 3.5 Deprecated the group method. .. versionchanged:: 3.4 Added the `collation` option. .. versionchanged:: 2.2 Removed deprecated argument: command """ warnings.warn("The group method is deprecated and will be removed in " "PyMongo 4.0. Use the aggregate method with the $group " "stage or the map_reduce method instead.", DeprecationWarning, stacklevel=2) group = {} if isinstance(key, string_type): group["$keyf"] = Code(key) elif key is not None: group = {"key": helpers._fields_list_to_dict(key, "key")} group["ns"] = self.__name group["$reduce"] = Code(reduce) group["cond"] = condition group["initial"] = initial if finalize is not None: group["finalize"] = Code(finalize) cmd = SON([("group", group)]) collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) with self._socket_for_reads(session=None) as (sock_info, slave_ok): return self._command(sock_info, cmd, slave_ok, collation=collation, user_fields={'retval': 1})["retval"] def rename(self, new_name, session=None, **kwargs): """Rename this collection. If operating in auth mode, client must be authorized as an admin to perform this operation. Raises :class:`TypeError` if `new_name` is not an instance of :class:`basestring` (:class:`str` in python 3). Raises :class:`~pymongo.errors.InvalidName` if `new_name` is not a valid collection name. :Parameters: - `new_name`: new name for this collection - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional arguments to the rename command may be passed as keyword arguments to this helper method (i.e. ``dropTarget=True``) .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation when using MongoDB >= 3.4. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. """ if not isinstance(new_name, string_type): raise TypeError("new_name must be an " "instance of %s" % (string_type.__name__,)) if not new_name or ".." in new_name: raise InvalidName("collection names cannot be empty") if new_name[0] == "." or new_name[-1] == ".": raise InvalidName("collection names must not start or end with '.'") if "$" in new_name and not new_name.startswith("oplog.$main"): raise InvalidName("collection names must not contain '$'") new_name = "%s.%s" % (self.__database.name, new_name) cmd = SON([("renameCollection", self.__full_name), ("to", new_name)]) cmd.update(kwargs) write_concern = self._write_concern_for_cmd(cmd, session) with self._socket_for_writes(session) as sock_info: with self.__database.client._tmp_session(session) as s: return sock_info.command( 'admin', cmd, write_concern=write_concern, parse_write_concern_error=True, session=s, client=self.__database.client) def distinct(self, key, filter=None, session=None, **kwargs): """Get a list of distinct values for `key` among all documents in this collection. Raises :class:`TypeError` if `key` is not an instance of :class:`basestring` (:class:`str` in python 3). All optional distinct parameters should be passed as keyword arguments to this method. Valid options include: - `maxTimeMS` (int): The maximum amount of time to allow the distinct command to run, in milliseconds. 
- `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. The :meth:`distinct` method obeys the :attr:`read_preference` of this :class:`Collection`. :Parameters: - `key`: name of the field for which we want to get the distinct values - `filter` (optional): A query document that specifies the documents from which to retrieve the distinct values. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): See list of options above. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Support the `collation` option. """ if not isinstance(key, string_type): raise TypeError("key must be an " "instance of %s" % (string_type.__name__,)) cmd = SON([("distinct", self.__name), ("key", key)]) if filter is not None: if "query" in kwargs: raise ConfigurationError("can't pass both filter and query") kwargs["query"] = filter collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) def _cmd(session, server, sock_info, slave_ok): return self._command( sock_info, cmd, slave_ok, read_concern=self.read_concern, collation=collation, session=session, user_fields={"values": 1})["values"] return self.__database.client._retryable_read( _cmd, self._read_preference_for(session), session) def _map_reduce(self, map, reduce, out, session, read_pref, **kwargs): """Internal mapReduce helper.""" cmd = SON([("mapReduce", self.__name), ("map", map), ("reduce", reduce), ("out", out)]) collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd.update(kwargs) inline = 'inline' in out if inline: user_fields = {'results': 1} else: user_fields = None read_pref = ((session and session._txn_read_preference()) or read_pref) with self.__database.client._socket_for_reads(read_pref, session) as ( sock_info, slave_ok): if (sock_info.max_wire_version >= 4 and ('readConcern' not in cmd) and inline): read_concern = self.read_concern else: read_concern = None if 'writeConcern' not in cmd and not inline: write_concern = self._write_concern_for(session) else: write_concern = None return self._command( sock_info, cmd, slave_ok, read_pref, read_concern=read_concern, write_concern=write_concern, collation=collation, session=session, user_fields=user_fields) def map_reduce(self, map, reduce, out, full_response=False, session=None, **kwargs): """Perform a map/reduce operation on this collection. If `full_response` is ``False`` (default) returns a :class:`~pymongo.collection.Collection` instance containing the results of the operation. Otherwise, returns the full response from the server to the `map reduce command`_. :Parameters: - `map`: map function (as a JavaScript string) - `reduce`: reduce function (as a JavaScript string) - `out`: output collection name or `out object` (dict). See the `map reduce command`_ documentation for available options. Note: `out` options are order sensitive. :class:`~bson.son.SON` can be used to specify multiple options. e.g. SON([('replace', ), ('db', )]) - `full_response` (optional): if ``True``, return full response to this command - otherwise just return the result collection - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional arguments to the `map reduce command`_ may be passed as keyword arguments to this helper method, e.g.:: >>> db.test.map_reduce(map, reduce, "myresults", limit=2) .. 
note:: The :meth:`map_reduce` method does **not** obey the :attr:`read_preference` of this :class:`Collection`. To run mapReduce on a secondary use the :meth:`inline_map_reduce` method instead. .. note:: The :attr:`~pymongo.collection.Collection.write_concern` of this collection is automatically applied to this operation (if the output is not inline) when using MongoDB >= 3.4. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Apply this collection's write concern automatically to this operation when connected to MongoDB >= 3.4. .. seealso:: :doc:`/examples/aggregation` .. versionchanged:: 3.4 Added the `collation` option. .. versionchanged:: 2.2 Removed deprecated arguments: merge_output and reduce_output .. _map reduce command: http://docs.mongodb.org/manual/reference/command/mapReduce/ .. mongodoc:: mapreduce """ if not isinstance(out, (string_type, abc.Mapping)): raise TypeError("'out' must be an instance of " "%s or a mapping" % (string_type.__name__,)) response = self._map_reduce(map, reduce, out, session, ReadPreference.PRIMARY, **kwargs) if full_response or not response.get('result'): return response elif isinstance(response['result'], dict): dbase = response['result']['db'] coll = response['result']['collection'] return self.__database.client[dbase][coll] else: return self.__database[response["result"]] def inline_map_reduce(self, map, reduce, full_response=False, session=None, **kwargs): """Perform an inline map/reduce operation on this collection. Perform the map/reduce operation on the server in RAM. A result collection is not created. The result set is returned as a list of documents. If `full_response` is ``False`` (default) returns the result documents in a list. Otherwise, returns the full response from the server to the `map reduce command`_. The :meth:`inline_map_reduce` method obeys the :attr:`read_preference` of this :class:`Collection`. :Parameters: - `map`: map function (as a JavaScript string) - `reduce`: reduce function (as a JavaScript string) - `full_response` (optional): if ``True``, return full response to this command - otherwise just return the result collection - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional arguments to the `map reduce command`_ may be passed as keyword arguments to this helper method, e.g.:: >>> db.test.inline_map_reduce(map, reduce, limit=2) .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Added the `collation` option. 
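A minimal end-to-end sketch (the collection, field names, and JavaScript below are illustrative only)::

    from bson.code import Code
    mapper = Code("function () { emit(this.category, 1); }")
    reducer = Code("function (key, values) { return Array.sum(values); }")
    results = db.test.inline_map_reduce(mapper, reducer)
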
""" res = self._map_reduce(map, reduce, {"inline": 1}, session, self.read_preference, **kwargs) if full_response: return res else: return res.get("results") def _write_concern_for_cmd(self, cmd, session): raw_wc = cmd.get('writeConcern') if raw_wc is not None: return WriteConcern(**raw_wc) else: return self._write_concern_for(session) def __find_and_modify(self, filter, projection, sort, upsert=None, return_document=ReturnDocument.BEFORE, array_filters=None, hint=None, session=None, **kwargs): """Internal findAndModify helper.""" common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): raise ValueError("return_document must be " "ReturnDocument.BEFORE or ReturnDocument.AFTER") collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd = SON([("findAndModify", self.__name), ("query", filter), ("new", return_document)]) cmd.update(kwargs) if projection is not None: cmd["fields"] = helpers._fields_list_to_dict(projection, "projection") if sort is not None: cmd["sort"] = helpers._index_document(sort) if upsert is not None: common.validate_boolean("upsert", upsert) cmd["upsert"] = upsert if hint is not None: if not isinstance(hint, string_type): hint = helpers._index_document(hint) write_concern = self._write_concern_for_cmd(cmd, session) def _find_and_modify(session, sock_info, retryable_write): if array_filters is not None: if sock_info.max_wire_version < 6: raise ConfigurationError( 'Must be connected to MongoDB 3.6+ to use ' 'arrayFilters.') if not write_concern.acknowledged: raise ConfigurationError( 'arrayFilters is unsupported for unacknowledged ' 'writes.') cmd["arrayFilters"] = array_filters if hint is not None: if sock_info.max_wire_version < 8: raise ConfigurationError( 'Must be connected to MongoDB 4.2+ to use hint.') if not write_concern.acknowledged: raise ConfigurationError( 'hint is unsupported for unacknowledged writes.') cmd['hint'] = hint if (sock_info.max_wire_version >= 4 and not write_concern.is_server_default): cmd['writeConcern'] = write_concern.document out = self._command(sock_info, cmd, read_preference=ReadPreference.PRIMARY, write_concern=write_concern, collation=collation, session=session, retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS) _check_write_command_response(out) return out.get("value") return self.__database.client._retryable_write( write_concern.acknowledged, _find_and_modify, session) def find_one_and_delete(self, filter, projection=None, sort=None, hint=None, session=None, **kwargs): """Finds a single document and deletes it, returning the document. >>> db.test.count_documents({'x': 1}) 2 >>> db.test.find_one_and_delete({'x': 1}) {u'x': 1, u'_id': ObjectId('54f4e12bfba5220aa4d6dee8')} >>> db.test.count_documents({'x': 1}) 1 If multiple documents match *filter*, a *sort* can be applied. >>> for doc in db.test.find({'x': 1}): ... print(doc) ... {u'x': 1, u'_id': 0} {u'x': 1, u'_id': 1} {u'x': 1, u'_id': 2} >>> db.test.find_one_and_delete( ... {'x': 1}, sort=[('_id', pymongo.DESCENDING)]) {u'x': 1, u'_id': 2} The *projection* option can be used to limit the fields returned. >>> db.test.find_one_and_delete({'x': 1}, projection={'_id': False}) {u'x': 1} :Parameters: - `filter`: A query that matches the document to delete. - `projection` (optional): a list of field names that should be returned in the result document or a mapping specifying the fields to include or exclude. If `projection` is a list "_id" will always be returned. Use a mapping to exclude fields from the result (e.g. 
projection={'_id': False}). - `sort` (optional): a list of (key, direction) pairs specifying the sort order for the query. If multiple documents match the query, they are sorted and the first is deleted. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional command arguments can be passed as keyword arguments (for example maxTimeMS can be used with recent server versions). .. versionchanged:: 3.11 Added ``hint`` parameter. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.2 Respects write concern. .. warning:: Starting in PyMongo 3.2, this command uses the :class:`~pymongo.write_concern.WriteConcern` of this :class:`~pymongo.collection.Collection` when connected to MongoDB >= 3.2. Note that using an elevated write concern with this command may be slower compared to using the default write concern. .. versionchanged:: 3.4 Added the `collation` option. .. versionadded:: 3.0 """ kwargs['remove'] = True return self.__find_and_modify(filter, projection, sort, hint=hint, session=session, **kwargs) def find_one_and_replace(self, filter, replacement, projection=None, sort=None, upsert=False, return_document=ReturnDocument.BEFORE, hint=None, session=None, **kwargs): """Finds a single document and replaces it, returning either the original or the replaced document. The :meth:`find_one_and_replace` method differs from :meth:`find_one_and_update` by replacing the document matched by *filter*, rather than modifying the existing document. >>> for doc in db.test.find({}): ... print(doc) ... {u'x': 1, u'_id': 0} {u'x': 1, u'_id': 1} {u'x': 1, u'_id': 2} >>> db.test.find_one_and_replace({'x': 1}, {'y': 1}) {u'x': 1, u'_id': 0} >>> for doc in db.test.find({}): ... print(doc) ... {u'y': 1, u'_id': 0} {u'x': 1, u'_id': 1} {u'x': 1, u'_id': 2} :Parameters: - `filter`: A query that matches the document to replace. - `replacement`: The replacement document. - `projection` (optional): A list of field names that should be returned in the result document or a mapping specifying the fields to include or exclude. If `projection` is a list "_id" will always be returned. Use a mapping to exclude fields from the result (e.g. projection={'_id': False}). - `sort` (optional): a list of (key, direction) pairs specifying the sort order for the query. If multiple documents match the query, they are sorted and the first is replaced. - `upsert` (optional): When ``True``, inserts a new document if no document matches the query. Defaults to ``False``. - `return_document`: If :attr:`ReturnDocument.BEFORE` (the default), returns the original document before it was replaced, or ``None`` if no document matches. If :attr:`ReturnDocument.AFTER`, returns the replaced or inserted document. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. 
- `**kwargs` (optional): additional command arguments can be passed as keyword arguments (for example maxTimeMS can be used with recent server versions). .. versionchanged:: 3.11 Added the ``hint`` option. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Added the ``collation`` option. .. versionchanged:: 3.2 Respects write concern. .. warning:: Starting in PyMongo 3.2, this command uses the :class:`~pymongo.write_concern.WriteConcern` of this :class:`~pymongo.collection.Collection` when connected to MongoDB >= 3.2. Note that using an elevated write concern with this command may be slower compared to using the default write concern. .. versionadded:: 3.0 """ common.validate_ok_for_replace(replacement) kwargs['update'] = replacement return self.__find_and_modify(filter, projection, sort, upsert, return_document, hint=hint, session=session, **kwargs) def find_one_and_update(self, filter, update, projection=None, sort=None, upsert=False, return_document=ReturnDocument.BEFORE, array_filters=None, hint=None, session=None, **kwargs): """Finds a single document and updates it, returning either the original or the updated document. >>> db.test.find_one_and_update( ... {'_id': 665}, {'$inc': {'count': 1}, '$set': {'done': True}}) {u'_id': 665, u'done': False, u'count': 25}} Returns ``None`` if no document matches the filter. >>> db.test.find_one_and_update( ... {'_exists': False}, {'$inc': {'count': 1}}) When the filter matches, by default :meth:`find_one_and_update` returns the original version of the document before the update was applied. To return the updated (or inserted in the case of *upsert*) version of the document instead, use the *return_document* option. >>> from pymongo import ReturnDocument >>> db.example.find_one_and_update( ... {'_id': 'userid'}, ... {'$inc': {'seq': 1}}, ... return_document=ReturnDocument.AFTER) {u'_id': u'userid', u'seq': 1} You can limit the fields returned with the *projection* option. >>> db.example.find_one_and_update( ... {'_id': 'userid'}, ... {'$inc': {'seq': 1}}, ... projection={'seq': True, '_id': False}, ... return_document=ReturnDocument.AFTER) {u'seq': 2} The *upsert* option can be used to create the document if it doesn't already exist. >>> db.example.delete_many({}).deleted_count 1 >>> db.example.find_one_and_update( ... {'_id': 'userid'}, ... {'$inc': {'seq': 1}}, ... projection={'seq': True, '_id': False}, ... upsert=True, ... return_document=ReturnDocument.AFTER) {u'seq': 1} If multiple documents match *filter*, a *sort* can be applied. >>> for doc in db.test.find({'done': True}): ... print(doc) ... {u'_id': 665, u'done': True, u'result': {u'count': 26}} {u'_id': 701, u'done': True, u'result': {u'count': 17}} >>> db.test.find_one_and_update( ... {'done': True}, ... {'$set': {'final': True}}, ... sort=[('_id', pymongo.DESCENDING)]) {u'_id': 701, u'done': True, u'result': {u'count': 17}} :Parameters: - `filter`: A query that matches the document to update. - `update`: The update operations to apply. - `projection` (optional): A list of field names that should be returned in the result document or a mapping specifying the fields to include or exclude. If `projection` is a list "_id" will always be returned. Use a dict to exclude fields from the result (e.g. projection={'_id': False}). - `sort` (optional): a list of (key, direction) pairs specifying the sort order for the query. If multiple documents match the query, they are sorted and the first is updated. 
- `upsert` (optional): When ``True``, inserts a new document if no document matches the query. Defaults to ``False``. - `return_document`: If :attr:`ReturnDocument.BEFORE` (the default), returns the original document before it was updated. If :attr:`ReturnDocument.AFTER`, returns the updated or inserted document. - `array_filters` (optional): A list of filters specifying which array elements an update should apply. This option is only supported on MongoDB 3.6 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional command arguments can be passed as keyword arguments (for example maxTimeMS can be used with recent server versions). .. versionchanged:: 3.11 Added the ``hint`` option. .. versionchanged:: 3.9 Added the ability to accept a pipeline as the ``update``. .. versionchanged:: 3.6 Added the ``array_filters`` and ``session`` options. .. versionchanged:: 3.4 Added the ``collation`` option. .. versionchanged:: 3.2 Respects write concern. .. warning:: Starting in PyMongo 3.2, this command uses the :class:`~pymongo.write_concern.WriteConcern` of this :class:`~pymongo.collection.Collection` when connected to MongoDB >= 3.2. Note that using an elevated write concern with this command may be slower compared to using the default write concern. .. versionadded:: 3.0 """ common.validate_ok_for_update(update) common.validate_list_or_none('array_filters', array_filters) kwargs['update'] = update return self.__find_and_modify(filter, projection, sort, upsert, return_document, array_filters, hint=hint, session=session, **kwargs) def save(self, to_save, manipulate=True, check_keys=True, **kwargs): """Save a document in this collection. **DEPRECATED** - Use :meth:`insert_one` or :meth:`replace_one` instead. .. versionchanged:: 3.0 Removed the `safe` parameter. Pass ``w=0`` for unacknowledged write operations. """ warnings.warn("save is deprecated. Use insert_one or replace_one " "instead", DeprecationWarning, stacklevel=2) common.validate_is_document_type("to_save", to_save) write_concern = None collation = validate_collation_or_none(kwargs.pop('collation', None)) if kwargs: write_concern = WriteConcern(**kwargs) if not (isinstance(to_save, RawBSONDocument) or "_id" in to_save): return self._insert( to_save, True, check_keys, manipulate, write_concern) else: self._update_retryable( {"_id": to_save["_id"]}, to_save, True, check_keys, False, manipulate, write_concern, collation=collation) return to_save.get("_id") def insert(self, doc_or_docs, manipulate=True, check_keys=True, continue_on_error=False, **kwargs): """Insert a document(s) into this collection. **DEPRECATED** - Use :meth:`insert_one` or :meth:`insert_many` instead. .. versionchanged:: 3.0 Removed the `safe` parameter. Pass ``w=0`` for unacknowledged write operations. """ warnings.warn("insert is deprecated. Use insert_one or insert_many " "instead.", DeprecationWarning, stacklevel=2) write_concern = None if kwargs: write_concern = WriteConcern(**kwargs) return self._insert(doc_or_docs, not continue_on_error, check_keys, manipulate, write_concern) def update(self, spec, document, upsert=False, manipulate=False, multi=False, check_keys=True, **kwargs): """Update a document(s) in this collection. 
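The modern CRUD helpers cover the same use cases; as a sketch (the filter and field names are illustrative only)::

    db.test.update_one({'x': 1}, {'$set': {'y': 2}})
    db.test.update_many({'x': 1}, {'$set': {'y': 2}})
    db.test.replace_one({'x': 1}, {'x': 1, 'y': 2})
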
**DEPRECATED** - Use :meth:`replace_one`, :meth:`update_one`, or :meth:`update_many` instead. .. versionchanged:: 3.0 Removed the `safe` parameter. Pass ``w=0`` for unacknowledged write operations. """ warnings.warn("update is deprecated. Use replace_one, update_one or " "update_many instead.", DeprecationWarning, stacklevel=2) common.validate_is_mapping("spec", spec) common.validate_is_mapping("document", document) if document: # If a top level key begins with '$' this is a modify operation # and we should skip key validation. It doesn't matter which key # we check here. Passing a document with a mix of top level keys # starting with and without a '$' is invalid and the server will # raise an appropriate exception. first = next(iter(document)) if first.startswith('$'): check_keys = False write_concern = None collation = validate_collation_or_none(kwargs.pop('collation', None)) if kwargs: write_concern = WriteConcern(**kwargs) return self._update_retryable( spec, document, upsert, check_keys, multi, manipulate, write_concern, collation=collation) def remove(self, spec_or_id=None, multi=True, **kwargs): """Remove a document(s) from this collection. **DEPRECATED** - Use :meth:`delete_one` or :meth:`delete_many` instead. .. versionchanged:: 3.0 Removed the `safe` parameter. Pass ``w=0`` for unacknowledged write operations. """ warnings.warn("remove is deprecated. Use delete_one or delete_many " "instead.", DeprecationWarning, stacklevel=2) if spec_or_id is None: spec_or_id = {} if not isinstance(spec_or_id, abc.Mapping): spec_or_id = {"_id": spec_or_id} write_concern = None collation = validate_collation_or_none(kwargs.pop('collation', None)) if kwargs: write_concern = WriteConcern(**kwargs) return self._delete_retryable( spec_or_id, multi, write_concern, collation=collation) def find_and_modify(self, query={}, update=None, upsert=False, sort=None, full_response=False, manipulate=False, **kwargs): """Update and return an object. **DEPRECATED** - Use :meth:`find_one_and_delete`, :meth:`find_one_and_replace`, or :meth:`find_one_and_update` instead. """ warnings.warn("find_and_modify is deprecated, use find_one_and_delete" ", find_one_and_replace, or find_one_and_update instead", DeprecationWarning, stacklevel=2) if not update and not kwargs.get('remove', None): raise ValueError("Must either update or remove") if update and kwargs.get('remove', None): raise ValueError("Can't do both update and remove") # No need to include empty args if query: kwargs['query'] = query if update: kwargs['update'] = update if upsert: kwargs['upsert'] = upsert if sort: # Accept a list of tuples to match Cursor's sort parameter. if isinstance(sort, list): kwargs['sort'] = helpers._index_document(sort) # Accept OrderedDict, SON, and dict with len == 1 so we # don't break existing code already using find_and_modify. 
elif (isinstance(sort, ORDERED_TYPES) or isinstance(sort, dict) and len(sort) == 1): warnings.warn("Passing mapping types for `sort` is deprecated," " use a list of (key, direction) pairs instead", DeprecationWarning, stacklevel=2) kwargs['sort'] = sort else: raise TypeError("sort must be a list of (key, direction) " "pairs, a dict of len 1, or an instance of " "SON or OrderedDict") fields = kwargs.pop("fields", None) if fields is not None: kwargs["fields"] = helpers._fields_list_to_dict(fields, "fields") collation = validate_collation_or_none(kwargs.pop('collation', None)) cmd = SON([("findAndModify", self.__name)]) cmd.update(kwargs) write_concern = self._write_concern_for_cmd(cmd, None) def _find_and_modify(session, sock_info, retryable_write): if (sock_info.max_wire_version >= 4 and not write_concern.is_server_default): cmd['writeConcern'] = write_concern.document result = self._command( sock_info, cmd, read_preference=ReadPreference.PRIMARY, collation=collation, session=session, retryable_write=retryable_write, user_fields=_FIND_AND_MODIFY_DOC_FIELDS) _check_write_command_response(result) return result out = self.__database.client._retryable_write( write_concern.acknowledged, _find_and_modify, None) if full_response: return out else: document = out.get('value') if manipulate: document = self.__database._fix_outgoing(document, self) return document def __iter__(self): return self def __next__(self): raise TypeError("'Collection' object is not iterable") next = __next__ def __call__(self, *args, **kwargs): """This is only here so that some API misusages are easier to debug. """ if "." not in self.__name: raise TypeError("'Collection' object is not callable. If you " "meant to call the '%s' method on a 'Database' " "object it is failing because no such method " "exists." % self.__name) raise TypeError("'Collection' object is not callable. If you meant to " "call the '%s' method on a 'Collection' object it is " "failing because no such method exists." % self.__name.split(".")[-1]) pymongo-3.11.0/pymongo/command_cursor.py000066400000000000000000000253511374256237000203600ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """CommandCursor class to iterate over command results.""" from collections import deque from bson.py3compat import integer_types from pymongo.errors import (ConnectionFailure, InvalidOperation, NotMasterError, OperationFailure) from pymongo.message import (_CursorAddress, _GetMore, _RawBatchGetMore) class CommandCursor(object): """A cursor / iterator over command cursors.""" _getmore_class = _GetMore def __init__(self, collection, cursor_info, address, retrieved=0, batch_size=0, max_await_time_ms=None, session=None, explicit_session=False): """Create a new command cursor. The parameter 'retrieved' is unused. 
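Application code does not normally construct a :class:`CommandCursor` directly; instances are returned by helpers such as :meth:`~pymongo.collection.Collection.aggregate`, e.g. (the collection name and pipeline are illustrative)::

    cursor = db.test.aggregate([{'$match': {'qty': {'$gt': 0}}}])
    for doc in cursor:
        print(doc)
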
""" self.__collection = collection self.__id = cursor_info['id'] self.__data = deque(cursor_info['firstBatch']) self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken') self.__address = address self.__batch_size = batch_size self.__max_await_time_ms = max_await_time_ms self.__session = session self.__explicit_session = explicit_session self.__killed = (self.__id == 0) if self.__killed: self.__end_session(True) if "ns" in cursor_info: self.__ns = cursor_info["ns"] else: self.__ns = collection.full_name self.batch_size(batch_size) if (not isinstance(max_await_time_ms, integer_types) and max_await_time_ms is not None): raise TypeError("max_await_time_ms must be an integer or None") def __del__(self): if self.__id and not self.__killed: self.__die() def __die(self, synchronous=False): """Closes this cursor. """ already_killed = self.__killed self.__killed = True if self.__id and not already_killed: address = _CursorAddress( self.__address, self.__collection.full_name) if synchronous: self.__collection.database.client._close_cursor_now( self.__id, address, session=self.__session) else: # The cursor will be closed later in a different session. self.__collection.database.client._close_cursor( self.__id, address) self.__end_session(synchronous) def __end_session(self, synchronous): if self.__session and not self.__explicit_session: self.__session._end_session(lock=synchronous) self.__session = None def close(self): """Explicitly close / kill this cursor. """ self.__die(True) def batch_size(self, batch_size): """Limits the number of documents returned in one batch. Each batch requires a round trip to the server. It can be adjusted to optimize performance and limit data transfer. .. note:: batch_size can not override MongoDB's internal limits on the amount of data it will return to the client in a single batch (i.e if you set batch size to 1,000,000,000, MongoDB will currently only return 4-16MB of results per batch). Raises :exc:`TypeError` if `batch_size` is not an integer. Raises :exc:`ValueError` if `batch_size` is less than ``0``. :Parameters: - `batch_size`: The size of each batch of results requested. """ if not isinstance(batch_size, integer_types): raise TypeError("batch_size must be an integer") if batch_size < 0: raise ValueError("batch_size must be >= 0") self.__batch_size = batch_size == 1 and 2 or batch_size return self def _has_next(self): """Returns `True` if the cursor has documents remaining from the previous batch.""" return len(self.__data) > 0 @property def _post_batch_resume_token(self): """Retrieve the postBatchResumeToken from the response to a changeStream aggregate or getMore.""" return self.__postbatchresumetoken def __send_message(self, operation): """Send a getmore message and handle the response. """ def kill(): self.__killed = True self.__end_session(True) client = self.__collection.database.client try: response = client._run_operation_with_response( operation, self._unpack_response, address=self.__address) except OperationFailure: kill() raise except NotMasterError: # Don't send kill cursors to another server after a "not master" # error. It's completely pointless. kill() raise except ConnectionFailure: # Don't try to send kill cursors on another socket # or to another server. It can cause a _pinValue # assertion on some server releases if we get here # due to a socket timeout. 
kill() raise except Exception: # Close the cursor self.__die() raise from_command = response.from_command reply = response.data docs = response.docs if from_command: cursor = docs[0]['cursor'] documents = cursor['nextBatch'] self.__postbatchresumetoken = cursor.get('postBatchResumeToken') self.__id = cursor['id'] else: documents = docs self.__id = reply.cursor_id if self.__id == 0: kill() self.__data = deque(documents) def _unpack_response(self, response, cursor_id, codec_options, user_fields=None, legacy_response=False): return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _refresh(self): """Refreshes the cursor with more data from the server. Returns the length of self.__data after refresh. Will exit early if self.__data is already non-empty. Raises OperationFailure when the cursor cannot be refreshed due to an error on the query. """ if len(self.__data) or self.__killed: return len(self.__data) if self.__id: # Get More dbname, collname = self.__ns.split('.', 1) read_pref = self.__collection._read_preference_for(self.session) self.__send_message( self._getmore_class(dbname, collname, self.__batch_size, self.__id, self.__collection.codec_options, read_pref, self.__session, self.__collection.database.client, self.__max_await_time_ms, False)) else: # Cursor id is zero nothing else to return self.__killed = True self.__end_session(True) return len(self.__data) @property def alive(self): """Does this cursor have the potential to return more data? Even if :attr:`alive` is ``True``, :meth:`next` can raise :exc:`StopIteration`. Best to use a for loop:: for doc in collection.aggregate(pipeline): print(doc) .. note:: :attr:`alive` can be True while iterating a cursor from a failed server. In this case :attr:`alive` will return False after :meth:`next` fails to retrieve the next batch of results from the server. """ return bool(len(self.__data) or (not self.__killed)) @property def cursor_id(self): """Returns the id of the cursor.""" return self.__id @property def address(self): """The (host, port) of the server used, or None. .. versionadded:: 3.0 """ return self.__address @property def session(self): """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. .. versionadded:: 3.6 """ if self.__explicit_session: return self.__session def __iter__(self): return self def next(self): """Advance the cursor.""" # Block until a document is returnable. while self.alive: doc = self._try_next(True) if doc is not None: return doc raise StopIteration __next__ = next def _try_next(self, get_more_allowed): """Advance the cursor blocking for at most one getMore command.""" if not len(self.__data) and not self.__killed and get_more_allowed: self._refresh() if len(self.__data): coll = self.__collection return coll.database._fix_outgoing(self.__data.popleft(), coll) else: return None def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() class RawBatchCommandCursor(CommandCursor): _getmore_class = _RawBatchGetMore def __init__(self, collection, cursor_info, address, retrieved=0, batch_size=0, max_await_time_ms=None, session=None, explicit_session=False): """Create a new cursor / iterator over raw batches of BSON data. Should not be called directly by application developers - see :meth:`~pymongo.collection.Collection.aggregate_raw_batches` instead. .. 
mongodoc:: cursors """ assert not cursor_info.get('firstBatch') super(RawBatchCommandCursor, self).__init__( collection, cursor_info, address, retrieved, batch_size, max_await_time_ms, session, explicit_session) def _unpack_response(self, response, cursor_id, codec_options, user_fields=None, legacy_response=False): return response.raw_response(cursor_id) def __getitem__(self, index): raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor") pymongo-3.11.0/pymongo/common.py000066400000000000000000001004671374256237000166370ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Functions and classes common to multiple pymongo modules.""" import datetime import warnings from bson import SON from bson.binary import UuidRepresentation from bson.codec_options import CodecOptions, TypeRegistry from bson.py3compat import abc, integer_types, iteritems, string_type, PY3 from bson.raw_bson import RawBSONDocument from pymongo.auth import MECHANISMS from pymongo.compression_support import (validate_compressors, validate_zlib_compression_level) from pymongo.driver_info import DriverInfo from pymongo.encryption_options import validate_auto_encryption_opts_or_none from pymongo.errors import ConfigurationError from pymongo.monitoring import _validate_event_listeners from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _MONGOS_MODES, _ServerMode from pymongo.ssl_support import (validate_cert_reqs, validate_allow_invalid_certs) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern try: from collections import OrderedDict ORDERED_TYPES = (SON, OrderedDict) except ImportError: ORDERED_TYPES = (SON,) if PY3: from urllib.parse import unquote_plus else: from urllib import unquote_plus # Defaults until we connect to a server and get updated limits. MAX_BSON_SIZE = 16 * (1024 ** 2) MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE MIN_WIRE_VERSION = 0 MAX_WIRE_VERSION = 0 MAX_WRITE_BATCH_SIZE = 1000 # What this version of PyMongo supports. MIN_SUPPORTED_SERVER_VERSION = "2.6" MIN_SUPPORTED_WIRE_VERSION = 2 MAX_SUPPORTED_WIRE_VERSION = 9 # Frequency to call ismaster on servers, in seconds. HEARTBEAT_FREQUENCY = 10 # Frequency to process kill-cursors, in seconds. See MongoClient.close_cursor. KILL_CURSOR_FREQUENCY = 1 # Frequency to process events queue, in seconds. EVENTS_QUEUE_FREQUENCY = 1 # How long to wait, in seconds, for a suitable server to be found before # aborting an operation. For example, if the client attempts an insert # during a replica set election, SERVER_SELECTION_TIMEOUT governs the # longest it is willing to wait for a new primary to be found. SERVER_SELECTION_TIMEOUT = 30 # Spec requires at least 500ms between ismaster calls. MIN_HEARTBEAT_INTERVAL = 0.5 # Spec requires at least 60s between SRV rescans. MIN_SRV_RESCAN_INTERVAL = 60 # Default connectTimeout in seconds. CONNECT_TIMEOUT = 20.0 # Default value for maxPoolSize. MAX_POOL_SIZE = 100 # Default value for minPoolSize. 
MIN_POOL_SIZE = 0 # Default value for maxIdleTimeMS. MAX_IDLE_TIME_MS = None # Default value for maxIdleTimeMS in seconds. MAX_IDLE_TIME_SEC = None # Default value for waitQueueTimeoutMS in seconds. WAIT_QUEUE_TIMEOUT = None # Default value for localThresholdMS. LOCAL_THRESHOLD_MS = 15 # Default value for retryWrites. RETRY_WRITES = True # Default value for retryReads. RETRY_READS = True # mongod/s 2.6 and above return code 59 when a command doesn't exist. COMMAND_NOT_FOUND_CODES = (59,) # Error codes to ignore if GridFS calls createIndex on a secondary UNAUTHORIZED_CODES = (13, 16547, 16548) # Maximum number of sessions to send in a single endSessions command. # From the driver sessions spec. _MAX_END_SESSIONS = 10000 def partition_node(node): """Split a host:port string into (host, int(port)) pair.""" host = node port = 27017 idx = node.rfind(':') if idx != -1: host, port = node[:idx], int(node[idx + 1:]) if host.startswith('['): host = host[1:-1] return host, port def clean_node(node): """Split and normalize a node name from an ismaster response.""" host, port = partition_node(node) # Normalize hostname to lowercase, since DNS is case-insensitive: # http://tools.ietf.org/html/rfc4343 # This prevents useless rediscovery if "foo.com" is in the seed list but # "FOO.com" is in the ismaster response. return host.lower(), port def raise_config_error(key, dummy): """Raise ConfigurationError with the given key name.""" raise ConfigurationError("Unknown option %s" % (key,)) # Mapping of URI uuid representation options to valid subtypes. _UUID_REPRESENTATIONS = { 'unspecified': UuidRepresentation.UNSPECIFIED, 'standard': UuidRepresentation.STANDARD, 'pythonLegacy': UuidRepresentation.PYTHON_LEGACY, 'javaLegacy': UuidRepresentation.JAVA_LEGACY, 'csharpLegacy': UuidRepresentation.CSHARP_LEGACY } def validate_boolean(option, value): """Validates that 'value' is True or False.""" if isinstance(value, bool): return value raise TypeError("%s must be True or False" % (option,)) def validate_boolean_or_string(option, value): """Validates that value is True, False, 'true', or 'false'.""" if isinstance(value, string_type): if value not in ('true', 'false'): raise ValueError("The value of %s must be " "'true' or 'false'" % (option,)) return value == 'true' return validate_boolean(option, value) def validate_integer(option, value): """Validates that 'value' is an integer (or basestring representation). """ if isinstance(value, integer_types): return value elif isinstance(value, string_type): try: return int(value) except ValueError: raise ValueError("The value of %s must be " "an integer" % (option,)) raise TypeError("Wrong type for %s, value must be an integer" % (option,)) def validate_positive_integer(option, value): """Validate that 'value' is a positive integer, which does not include 0. """ val = validate_integer(option, value) if val <= 0: raise ValueError("The value of %s must be " "a positive integer" % (option,)) return val def validate_non_negative_integer(option, value): """Validate that 'value' is a positive integer or 0. """ val = validate_integer(option, value) if val < 0: raise ValueError("The value of %s must be " "a non negative integer" % (option,)) return val def validate_readable(option, value): """Validates that 'value' is file-like and readable. 
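For example (the option name and path are illustrative)::

    validate_readable('tlsCAFile', '/etc/ssl/ca.pem')  # returns the path if it can be opened for reading
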
""" if value is None: return value # First make sure its a string py3.3 open(True, 'r') succeeds # Used in ssl cert checking due to poor ssl module error reporting value = validate_string(option, value) open(value, 'r').close() return value def validate_positive_integer_or_none(option, value): """Validate that 'value' is a positive integer or None. """ if value is None: return value return validate_positive_integer(option, value) def validate_non_negative_integer_or_none(option, value): """Validate that 'value' is a positive integer or 0 or None. """ if value is None: return value return validate_non_negative_integer(option, value) def validate_string(option, value): """Validates that 'value' is an instance of `basestring` for Python 2 or `str` for Python 3. """ if isinstance(value, string_type): return value raise TypeError("Wrong type for %s, value must be " "an instance of %s" % (option, string_type.__name__)) def validate_string_or_none(option, value): """Validates that 'value' is an instance of `basestring` or `None`. """ if value is None: return value return validate_string(option, value) def validate_int_or_basestring(option, value): """Validates that 'value' is an integer or string. """ if isinstance(value, integer_types): return value elif isinstance(value, string_type): try: return int(value) except ValueError: return value raise TypeError("Wrong type for %s, value must be an " "integer or a string" % (option,)) def validate_non_negative_int_or_basestring(option, value): """Validates that 'value' is an integer or string. """ if isinstance(value, integer_types): return value elif isinstance(value, string_type): try: val = int(value) except ValueError: return value return validate_non_negative_integer(option, val) raise TypeError("Wrong type for %s, value must be an " "non negative integer or a string" % (option,)) def validate_positive_float(option, value): """Validates that 'value' is a float, or can be converted to one, and is positive. """ errmsg = "%s must be an integer or float" % (option,) try: value = float(value) except ValueError: raise ValueError(errmsg) except TypeError: raise TypeError(errmsg) # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at # one billion - this is a reasonable approximation for infinity if not 0 < value < 1e9: raise ValueError("%s must be greater than 0 and " "less than one billion" % (option,)) return value def validate_positive_float_or_zero(option, value): """Validates that 'value' is 0 or a positive float, or can be converted to 0 or a positive float. """ if value == 0 or value == "0": return 0 return validate_positive_float(option, value) def validate_timeout_or_none(option, value): """Validates a timeout specified in milliseconds returning a value in floating point seconds. """ if value is None: return value return validate_positive_float(option, value) / 1000.0 def validate_timeout_or_zero(option, value): """Validates a timeout specified in milliseconds returning a value in floating point seconds for the case where None is an error and 0 is valid. Setting the timeout to nothing in the URI string is a config error. """ if value is None: raise ConfigurationError("%s cannot be None" % (option, )) if value == 0 or value == "0": return 0 return validate_positive_float(option, value) / 1000.0 def validate_timeout_or_none_or_zero(option, value): """Validates a timeout specified in milliseconds returning a value in floating point seconds. value=0 and value="0" are treated the same as value=None which means unlimited timeout. 
""" if value is None or value == 0 or value == "0": return None return validate_positive_float(option, value) / 1000.0 def validate_max_staleness(option, value): """Validates maxStalenessSeconds according to the Max Staleness Spec.""" if value == -1 or value == "-1": # Default: No maximum staleness. return -1 return validate_positive_integer(option, value) def validate_read_preference(dummy, value): """Validate a read preference. """ if not isinstance(value, _ServerMode): raise TypeError("%r is not a read preference." % (value,)) return value def validate_read_preference_mode(dummy, value): """Validate read preference mode for a MongoReplicaSetClient. .. versionchanged:: 3.5 Returns the original ``value`` instead of the validated read preference mode. """ if value not in _MONGOS_MODES: raise ValueError("%s is not a valid read preference" % (value,)) return value def validate_auth_mechanism(option, value): """Validate the authMechanism URI option. """ # CRAM-MD5 is for server testing only. Undocumented, # unsupported, may be removed at any time. You have # been warned. if value not in MECHANISMS and value != 'CRAM-MD5': raise ValueError("%s must be in %s" % (option, tuple(MECHANISMS))) return value def validate_uuid_representation(dummy, value): """Validate the uuid representation option selected in the URI. """ try: return _UUID_REPRESENTATIONS[value] except KeyError: raise ValueError("%s is an invalid UUID representation. " "Must be one of " "%s" % (value, tuple(_UUID_REPRESENTATIONS))) def validate_read_preference_tags(name, value): """Parse readPreferenceTags if passed as a client kwarg. """ if not isinstance(value, list): value = [value] tag_sets = [] for tag_set in value: if tag_set == '': tag_sets.append({}) continue try: tags = {} for tag in tag_set.split(","): key, val = tag.split(":") tags[unquote_plus(key)] = unquote_plus(val) tag_sets.append(tags) except Exception: raise ValueError("%r not a valid " "value for %s" % (tag_set, name)) return tag_sets _MECHANISM_PROPS = frozenset(['SERVICE_NAME', 'CANONICALIZE_HOST_NAME', 'SERVICE_REALM', 'AWS_SESSION_TOKEN']) def validate_auth_mechanism_properties(option, value): """Validate authMechanismProperties.""" value = validate_string(option, value) props = {} for opt in value.split(','): try: key, val = opt.split(':') except ValueError: # Try not to leak the token. if 'AWS_SESSION_TOKEN' in opt: opt = ('AWS_SESSION_TOKEN:, did you forget ' 'to percent-escape the token with quote_plus?') raise ValueError("auth mechanism properties must be " "key:value pairs like SERVICE_NAME:" "mongodb, not %s." % (opt,)) if key not in _MECHANISM_PROPS: raise ValueError("%s is not a supported auth " "mechanism property. Must be one of " "%s." 
% (key, tuple(_MECHANISM_PROPS))) if key == 'CANONICALIZE_HOST_NAME': props[key] = validate_boolean_or_string(key, val) else: props[key] = unquote_plus(val) return props def validate_document_class(option, value): """Validate the document_class option.""" if not issubclass(value, (abc.MutableMapping, RawBSONDocument)): raise TypeError("%s must be dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or a " "sublass of collections.MutableMapping" % (option,)) return value def validate_type_registry(option, value): """Validate the type_registry option.""" if value is not None and not isinstance(value, TypeRegistry): raise TypeError("%s must be an instance of %s" % ( option, TypeRegistry)) return value def validate_list(option, value): """Validates that 'value' is a list.""" if not isinstance(value, list): raise TypeError("%s must be a list" % (option,)) return value def validate_list_or_none(option, value): """Validates that 'value' is a list or None.""" if value is None: return value return validate_list(option, value) def validate_list_or_mapping(option, value): """Validates that 'value' is a list or a document.""" if not isinstance(value, (abc.Mapping, list)): raise TypeError("%s must either be a list or an instance of dict, " "bson.son.SON, or any other type that inherits from " "collections.Mapping" % (option,)) def validate_is_mapping(option, value): """Validate the type of method arguments that expect a document.""" if not isinstance(value, abc.Mapping): raise TypeError("%s must be an instance of dict, bson.son.SON, or " "any other type that inherits from " "collections.Mapping" % (option,)) def validate_is_document_type(option, value): """Validate the type of method arguments that expect a MongoDB document.""" if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): raise TypeError("%s must be an instance of dict, bson.son.SON, " "bson.raw_bson.RawBSONDocument, or " "a type that inherits from " "collections.MutableMapping" % (option,)) def validate_appname_or_none(option, value): """Validate the appname option.""" if value is None: return value validate_string(option, value) # We need length in bytes, so encode utf8 first. if len(value.encode('utf-8')) > 128: raise ValueError("%s must be <= 128 bytes" % (option,)) return value def validate_driver_or_none(option, value): """Validate the driver keyword arg.""" if value is None: return value if not isinstance(value, DriverInfo): raise TypeError("%s must be an instance of DriverInfo" % (option,)) return value def validate_is_callable_or_none(option, value): """Validates that 'value' is a callable.""" if value is None: return value if not callable(value): raise ValueError("%s must be a callable" % (option,)) return value def validate_ok_for_replace(replacement): """Validate a replacement document.""" validate_is_mapping("replacement", replacement) # Replacement can be {} if replacement and not isinstance(replacement, RawBSONDocument): first = next(iter(replacement)) if first.startswith('$'): raise ValueError('replacement can not include $ operators') def validate_ok_for_update(update): """Validate an update document.""" validate_list_or_mapping("update", update) # Update cannot be {}. 
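# A few illustrative inputs (not an exhaustive list):
#   {'$set': {'x': 1}}    -> accepted: a modifier document
#   [{'$set': {'x': 1}}]  -> accepted: list (aggregation pipeline) form; the
#                            '$' key check below only applies to documents
#   {}                    -> rejected below: update cannot be empty
#   {'x': 1}              -> rejected below: update only works with $ operators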
if not update: raise ValueError('update cannot be empty') is_document = not isinstance(update, list) first = next(iter(update)) if is_document and not first.startswith('$'): raise ValueError('update only works with $ operators') _UNICODE_DECODE_ERROR_HANDLERS = frozenset(['strict', 'replace', 'ignore']) def validate_unicode_decode_error_handler(dummy, value): """Validate the Unicode decode error handler option of CodecOptions. """ if value not in _UNICODE_DECODE_ERROR_HANDLERS: raise ValueError("%s is an invalid Unicode decode error handler. " "Must be one of " "%s" % (value, tuple(_UNICODE_DECODE_ERROR_HANDLERS))) return value def validate_tzinfo(dummy, value): """Validate the tzinfo option """ if value is not None and not isinstance(value, datetime.tzinfo): raise TypeError("%s must be an instance of datetime.tzinfo" % value) return value # Dictionary where keys are the names of public URI options, and values # are lists of aliases for that option. Aliases of option names are assumed # to have been deprecated. URI_OPTIONS_ALIAS_MAP = { 'journal': ['j'], 'wtimeoutms': ['wtimeout'], 'tls': ['ssl'], 'tlsallowinvalidcertificates': ['ssl_cert_reqs'], 'tlsallowinvalidhostnames': ['ssl_match_hostname'], 'tlscrlfile': ['ssl_crlfile'], 'tlscafile': ['ssl_ca_certs'], 'tlscertificatekeyfile': ['ssl_certfile'], 'tlscertificatekeyfilepassword': ['ssl_pem_passphrase'], } # Dictionary where keys are the names of URI options, and values # are functions that validate user-input values for that option. If an option # alias uses a different validator than its public counterpart, it should be # included here as a key, value pair. URI_OPTIONS_VALIDATOR_MAP = { 'appname': validate_appname_or_none, 'authmechanism': validate_auth_mechanism, 'authmechanismproperties': validate_auth_mechanism_properties, 'authsource': validate_string, 'compressors': validate_compressors, 'connecttimeoutms': validate_timeout_or_none_or_zero, 'directconnection': validate_boolean_or_string, 'heartbeatfrequencyms': validate_timeout_or_none, 'journal': validate_boolean_or_string, 'localthresholdms': validate_positive_float_or_zero, 'maxidletimems': validate_timeout_or_none, 'maxpoolsize': validate_positive_integer_or_none, 'maxstalenessseconds': validate_max_staleness, 'readconcernlevel': validate_string_or_none, 'readpreference': validate_read_preference_mode, 'readpreferencetags': validate_read_preference_tags, 'replicaset': validate_string_or_none, 'retryreads': validate_boolean_or_string, 'retrywrites': validate_boolean_or_string, 'serverselectiontimeoutms': validate_timeout_or_zero, 'sockettimeoutms': validate_timeout_or_none_or_zero, 'ssl_keyfile': validate_readable, 'tls': validate_boolean_or_string, 'tlsallowinvalidcertificates': validate_allow_invalid_certs, 'ssl_cert_reqs': validate_cert_reqs, 'tlsallowinvalidhostnames': lambda *x: not validate_boolean_or_string(*x), 'ssl_match_hostname': validate_boolean_or_string, 'tlscafile': validate_readable, 'tlscertificatekeyfile': validate_readable, 'tlscertificatekeyfilepassword': validate_string_or_none, 'tlsdisableocspendpointcheck': validate_boolean_or_string, 'tlsinsecure': validate_boolean_or_string, 'w': validate_non_negative_int_or_basestring, 'wtimeoutms': validate_non_negative_integer, 'zlibcompressionlevel': validate_zlib_compression_level, } # Dictionary where keys are the names of URI options specific to pymongo, # and values are functions that validate user-input values for those options. 
NONSPEC_OPTIONS_VALIDATOR_MAP = { 'connect': validate_boolean_or_string, 'driver': validate_driver_or_none, 'fsync': validate_boolean_or_string, 'minpoolsize': validate_non_negative_integer, 'socketkeepalive': validate_boolean_or_string, 'tlscrlfile': validate_readable, 'tz_aware': validate_boolean_or_string, 'unicode_decode_error_handler': validate_unicode_decode_error_handler, 'uuidrepresentation': validate_uuid_representation, 'waitqueuemultiple': validate_non_negative_integer_or_none, 'waitqueuetimeoutms': validate_timeout_or_none, } # Dictionary where keys are the names of keyword-only options for the # MongoClient constructor, and values are functions that validate user-input # values for those options. KW_VALIDATORS = { 'document_class': validate_document_class, 'type_registry': validate_type_registry, 'read_preference': validate_read_preference, 'event_listeners': _validate_event_listeners, 'tzinfo': validate_tzinfo, 'username': validate_string_or_none, 'password': validate_string_or_none, 'server_selector': validate_is_callable_or_none, 'auto_encryption_opts': validate_auto_encryption_opts_or_none, } # Dictionary where keys are any URI option name, and values are the # internally-used names of that URI option. Options with only one name # variant need not be included here. Options whose public and internal # names are the same need not be included here. INTERNAL_URI_OPTION_NAME_MAP = { 'j': 'journal', 'wtimeout': 'wtimeoutms', 'tls': 'ssl', 'tlsallowinvalidcertificates': 'ssl_cert_reqs', 'tlsallowinvalidhostnames': 'ssl_match_hostname', 'tlscrlfile': 'ssl_crlfile', 'tlscafile': 'ssl_ca_certs', 'tlscertificatekeyfile': 'ssl_certfile', 'tlscertificatekeyfilepassword': 'ssl_pem_passphrase', 'tlsdisableocspendpointcheck': 'ssl_check_ocsp_endpoint', } # Map from deprecated URI option names to a tuple indicating the method of # their deprecation and any additional information that may be needed to # construct the warning message. URI_OPTIONS_DEPRECATION_MAP = { # format: : (, ), # Supported values: # - 'renamed': should be the new option name. Note that case is # preserved for renamed options as they are part of user warnings. # - 'removed': may suggest the rationale for deprecating the # option and/or recommend remedial action. 'j': ('renamed', 'journal'), 'wtimeout': ('renamed', 'wTimeoutMS'), 'ssl_cert_reqs': ('renamed', 'tlsAllowInvalidCertificates'), 'ssl_match_hostname': ('renamed', 'tlsAllowInvalidHostnames'), 'ssl_crlfile': ('renamed', 'tlsCRLFile'), 'ssl_ca_certs': ('renamed', 'tlsCAFile'), 'ssl_pem_passphrase': ('renamed', 'tlsCertificateKeyFilePassword'), 'waitqueuemultiple': ('removed', ( 'Instead of using waitQueueMultiple to bound queuing, limit the size ' 'of the thread pool in your application server')) } # Augment the option validator map with pymongo-specific option information. URI_OPTIONS_VALIDATOR_MAP.update(NONSPEC_OPTIONS_VALIDATOR_MAP) for optname, aliases in iteritems(URI_OPTIONS_ALIAS_MAP): for alias in aliases: if alias not in URI_OPTIONS_VALIDATOR_MAP: URI_OPTIONS_VALIDATOR_MAP[alias] = ( URI_OPTIONS_VALIDATOR_MAP[optname]) # Map containing all URI option and keyword argument validators. VALIDATORS = URI_OPTIONS_VALIDATOR_MAP.copy() VALIDATORS.update(KW_VALIDATORS) # List of timeout-related options. 
TIMEOUT_OPTIONS = [ 'connecttimeoutms', 'heartbeatfrequencyms', 'maxidletimems', 'maxstalenessseconds', 'serverselectiontimeoutms', 'sockettimeoutms', 'waitqueuetimeoutms', ] _AUTH_OPTIONS = frozenset(['authmechanismproperties']) def validate_auth_option(option, value): """Validate optional authentication parameters. """ lower, value = validate(option, value) if lower not in _AUTH_OPTIONS: raise ConfigurationError('Unknown ' 'authentication option: %s' % (option,)) return option, value def validate(option, value): """Generic validation function. """ lower = option.lower() validator = VALIDATORS.get(lower, raise_config_error) value = validator(option, value) return option, value def get_validated_options(options, warn=True): """Validate each entry in options and raise a warning if it is not valid. Returns a copy of options with invalid entries removed. :Parameters: - `opts`: A dict containing MongoDB URI options. - `warn` (optional): If ``True`` then warnings will be logged and invalid options will be ignored. Otherwise, invalid options will cause errors. """ if isinstance(options, _CaseInsensitiveDictionary): validated_options = _CaseInsensitiveDictionary() get_normed_key = lambda x: x get_setter_key = lambda x: options.cased_key(x) else: validated_options = {} get_normed_key = lambda x: x.lower() get_setter_key = lambda x: x for opt, value in iteritems(options): normed_key = get_normed_key(opt) try: validator = URI_OPTIONS_VALIDATOR_MAP.get( normed_key, raise_config_error) value = validator(opt, value) except (ValueError, TypeError, ConfigurationError) as exc: if warn: warnings.warn(str(exc)) else: raise else: validated_options[get_setter_key(normed_key)] = value return validated_options # List of write-concern-related options. WRITE_CONCERN_OPTIONS = frozenset([ 'w', 'wtimeout', 'wtimeoutms', 'fsync', 'j', 'journal' ]) class BaseObject(object): """A base class that provides attributes and methods common to multiple pymongo classes. SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB. """ def __init__(self, codec_options, read_preference, write_concern, read_concern): if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of " "bson.codec_options.CodecOptions") self.__codec_options = codec_options if not isinstance(read_preference, _ServerMode): raise TypeError("%r is not valid for read_preference. See " "pymongo.read_preferences for valid " "options." % (read_preference,)) self.__read_preference = read_preference if not isinstance(write_concern, WriteConcern): raise TypeError("write_concern must be an instance of " "pymongo.write_concern.WriteConcern") self.__write_concern = write_concern if not isinstance(read_concern, ReadConcern): raise TypeError("read_concern must be an instance of " "pymongo.read_concern.ReadConcern") self.__read_concern = read_concern @property def codec_options(self): """Read only access to the :class:`~bson.codec_options.CodecOptions` of this instance. """ return self.__codec_options @property def write_concern(self): """Read only access to the :class:`~pymongo.write_concern.WriteConcern` of this instance. .. versionchanged:: 3.0 The :attr:`write_concern` attribute is now read only. """ return self.__write_concern def _write_concern_for(self, session): """Read only access to the write concern of this instance or session. """ # Override this operation's write concern with the transaction's. 
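# (Operations inside a transaction must not send a per-operation write
# concern; the transaction's write concern is applied when the
# transaction is committed or aborted, so the server default is used
# for the individual operations here.)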
if session and session.in_transaction: return DEFAULT_WRITE_CONCERN return self.write_concern @property def read_preference(self): """Read only access to the read preference of this instance. .. versionchanged:: 3.0 The :attr:`read_preference` attribute is now read only. """ return self.__read_preference def _read_preference_for(self, session): """Read only access to the read preference of this instance or session. """ # Override this operation's read preference with the transaction's. if session: return session._txn_read_preference() or self.__read_preference return self.__read_preference @property def read_concern(self): """Read only access to the :class:`~pymongo.read_concern.ReadConcern` of this instance. .. versionadded:: 3.2 """ return self.__read_concern class _CaseInsensitiveDictionary(abc.MutableMapping): def __init__(self, *args, **kwargs): self.__casedkeys = {} self.__data = {} self.update(dict(*args, **kwargs)) def __contains__(self, key): return key.lower() in self.__data def __len__(self): return len(self.__data) def __iter__(self): return (key for key in self.__casedkeys) def __repr__(self): return str({self.__casedkeys[k]: self.__data[k] for k in self}) def __setitem__(self, key, value): lc_key = key.lower() self.__casedkeys[lc_key] = key self.__data[lc_key] = value def __getitem__(self, key): return self.__data[key.lower()] def __delitem__(self, key): lc_key = key.lower() del self.__casedkeys[lc_key] del self.__data[lc_key] def __eq__(self, other): if not isinstance(other, abc.Mapping): return NotImplemented if len(self) != len(other): return False for key in other: if self[key] != other[key]: return False return True def get(self, key, default=None): return self.__data.get(key.lower(), default) def pop(self, key, *args, **kwargs): lc_key = key.lower() self.__casedkeys.pop(lc_key, None) return self.__data.pop(lc_key, *args, **kwargs) def popitem(self): lc_key, cased_key = self.__casedkeys.popitem() value = self.__data.pop(lc_key) return cased_key, value def clear(self): self.__casedkeys.clear() self.__data.clear() def setdefault(self, key, default=None): lc_key = key.lower() if key in self: return self.__data[lc_key] else: self.__casedkeys[lc_key] = key self.__data[lc_key] = default return default def update(self, other): if isinstance(other, _CaseInsensitiveDictionary): for key in other: self[other.cased_key(key)] = other[key] else: for key in other: self[key] = other[key] def cased_key(self, key): return self.__casedkeys[key.lower()]pymongo-3.11.0/pymongo/compression_support.py000066400000000000000000000120621374256237000214750ustar00rootroot00000000000000# Copyright 2018 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings try: import snappy _HAVE_SNAPPY = True except ImportError: # python-snappy isn't available. _HAVE_SNAPPY = False try: import zlib _HAVE_ZLIB = True except ImportError: # Python built without zlib support. 
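# Hedged usage sketch for the _CaseInsensitiveDictionary class defined
# earlier in common.py: lookups ignore case, while the originally supplied
# spelling is preserved and retrievable via cased_key().
#
#   d = _CaseInsensitiveDictionary()
#   d['readPreference'] = 'secondary'
#   d['READPREFERENCE']              # -> 'secondary'
#   d.cased_key('readpreference')    # -> 'readPreference'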
_HAVE_ZLIB = False try: from zstandard import ZstdCompressor, ZstdDecompressor _HAVE_ZSTD = True except ImportError: _HAVE_ZSTD = False from pymongo.monitoring import _SENSITIVE_COMMANDS _SUPPORTED_COMPRESSORS = set(["snappy", "zlib", "zstd"]) _NO_COMPRESSION = set(['ismaster']) _NO_COMPRESSION.update(_SENSITIVE_COMMANDS) def validate_compressors(dummy, value): try: # `value` is string. compressors = value.split(",") except AttributeError: # `value` is an iterable. compressors = list(value) for compressor in compressors[:]: if compressor not in _SUPPORTED_COMPRESSORS: compressors.remove(compressor) warnings.warn("Unsupported compressor: %s" % (compressor,)) elif compressor == "snappy" and not _HAVE_SNAPPY: compressors.remove(compressor) warnings.warn( "Wire protocol compression with snappy is not available. " "You must install the python-snappy module for snappy support.") elif compressor == "zlib" and not _HAVE_ZLIB: compressors.remove(compressor) warnings.warn( "Wire protocol compression with zlib is not available. " "The zlib module is not available.") elif compressor == "zstd" and not _HAVE_ZSTD: compressors.remove(compressor) warnings.warn( "Wire protocol compression with zstandard is not available. " "You must install the zstandard module for zstandard support.") return compressors def validate_zlib_compression_level(option, value): try: level = int(value) except: raise TypeError("%s must be an integer, not %r." % (option, value)) if level < -1 or level > 9: raise ValueError( "%s must be between -1 and 9, not %d." % (option, level)) return level class CompressionSettings(object): def __init__(self, compressors, zlib_compression_level): self.compressors = compressors self.zlib_compression_level = zlib_compression_level def get_compression_context(self, compressors): if compressors: chosen = compressors[0] if chosen == "snappy": return SnappyContext() elif chosen == "zlib": return ZlibContext(self.zlib_compression_level) elif chosen == "zstd": return ZstdContext() def _zlib_no_compress(data): """Compress data with zlib level 0.""" cobj = zlib.compressobj(0) return b"".join([cobj.compress(data), cobj.flush()]) class SnappyContext(object): compressor_id = 1 @staticmethod def compress(data): return snappy.compress(data) class ZlibContext(object): compressor_id = 2 def __init__(self, level): # Jython zlib.compress doesn't support -1 if level == -1: self.compress = zlib.compress # Jython zlib.compress also doesn't support 0 elif level == 0: self.compress = _zlib_no_compress else: self.compress = lambda data: zlib.compress(data, level) class ZstdContext(object): compressor_id = 3 @staticmethod def compress(data): # ZstdCompressor is not thread safe. # TODO: Use a pool? return ZstdCompressor().compress(data) def decompress(data, compressor_id): if compressor_id == SnappyContext.compressor_id: # python-snappy doesn't support the buffer interface. # https://github.com/andrix/python-snappy/issues/65 # This only matters when data is a memoryview since # id(bytes(data)) == id(data) when data is a bytes. # NOTE: bytes(memoryview) returns the memoryview repr # in Python 2.7. The right thing to do in 2.7 is call # memoryview.tobytes(), but we currently only use # memoryview in Python 3.x. return snappy.uncompress(bytes(data)) elif compressor_id == ZlibContext.compressor_id: return zlib.decompress(data) elif compressor_id == ZstdContext.compressor_id: # ZstdDecompressor is not thread safe. # TODO: Use a pool? 
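# Hedged usage sketch for the helpers above: validate_compressors() drops
# unsupported or unavailable names with a warning, and CompressionSettings
# picks a compression context from the first remaining name.
#
#   names = validate_compressors('compressors', 'zlib,noop')
#   # warns "Unsupported compressor: noop"; names == ['zlib'] when the
#   # zlib module is importable.
#   settings = CompressionSettings(names, zlib_compression_level=-1)
#   ctx = settings.get_compression_context(names)   # a ZlibContext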
return ZstdDecompressor().decompress(data) else: raise ValueError("Unknown compressorId %d" % (compressor_id,)) pymongo-3.11.0/pymongo/cursor.py000066400000000000000000001400301374256237000166520ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Cursor class to iterate over Mongo query results.""" import copy import warnings from collections import deque from bson import RE_TYPE from bson.code import Code from bson.py3compat import (iteritems, integer_types, string_type) from bson.son import SON from pymongo import helpers from pymongo.common import validate_boolean, validate_is_mapping from pymongo.collation import validate_collation_or_none from pymongo.errors import (ConnectionFailure, InvalidOperation, NotMasterError, OperationFailure) from pymongo.message import (_CursorAddress, _GetMore, _RawBatchGetMore, _Query, _RawBatchQuery) from pymongo.monitoring import ConnectionClosedReason _QUERY_OPTIONS = { "tailable_cursor": 2, "slave_okay": 4, "oplog_replay": 8, "no_timeout": 16, "await_data": 32, "exhaust": 64, "partial": 128} class CursorType(object): NON_TAILABLE = 0 """The standard cursor type.""" TAILABLE = _QUERY_OPTIONS["tailable_cursor"] """The tailable cursor type. Tailable cursors are only for use with capped collections. They are not closed when the last data is retrieved but are kept open and the cursor location marks the final document position. If more data is received iteration of the cursor will continue from the last document received. """ TAILABLE_AWAIT = TAILABLE | _QUERY_OPTIONS["await_data"] """A tailable cursor with the await option set. Creates a tailable cursor that will wait for a few seconds after returning the full result set so that it can capture and return additional data added during the query. """ EXHAUST = _QUERY_OPTIONS["exhaust"] """An exhaust cursor. MongoDB will stream batched results to the client without waiting for the client to request each batch, reducing latency. """ # This has to be an old style class due to # http://bugs.jython.org/issue1057 class _SocketManager: """Used with exhaust cursors to ensure the socket is returned. """ def __init__(self, sock, pool): self.sock = sock self.pool = pool self.__closed = False def __del__(self): self.close() def close(self): """Return this instance's socket to the connection pool. """ if not self.__closed: self.__closed = True self.pool.return_socket(self.sock) self.sock, self.pool = None, None class Cursor(object): """A cursor / iterator over Mongo query results. 
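# Hedged usage sketch: the CursorType flags below are selected through the
# cursor_type argument of Collection.find().  'events' is a hypothetical
# capped collection; tailable cursors require a capped collection.
#
#   from pymongo import MongoClient, CursorType
#   client = MongoClient()
#   cursor = client.test.events.find(
#       {}, cursor_type=CursorType.TAILABLE_AWAIT)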
""" _query_class = _Query _getmore_class = _GetMore def __init__(self, collection, filter=None, projection=None, skip=0, limit=0, no_cursor_timeout=False, cursor_type=CursorType.NON_TAILABLE, sort=None, allow_partial_results=False, oplog_replay=False, modifiers=None, batch_size=0, manipulate=True, collation=None, hint=None, max_scan=None, max_time_ms=None, max=None, min=None, return_key=False, show_record_id=False, snapshot=False, comment=None, session=None, allow_disk_use=None): """Create a new cursor. Should not be called directly by application developers - see :meth:`~pymongo.collection.Collection.find` instead. .. mongodoc:: cursors """ # Initialize all attributes used in __del__ before possibly raising # an error to avoid attribute errors during garbage collection. self.__id = None self.__exhaust = False self.__exhaust_mgr = None self.__killed = False if session: self.__session = session self.__explicit_session = True else: self.__session = None self.__explicit_session = False spec = filter if spec is None: spec = {} validate_is_mapping("filter", spec) if not isinstance(skip, int): raise TypeError("skip must be an instance of int") if not isinstance(limit, int): raise TypeError("limit must be an instance of int") validate_boolean("no_cursor_timeout", no_cursor_timeout) if cursor_type not in (CursorType.NON_TAILABLE, CursorType.TAILABLE, CursorType.TAILABLE_AWAIT, CursorType.EXHAUST): raise ValueError("not a valid value for cursor_type") validate_boolean("allow_partial_results", allow_partial_results) validate_boolean("oplog_replay", oplog_replay) if modifiers is not None: warnings.warn("the 'modifiers' parameter is deprecated", DeprecationWarning, stacklevel=2) validate_is_mapping("modifiers", modifiers) if not isinstance(batch_size, integer_types): raise TypeError("batch_size must be an integer") if batch_size < 0: raise ValueError("batch_size must be >= 0") # Only set if allow_disk_use is provided by the user, else None. if allow_disk_use is not None: allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use) if projection is not None: if not projection: projection = {"_id": 1} projection = helpers._fields_list_to_dict(projection, "projection") self.__collection = collection self.__spec = spec self.__projection = projection self.__skip = skip self.__limit = limit self.__batch_size = batch_size self.__modifiers = modifiers and modifiers.copy() or {} self.__ordering = sort and helpers._index_document(sort) or None self.__max_scan = max_scan self.__explain = False self.__comment = comment self.__max_time_ms = max_time_ms self.__max_await_time_ms = None self.__max = max self.__min = min self.__manipulate = manipulate self.__collation = validate_collation_or_none(collation) self.__return_key = return_key self.__show_record_id = show_record_id self.__allow_disk_use = allow_disk_use self.__snapshot = snapshot self.__set_hint(hint) # Exhaust cursor support if cursor_type == CursorType.EXHAUST: if self.__collection.database.client.is_mongos: raise InvalidOperation('Exhaust cursors are ' 'not supported by mongos') if limit: raise InvalidOperation("Can't use limit and exhaust together.") self.__exhaust = True # This is ugly. People want to be able to do cursor[5:5] and # get an empty result set (old behavior was an # exception). It's hard to do that right, though, because the # server uses limit(0) to mean 'no limit'. So we set __empty # in that case and check for it when iterating. We also unset # it anytime we change __limit. 
self.__empty = False self.__data = deque() self.__address = None self.__retrieved = 0 self.__codec_options = collection.codec_options # Read preference is set when the initial find is sent. self.__read_preference = None self.__read_concern = collection.read_concern self.__query_flags = cursor_type if no_cursor_timeout: self.__query_flags |= _QUERY_OPTIONS["no_timeout"] if allow_partial_results: self.__query_flags |= _QUERY_OPTIONS["partial"] if oplog_replay: self.__query_flags |= _QUERY_OPTIONS["oplog_replay"] # The namespace to use for find/getMore commands. self.__dbname = collection.database.name self.__collname = collection.name @property def collection(self): """The :class:`~pymongo.collection.Collection` that this :class:`Cursor` is iterating. """ return self.__collection @property def retrieved(self): """The number of documents retrieved so far. """ return self.__retrieved def __del__(self): self.__die() def rewind(self): """Rewind this cursor to its unevaluated state. Reset this cursor if it has been partially or completely evaluated. Any options that are present on the cursor will remain in effect. Future iterating performed on this cursor will cause new queries to be sent to the server, even if the resultant data has already been retrieved by this cursor. """ self.__data = deque() self.__id = None self.__address = None self.__retrieved = 0 self.__killed = False return self def clone(self): """Get a clone of this cursor. Returns a new Cursor instance with options matching those that have been set on the current instance. The clone will be completely unevaluated, even if the current instance has been partially or completely evaluated. """ return self._clone(True) def _clone(self, deepcopy=True, base=None): """Internal clone helper.""" if not base: if self.__explicit_session: base = self._clone_base(self.__session) else: base = self._clone_base(None) values_to_clone = ("spec", "projection", "skip", "limit", "max_time_ms", "max_await_time_ms", "comment", "max", "min", "ordering", "explain", "hint", "batch_size", "max_scan", "manipulate", "query_flags", "modifiers", "collation", "empty", "show_record_id", "return_key", "allow_disk_use", "snapshot", "exhaust") data = dict((k, v) for k, v in iteritems(self.__dict__) if k.startswith('_Cursor__') and k[9:] in values_to_clone) if deepcopy: data = self._deepcopy(data) base.__dict__.update(data) return base def _clone_base(self, session): """Creates an empty Cursor object for information to be copied into. """ return self.__class__(self.__collection, session=session) def __die(self, synchronous=False): """Closes this cursor. """ try: already_killed = self.__killed except AttributeError: # __init__ did not run to completion (or at all). return self.__killed = True if self.__id and not already_killed: if self.__exhaust and self.__exhaust_mgr: # If this is an exhaust cursor and we haven't completely # exhausted the result set we *must* close the socket # to stop the server from sending more data. self.__exhaust_mgr.sock.close_socket( ConnectionClosedReason.ERROR) else: address = _CursorAddress( self.__address, self.__collection.full_name) if synchronous: self.__collection.database.client._close_cursor_now( self.__id, address, session=self.__session) else: # The cursor will be closed later in a different session. 
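# Hedged usage sketch for rewind()/clone() above: clone() returns an
# independent, unevaluated cursor with the same options, while rewind()
# resets this cursor so iteration re-runs the query from the start.
#
#   cursor = db.test.find().sort('x', 1).limit(10)
#   first_pass = list(cursor)
#   second_pass = list(cursor.rewind())   # issues a new query
#   other = cursor.clone().skip(5)        # does not affect `cursor`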
self.__collection.database.client._close_cursor( self.__id, address) if self.__exhaust and self.__exhaust_mgr: self.__exhaust_mgr.close() if self.__session and not self.__explicit_session: self.__session._end_session(lock=synchronous) self.__session = None def close(self): """Explicitly close / kill this cursor. """ self.__die(True) def __query_spec(self): """Get the spec to use for a query. """ operators = self.__modifiers.copy() if self.__ordering: operators["$orderby"] = self.__ordering if self.__explain: operators["$explain"] = True if self.__hint: operators["$hint"] = self.__hint if self.__comment: operators["$comment"] = self.__comment if self.__max_scan: operators["$maxScan"] = self.__max_scan if self.__max_time_ms is not None: operators["$maxTimeMS"] = self.__max_time_ms if self.__max: operators["$max"] = self.__max if self.__min: operators["$min"] = self.__min if self.__return_key: operators["$returnKey"] = self.__return_key if self.__show_record_id: # This is upgraded to showRecordId for MongoDB 3.2+ "find" command. operators["$showDiskLoc"] = self.__show_record_id if self.__snapshot: operators["$snapshot"] = self.__snapshot if operators: # Make a shallow copy so we can cleanly rewind or clone. spec = self.__spec.copy() # White-listed commands must be wrapped in $query. if "$query" not in spec: # $query has to come first spec = SON([("$query", spec)]) if not isinstance(spec, SON): # Ensure the spec is SON. As order is important this will # ensure its set before merging in any extra operators. spec = SON(spec) spec.update(operators) return spec # Have to wrap with $query if "query" is the first key. # We can't just use $query anytime "query" is a key as # that breaks commands like count and find_and_modify. # Checking spec.keys()[0] covers the case that the spec # was passed as an instance of SON or OrderedDict. elif ("query" in self.__spec and (len(self.__spec) == 1 or next(iter(self.__spec)) == "query")): return SON({"$query": self.__spec}) return self.__spec def __check_okay_to_chain(self): """Check if it is okay to chain more options onto this cursor. """ if self.__retrieved or self.__id is not None: raise InvalidOperation("cannot set options after executing query") def add_option(self, mask): """Set arbitrary query flags using a bitmask. To set the tailable flag: cursor.add_option(2) """ if not isinstance(mask, int): raise TypeError("mask must be an int") self.__check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: if self.__limit: raise InvalidOperation("Can't use limit and exhaust together.") if self.__collection.database.client.is_mongos: raise InvalidOperation('Exhaust cursors are ' 'not supported by mongos') self.__exhaust = True self.__query_flags |= mask return self def remove_option(self, mask): """Unset arbitrary query flags using a bitmask. To unset the tailable flag: cursor.remove_option(2) """ if not isinstance(mask, int): raise TypeError("mask must be an int") self.__check_okay_to_chain() if mask & _QUERY_OPTIONS["exhaust"]: self.__exhaust = False self.__query_flags &= ~mask return self def allow_disk_use(self, allow_disk_use): """Specifies whether MongoDB can use temporary disk files while processing a blocking sort operation. Raises :exc:`TypeError` if `allow_disk_use` is not a boolean. .. note:: `allow_disk_use` requires server version **>= 4.4** :Parameters: - `allow_disk_use`: if True, MongoDB may use temporary disk files to store data exceeding the system memory limit while processing a blocking sort operation. .. 
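# Hedged usage sketch for add_option()/remove_option() above: raw wire
# protocol flags from _QUERY_OPTIONS can be toggled by bitmask
# (2 == tailable_cursor, 16 == no_timeout, ...).
#
#   cursor = db.capped_coll.find().add_option(2)   # set the tailable flag
#   cursor = cursor.remove_option(2)               # clear it again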
versionadded:: 3.11 """ if not isinstance(allow_disk_use, bool): raise TypeError('allow_disk_use must be a bool') self.__check_okay_to_chain() self.__allow_disk_use = allow_disk_use return self def limit(self, limit): """Limits the number of results to be returned by this cursor. Raises :exc:`TypeError` if `limit` is not an integer. Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. The last `limit` applied to this cursor takes precedence. A limit of ``0`` is equivalent to no limit. :Parameters: - `limit`: the number of results to return .. mongodoc:: limit """ if not isinstance(limit, integer_types): raise TypeError("limit must be an integer") if self.__exhaust: raise InvalidOperation("Can't use limit and exhaust together.") self.__check_okay_to_chain() self.__empty = False self.__limit = limit return self def batch_size(self, batch_size): """Limits the number of documents returned in one batch. Each batch requires a round trip to the server. It can be adjusted to optimize performance and limit data transfer. .. note:: batch_size can not override MongoDB's internal limits on the amount of data it will return to the client in a single batch (i.e if you set batch size to 1,000,000,000, MongoDB will currently only return 4-16MB of results per batch). Raises :exc:`TypeError` if `batch_size` is not an integer. Raises :exc:`ValueError` if `batch_size` is less than ``0``. Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. The last `batch_size` applied to this cursor takes precedence. :Parameters: - `batch_size`: The size of each batch of results requested. """ if not isinstance(batch_size, integer_types): raise TypeError("batch_size must be an integer") if batch_size < 0: raise ValueError("batch_size must be >= 0") self.__check_okay_to_chain() self.__batch_size = batch_size return self def skip(self, skip): """Skips the first `skip` results of this cursor. Raises :exc:`TypeError` if `skip` is not an integer. Raises :exc:`ValueError` if `skip` is less than ``0``. Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. The last `skip` applied to this cursor takes precedence. :Parameters: - `skip`: the number of results to skip """ if not isinstance(skip, integer_types): raise TypeError("skip must be an integer") if skip < 0: raise ValueError("skip must be >= 0") self.__check_okay_to_chain() self.__skip = skip return self def max_time_ms(self, max_time_ms): """Specifies a time limit for a query operation. If the specified time is exceeded, the operation will be aborted and :exc:`~pymongo.errors.ExecutionTimeout` is raised. If `max_time_ms` is ``None`` no limit is applied. Raises :exc:`TypeError` if `max_time_ms` is not an integer or ``None``. Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. :Parameters: - `max_time_ms`: the time limit after which the operation is aborted """ if (not isinstance(max_time_ms, integer_types) and max_time_ms is not None): raise TypeError("max_time_ms must be an integer or None") self.__check_okay_to_chain() self.__max_time_ms = max_time_ms return self def max_await_time_ms(self, max_await_time_ms): """Specifies a time limit for a getMore operation on a :attr:`~pymongo.cursor.CursorType.TAILABLE_AWAIT` cursor. For all other types of cursor max_await_time_ms is ignored. Raises :exc:`TypeError` if `max_await_time_ms` is not an integer or ``None``. 
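# Hedged usage sketch for the chainable modifiers above (skip, limit,
# batch_size, max_time_ms); all must be applied before the first batch is
# fetched or InvalidOperation is raised.
#
#   cursor = (db.test.find({'status': 'A'})
#             .skip(20)
#             .limit(10)
#             .batch_size(5)
#             .max_time_ms(2000))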
Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. .. note:: `max_await_time_ms` requires server version **>= 3.2** :Parameters: - `max_await_time_ms`: the time limit after which the operation is aborted .. versionadded:: 3.2 """ if (not isinstance(max_await_time_ms, integer_types) and max_await_time_ms is not None): raise TypeError("max_await_time_ms must be an integer or None") self.__check_okay_to_chain() # Ignore max_await_time_ms if not tailable or await_data is False. if self.__query_flags & CursorType.TAILABLE_AWAIT: self.__max_await_time_ms = max_await_time_ms return self def __getitem__(self, index): """Get a single document or a slice of documents from this cursor. Raises :class:`~pymongo.errors.InvalidOperation` if this cursor has already been used. To get a single document use an integral index, e.g.:: >>> db.test.find()[50] An :class:`IndexError` will be raised if the index is negative or greater than the amount of documents in this cursor. Any limit previously applied to this cursor will be ignored. To get a slice of documents use a slice index, e.g.:: >>> db.test.find()[20:25] This will return this cursor with a limit of ``5`` and skip of ``20`` applied. Using a slice index will override any prior limits or skips applied to this cursor (including those applied through previous calls to this method). Raises :class:`IndexError` when the slice has a step, a negative start value, or a stop value less than or equal to the start value. :Parameters: - `index`: An integer or slice index to be applied to this cursor """ self.__check_okay_to_chain() self.__empty = False if isinstance(index, slice): if index.step is not None: raise IndexError("Cursor instances do not support slice steps") skip = 0 if index.start is not None: if index.start < 0: raise IndexError("Cursor instances do not support " "negative indices") skip = index.start if index.stop is not None: limit = index.stop - skip if limit < 0: raise IndexError("stop index must be greater than start " "index for slice %r" % index) if limit == 0: self.__empty = True else: limit = 0 self.__skip = skip self.__limit = limit return self if isinstance(index, integer_types): if index < 0: raise IndexError("Cursor instances do not support negative " "indices") clone = self.clone() clone.skip(index + self.__skip) clone.limit(-1) # use a hard limit clone.__query_flags &= ~CursorType.TAILABLE_AWAIT # PYTHON-1371 for doc in clone: return doc raise IndexError("no such item for Cursor instance") raise TypeError("index %r cannot be applied to Cursor " "instances" % index) def max_scan(self, max_scan): """**DEPRECATED** - Limit the number of documents to scan when performing the query. Raises :class:`~pymongo.errors.InvalidOperation` if this cursor has already been used. Only the last :meth:`max_scan` applied to this cursor has any effect. :Parameters: - `max_scan`: the maximum number of documents to scan .. versionchanged:: 3.7 Deprecated :meth:`max_scan`. Support for this option is deprecated in MongoDB 4.0. Use :meth:`max_time_ms` instead to limit server side execution time. """ self.__check_okay_to_chain() self.__max_scan = max_scan return self def max(self, spec): """Adds ``max`` operator that specifies upper bound for specific index. When using ``max``, :meth:`~hint` should also be configured to ensure the query uses the expected index and starting in MongoDB 4.2 :meth:`~hint` will be required. 
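# Hedged usage sketch for __getitem__ above: an integer index fetches a
# single document via a one-shot clone, while a slice rewrites this
# cursor's skip and limit in place.
#
#   doc = db.test.find()[50]        # the 51st matching document
#   page = db.test.find()[20:25]    # same cursor with skip=20, limit=5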
:Parameters: - `spec`: a list of field, limit pairs specifying the exclusive upper bound for all keys of a specific index in order. .. versionchanged:: 3.8 Deprecated cursors that use ``max`` without a :meth:`~hint`. .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): raise TypeError("spec must be an instance of list or tuple") self.__check_okay_to_chain() self.__max = SON(spec) return self def min(self, spec): """Adds ``min`` operator that specifies lower bound for specific index. When using ``min``, :meth:`~hint` should also be configured to ensure the query uses the expected index and starting in MongoDB 4.2 :meth:`~hint` will be required. :Parameters: - `spec`: a list of field, limit pairs specifying the inclusive lower bound for all keys of a specific index in order. .. versionchanged:: 3.8 Deprecated cursors that use ``min`` without a :meth:`~hint`. .. versionadded:: 2.7 """ if not isinstance(spec, (list, tuple)): raise TypeError("spec must be an instance of list or tuple") self.__check_okay_to_chain() self.__min = SON(spec) return self def sort(self, key_or_list, direction=None): """Sorts this cursor's results. Pass a field name and a direction, either :data:`~pymongo.ASCENDING` or :data:`~pymongo.DESCENDING`:: for doc in collection.find().sort('field', pymongo.ASCENDING): print(doc) To sort by multiple fields, pass a list of (key, direction) pairs:: for doc in collection.find().sort([ ('field1', pymongo.ASCENDING), ('field2', pymongo.DESCENDING)]): print(doc) Beginning with MongoDB version 2.6, text search results can be sorted by relevance:: cursor = db.test.find( {'$text': {'$search': 'some words'}}, {'score': {'$meta': 'textScore'}}) # Sort by 'score' field. cursor.sort([('score', {'$meta': 'textScore'})]) for doc in cursor: print(doc) For more advanced text search functionality, see MongoDB's `Atlas Search `_. Raises :class:`~pymongo.errors.InvalidOperation` if this cursor has already been used. Only the last :meth:`sort` applied to this cursor has any effect. :Parameters: - `key_or_list`: a single key or a list of (key, direction) pairs specifying the keys to sort on - `direction` (optional): only used if `key_or_list` is a single key, if not given :data:`~pymongo.ASCENDING` is assumed """ self.__check_okay_to_chain() keys = helpers._index_list(key_or_list, direction) self.__ordering = helpers._index_document(keys) return self def count(self, with_limit_and_skip=False): """**DEPRECATED** - Get the size of the results set for this query. The :meth:`count` method is deprecated and **not** supported in a transaction. Please use :meth:`~pymongo.collection.Collection.count_documents` instead. Returns the number of documents in the results set for this query. Does not take :meth:`limit` and :meth:`skip` into account by default - set `with_limit_and_skip` to ``True`` if that is the desired behavior. Raises :class:`~pymongo.errors.OperationFailure` on a database error. When used with MongoDB >= 2.6, :meth:`~count` uses any :meth:`~hint` applied to the query. In the following example the hint is passed to the count command: collection.find({'field': 'value'}).hint('field_1').count() The :meth:`count` method obeys the :attr:`~pymongo.collection.Collection.read_preference` of the :class:`~pymongo.collection.Collection` instance on which :meth:`~pymongo.collection.Collection.find` was called. :Parameters: - `with_limit_and_skip` (optional): take any :meth:`limit` or :meth:`skip` that has been applied to this cursor into account when getting the count .. 
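# Hedged usage sketch for min()/max()/sort() above: min and max take
# (field, value) pairs for one specific index and should be paired with
# hint(); 'x_1' is a hypothetical index name.
#
#   import pymongo
#   cursor = (db.test.find()
#             .min([('x', 10)])
#             .max([('x', 100)])
#             .hint('x_1')
#             .sort([('x', pymongo.ASCENDING), ('y', pymongo.DESCENDING)]))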
note:: The `with_limit_and_skip` parameter requires server version **>= 1.1.4-** .. versionchanged:: 3.7 Deprecated. .. versionchanged:: 2.8 The :meth:`~count` method now supports :meth:`~hint`. """ warnings.warn("count is deprecated. Use Collection.count_documents " "instead.", DeprecationWarning, stacklevel=2) validate_boolean("with_limit_and_skip", with_limit_and_skip) cmd = SON([("count", self.__collection.name), ("query", self.__spec)]) if self.__max_time_ms is not None: cmd["maxTimeMS"] = self.__max_time_ms if self.__comment: cmd["comment"] = self.__comment if self.__hint is not None: cmd["hint"] = self.__hint if with_limit_and_skip: if self.__limit: cmd["limit"] = self.__limit if self.__skip: cmd["skip"] = self.__skip return self.__collection._count( cmd, self.__collation, session=self.__session) def distinct(self, key): """Get a list of distinct values for `key` among all documents in the result set of this query. Raises :class:`TypeError` if `key` is not an instance of :class:`basestring` (:class:`str` in python 3). The :meth:`distinct` method obeys the :attr:`~pymongo.collection.Collection.read_preference` of the :class:`~pymongo.collection.Collection` instance on which :meth:`~pymongo.collection.Collection.find` was called. :Parameters: - `key`: name of key for which we want to get the distinct values .. seealso:: :meth:`pymongo.collection.Collection.distinct` """ options = {} if self.__spec: options["query"] = self.__spec if self.__max_time_ms is not None: options['maxTimeMS'] = self.__max_time_ms if self.__comment: options['comment'] = self.__comment if self.__collation is not None: options['collation'] = self.__collation return self.__collection.distinct( key, session=self.__session, **options) def explain(self): """Returns an explain plan record for this cursor. .. note:: Starting with MongoDB 3.2 :meth:`explain` uses the default verbosity mode of the `explain command `_, ``allPlansExecution``. To use a different verbosity use :meth:`~pymongo.database.Database.command` to run the explain command directly. .. mongodoc:: explain """ c = self.clone() c.__explain = True # always use a hard limit for explains if c.__limit: c.__limit = -abs(c.__limit) return next(c) def __set_hint(self, index): if index is None: self.__hint = None return if isinstance(index, string_type): self.__hint = index else: self.__hint = helpers._index_document(index) def hint(self, index): """Adds a 'hint', telling Mongo the proper index to use for the query. Judicious use of hints can greatly improve query performance. When doing a query on multiple fields (at least one of which is indexed) pass the indexed field as a hint to the query. Raises :class:`~pymongo.errors.OperationFailure` if the provided hint requires an index that does not exist on this collection, and raises :class:`~pymongo.errors.InvalidOperation` if this cursor has already been used. `index` should be an index as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``) or the name of the index. If `index` is ``None`` any existing hint for this query is cleared. The last hint applied to this cursor takes precedence over all others. :Parameters: - `index`: index to hint on (as an index specifier) .. versionchanged:: 2.8 The :meth:`~hint` method accepts the name of the index. """ self.__check_okay_to_chain() self.__set_hint(index) return self def comment(self, comment): """Adds a 'comment' to the cursor. 
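# Hedged usage sketch: count() above is deprecated in favour of
# Collection.count_documents(); distinct() and explain() remain available
# on the cursor.  Collection and field names are examples only.
#
#   n = db.orders.count_documents({'status': 'shipped'})
#   cities = db.orders.find({'status': 'shipped'}).distinct('city')
#   plan = db.orders.find({'status': 'shipped'}).explain()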
http://docs.mongodb.org/manual/reference/operator/comment/ :Parameters: - `comment`: A string to attach to the query to help interpret and trace the operation in the server logs and in profile data. .. versionadded:: 2.7 """ self.__check_okay_to_chain() self.__comment = comment return self def where(self, code): """Adds a `$where`_ clause to this query. The `code` argument must be an instance of :class:`basestring` (:class:`str` in python 3) or :class:`~bson.code.Code` containing a JavaScript expression. This expression will be evaluated for each document scanned. Only those documents for which the expression evaluates to *true* will be returned as results. The keyword *this* refers to the object currently being scanned. For example:: # Find all documents where field "a" is less than "b" plus "c". for doc in db.test.find().where('this.a < (this.b + this.c)'): print(doc) Raises :class:`TypeError` if `code` is not an instance of :class:`basestring` (:class:`str` in python 3). Raises :class:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. Only the last call to :meth:`where` applied to a :class:`Cursor` has any effect. .. note:: MongoDB 4.4 drops support for :class:`~bson.code.Code` with scope variables. Consider using `$expr`_ instead. :Parameters: - `code`: JavaScript expression to use as a filter .. _$expr: https://docs.mongodb.com/manual/reference/operator/query/expr/ .. _$where: https://docs.mongodb.com/manual/reference/operator/query/where/ """ self.__check_okay_to_chain() if not isinstance(code, Code): code = Code(code) self.__spec["$where"] = code return self def collation(self, collation): """Adds a :class:`~pymongo.collation.Collation` to this query. This option is only supported on MongoDB 3.4 and above. Raises :exc:`TypeError` if `collation` is not an instance of :class:`~pymongo.collation.Collation` or a ``dict``. Raises :exc:`~pymongo.errors.InvalidOperation` if this :class:`Cursor` has already been used. Only the last collation applied to this cursor has any effect. :Parameters: - `collation`: An instance of :class:`~pymongo.collation.Collation`. """ self.__check_okay_to_chain() self.__collation = validate_collation_or_none(collation) return self def __send_message(self, operation): """Send a query or getmore operation and handles the response. If operation is ``None`` this is an exhaust cursor, which reads the next result batch off the exhaust socket instead of sending getMore messages to the server. Can raise ConnectionFailure. """ client = self.__collection.database.client # OP_MSG is required to support exhaust cursors with encryption. if client._encrypter and self.__exhaust: raise InvalidOperation( "exhaust cursors do not support auto encryption") try: response = client._run_operation_with_response( operation, self._unpack_response, exhaust=self.__exhaust, address=self.__address) except OperationFailure: self.__killed = True # Make sure exhaust socket is returned immediately, if necessary. self.__die() # If this is a tailable cursor the error is likely # due to capped collection roll over. Setting # self.__killed to True ensures Cursor.alive will be # False. No need to re-raise. if self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]: return raise except NotMasterError: # Don't send kill cursors to another server after a "not master" # error. It's completely pointless. self.__killed = True # Make sure exhaust socket is returned immediately, if necessary. 
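# Hedged usage sketch for comment()/where()/collation() above.
#
#   from pymongo.collation import Collation
#   cursor = (db.test.find({'a': {'$gt': 0}})
#             .comment('nightly-report')
#             .where('this.b > this.c')
#             .collation(Collation(locale='en_US', strength=2)))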
self.__die() raise except ConnectionFailure: # Don't try to send kill cursors on another socket # or to another server. It can cause a _pinValue # assertion on some server releases if we get here # due to a socket timeout. self.__killed = True self.__die() raise except Exception: # Close the cursor self.__die() raise self.__address = response.address if self.__exhaust and not self.__exhaust_mgr: # 'response' is an ExhaustResponse. self.__exhaust_mgr = _SocketManager(response.socket_info, response.pool) cmd_name = operation.name docs = response.docs if response.from_command: if cmd_name != "explain": cursor = docs[0]['cursor'] self.__id = cursor['id'] if cmd_name == 'find': documents = cursor['firstBatch'] # Update the namespace used for future getMore commands. ns = cursor.get('ns') if ns: self.__dbname, self.__collname = ns.split('.', 1) else: documents = cursor['nextBatch'] self.__data = deque(documents) self.__retrieved += len(documents) else: self.__id = 0 self.__data = deque(docs) self.__retrieved += len(docs) else: self.__id = response.data.cursor_id self.__data = deque(docs) self.__retrieved += response.data.number_returned if self.__id == 0: self.__killed = True # Don't wait for garbage collection to call __del__, return the # socket and the session to the pool now. self.__die() if self.__limit and self.__id and self.__limit <= self.__retrieved: self.__die() def _unpack_response(self, response, cursor_id, codec_options, user_fields=None, legacy_response=False): return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _read_preference(self): if self.__read_preference is None: # Save the read preference for getMore commands. self.__read_preference = self.__collection._read_preference_for( self.session) return self.__read_preference def _refresh(self): """Refreshes the cursor with more data from Mongo. Returns the length of self.__data after refresh. Will exit early if self.__data is already non-empty. Raises OperationFailure when the cursor cannot be refreshed due to an error on the query. """ if len(self.__data) or self.__killed: return len(self.__data) if not self.__session: self.__session = self.__collection.database.client._ensure_session() if self.__id is None: # Query if (self.__min or self.__max) and not self.__hint: warnings.warn("using a min/max query operator without " "specifying a Cursor.hint is deprecated. A " "hint will be required when using min/max in " "PyMongo 4.0", DeprecationWarning, stacklevel=3) q = self._query_class(self.__query_flags, self.__collection.database.name, self.__collection.name, self.__skip, self.__query_spec(), self.__projection, self.__codec_options, self._read_preference(), self.__limit, self.__batch_size, self.__read_concern, self.__collation, self.__session, self.__collection.database.client, self.__allow_disk_use) self.__send_message(q) elif self.__id: # Get More if self.__limit: limit = self.__limit - self.__retrieved if self.__batch_size: limit = min(limit, self.__batch_size) else: limit = self.__batch_size # Exhaust cursors don't send getMore messages. g = self._getmore_class(self.__dbname, self.__collname, limit, self.__id, self.__codec_options, self._read_preference(), self.__session, self.__collection.database.client, self.__max_await_time_ms, self.__exhaust_mgr) self.__send_message(g) return len(self.__data) @property def alive(self): """Does this cursor have the potential to return more data? 
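# Hedged sketch of the getMore sizing rule in _refresh() above: when both a
# limit and a batch size are set, each getMore requests the smaller of the
# documents still owed and the batch size.  Numbers are examples only.
#
#   limit, batch_size, retrieved = 25, 10, 20
#   next_batch = min(limit - retrieved, batch_size)   # == 5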
This is mostly useful with `tailable cursors `_ since they will stop iterating even though they *may* return more results in the future. With regular cursors, simply use a for loop instead of :attr:`alive`:: for doc in collection.find(): print(doc) .. note:: Even if :attr:`alive` is True, :meth:`next` can raise :exc:`StopIteration`. :attr:`alive` can also be True while iterating a cursor from a failed server. In this case :attr:`alive` will return False after :meth:`next` fails to retrieve the next batch of results from the server. """ return bool(len(self.__data) or (not self.__killed)) @property def cursor_id(self): """Returns the id of the cursor Useful if you need to manage cursor ids and want to handle killing cursors manually using :meth:`~pymongo.mongo_client.MongoClient.kill_cursors` .. versionadded:: 2.2 """ return self.__id @property def address(self): """The (host, port) of the server used, or None. .. versionchanged:: 3.0 Renamed from "conn_id". """ return self.__address @property def session(self): """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. .. versionadded:: 3.6 """ if self.__explicit_session: return self.__session def __iter__(self): return self def next(self): """Advance the cursor.""" if self.__empty: raise StopIteration if len(self.__data) or self._refresh(): if self.__manipulate: _db = self.__collection.database return _db._fix_outgoing(self.__data.popleft(), self.__collection) else: return self.__data.popleft() else: raise StopIteration __next__ = next def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def __copy__(self): """Support function for `copy.copy()`. .. versionadded:: 2.4 """ return self._clone(deepcopy=False) def __deepcopy__(self, memo): """Support function for `copy.deepcopy()`. .. versionadded:: 2.4 """ return self._clone(deepcopy=True) def _deepcopy(self, x, memo=None): """Deepcopy helper for the data dictionary or list. Regular expressions cannot be deep copied but as they are immutable we don't have to copy them when cloning. """ if not hasattr(x, 'items'): y, is_list, iterator = [], True, enumerate(x) else: y, is_list, iterator = {}, False, iteritems(x) if memo is None: memo = {} val_id = id(x) if val_id in memo: return memo.get(val_id) memo[val_id] = y for key, value in iterator: if isinstance(value, (dict, list)) and not isinstance(value, SON): value = self._deepcopy(value, memo) elif not isinstance(value, RE_TYPE): value = copy.deepcopy(value, memo) if is_list: y.append(value) else: if not isinstance(key, RE_TYPE): key = copy.deepcopy(key, memo) y[key] = value return y class RawBatchCursor(Cursor): """A cursor / iterator over raw batches of BSON data from a query result.""" _query_class = _RawBatchQuery _getmore_class = _RawBatchGetMore def __init__(self, *args, **kwargs): """Create a new cursor / iterator over raw batches of BSON data. Should not be called directly by application developers - see :meth:`~pymongo.collection.Collection.find_raw_batches` instead. .. mongodoc:: cursors """ manipulate = kwargs.get('manipulate') kwargs['manipulate'] = False super(RawBatchCursor, self).__init__(*args, **kwargs) # Throw only after cursor's initialized, to prevent errors in __del__. if manipulate: raise InvalidOperation( "Cannot use RawBatchCursor with manipulate=True") def _unpack_response(self, response, cursor_id, codec_options, user_fields=None, legacy_response=False): return response.raw_response(cursor_id) def explain(self): """Returns an explain plan record for this cursor. .. 
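# Hedged usage sketch for the iteration API above: cursors are iterable and
# also usable as context managers, which closes them promptly on exit.
#
#   with db.test.find({'x': 1}) as cursor:
#       for doc in cursor:
#           print(doc)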
mongodoc:: explain """ clone = self._clone(deepcopy=True, base=Cursor(self.collection)) return clone.explain() def __getitem__(self, index): raise InvalidOperation("Cannot call __getitem__ on RawBatchCursor") pymongo-3.11.0/pymongo/cursor_manager.py000066400000000000000000000040501374256237000203450ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """DEPRECATED - A manager to handle when cursors are killed after they are closed. New cursor managers should be defined as subclasses of CursorManager and can be installed on a client by calling :meth:`~pymongo.mongo_client.MongoClient.set_cursor_manager`. .. versionchanged:: 3.3 Deprecated, for real this time. .. versionchanged:: 3.0 Undeprecated. :meth:`~pymongo.cursor_manager.CursorManager.close` now requires an `address` argument. The ``BatchCursorManager`` class is removed. """ import warnings import weakref from bson.py3compat import integer_types class CursorManager(object): """DEPRECATED - The cursor manager base class.""" def __init__(self, client): """Instantiate the manager. :Parameters: - `client`: a MongoClient """ warnings.warn( "Cursor managers are deprecated.", DeprecationWarning, stacklevel=2) self.__client = weakref.ref(client) def close(self, cursor_id, address): """Kill a cursor. Raises TypeError if cursor_id is not an instance of (int, long). :Parameters: - `cursor_id`: cursor id to close - `address`: the cursor's server's (host, port) pair .. versionchanged:: 3.0 Now requires an `address` argument. """ if not isinstance(cursor_id, integer_types): raise TypeError("cursor_id must be an integer") self.__client().kill_cursors([cursor_id], address) pymongo-3.11.0/pymongo/daemon.py000066400000000000000000000125301374256237000166030ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Support for spawning a daemon process. PyMongo only attempts to spawn the mongocryptd daemon process when automatic client-side field level encryption is enabled. See :ref:`automatic-client-side-encryption` for more info. """ import os import subprocess import sys import time # The maximum amount of time to wait for the intermediate subprocess. 
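# Hedged usage sketch for RawBatchCursor above: Collection.find_raw_batches()
# yields raw BSON batches that can be decoded separately, for example with
# bson.decode_all().
#
#   from bson import decode_all
#   for batch in db.test.find_raw_batches({'x': 1}):
#       docs = decode_all(batch)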
_WAIT_TIMEOUT = 10 _THIS_FILE = os.path.realpath(__file__) if sys.version_info[0] < 3: def _popen_wait(popen, timeout): """Implement wait timeout support for Python 2.""" from pymongo.monotonic import time as _time deadline = _time() + timeout # Initial delay of 1ms delay = .0005 while True: returncode = popen.poll() if returncode is not None: return returncode remaining = deadline - _time() if remaining <= 0: # Just return None instead of raising an error. return None delay = min(delay * 2, remaining, .5) time.sleep(delay) else: def _popen_wait(popen, timeout): """Implement wait timeout support for Python 3.""" try: return popen.wait(timeout=timeout) except subprocess.TimeoutExpired: # Silence TimeoutExpired errors. return None def _silence_resource_warning(popen): """Silence Popen's ResourceWarning. Note this should only be used if the process was created as a daemon. """ # Set the returncode to avoid this warning when popen is garbage collected: # "ResourceWarning: subprocess XXX is still running". # See https://bugs.python.org/issue38890 and # https://bugs.python.org/issue26741. popen.returncode = 0 if sys.platform == 'win32': # On Windows we spawn the daemon process simply by using DETACHED_PROCESS. _DETACHED_PROCESS = getattr(subprocess, 'DETACHED_PROCESS', 0x00000008) def _spawn_daemon(args): """Spawn a daemon process (Windows).""" with open(os.devnull, 'r+b') as devnull: popen = subprocess.Popen( args, creationflags=_DETACHED_PROCESS, stdin=devnull, stderr=devnull, stdout=devnull) _silence_resource_warning(popen) else: # On Unix we spawn the daemon process with a double Popen. # 1) The first Popen runs this file as a Python script using the current # interpreter. # 2) The script then decouples itself and performs the second Popen to # spawn the daemon process. # 3) The original process waits up to 10 seconds for the script to exit. # # Note that we do not call fork() directly because we want this procedure # to be safe to call from any thread. Using Popen instead of fork also # avoids triggering the application's os.register_at_fork() callbacks when # we spawn the mongocryptd daemon process. def _spawn(args): """Spawn the process and silence stdout/stderr.""" with open(os.devnull, 'r+b') as devnull: return subprocess.Popen( args, close_fds=True, stdin=devnull, stderr=devnull, stdout=devnull) def _spawn_daemon_double_popen(args): """Spawn a daemon process using a double subprocess.Popen.""" spawner_args = [sys.executable, _THIS_FILE] spawner_args.extend(args) temp_proc = subprocess.Popen(spawner_args, close_fds=True) # Reap the intermediate child process to avoid creating zombie # processes. _popen_wait(temp_proc, _WAIT_TIMEOUT) def _spawn_daemon(args): """Spawn a daemon process (Unix).""" # "If Python is unable to retrieve the real path to its executable, # sys.executable will be an empty string or None". if sys.executable: _spawn_daemon_double_popen(args) else: # Fallback to spawn a non-daemon process without silencing the # resource warning. We do not use fork here because it is not # safe to call from a thread on all systems. # Unfortunately, this means that: # 1) If the parent application is killed via Ctrl-C, the # non-daemon process will also be killed. # 2) Each non-daemon process will hang around as a zombie process # until the main application exits. _spawn(args) if __name__ == '__main__': # Attempt to start a new session to decouple from the parent. 
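# Hedged usage sketch for _spawn_daemon() above: pymongo invokes it with the
# mongocryptd command line when automatic encryption needs the daemon.  The
# arguments shown are illustrative only, not the library's defaults.
#
#   _spawn_daemon(['mongocryptd', '--idleShutdownTimeoutSecs=60'])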
if hasattr(os, 'setsid'): try: os.setsid() except OSError: pass # We are performing a double fork (Popen) to spawn the process as a # daemon so it is safe to ignore the resource warning. _silence_resource_warning(_spawn(sys.argv[1:])) os._exit(0) pymongo-3.11.0/pymongo/database.py000066400000000000000000002026041374256237000171070ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Database level operations.""" import warnings from bson.code import Code from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.dbref import DBRef from bson.py3compat import iteritems, string_type, _unicode from bson.son import SON from pymongo import auth, common from pymongo.aggregation import _DatabaseAggregationCommand from pymongo.change_stream import DatabaseChangeStream from pymongo.collection import Collection from pymongo.command_cursor import CommandCursor from pymongo.errors import (CollectionInvalid, ConfigurationError, InvalidName, OperationFailure) from pymongo.message import _first_batch from pymongo.read_preferences import ReadPreference from pymongo.son_manipulator import SONManipulator from pymongo.write_concern import DEFAULT_WRITE_CONCERN _INDEX_REGEX = {"name": {"$regex": r"^(?!.*\$)"}} _SYSTEM_FILTER = {"filter": {"name": {"$regex": r"^(?!system\.)"}}} def _check_name(name): """Check if a database name is valid. """ if not name: raise InvalidName("database name cannot be the empty string") for invalid_char in [' ', '.', '$', '/', '\\', '\x00', '"']: if invalid_char in name: raise InvalidName("database names cannot contain the " "character %r" % invalid_char) class Database(common.BaseObject): """A Mongo database. """ def __init__(self, client, name, codec_options=None, read_preference=None, write_concern=None, read_concern=None): """Get a database by client and name. Raises :class:`TypeError` if `name` is not an instance of :class:`basestring` (:class:`str` in python 3). Raises :class:`~pymongo.errors.InvalidName` if `name` is not a valid database name. :Parameters: - `client`: A :class:`~pymongo.mongo_client.MongoClient` instance. - `name`: The database name. - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) client.codec_options is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) client.read_preference is used. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) client.write_concern is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) client.read_concern is used. .. mongodoc:: databases .. versionchanged:: 3.2 Added the read_concern option. .. versionchanged:: 3.0 Added the codec_options, read_preference, and write_concern options. :class:`~pymongo.database.Database` no longer returns an instance of :class:`~pymongo.collection.Collection` for attribute names with leading underscores. 
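# Hedged usage sketch for _check_name() above: database names containing a
# reserved character (space, '.', '$', '/', '\\', NUL, '"') raise
# InvalidName, whether created via MongoClient[...] or Database(...).
#
#   from pymongo.errors import InvalidName
#   try:
#       bad = client['my db']   # space is not allowed
#   except InvalidName:
#       pass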
You must use dict-style lookups instead:: db['__my_collection__'] Not: db.__my_collection__ """ super(Database, self).__init__( codec_options or client.codec_options, read_preference or client.read_preference, write_concern or client.write_concern, read_concern or client.read_concern) if not isinstance(name, string_type): raise TypeError("name must be an instance " "of %s" % (string_type.__name__,)) if name != '$external': _check_name(name) self.__name = _unicode(name) self.__client = client self.__incoming_manipulators = [] self.__incoming_copying_manipulators = [] self.__outgoing_manipulators = [] self.__outgoing_copying_manipulators = [] def add_son_manipulator(self, manipulator): """Add a new son manipulator to this database. **DEPRECATED** - `add_son_manipulator` is deprecated. .. versionchanged:: 3.0 Deprecated add_son_manipulator. """ warnings.warn("add_son_manipulator is deprecated", DeprecationWarning, stacklevel=2) base = SONManipulator() def method_overwritten(instance, method): """Test if this method has been overridden.""" return (getattr( instance, method).__func__ != getattr(base, method).__func__) if manipulator.will_copy(): if method_overwritten(manipulator, "transform_incoming"): self.__incoming_copying_manipulators.insert(0, manipulator) if method_overwritten(manipulator, "transform_outgoing"): self.__outgoing_copying_manipulators.insert(0, manipulator) else: if method_overwritten(manipulator, "transform_incoming"): self.__incoming_manipulators.insert(0, manipulator) if method_overwritten(manipulator, "transform_outgoing"): self.__outgoing_manipulators.insert(0, manipulator) @property def system_js(self): """**DEPRECATED**: :class:`SystemJS` helper for this :class:`Database`. See the documentation for :class:`SystemJS` for more details. """ return SystemJS(self) @property def client(self): """The client instance for this :class:`Database`.""" return self.__client @property def name(self): """The name of this :class:`Database`.""" return self.__name @property def incoming_manipulators(self): """**DEPRECATED**: All incoming SON manipulators. .. versionchanged:: 3.5 Deprecated. .. versionadded:: 2.0 """ warnings.warn("Database.incoming_manipulators() is deprecated", DeprecationWarning, stacklevel=2) return [manipulator.__class__.__name__ for manipulator in self.__incoming_manipulators] @property def incoming_copying_manipulators(self): """**DEPRECATED**: All incoming SON copying manipulators. .. versionchanged:: 3.5 Deprecated. .. versionadded:: 2.0 """ warnings.warn("Database.incoming_copying_manipulators() is deprecated", DeprecationWarning, stacklevel=2) return [manipulator.__class__.__name__ for manipulator in self.__incoming_copying_manipulators] @property def outgoing_manipulators(self): """**DEPRECATED**: All outgoing SON manipulators. .. versionchanged:: 3.5 Deprecated. .. versionadded:: 2.0 """ warnings.warn("Database.outgoing_manipulators() is deprecated", DeprecationWarning, stacklevel=2) return [manipulator.__class__.__name__ for manipulator in self.__outgoing_manipulators] @property def outgoing_copying_manipulators(self): """**DEPRECATED**: All outgoing SON copying manipulators. .. versionchanged:: 3.5 Deprecated. .. 
versionadded:: 2.0 """ warnings.warn("Database.outgoing_copying_manipulators() is deprecated", DeprecationWarning, stacklevel=2) return [manipulator.__class__.__name__ for manipulator in self.__outgoing_copying_manipulators] def with_options(self, codec_options=None, read_preference=None, write_concern=None, read_concern=None): """Get a clone of this database changing the specified settings. >>> db1.read_preference Primary() >>> from pymongo import ReadPreference >>> db2 = db1.with_options(read_preference=ReadPreference.SECONDARY) >>> db1.read_preference Primary() >>> db2.read_preference Secondary(tag_sets=None) :Parameters: - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) the :attr:`codec_options` of this :class:`Collection` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`Collection` is used. See :mod:`~pymongo.read_preferences` for options. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) the :attr:`write_concern` of this :class:`Collection` is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) the :attr:`read_concern` of this :class:`Collection` is used. .. versionadded:: 3.8 """ return Database(self.client, self.__name, codec_options or self.codec_options, read_preference or self.read_preference, write_concern or self.write_concern, read_concern or self.read_concern) def __eq__(self, other): if isinstance(other, Database): return (self.__client == other.client and self.__name == other.name) return NotImplemented def __ne__(self, other): return not self == other def __repr__(self): return "Database(%r, %r)" % (self.__client, self.__name) def __getattr__(self, name): """Get a collection of this database by name. Raises InvalidName if an invalid collection name is used. :Parameters: - `name`: the name of the collection to get """ if name.startswith('_'): raise AttributeError( "Database has no attribute %r. To access the %s" " collection, use database[%r]." % (name, name, name)) return self.__getitem__(name) def __getitem__(self, name): """Get a collection of this database by name. Raises InvalidName if an invalid collection name is used. :Parameters: - `name`: the name of the collection to get """ return Collection(self, name) def get_collection(self, name, codec_options=None, read_preference=None, write_concern=None, read_concern=None): """Get a :class:`~pymongo.collection.Collection` with the given name and options. Useful for creating a :class:`~pymongo.collection.Collection` with different codec options, read preference, and/or write concern from this :class:`Database`. >>> db.read_preference Primary() >>> coll1 = db.test >>> coll1.read_preference Primary() >>> from pymongo import ReadPreference >>> coll2 = db.get_collection( ... 'test', read_preference=ReadPreference.SECONDARY) >>> coll2.read_preference Secondary(tag_sets=None) :Parameters: - `name`: The name of the collection - a string. - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) the :attr:`codec_options` of this :class:`Database` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`Database` is used. See :mod:`~pymongo.read_preferences` for options. 
- `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) the :attr:`write_concern` of this :class:`Database` is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) the :attr:`read_concern` of this :class:`Database` is used. """ return Collection( self, name, False, codec_options, read_preference, write_concern, read_concern) def create_collection(self, name, codec_options=None, read_preference=None, write_concern=None, read_concern=None, session=None, **kwargs): """Create a new :class:`~pymongo.collection.Collection` in this database. Normally collection creation is automatic. This method should only be used to specify options on creation. :class:`~pymongo.errors.CollectionInvalid` will be raised if the collection already exists. Options should be passed as keyword arguments to this method. Supported options vary with MongoDB release. Some examples include: - "size": desired initial size for the collection (in bytes). For capped collections this size is the max size of the collection. - "capped": if True, this is a capped collection - "max": maximum number of objects if capped (optional) See the MongoDB documentation for a full list of supported options by server version. :Parameters: - `name`: the name of the collection to create - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) the :attr:`codec_options` of this :class:`Database` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`Database` is used. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) the :attr:`write_concern` of this :class:`Database` is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) the :attr:`read_concern` of this :class:`Database` is used. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional keyword arguments will be passed as options for the create collection command .. versionchanged:: 3.11 This method is now supported inside multi-document transactions with MongoDB 4.4+. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Added the collation option. .. versionchanged:: 3.0 Added the codec_options, read_preference, and write_concern options. .. versionchanged:: 2.2 Removed deprecated argument: options """ with self.__client._tmp_session(session) as s: # Skip this check in a transaction where listCollections is not # supported. 
if ((not s or not s.in_transaction) and name in self.list_collection_names( filter={"name": name}, session=s)): raise CollectionInvalid("collection %s already exists" % name) return Collection(self, name, True, codec_options, read_preference, write_concern, read_concern, session=s, **kwargs) def _apply_incoming_manipulators(self, son, collection): """Apply incoming manipulators to `son`.""" for manipulator in self.__incoming_manipulators: son = manipulator.transform_incoming(son, collection) return son def _apply_incoming_copying_manipulators(self, son, collection): """Apply incoming copying manipulators to `son`.""" for manipulator in self.__incoming_copying_manipulators: son = manipulator.transform_incoming(son, collection) return son def _fix_incoming(self, son, collection): """Apply manipulators to an incoming SON object before it gets stored. :Parameters: - `son`: the son object going into the database - `collection`: the collection the son object is being saved in """ son = self._apply_incoming_manipulators(son, collection) son = self._apply_incoming_copying_manipulators(son, collection) return son def _fix_outgoing(self, son, collection): """Apply manipulators to a SON object as it comes out of the database. :Parameters: - `son`: the son object coming out of the database - `collection`: the collection the son object was saved in """ for manipulator in reversed(self.__outgoing_manipulators): son = manipulator.transform_outgoing(son, collection) for manipulator in reversed(self.__outgoing_copying_manipulators): son = manipulator.transform_outgoing(son, collection) return son def aggregate(self, pipeline, session=None, **kwargs): """Perform a database-level aggregation. See the `aggregation pipeline`_ documentation for a list of stages that are supported. Introduced in MongoDB 3.6. .. code-block:: python # Lists all operations currently running on the server. with client.admin.aggregate([{"$currentOp": {}}]) as cursor: for operation in cursor: print(operation) All optional `aggregate command`_ parameters should be passed as keyword arguments to this method. Valid options include, but are not limited to: - `allowDiskUse` (bool): Enables writing to temporary files. When set to True, aggregation stages can write data to the _tmp subdirectory of the --dbpath directory. The default is False. - `maxTimeMS` (int): The maximum amount of time to allow the operation to run in milliseconds. - `batchSize` (int): The maximum number of documents to return per batch. Ignored if the connected mongod or mongos does not support returning aggregate results using a cursor. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. The :meth:`aggregate` method obeys the :attr:`read_preference` of this :class:`Database`, except when ``$out`` or ``$merge`` are used, in which case :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY` is used. .. note:: This method does not support the 'explain' option. Please use :meth:`~pymongo.database.Database.command` instead. .. note:: The :attr:`~pymongo.database.Database.write_concern` of this collection is automatically applied to this operation. :Parameters: - `pipeline`: a list of aggregation pipeline stages - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): See list of options above. :Returns: A :class:`~pymongo.command_cursor.CommandCursor` over the result set. .. versionadded:: 3.9 .. _aggregation pipeline: https://docs.mongodb.com/manual/reference/operator/aggregation-pipeline .. 
_aggregate command: https://docs.mongodb.com/manual/reference/command/aggregate """ with self.client._tmp_session(session, close=False) as s: cmd = _DatabaseAggregationCommand( self, CommandCursor, pipeline, kwargs, session is not None, user_fields={'cursor': {'firstBatch': 1}}) return self.client._retryable_read( cmd.get_cursor, cmd.get_read_preference(s), s, retryable=not cmd._performs_write) def watch(self, pipeline=None, full_document=None, resume_after=None, max_await_time_ms=None, batch_size=None, collation=None, start_at_operation_time=None, session=None, start_after=None): """Watch changes on this database. Performs an aggregation with an implicit initial ``$changeStream`` stage and returns a :class:`~pymongo.change_stream.DatabaseChangeStream` cursor which iterates over changes on all collections in this database. Introduced in MongoDB 4.0. .. code-block:: python with db.watch() as stream: for change in stream: print(change) The :class:`~pymongo.change_stream.DatabaseChangeStream` iterable blocks until the next change document is returned or an error is raised. If the :meth:`~pymongo.change_stream.DatabaseChangeStream.next` method encounters a network error when retrieving a batch from the server, it will automatically attempt to recreate the cursor such that no change events are missed. Any error encountered during the resume attempt indicates there may be an outage and will be raised. .. code-block:: python try: with db.watch( [{'$match': {'operationType': 'insert'}}]) as stream: for insert_change in stream: print(insert_change) except pymongo.errors.PyMongoError: # The ChangeStream encountered an unrecoverable error or the # resume attempt failed to recreate the cursor. logging.error('...') For a precise description of the resume process see the `change streams specification`_. :Parameters: - `pipeline` (optional): A list of aggregation pipeline stages to append to an initial ``$changeStream`` stage. Not all pipeline stages are valid after a ``$changeStream`` stage, see the MongoDB documentation on change streams for the supported stages. - `full_document` (optional): The fullDocument to pass as an option to the ``$changeStream`` stage. Allowed values: 'updateLookup'. When set to 'updateLookup', the change notification for partial updates will include both a delta describing the changes to the document, as well as a copy of the entire document that was changed from some time after the change occurred. - `resume_after` (optional): A resume token. If provided, the change stream will start returning changes that occur directly after the operation specified in the resume token. A resume token is the _id value of a change document. - `max_await_time_ms` (optional): The maximum time in milliseconds for the server to wait for changes before responding to a getMore operation. - `batch_size` (optional): The maximum number of documents to return per batch. - `collation` (optional): The :class:`~pymongo.collation.Collation` to use for the aggregation. - `start_at_operation_time` (optional): If provided, the resulting change stream will only return changes that occurred at or after the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `start_after` (optional): The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. :Returns: A :class:`~pymongo.change_stream.DatabaseChangeStream` cursor. .. 
versionchanged:: 3.9 Added the ``start_after`` parameter. .. versionadded:: 3.7 .. mongodoc:: changeStreams .. _change streams specification: https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.rst """ return DatabaseChangeStream( self, pipeline, full_document, resume_after, max_await_time_ms, batch_size, collation, start_at_operation_time, session, start_after) def _command(self, sock_info, command, slave_ok=False, value=1, check=True, allowable_errors=None, read_preference=ReadPreference.PRIMARY, codec_options=DEFAULT_CODEC_OPTIONS, write_concern=None, parse_write_concern_error=False, session=None, **kwargs): """Internal command helper.""" if isinstance(command, string_type): command = SON([(command, value)]) command.update(kwargs) with self.__client._tmp_session(session) as s: return sock_info.command( self.__name, command, slave_ok, read_preference, codec_options, check, allowable_errors, write_concern=write_concern, parse_write_concern_error=parse_write_concern_error, session=s, client=self.__client) def command(self, command, value=1, check=True, allowable_errors=None, read_preference=None, codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs): """Issue a MongoDB command. Send command `command` to the database and return the response. If `command` is an instance of :class:`basestring` (:class:`str` in python 3) then the command {`command`: `value`} will be sent. Otherwise, `command` must be an instance of :class:`dict` and will be sent as is. Any additional keyword arguments will be added to the final command document before it is sent. For example, a command like ``{buildinfo: 1}`` can be sent using: >>> db.command("buildinfo") For a command where the value matters, like ``{collstats: collection_name}`` we can do: >>> db.command("collstats", collection_name) For commands that take additional arguments we can use kwargs. So ``{filemd5: object_id, root: file_root}`` becomes: >>> db.command("filemd5", object_id, root=file_root) :Parameters: - `command`: document representing the command to be issued, or the name of the command (for simple commands only). .. note:: the order of keys in the `command` document is significant (the "verb" must come first), so commands which require multiple keys (e.g. `findandmodify`) should use an instance of :class:`~bson.son.SON` or a string and kwargs instead of a Python `dict`. - `value` (optional): value to use for the command verb when `command` is passed as a string - `check` (optional): check the response for errors, raising :class:`~pymongo.errors.OperationFailure` if there are any - `allowable_errors`: if `check` is ``True``, error messages in this list will be ignored by error-checking - `read_preference` (optional): The read preference for this operation. See :mod:`~pymongo.read_preferences` for options. If the provided `session` is in a transaction, defaults to the read preference configured for the transaction. Otherwise, defaults to :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. - `codec_options`: A :class:`~bson.codec_options.CodecOptions` instance. - `session` (optional): A :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): additional keyword arguments will be added to the command document before it is sent .. note:: :meth:`command` does **not** obey this Database's :attr:`read_preference` or :attr:`codec_options`. You must use the `read_preference` and `codec_options` parameters instead. .. 
note:: :meth:`command` does **not** apply any custom TypeDecoders when decoding the command response. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.0 Removed the `as_class`, `fields`, `uuid_subtype`, `tag_sets`, and `secondary_acceptable_latency_ms` option. Removed `compile_re` option: PyMongo now always represents BSON regular expressions as :class:`~bson.regex.Regex` objects. Use :meth:`~bson.regex.Regex.try_compile` to attempt to convert from a BSON regular expression to a Python regular expression object. Added the `codec_options` parameter. .. versionchanged:: 2.7 Added `compile_re` option. If set to False, PyMongo represented BSON regular expressions as :class:`~bson.regex.Regex` objects instead of attempting to compile BSON regular expressions as Python native regular expressions, thus preventing errors for some incompatible patterns, see `PYTHON-500`_. .. versionchanged:: 2.3 Added `tag_sets` and `secondary_acceptable_latency_ms` options. .. versionchanged:: 2.2 Added support for `as_class` - the class you want to use for the resulting documents .. _PYTHON-500: https://jira.mongodb.org/browse/PYTHON-500 .. mongodoc:: commands """ if read_preference is None: read_preference = ((session and session._txn_read_preference()) or ReadPreference.PRIMARY) with self.__client._socket_for_reads( read_preference, session) as (sock_info, slave_ok): return self._command(sock_info, command, slave_ok, value, check, allowable_errors, read_preference, codec_options, session=session, **kwargs) def _retryable_read_command(self, command, value=1, check=True, allowable_errors=None, read_preference=None, codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs): """Same as command but used for retryable read commands.""" if read_preference is None: read_preference = ((session and session._txn_read_preference()) or ReadPreference.PRIMARY) def _cmd(session, server, sock_info, slave_ok): return self._command(sock_info, command, slave_ok, value, check, allowable_errors, read_preference, codec_options, session=session, **kwargs) return self.__client._retryable_read( _cmd, read_preference, session) def _list_collections(self, sock_info, slave_okay, session, read_preference, **kwargs): """Internal listCollections helper.""" coll = self.get_collection( "$cmd", read_preference=read_preference) if sock_info.max_wire_version > 2: cmd = SON([("listCollections", 1), ("cursor", {})]) cmd.update(kwargs) with self.__client._tmp_session( session, close=False) as tmp_session: cursor = self._command( sock_info, cmd, slave_okay, read_preference=read_preference, session=tmp_session)["cursor"] return CommandCursor( coll, cursor, sock_info.address, session=tmp_session, explicit_session=session is not None) else: match = _INDEX_REGEX if "filter" in kwargs: match = {"$and": [_INDEX_REGEX, kwargs["filter"]]} dblen = len(self.name.encode("utf8") + b".") pipeline = [ {"$project": {"name": {"$substr": ["$name", dblen, -1]}, "options": 1}}, {"$match": match} ] cmd = SON([("aggregate", "system.namespaces"), ("pipeline", pipeline), ("cursor", kwargs.get("cursor", {}))]) cursor = self._command(sock_info, cmd, slave_okay)["cursor"] return CommandCursor(coll, cursor, sock_info.address) def list_collections(self, session=None, filter=None, **kwargs): """Get a cursor over the collectons of this database. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `filter` (optional): A query document to filter the list of collections returned from the listCollections command. 
- `**kwargs` (optional): Optional parameters of the `listCollections command `_ can be passed as keyword arguments to this method. The supported options differ by server version. :Returns: An instance of :class:`~pymongo.command_cursor.CommandCursor`. .. versionadded:: 3.6 """ if filter is not None: kwargs['filter'] = filter read_pref = ((session and session._txn_read_preference()) or ReadPreference.PRIMARY) def _cmd(session, server, sock_info, slave_okay): return self._list_collections( sock_info, slave_okay, session, read_preference=read_pref, **kwargs) return self.__client._retryable_read( _cmd, read_pref, session) def list_collection_names(self, session=None, filter=None, **kwargs): """Get a list of all the collection names in this database. For example, to list all non-system collections:: filter = {"name": {"$regex": r"^(?!system\\.)"}} db.list_collection_names(filter=filter) :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `filter` (optional): A query document to filter the list of collections returned from the listCollections command. - `**kwargs` (optional): Optional parameters of the `listCollections command `_ can be passed as keyword arguments to this method. The supported options differ by server version. .. versionchanged:: 3.8 Added the ``filter`` and ``**kwargs`` parameters. .. versionadded:: 3.6 """ if filter is None: kwargs["nameOnly"] = True else: # The enumerate collections spec states that "drivers MUST NOT set # nameOnly if a filter specifies any keys other than name." common.validate_is_mapping("filter", filter) kwargs["filter"] = filter if not filter or (len(filter) == 1 and "name" in filter): kwargs["nameOnly"] = True return [result["name"] for result in self.list_collections(session=session, **kwargs)] def collection_names(self, include_system_collections=True, session=None): """**DEPRECATED**: Get a list of all the collection names in this database. :Parameters: - `include_system_collections` (optional): if ``False`` list will not include system collections (e.g ``system.indexes``) - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.7 Deprecated. Use :meth:`list_collection_names` instead. .. versionchanged:: 3.6 Added ``session`` parameter. """ warnings.warn("collection_names is deprecated. Use " "list_collection_names instead.", DeprecationWarning, stacklevel=2) kws = {} if include_system_collections else _SYSTEM_FILTER return [result["name"] for result in self.list_collections(session=session, nameOnly=True, **kws)] def drop_collection(self, name_or_collection, session=None): """Drop a collection. :Parameters: - `name_or_collection`: the name of a collection to drop or the collection object itself - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. note:: The :attr:`~pymongo.database.Database.write_concern` of this database is automatically applied to this operation when using MongoDB >= 3.4. .. versionchanged:: 3.6 Added ``session`` parameter. .. versionchanged:: 3.4 Apply this database's write concern automatically to this operation when connected to MongoDB >= 3.4. 
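# --- Illustrative sketch (not part of the original source): listing
# collection names with a filter and dropping one, using the helpers
# documented above. Assumes `db` is an existing Database handle and a
# hypothetical "obsolete_events" collection.
non_system = db.list_collection_names(
    filter={"name": {"$regex": r"^(?!system\.)"}})
for collection_name in non_system:
    print(collection_name)
# Either a collection name or a Collection object may be passed.
db.drop_collection("obsolete_events")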
""" name = name_or_collection if isinstance(name, Collection): name = name.name if not isinstance(name, string_type): raise TypeError("name_or_collection must be an " "instance of %s" % (string_type.__name__,)) self.__client._purge_index(self.__name, name) with self.__client._socket_for_writes(session) as sock_info: return self._command( sock_info, 'drop', value=_unicode(name), allowable_errors=['ns not found', 26], write_concern=self._write_concern_for(session), parse_write_concern_error=True, session=session) def validate_collection(self, name_or_collection, scandata=False, full=False, session=None, background=None): """Validate a collection. Returns a dict of validation info. Raises CollectionInvalid if validation fails. See also the MongoDB documentation on the `validate command`_. :Parameters: - `name_or_collection`: A Collection object or the name of a collection to validate. - `scandata`: Do extra checks beyond checking the overall structure of the collection. - `full`: Have the server do a more thorough scan of the collection. Use with `scandata` for a thorough scan of the structure of the collection and the individual documents. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `background` (optional): A boolean flag that determines whether the command runs in the background. Requires MongoDB 4.4+. .. versionchanged:: 3.11 Added ``background`` parameter. .. versionchanged:: 3.6 Added ``session`` parameter. .. _validate command: https://docs.mongodb.com/manual/reference/command/validate/ """ name = name_or_collection if isinstance(name, Collection): name = name.name if not isinstance(name, string_type): raise TypeError("name_or_collection must be an instance of " "%s or Collection" % (string_type.__name__,)) cmd = SON([("validate", _unicode(name)), ("scandata", scandata), ("full", full)]) if background is not None: cmd["background"] = background result = self.command(cmd, session=session) valid = True # Pre 1.9 results if "result" in result: info = result["result"] if info.find("exception") != -1 or info.find("corrupt") != -1: raise CollectionInvalid("%s invalid: %s" % (name, info)) # Sharded results elif "raw" in result: for _, res in iteritems(result["raw"]): if "result" in res: info = res["result"] if (info.find("exception") != -1 or info.find("corrupt") != -1): raise CollectionInvalid("%s invalid: " "%s" % (name, info)) elif not res.get("valid", False): valid = False break # Post 1.9 non-sharded results. elif not result.get("valid", False): valid = False if not valid: raise CollectionInvalid("%s invalid: %r" % (name, result)) return result def _current_op(self, include_all=False, session=None): """Helper for running $currentOp.""" cmd = SON([("currentOp", 1), ("$all", include_all)]) with self.__client._socket_for_writes(session) as sock_info: if sock_info.max_wire_version >= 4: return self.__client.admin._command( sock_info, cmd, codec_options=self.codec_options, session=session) else: spec = {"$all": True} if include_all else {} return _first_batch(sock_info, "admin", "$cmd.sys.inprog", spec, -1, True, self.codec_options, ReadPreference.PRIMARY, cmd, self.client._event_listeners) def current_op(self, include_all=False, session=None): """**DEPRECATED**: Get information on operations currently running. Starting with MongoDB 3.6 this helper is obsolete. The functionality provided by this helper is available in MongoDB 3.6+ using the `$currentOp aggregation pipeline stage`_, which can be used with :meth:`aggregate`. 
Note that, while this helper can only return a single document limited to a 16MB result, :meth:`aggregate` returns a cursor avoiding that limitation. Users of MongoDB versions older than 3.6 can use the `currentOp command`_ directly:: # MongoDB 3.2 and 3.4 client.admin.command("currentOp") Or query the "inprog" virtual collection:: # MongoDB 2.6 and 3.0 client.admin["$cmd.sys.inprog"].find_one() :Parameters: - `include_all` (optional): if ``True`` also list currently idle operations in the result - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.9 Deprecated. .. versionchanged:: 3.6 Added ``session`` parameter. .. _$currentOp aggregation pipeline stage: https://docs.mongodb.com/manual/reference/operator/aggregation/currentOp/ .. _currentOp command: https://docs.mongodb.com/manual/reference/command/currentOp/ """ warnings.warn("current_op() is deprecated. See the documentation for " "more information", DeprecationWarning, stacklevel=2) return self._current_op(include_all, session) def profiling_level(self, session=None): """Get the database's current profiling level. Returns one of (:data:`~pymongo.OFF`, :data:`~pymongo.SLOW_ONLY`, :data:`~pymongo.ALL`). :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. .. mongodoc:: profiling """ result = self.command("profile", -1, session=session) assert result["was"] >= 0 and result["was"] <= 2 return result["was"] def set_profiling_level(self, level, slow_ms=None, session=None): """Set the database's profiling level. :Parameters: - `level`: Specifies a profiling level, see list of possible values below. - `slow_ms`: Optionally modify the threshold for the profile to consider a query or operation. Even if the profiler is off queries slower than the `slow_ms` level will get written to the logs. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. Possible `level` values: +----------------------------+------------------------------------+ | Level | Setting | +============================+====================================+ | :data:`~pymongo.OFF` | Off. No profiling. | +----------------------------+------------------------------------+ | :data:`~pymongo.SLOW_ONLY` | On. Only includes slow operations. | +----------------------------+------------------------------------+ | :data:`~pymongo.ALL` | On. Includes all operations. | +----------------------------+------------------------------------+ Raises :class:`ValueError` if level is not one of (:data:`~pymongo.OFF`, :data:`~pymongo.SLOW_ONLY`, :data:`~pymongo.ALL`). .. versionchanged:: 3.6 Added ``session`` parameter. .. mongodoc:: profiling """ if not isinstance(level, int) or level < 0 or level > 2: raise ValueError("level must be one of (OFF, SLOW_ONLY, ALL)") if slow_ms is not None and not isinstance(slow_ms, int): raise TypeError("slow_ms must be an integer") if slow_ms is not None: self.command("profile", level, slowms=slow_ms, session=session) else: self.command("profile", level, session=session) def profiling_info(self, session=None): """Returns a list containing current profiling information. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. .. mongodoc:: profiling """ return list(self["system.profile"].find(session=session)) def error(self): """**DEPRECATED**: Get the error if one occurred on the last operation. 
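# --- Illustrative sketch (not part of the original source): driving the
# profiling helpers defined above. Assumes `db` is an existing Database
# handle and the connected user is allowed to run the "profile" command.
from pymongo import OFF, SLOW_ONLY

db.set_profiling_level(SLOW_ONLY, slow_ms=200)   # only log operations slower than 200ms
assert db.profiling_level() == SLOW_ONLY
for entry in db.profiling_info():                # documents from system.profile
    print(entry["op"], entry.get("millis"))
db.set_profiling_level(OFF)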
This method is obsolete: all MongoDB write operations (insert, update, remove, and so on) use the write concern ``w=1`` and report their errors by default. .. versionchanged:: 2.8 Deprecated. """ warnings.warn("Database.error() is deprecated", DeprecationWarning, stacklevel=2) error = self.command("getlasterror") error_msg = error.get("err", "") if error_msg is None: return None if error_msg.startswith("not master"): # Reset primary server and request check, if another thread isn't # doing so already. primary = self.__client.primary if primary: self.__client._handle_getlasterror(primary, error_msg) return error def last_status(self): """**DEPRECATED**: Get status information from the last operation. This method is obsolete: all MongoDB write operations (insert, update, remove, and so on) use the write concern ``w=1`` and report their errors by default. Returns a SON object with status information. .. versionchanged:: 2.8 Deprecated. """ warnings.warn("last_status() is deprecated", DeprecationWarning, stacklevel=2) return self.command("getlasterror") def previous_error(self): """**DEPRECATED**: Get the most recent error on this database. This method is obsolete: all MongoDB write operations (insert, update, remove, and so on) use the write concern ``w=1`` and report their errors by default. Only returns errors that have occurred since the last call to :meth:`reset_error_history`. Returns None if no such errors have occurred. .. versionchanged:: 2.8 Deprecated. """ warnings.warn("previous_error() is deprecated", DeprecationWarning, stacklevel=2) error = self.command("getpreverror") if error.get("err", 0) is None: return None return error def reset_error_history(self): """**DEPRECATED**: Reset the error history of this database. This method is obsolete: all MongoDB write operations (insert, update, remove, and so on) use the write concern ``w=1`` and report their errors by default. Calls to :meth:`previous_error` will only return errors that have occurred since the most recent call to this method. .. versionchanged:: 2.8 Deprecated. """ warnings.warn("reset_error_history() is deprecated", DeprecationWarning, stacklevel=2) self.command("reseterror") def __iter__(self): return self def __next__(self): raise TypeError("'Database' object is not iterable") next = __next__ def _default_role(self, read_only): """Return the default user role for this database.""" if self.name == "admin": if read_only: return "readAnyDatabase" else: return "root" else: if read_only: return "read" else: return "dbOwner" def _create_or_update_user( self, create, name, password, read_only, session=None, **kwargs): """Use a command to create (if create=True) or modify a user. """ opts = {} if read_only or (create and "roles" not in kwargs): warnings.warn("Creating a user with the read_only option " "or without roles is deprecated in MongoDB " ">= 2.6", DeprecationWarning) opts["roles"] = [self._default_role(read_only)] if read_only: warnings.warn("The read_only option is deprecated in MongoDB " ">= 2.6, use 'roles' instead", DeprecationWarning) if password is not None: if "digestPassword" in kwargs: raise ConfigurationError("The digestPassword option is not " "supported via add_user. Please use " "db.command('createUser', ...) " "instead for this option.") opts["pwd"] = password # Don't send {} as writeConcern. 
if self.write_concern.acknowledged and self.write_concern.document: opts["writeConcern"] = self.write_concern.document opts.update(kwargs) if create: command_name = "createUser" else: command_name = "updateUser" self.command(command_name, name, session=session, **opts) def add_user(self, name, password=None, read_only=None, session=None, **kwargs): """**DEPRECATED**: Create user `name` with password `password`. Add a new user with permissions for this :class:`Database`. .. note:: Will change the password if user `name` already exists. .. note:: add_user is deprecated and will be removed in PyMongo 4.0. Starting with MongoDB 2.6 user management is handled with four database commands, createUser_, usersInfo_, updateUser_, and dropUser_. To create a user:: db.command("createUser", "admin", pwd="password", roles=["root"]) To create a read-only user:: db.command("createUser", "user", pwd="password", roles=["read"]) To change a password:: db.command("updateUser", "user", pwd="newpassword") Or change roles:: db.command("updateUser", "user", roles=["readWrite"]) .. _createUser: https://docs.mongodb.com/manual/reference/command/createUser/ .. _usersInfo: https://docs.mongodb.com/manual/reference/command/usersInfo/ .. _updateUser: https://docs.mongodb.com/manual/reference/command/updateUser/ .. _dropUser: https://docs.mongodb.com/manual/reference/command/createUser/ .. warning:: Never create or modify users over an insecure network without the use of TLS. See :doc:`/examples/tls` for more information. :Parameters: - `name`: the name of the user to create - `password` (optional): the password of the user to create. Can not be used with the ``userSource`` argument. - `read_only` (optional): if ``True`` the user will be read only - `**kwargs` (optional): optional fields for the user document (e.g. ``userSource``, ``otherDBRoles``, or ``roles``). See ``_ for more information. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.7 Added support for SCRAM-SHA-256 users with MongoDB 4.0 and later. .. versionchanged:: 3.6 Added ``session`` parameter. Deprecated add_user. .. versionchanged:: 2.5 Added kwargs support for optional fields introduced in MongoDB 2.4 .. versionchanged:: 2.2 Added support for read only users """ warnings.warn("add_user is deprecated and will be removed in PyMongo " "4.0. Use db.command with createUser or updateUser " "instead", DeprecationWarning, stacklevel=2) if not isinstance(name, string_type): raise TypeError("name must be an " "instance of %s" % (string_type.__name__,)) if password is not None: if not isinstance(password, string_type): raise TypeError("password must be an " "instance of %s" % (string_type.__name__,)) if len(password) == 0: raise ValueError("password can't be empty") if read_only is not None: read_only = common.validate_boolean('read_only', read_only) if 'roles' in kwargs: raise ConfigurationError("Can not use " "read_only and roles together") try: uinfo = self.command("usersInfo", name, session=session) # Create the user if not found in uinfo, otherwise update one. self._create_or_update_user( (not uinfo["users"]), name, password, read_only, session=session, **kwargs) except OperationFailure as exc: # Unauthorized. Attempt to create the user in case of # localhost exception. if exc.code == 13: self._create_or_update_user( True, name, password, read_only, session=session, **kwargs) else: raise def remove_user(self, name, session=None): """**DEPRECATED**: Remove user `name` from this :class:`Database`. 
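# --- Illustrative sketch (not part of the original source): the documented
# replacements for the deprecated add_user()/remove_user() helpers, issuing
# the user-management commands directly. Assumes `db` is an existing
# Database handle and a hypothetical "app_user" account.
db.command("createUser", "app_user", pwd="s3cr3t", roles=["readWrite"])
db.command("updateUser", "app_user", pwd="n3w-s3cr3t")   # change the password
db.command("usersInfo", "app_user")                      # inspect the user
db.command("dropUser", "app_user")                       # remove the user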
User `name` will no longer have permissions to access this :class:`Database`. .. note:: remove_user is deprecated and will be removed in PyMongo 4.0. Use the dropUser command instead:: db.command("dropUser", "user") :Parameters: - `name`: the name of the user to remove - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. Deprecated remove_user. """ warnings.warn("remove_user is deprecated and will be removed in " "PyMongo 4.0. Use db.command with dropUser " "instead", DeprecationWarning, stacklevel=2) cmd = SON([("dropUser", name)]) # Don't send {} as writeConcern. if self.write_concern.acknowledged and self.write_concern.document: cmd["writeConcern"] = self.write_concern.document self.command(cmd, session=session) def authenticate(self, name=None, password=None, source=None, mechanism='DEFAULT', **kwargs): """**DEPRECATED**: Authenticate to use this database. .. warning:: Starting in MongoDB 3.6, calling :meth:`authenticate` invalidates all existing cursors. It may also leave logical sessions open on the server for up to 30 minutes until they time out. Authentication lasts for the life of the underlying client instance, or until :meth:`logout` is called. Raises :class:`TypeError` if (required) `name`, (optional) `password`, or (optional) `source` is not an instance of :class:`basestring` (:class:`str` in python 3). .. note:: - This method authenticates the current connection, and will also cause all new :class:`~socket.socket` connections in the underlying client instance to be authenticated automatically. - Authenticating more than once on the same database with different credentials is not supported. You must call :meth:`logout` before authenticating with new credentials. - When sharing a client instance between multiple threads, all threads will share the authentication. If you need different authentication profiles for different purposes you must use distinct client instances. :Parameters: - `name`: the name of the user to authenticate. Optional when `mechanism` is MONGODB-X509 and the MongoDB server version is >= 3.4. - `password` (optional): the password of the user to authenticate. Not used with GSSAPI or MONGODB-X509 authentication. - `source` (optional): the database to authenticate on. If not specified the current database is used. - `mechanism` (optional): See :data:`~pymongo.auth.MECHANISMS` for options. If no mechanism is specified, PyMongo automatically uses MONGODB-CR when connected to a pre-3.0 version of MongoDB, SCRAM-SHA-1 when connected to MongoDB 3.0 through 3.6, and negotiates the mechanism to use (SCRAM-SHA-1 or SCRAM-SHA-256) when connected to MongoDB 4.0+. - `authMechanismProperties` (optional): Used to specify authentication mechanism specific options. To specify the service name for GSSAPI authentication pass ``authMechanismProperties='SERVICE_NAME:'``. To specify the session token for MONGODB-AWS authentication pass ``authMechanismProperties='AWS_SESSION_TOKEN:'``. .. versionchanged:: 3.7 Added support for SCRAM-SHA-256 with MongoDB 4.0 and later. .. versionchanged:: 3.5 Deprecated. Authenticating multiple users conflicts with support for logical sessions in MongoDB 3.6. To authenticate as multiple users, create multiple instances of MongoClient. .. versionadded:: 2.8 Use SCRAM-SHA-1 with MongoDB 3.0 and later. .. versionchanged:: 2.5 Added the `source` and `mechanism` parameters. 
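# --- Illustrative sketch (not part of the original source): instead of the
# deprecated authenticate()/logout() helpers, credentials can be supplied
# when constructing the client. Assumes a hypothetical "app_user" defined on
# the "admin" database of a server at localhost.
from pymongo import MongoClient

secure_client = MongoClient(
    "mongodb://localhost:27017",
    username="app_user",
    password="s3cr3t",
    authSource="admin",
    authMechanism="SCRAM-SHA-256",
)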
:meth:`authenticate` now raises a subclass of :class:`~pymongo.errors.PyMongoError` if authentication fails due to invalid credentials or configuration issues. .. mongodoc:: authenticate """ if name is not None and not isinstance(name, string_type): raise TypeError("name must be an " "instance of %s" % (string_type.__name__,)) if password is not None and not isinstance(password, string_type): raise TypeError("password must be an " "instance of %s" % (string_type.__name__,)) if source is not None and not isinstance(source, string_type): raise TypeError("source must be an " "instance of %s" % (string_type.__name__,)) common.validate_auth_mechanism('mechanism', mechanism) validated_options = common._CaseInsensitiveDictionary() for option, value in iteritems(kwargs): normalized, val = common.validate_auth_option(option, value) validated_options[normalized] = val credentials = auth._build_credentials_tuple( mechanism, source, name, password, validated_options, self.name) self.client._cache_credentials( self.name, credentials, connect=True) return True def logout(self): """**DEPRECATED**: Deauthorize use of this database. .. warning:: Starting in MongoDB 3.6, calling :meth:`logout` invalidates all existing cursors. It may also leave logical sessions open on the server for up to 30 minutes until they time out. """ warnings.warn("Database.logout() is deprecated", DeprecationWarning, stacklevel=2) # Sockets will be deauthenticated as they are used. self.client._purge_credentials(self.name) def dereference(self, dbref, session=None, **kwargs): """Dereference a :class:`~bson.dbref.DBRef`, getting the document it points to. Raises :class:`TypeError` if `dbref` is not an instance of :class:`~bson.dbref.DBRef`. Returns a document, or ``None`` if the reference does not point to a valid document. Raises :class:`ValueError` if `dbref` has a database specified that is different from the current database. :Parameters: - `dbref`: the reference - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): any additional keyword arguments are the same as the arguments to :meth:`~pymongo.collection.Collection.find`. .. versionchanged:: 3.6 Added ``session`` parameter. """ if not isinstance(dbref, DBRef): raise TypeError("cannot dereference a %s" % type(dbref)) if dbref.database is not None and dbref.database != self.__name: raise ValueError("trying to dereference a DBRef that points to " "another database (%r not %r)" % (dbref.database, self.__name)) return self[dbref.collection].find_one( {"_id": dbref.id}, session=session, **kwargs) def eval(self, code, *args): """**DEPRECATED**: Evaluate a JavaScript expression in MongoDB. :Parameters: - `code`: string representation of JavaScript code to be evaluated - `args` (optional): additional positional arguments are passed to the `code` being evaluated .. warning:: the eval command is deprecated in MongoDB 3.0 and will be removed in a future server version. """ warnings.warn("Database.eval() is deprecated", DeprecationWarning, stacklevel=2) if not isinstance(code, Code): code = Code(code) result = self.command("$eval", code, args=args) return result.get("retval", None) def __call__(self, *args, **kwargs): """This is only here so that some API misusages are easier to debug. """ raise TypeError("'Database' object is not callable. If you meant to " "call the '%s' method on a '%s' object it is " "failing because no such method exists." 
% ( self.__name, self.__client.__class__.__name__)) class SystemJS(object): """**DEPRECATED**: Helper class for dealing with stored JavaScript. """ def __init__(self, database): """**DEPRECATED**: Get a system js helper for the database `database`. SystemJS will be removed in PyMongo 4.0. """ warnings.warn("SystemJS is deprecated", DeprecationWarning, stacklevel=2) if not database.write_concern.acknowledged: database = database.client.get_database( database.name, write_concern=DEFAULT_WRITE_CONCERN) # can't just assign it since we've overridden __setattr__ object.__setattr__(self, "_db", database) def __setattr__(self, name, code): self._db.system.js.replace_one( {"_id": name}, {"_id": name, "value": Code(code)}, True) def __setitem__(self, name, code): self.__setattr__(name, code) def __delattr__(self, name): self._db.system.js.delete_one({"_id": name}) def __delitem__(self, name): self.__delattr__(name) def __getattr__(self, name): return lambda *args: self._db.eval(Code("function() { " "return this[name].apply(" "this, arguments); }", scope={'name': name}), *args) def __getitem__(self, name): return self.__getattr__(name) def list(self): """Get a list of the names of the functions stored in this database.""" return [x["_id"] for x in self._db.system.js.find(projection=["_id"])] pymongo-3.11.0/pymongo/driver_info.py000066400000000000000000000032471374256237000176530ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Advanced options for MongoDB drivers implemented on top of PyMongo.""" from collections import namedtuple from bson.py3compat import string_type class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])): """Info about a driver wrapping PyMongo. The MongoDB server logs PyMongo's name, version, and platform whenever PyMongo establishes a connection. A driver implemented on top of PyMongo can add its own info to this log message. Initialize with three strings like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be None to accept PyMongo's default. """ def __new__(cls, name=None, version=None, platform=None): self = super(DriverInfo, cls).__new__(cls, name, version, platform) for name, value in self._asdict().items(): if value is not None and not isinstance(value, string_type): raise TypeError("Wrong type for DriverInfo %s option, value " "must be an instance of %s" % ( name, string_type.__name__)) return self pymongo-3.11.0/pymongo/encryption.py000066400000000000000000000470171374256237000175420ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Support for explicit client-side field level encryption.""" import contextlib import os import subprocess import uuid import weakref try: from pymongocrypt.auto_encrypter import AutoEncrypter from pymongocrypt.errors import MongoCryptError from pymongocrypt.explicit_encrypter import ExplicitEncrypter from pymongocrypt.mongocrypt import MongoCryptOptions from pymongocrypt.state_machine import MongoCryptCallback _HAVE_PYMONGOCRYPT = True except ImportError: _HAVE_PYMONGOCRYPT = False MongoCryptCallback = object from bson import _dict_to_bson, decode, encode from bson.codec_options import CodecOptions from bson.binary import (Binary, STANDARD, UUID_SUBTYPE) from bson.errors import BSONError from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson) from bson.son import SON from pymongo.errors import (ConfigurationError, EncryptionError, InvalidOperation, ServerSelectionTimeoutError) from pymongo.mongo_client import MongoClient from pymongo.pool import _configured_socket, PoolOptions from pymongo.read_concern import ReadConcern from pymongo.ssl_support import get_ssl_context from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern from pymongo.daemon import _spawn_daemon _HTTPS_PORT = 443 _KMS_CONNECT_TIMEOUT = 10 # TODO: CDRIVER-3262 will define this value. _MONGOCRYPTD_TIMEOUT_MS = 1000 _DATA_KEY_OPTS = CodecOptions(document_class=SON, uuid_representation=STANDARD) # Use RawBSONDocument codec options to avoid needlessly decoding # documents from the key vault. _KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument, uuid_representation=STANDARD) @contextlib.contextmanager def _wrap_encryption_errors(): """Context manager to wrap encryption related errors.""" try: yield except BSONError: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise except Exception as exc: raise EncryptionError(exc) class _EncryptionIO(MongoCryptCallback): def __init__(self, client, key_vault_coll, mongocryptd_client, opts): """Internal class to perform I/O on behalf of pymongocrypt.""" # Use a weak ref to break reference cycle. if client is not None: self.client_ref = weakref.ref(client) else: self.client_ref = None self.key_vault_coll = key_vault_coll.with_options( codec_options=_KEY_VAULT_OPTS, read_concern=ReadConcern(level='majority'), write_concern=WriteConcern(w='majority')) self.mongocryptd_client = mongocryptd_client self.opts = opts self._spawned = False def kms_request(self, kms_context): """Complete a KMS request. :Parameters: - `kms_context`: A :class:`MongoCryptKmsContext`. :Returns: None """ endpoint = kms_context.endpoint message = kms_context.message host, port = parse_host(endpoint, _HTTPS_PORT) ctx = get_ssl_context(None, None, None, None, None, None, True, True) opts = PoolOptions(connect_timeout=_KMS_CONNECT_TIMEOUT, socket_timeout=_KMS_CONNECT_TIMEOUT, ssl_context=ctx) conn = _configured_socket((host, port), opts) try: conn.sendall(message) while kms_context.bytes_needed > 0: data = conn.recv(kms_context.bytes_needed) kms_context.feed(data) finally: conn.close() def collection_info(self, database, filter): """Get the collection info for a namespace. The returned collection info is passed to libmongocrypt which reads the JSON schema. :Parameters: - `database`: The database on which to run listCollections. - `filter`: The filter to pass to listCollections. 
:Returns: The first document from the listCollections command response as BSON. """ with self.client_ref()[database].list_collections( filter=RawBSONDocument(filter)) as cursor: for doc in cursor: return _dict_to_bson(doc, False, _DATA_KEY_OPTS) def spawn(self): """Spawn mongocryptd. Note this method is thread safe; at most one mongocryptd will start successfully. """ self._spawned = True args = [self.opts._mongocryptd_spawn_path or 'mongocryptd'] args.extend(self.opts._mongocryptd_spawn_args) _spawn_daemon(args) def mark_command(self, database, cmd): """Mark a command for encryption. :Parameters: - `database`: The database on which to run this command. - `cmd`: The BSON command to run. :Returns: The marked command response from mongocryptd. """ if not self._spawned and not self.opts._mongocryptd_bypass_spawn: self.spawn() # Database.command only supports mutable mappings so we need to decode # the raw BSON command first. inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS) try: res = self.mongocryptd_client[database].command( inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS) except ServerSelectionTimeoutError: if self.opts._mongocryptd_bypass_spawn: raise self.spawn() res = self.mongocryptd_client[database].command( inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS) return res.raw def fetch_keys(self, filter): """Yields one or more keys from the key vault. :Parameters: - `filter`: The filter to pass to find. :Returns: A generator which yields the requested keys from the key vault. """ with self.key_vault_coll.find(RawBSONDocument(filter)) as cursor: for key in cursor: yield key.raw def insert_data_key(self, data_key): """Insert a data key into the key vault. :Parameters: - `data_key`: The data key document to insert. :Returns: The _id of the inserted data key document. """ raw_doc = RawBSONDocument(data_key) data_key_id = raw_doc.get('_id') if not isinstance(data_key_id, uuid.UUID): raise TypeError('data_key _id must be a UUID') self.key_vault_coll.insert_one(raw_doc) return Binary(data_key_id.bytes, subtype=UUID_SUBTYPE) def bson_encode(self, doc): """Encode a document to BSON. A document can be any mapping type (like :class:`dict`). :Parameters: - `doc`: mapping type representing a document :Returns: The encoded BSON bytes. """ return encode(doc) def close(self): """Release resources. Note it is not safe to call this method from __del__ or any GC hooks. """ self.client_ref = None self.key_vault_coll = None if self.mongocryptd_client: self.mongocryptd_client.close() self.mongocryptd_client = None class _Encrypter(object): def __init__(self, io_callbacks, opts): """Encrypts and decrypts MongoDB commands. This class is used to support automatic encryption and decryption of MongoDB commands. :Parameters: - `io_callbacks`: A :class:`MongoCryptCallback`. - `opts`: The encrypted client's :class:`AutoEncryptionOpts`. """ if opts._schema_map is None: schema_map = None else: schema_map = _dict_to_bson(opts._schema_map, False, _DATA_KEY_OPTS) self._auto_encrypter = AutoEncrypter(io_callbacks, MongoCryptOptions( opts._kms_providers, schema_map)) self._bypass_auto_encryption = opts._bypass_auto_encryption self._closed = False def encrypt(self, database, cmd, check_keys, codec_options): """Encrypt a MongoDB command. :Parameters: - `database`: The database for this command. - `cmd`: A command document. - `check_keys`: If True, check `cmd` for invalid keys. - `codec_options`: The CodecOptions to use while encoding `cmd`. :Returns: The encrypted command to execute. 
""" self._check_closed() # Workaround for $clusterTime which is incompatible with # check_keys. cluster_time = check_keys and cmd.pop('$clusterTime', None) encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options) with _wrap_encryption_errors(): encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd) # TODO: PYTHON-1922 avoid decoding the encrypted_cmd. encrypt_cmd = _inflate_bson( encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS) if cluster_time: encrypt_cmd['$clusterTime'] = cluster_time return encrypt_cmd def decrypt(self, response): """Decrypt a MongoDB command response. :Parameters: - `response`: A MongoDB command response as BSON. :Returns: The decrypted command response. """ self._check_closed() with _wrap_encryption_errors(): return self._auto_encrypter.decrypt(response) def _check_closed(self): if self._closed: raise InvalidOperation("Cannot use MongoClient after close") def close(self): """Cleanup resources.""" self._closed = True self._auto_encrypter.close() @staticmethod def create(client, opts): """Create a _CommandEncyptor for a client. :Parameters: - `client`: The encrypted MongoClient. - `opts`: The encrypted client's :class:`AutoEncryptionOpts`. :Returns: A :class:`_CommandEncrypter` for this client. """ key_vault_client = opts._key_vault_client or client db, coll = opts._key_vault_namespace.split('.', 1) key_vault_coll = key_vault_client[db][coll] mongocryptd_client = MongoClient( opts._mongocryptd_uri, connect=False, serverSelectionTimeoutMS=_MONGOCRYPTD_TIMEOUT_MS) io_callbacks = _EncryptionIO( client, key_vault_coll, mongocryptd_client, opts) return _Encrypter(io_callbacks, opts) class Algorithm(object): """An enum that defines the supported encryption algorithms.""" AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = ( "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic") AEAD_AES_256_CBC_HMAC_SHA_512_Random = ( "AEAD_AES_256_CBC_HMAC_SHA_512-Random") class ClientEncryption(object): """Explicit client-side field level encryption.""" def __init__(self, kms_providers, key_vault_namespace, key_vault_client, codec_options): """Explicit client-side field level encryption. The ClientEncryption class encapsulates explicit operations on a key vault collection that cannot be done directly on a MongoClient. Similar to configuring auto encryption on a MongoClient, it is constructed with a MongoClient (to a MongoDB cluster containing the key vault collection), KMS provider configuration, and keyVaultNamespace. It provides an API for explicitly encrypting and decrypting values, and creating data keys. It does not provide an API to query keys from the key vault collection, as this can be done directly on the MongoClient. See :ref:`explicit-client-side-encryption` for an example. :Parameters: - `kms_providers`: Map of KMS provider options. Two KMS providers are supported: "aws" and "local". The kmsProviders map values differ by provider: - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. These are the AWS access key ID and AWS secret access key used to generate KMS messages. - `local`: Map with "key" as a 96-byte array or string. "key" is the master key used to encrypt/decrypt data keys. This key should be generated and stored as securely as possible. - `key_vault_namespace`: The namespace for the key vault collection. The key vault collection contains all data keys used for encryption and decryption. Data keys are stored as documents in this MongoDB collection. Data keys are protected with encryption by a KMS provider. 
- `key_vault_client`: A MongoClient connected to a MongoDB cluster containing the `key_vault_namespace` collection. - `codec_options`: An instance of :class:`~bson.codec_options.CodecOptions` to use when encoding a value for encryption and decoding the decrypted BSON value. This should be the same CodecOptions instance configured on the MongoClient, Database, or Collection used to access application data. .. versionadded:: 3.9 """ if not _HAVE_PYMONGOCRYPT: raise ConfigurationError( "client-side field level encryption requires the pymongocrypt " "library: install a compatible version with: " "python -m pip install 'pymongo[encryption]'") if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of " "bson.codec_options.CodecOptions") self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace self._key_vault_client = key_vault_client self._codec_options = codec_options db, coll = key_vault_namespace.split('.', 1) key_vault_coll = key_vault_client[db][coll] self._io_callbacks = _EncryptionIO(None, key_vault_coll, None, None) self._encryption = ExplicitEncrypter( self._io_callbacks, MongoCryptOptions(kms_providers, None)) def create_data_key(self, kms_provider, master_key=None, key_alt_names=None): """Create and insert a new data key into the key vault collection. :Parameters: - `kms_provider`: The KMS provider to use. Supported values are "aws" and "local". - `master_key`: Identifies a KMS-specific key used to encrypt the new data key. If the kmsProvider is "local" the `master_key` is not applicable and may be omitted. If the `kms_provider` is "aws" it is required and has the following fields:: - `region` (string): Required. The AWS region, e.g. "us-east-1". - `key` (string): Required. The Amazon Resource Name (ARN) to the AWS customer. - `endpoint` (string): Optional. An alternate host to send KMS requests to. May include port number, e.g. "kms.us-east-1.amazonaws.com:443". - `key_alt_names` (optional): An optional list of string alternate names used to reference a key. If a key is created with alternate names, then encryption may refer to the key by the unique alternate name instead of by ``key_id``. The following example shows creating and referring to a data key by alternate name:: client_encryption.create_data_key("local", keyAltNames=["name1"]) # reference the key with the alternate name client_encryption.encrypt("457-55-5462", keyAltName="name1", algorithm=Algorithm.Random) :Returns: The ``_id`` of the created data key document as a :class:`~bson.binary.Binary` with subtype :data:`~bson.binary.UUID_SUBTYPE`. """ self._check_closed() with _wrap_encryption_errors(): return self._encryption.create_data_key( kms_provider, master_key=master_key, key_alt_names=key_alt_names) def encrypt(self, value, algorithm, key_id=None, key_alt_name=None): """Encrypt a BSON value with a given key and algorithm. Note that exactly one of ``key_id`` or ``key_alt_name`` must be provided. :Parameters: - `value`: The BSON value to encrypt. - `algorithm` (string): The encryption algorithm to use. See :class:`Algorithm` for some valid options. - `key_id`: Identifies a data key by ``_id`` which must be a :class:`~bson.binary.Binary` with subtype 4 ( :attr:`~bson.binary.UUID_SUBTYPE`). - `key_alt_name`: Identifies a key vault document by 'keyAltName'. :Returns: The encrypted value, a :class:`~bson.binary.Binary` with subtype 6. 
""" self._check_closed() if (key_id is not None and not ( isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE)): raise TypeError( 'key_id must be a bson.binary.Binary with subtype 4') doc = encode({'v': value}, codec_options=self._codec_options) with _wrap_encryption_errors(): encrypted_doc = self._encryption.encrypt( doc, algorithm, key_id=key_id, key_alt_name=key_alt_name) return decode(encrypted_doc)['v'] def decrypt(self, value): """Decrypt an encrypted value. :Parameters: - `value` (Binary): The encrypted value, a :class:`~bson.binary.Binary` with subtype 6. :Returns: The decrypted BSON value. """ self._check_closed() if not (isinstance(value, Binary) and value.subtype == 6): raise TypeError( 'value to decrypt must be a bson.binary.Binary with subtype 6') with _wrap_encryption_errors(): doc = encode({'v': value}) decrypted_doc = self._encryption.decrypt(doc) return decode(decrypted_doc, codec_options=self._codec_options)['v'] def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def _check_closed(self): if self._encryption is None: raise InvalidOperation("Cannot use closed ClientEncryption") def close(self): """Release resources. Note that using this class in a with-statement will automatically call :meth:`close`:: with ClientEncryption(...) as client_encryption: encrypted = client_encryption.encrypt(value, ...) decrypted = client_encryption.decrypt(encrypted) """ if self._io_callbacks: self._io_callbacks.close() self._encryption.close() self._io_callbacks = None self._encryption = None pymongo-3.11.0/pymongo/encryption_options.py000066400000000000000000000153201374256237000213050ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Support for automatic client-side field level encryption.""" import copy try: import pymongocrypt _HAVE_PYMONGOCRYPT = True except ImportError: _HAVE_PYMONGOCRYPT = False from pymongo.errors import ConfigurationError class AutoEncryptionOpts(object): """Options to configure automatic client-side field level encryption.""" def __init__(self, kms_providers, key_vault_namespace, key_vault_client=None, schema_map=None, bypass_auto_encryption=False, mongocryptd_uri='mongodb://localhost:27020', mongocryptd_bypass_spawn=False, mongocryptd_spawn_path='mongocryptd', mongocryptd_spawn_args=None): """Options to configure automatic client-side field level encryption. Automatic client-side field level encryption requires MongoDB 4.2 enterprise or a MongoDB 4.2 Atlas cluster. Automatic encryption is not supported for operations on a database or view and will result in error. Although automatic encryption requires MongoDB 4.2 enterprise or a MongoDB 4.2 Atlas cluster, automatic *decryption* is supported for all users. To configure automatic *decryption* without automatic *encryption* set ``bypass_auto_encryption=True``. Explicit encryption and explicit decryption is also supported for all users with the :class:`~pymongo.encryption.ClientEncryption` class. 
See :ref:`automatic-client-side-encryption` for an example. :Parameters: - `kms_providers`: Map of KMS provider options. Two KMS providers are supported: "aws" and "local". The kmsProviders map values differ by provider: - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. These are the AWS access key ID and AWS secret access key used to generate KMS messages. - `local`: Map with "key" as a 96-byte array or string. "key" is the master key used to encrypt/decrypt data keys. This key should be generated and stored as securely as possible. - `key_vault_namespace`: The namespace for the key vault collection. The key vault collection contains all data keys used for encryption and decryption. Data keys are stored as documents in this MongoDB collection. Data keys are protected with encryption by a KMS provider. - `key_vault_client` (optional): By default the key vault collection is assumed to reside in the same MongoDB cluster as the encrypted MongoClient. Use this option to route data key queries to a separate MongoDB cluster. - `schema_map` (optional): Map of collection namespace ("db.coll") to JSON Schema. By default, a collection's JSONSchema is periodically polled with the listCollections command. But a JSONSchema may be specified locally with the schemaMap option. **Supplying a `schema_map` provides more security than relying on JSON Schemas obtained from the server. It protects against a malicious server advertising a false JSON Schema, which could trick the client into sending unencrypted data that should be encrypted.** Schemas supplied in the schemaMap only apply to configuring automatic encryption for client side encryption. Other validation rules in the JSON schema will not be enforced by the driver and will result in an error. - `bypass_auto_encryption` (optional): If ``True``, automatic encryption will be disabled but automatic decryption will still be enabled. Defaults to ``False``. - `mongocryptd_uri` (optional): The MongoDB URI used to connect to the *local* mongocryptd process. Defaults to ``'mongodb://localhost:27020'``. - `mongocryptd_bypass_spawn` (optional): If ``True``, the encrypted MongoClient will not attempt to spawn the mongocryptd process. Defaults to ``False``. - `mongocryptd_spawn_path` (optional): Used for spawning the mongocryptd process. Defaults to ``'mongocryptd'`` and spawns mongocryptd from the system path. - `mongocryptd_spawn_args` (optional): A list of string arguments to use when spawning the mongocryptd process. Defaults to ``['--idleShutdownTimeoutSecs=60']``. If the list does not include the ``idleShutdownTimeoutSecs`` option then ``'--idleShutdownTimeoutSecs=60'`` will be added. .. 
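# --- Illustrative sketch (not part of the driver) -----------------------
# A minimal sketch of wiring the AutoEncryptionOpts parameters documented
# above into a MongoClient. It assumes MongoDB 4.2 enterprise (or Atlas) and
# a mongocryptd binary on the system path; the key vault namespace
# "encryption.__pymongoTestKeyVault", the "test.coll" namespace, and the
# json_schema argument are hypothetical placeholders.
import os

from pymongo import MongoClient
from pymongo.encryption_options import AutoEncryptionOpts


def _auto_encryption_example(json_schema):
    kms_providers = {"local": {"key": os.urandom(96)}}
    auto_encryption_opts = AutoEncryptionOpts(
        kms_providers,
        "encryption.__pymongoTestKeyVault",
        schema_map={"test.coll": json_schema},
        mongocryptd_spawn_args=["--idleShutdownTimeoutSecs=75"])
    # Writes to test.coll are transparently encrypted and reads are
    # transparently decrypted according to the supplied schema map.
    return MongoClient(auto_encryption_opts=auto_encryption_opts)
# -------------------------------------------------------------------------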
versionadded:: 3.9 """ if not _HAVE_PYMONGOCRYPT: raise ConfigurationError( "client side encryption requires the pymongocrypt library: " "install a compatible version with: " "python -m pip install 'pymongo[encryption]'") self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace self._key_vault_client = key_vault_client self._schema_map = schema_map self._bypass_auto_encryption = bypass_auto_encryption self._mongocryptd_uri = mongocryptd_uri self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn self._mongocryptd_spawn_path = mongocryptd_spawn_path self._mongocryptd_spawn_args = (copy.copy(mongocryptd_spawn_args) or ['--idleShutdownTimeoutSecs=60']) if not isinstance(self._mongocryptd_spawn_args, list): raise TypeError('mongocryptd_spawn_args must be a list') if not any('idleShutdownTimeoutSecs' in s for s in self._mongocryptd_spawn_args): self._mongocryptd_spawn_args.append('--idleShutdownTimeoutSecs=60') def validate_auto_encryption_opts_or_none(option, value): """Validate the driver keyword arg.""" if value is None: return value if not isinstance(value, AutoEncryptionOpts): raise TypeError("%s must be an instance of AutoEncryptionOpts" % ( option,)) return value pymongo-3.11.0/pymongo/errors.py000066400000000000000000000217631374256237000166640ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Exceptions raised by PyMongo.""" import sys from bson.errors import * try: # CPython 3.7+ from ssl import SSLCertVerificationError as CertificateError except ImportError: try: from ssl import CertificateError except ImportError: class CertificateError(ValueError): pass class PyMongoError(Exception): """Base class for all PyMongo exceptions.""" def __init__(self, message='', error_labels=None): super(PyMongoError, self).__init__(message) self._message = message self._error_labels = set(error_labels or []) def has_error_label(self, label): """Return True if this error contains the given label. .. versionadded:: 3.7 """ return label in self._error_labels def _add_error_label(self, label): """Add the given label to this error.""" self._error_labels.add(label) def _remove_error_label(self, label): """Remove the given label from this error.""" self._error_labels.discard(label) if sys.version_info[0] == 2: def __str__(self): if isinstance(self._message, unicode): return self._message.encode('utf-8', errors='replace') return str(self._message) def __unicode__(self): if isinstance(self._message, unicode): return self._message return unicode(self._message, 'utf-8', errors='replace') class ProtocolError(PyMongoError): """Raised for failures related to the wire protocol.""" class ConnectionFailure(PyMongoError): """Raised when a connection to the database cannot be made or is lost.""" class AutoReconnect(ConnectionFailure): """Raised when a connection to the database is lost and an attempt to auto-reconnect will be made. 
In order to auto-reconnect you must handle this exception, recognizing that the operation which caused it has not necessarily succeeded. Future operations will attempt to open a new connection to the database (and will continue to raise this exception until the first successful connection is made). Subclass of :exc:`~pymongo.errors.ConnectionFailure`. """ def __init__(self, message='', errors=None): error_labels = None if errors is not None and isinstance(errors, dict): error_labels = errors.get('errorLabels') super(AutoReconnect, self).__init__(message, error_labels) self.errors = self.details = errors or [] class NetworkTimeout(AutoReconnect): """An operation on an open connection exceeded socketTimeoutMS. The remaining connections in the pool stay open. In the case of a write operation, you cannot know whether it succeeded or failed. Subclass of :exc:`~pymongo.errors.AutoReconnect`. """ def _format_detailed_error(message, details): if details is not None: message = "%s, full error: %s" % (message, details) if sys.version_info[0] == 2 and isinstance(message, unicode): message = message.encode('utf-8', errors='replace') return message class NotMasterError(AutoReconnect): """The server responded "not master" or "node is recovering". These errors result from a query, write, or command. The operation failed because the client thought it was using the primary but the primary has stepped down, or the client thought it was using a healthy secondary but the secondary is stale and trying to recover. The client launches a refresh operation on a background thread, to update its view of the server as soon as possible after throwing this exception. Subclass of :exc:`~pymongo.errors.AutoReconnect`. """ def __init__(self, message='', errors=None): super(NotMasterError, self).__init__( _format_detailed_error(message, errors), errors=errors) class ServerSelectionTimeoutError(AutoReconnect): """Thrown when no MongoDB server is available for an operation If there is no suitable server for an operation PyMongo tries for ``serverSelectionTimeoutMS`` (default 30 seconds) to find one, then throws this exception. For example, it is thrown after attempting an operation when PyMongo cannot connect to any server, or if you attempt an insert into a replica set that has no primary and does not elect one within the timeout window, or if you attempt to query with a Read Preference that the replica set cannot satisfy. """ class ConfigurationError(PyMongoError): """Raised when something is incorrectly configured. """ class OperationFailure(PyMongoError): """Raised when a database operation fails. .. versionadded:: 2.7 The :attr:`details` attribute. """ def __init__(self, error, code=None, details=None, max_wire_version=None): error_labels = None if details is not None: error_labels = details.get('errorLabels') super(OperationFailure, self).__init__( _format_detailed_error(error, details), error_labels=error_labels) self.__code = code self.__details = details self.__max_wire_version = max_wire_version @property def _max_wire_version(self): return self.__max_wire_version @property def code(self): """The error code returned by the server, if any. """ return self.__code @property def details(self): """The complete error document returned by the server. Depending on the error that occurred, the error document may include useful information beyond just the error message. When connected to a mongos the error document may contain one or more subdocuments if errors occurred on multiple shards. 
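# --- Illustrative sketch (not part of the driver) -----------------------
# A minimal sketch of how an application might use the exception attributes
# documented above: ``code`` and ``details`` on OperationFailure, and
# ``has_error_label`` on PyMongoError. The collection and the document are
# hypothetical; "TransientTransactionError" is one label the server may
# attach.
from pymongo.errors import DuplicateKeyError, OperationFailure, PyMongoError


def _insert_with_error_reporting(collection, doc):
    try:
        collection.insert_one(doc)
    except DuplicateKeyError as exc:
        # DuplicateKeyError is a WriteError, itself an OperationFailure.
        print("duplicate key: code=%s details=%r" % (exc.code, exc.details))
    except OperationFailure as exc:
        print("server error: code=%s" % (exc.code,))
    except PyMongoError as exc:
        if exc.has_error_label("TransientTransactionError"):
            print("transient error, safe to retry the transaction")
        raise
# -------------------------------------------------------------------------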
""" return self.__details class CursorNotFound(OperationFailure): """Raised while iterating query results if the cursor is invalidated on the server. .. versionadded:: 2.7 """ class ExecutionTimeout(OperationFailure): """Raised when a database operation times out, exceeding the $maxTimeMS set in the query or command option. .. note:: Requires server version **>= 2.6.0** .. versionadded:: 2.7 """ class WriteConcernError(OperationFailure): """Base exception type for errors raised due to write concern. .. versionadded:: 3.0 """ class WriteError(OperationFailure): """Base exception type for errors raised during write operations. .. versionadded:: 3.0 """ class WTimeoutError(WriteConcernError): """Raised when a database operation times out (i.e. wtimeout expires) before replication completes. With newer versions of MongoDB the `details` attribute may include write concern fields like 'n', 'updatedExisting', or 'writtenTo'. .. versionadded:: 2.7 """ class DuplicateKeyError(WriteError): """Raised when an insert or update fails due to a duplicate key error.""" class BulkWriteError(OperationFailure): """Exception class for bulk write errors. .. versionadded:: 2.7 """ def __init__(self, results): super(BulkWriteError, self).__init__( "batch op errors occurred", 65, results) class InvalidOperation(PyMongoError): """Raised when a client attempts to perform an invalid operation.""" class InvalidName(PyMongoError): """Raised when an invalid name is used.""" class CollectionInvalid(PyMongoError): """Raised when collection validation fails.""" class InvalidURI(ConfigurationError): """Raised when trying to parse an invalid mongodb URI.""" class ExceededMaxWaiters(PyMongoError): """Raised when a thread tries to get a connection from a pool and ``maxPoolSize * waitQueueMultiple`` threads are already waiting. .. versionadded:: 2.6 """ pass class DocumentTooLarge(InvalidDocument): """Raised when an encoded document is too large for the connected server. """ pass class EncryptionError(PyMongoError): """Raised when encryption or decryption fails. This error always wraps another exception which can be retrieved via the :attr:`cause` property. .. versionadded:: 3.9 """ def __init__(self, cause): super(EncryptionError, self).__init__(str(cause)) self.__cause = cause @property def cause(self): """The exception that caused this encryption or decryption error.""" return self.__cause class _OperationCancelled(AutoReconnect): """Internal error raised when a socket operation is cancelled. """ pass pymongo-3.11.0/pymongo/event_loggers.py000066400000000000000000000201751374256237000202070ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Example event logger classes. .. versionadded:: 3.11 These loggers can be registered using :func:`register` or :class:`~pymongo.mongo_client.MongoClient`. 
``monitoring.register(CommandLogger())`` or ``MongoClient(event_listeners=[CommandLogger()])`` """ import logging from pymongo import monitoring class CommandLogger(monitoring.CommandListener): """A simple listener that logs command events. Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, :class:`~pymongo.monitoring.CommandSucceededEvent` and :class:`~pymongo.monitoring.CommandFailedEvent` events and logs them at the `INFO` severity level using :mod:`logging`. .. versionadded:: 3.11 """ def started(self, event): logging.info("Command {0.command_name} with request id " "{0.request_id} started on server " "{0.connection_id}".format(event)) def succeeded(self, event): logging.info("Command {0.command_name} with request id " "{0.request_id} on server {0.connection_id} " "succeeded in {0.duration_micros} " "microseconds".format(event)) def failed(self, event): logging.info("Command {0.command_name} with request id " "{0.request_id} on server {0.connection_id} " "failed in {0.duration_micros} " "microseconds".format(event)) class ServerLogger(monitoring.ServerListener): """A simple listener that logs server discovery events. Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, and :class:`~pymongo.monitoring.ServerClosedEvent` events and logs them at the `INFO` severity level using :mod:`logging`. .. versionadded:: 3.11 """ def opened(self, event): logging.info("Server {0.server_address} added to topology " "{0.topology_id}".format(event)) def description_changed(self, event): previous_server_type = event.previous_description.server_type new_server_type = event.new_description.server_type if new_server_type != previous_server_type: # server_type_name was added in PyMongo 3.4 logging.info( "Server {0.server_address} changed type from " "{0.previous_description.server_type_name} to " "{0.new_description.server_type_name}".format(event)) def closed(self, event): logging.warning("Server {0.server_address} removed from topology " "{0.topology_id}".format(event)) class HeartbeatLogger(monitoring.ServerHeartbeatListener): """A simple listener that logs server heartbeat events. Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` events and logs them at the `INFO` severity level using :mod:`logging`. .. versionadded:: 3.11 """ def started(self, event): logging.info("Heartbeat sent to server " "{0.connection_id}".format(event)) def succeeded(self, event): # The reply.document attribute was added in PyMongo 3.4. logging.info("Heartbeat to server {0.connection_id} " "succeeded with reply " "{0.reply.document}".format(event)) def failed(self, event): logging.warning("Heartbeat to server {0.connection_id} " "failed with error {0.reply}".format(event)) class TopologyLogger(monitoring.TopologyListener): """A simple listener that logs server topology events. Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, and :class:`~pymongo.monitoring.TopologyClosedEvent` events and logs them at the `INFO` severity level using :mod:`logging`. .. 
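# --- Illustrative sketch (not part of the driver) -----------------------
# The two registration styles mentioned in the module docstring above,
# spelled out for these example listeners. The URI is hypothetical; INFO
# logging must be enabled for the messages to appear.
import logging

from pymongo import MongoClient, monitoring
from pymongo.event_loggers import CommandLogger, ServerLogger


def _register_example_loggers():
    logging.basicConfig(level=logging.INFO)
    # Globally, for every MongoClient created afterwards:
    monitoring.register(CommandLogger())
    # Or per client, via the event_listeners option:
    return MongoClient("mongodb://localhost:27017",
                       event_listeners=[ServerLogger()])
# -------------------------------------------------------------------------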
versionadded:: 3.11 """ def opened(self, event): logging.info("Topology with id {0.topology_id} " "opened".format(event)) def description_changed(self, event): logging.info("Topology description updated for " "topology id {0.topology_id}".format(event)) previous_topology_type = event.previous_description.topology_type new_topology_type = event.new_description.topology_type if new_topology_type != previous_topology_type: # topology_type_name was added in PyMongo 3.4 logging.info( "Topology {0.topology_id} changed type from " "{0.previous_description.topology_type_name} to " "{0.new_description.topology_type_name}".format(event)) # The has_writable_server and has_readable_server methods # were added in PyMongo 3.4. if not event.new_description.has_writable_server(): logging.warning("No writable servers available.") if not event.new_description.has_readable_server(): logging.warning("No readable servers available.") def closed(self, event): logging.info("Topology with id {0.topology_id} " "closed".format(event)) class ConnectionPoolLogger(monitoring.ConnectionPoolListener): """A simple listener that logs server connection pool events. Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, :class:`~pymongo.monitoring.PoolClearedEvent`, :class:`~pymongo.monitoring.PoolClosedEvent`, :~pymongo.monitoring.class:`ConnectionCreatedEvent`, :class:`~pymongo.monitoring.ConnectionReadyEvent`, :class:`~pymongo.monitoring.ConnectionClosedEvent`, :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` events and logs them at the `INFO` severity level using :mod:`logging`. .. versionadded:: 3.11 """ def pool_created(self, event): logging.info("[pool {0.address}] pool created".format(event)) def pool_cleared(self, event): logging.info("[pool {0.address}] pool cleared".format(event)) def pool_closed(self, event): logging.info("[pool {0.address}] pool closed".format(event)) def connection_created(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection created".format(event)) def connection_ready(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection setup succeeded".format(event)) def connection_closed(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection closed, reason: " "{0.reason}".format(event)) def connection_check_out_started(self, event): logging.info("[pool {0.address}] connection check out " "started".format(event)) def connection_check_out_failed(self, event): logging.info("[pool {0.address}] connection check out " "failed, reason: {0.reason}".format(event)) def connection_checked_out(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection checked out of pool".format(event)) def connection_checked_in(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection checked into pool".format(event)) pymongo-3.11.0/pymongo/helpers.py000066400000000000000000000244641374256237000170130ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Bits and pieces used by the driver that don't really fit elsewhere.""" import sys import traceback from bson.py3compat import abc, iteritems, itervalues, string_type from bson.son import SON from pymongo import ASCENDING from pymongo.errors import (CursorNotFound, DuplicateKeyError, ExecutionTimeout, NotMasterError, OperationFailure, WriteError, WriteConcernError, WTimeoutError) # From the SDAM spec, the "node is shutting down" codes. _SHUTDOWN_CODES = frozenset([ 11600, # InterruptedAtShutdown 91, # ShutdownInProgress ]) # From the SDAM spec, the "not master" error codes are combined with the # "node is recovering" error codes (of which the "node is shutting down" # errors are a subset). _NOT_MASTER_CODES = frozenset([ 10107, # NotMaster 13435, # NotMasterNoSlaveOk 11602, # InterruptedDueToReplStateChange 13436, # NotMasterOrSecondary 189, # PrimarySteppedDown ]) | _SHUTDOWN_CODES # From the retryable writes spec. _RETRYABLE_ERROR_CODES = _NOT_MASTER_CODES | frozenset([ 7, # HostNotFound 6, # HostUnreachable 89, # NetworkTimeout 9001, # SocketException 262, # ExceededTimeLimit ]) _UUNDER = u"_" def _gen_index_name(keys): """Generate an index name from the set of fields it is over.""" return _UUNDER.join(["%s_%s" % item for item in keys]) def _index_list(key_or_list, direction=None): """Helper to generate a list of (key, direction) pairs. Takes such a list, or a single key, or a single key and direction. """ if direction is not None: return [(key_or_list, direction)] else: if isinstance(key_or_list, string_type): return [(key_or_list, ASCENDING)] elif not isinstance(key_or_list, (list, tuple)): raise TypeError("if no direction is specified, " "key_or_list must be an instance of list") return key_or_list def _index_document(index_list): """Helper to generate an index specifying document. Takes a list of (key, direction) pairs. """ if isinstance(index_list, abc.Mapping): raise TypeError("passing a dict to sort/create_index/hint is not " "allowed - use a list of tuples instead. did you " "mean %r?" % list(iteritems(index_list))) elif not isinstance(index_list, (list, tuple)): raise TypeError("must use a list of (key, direction) pairs, " "not: " + repr(index_list)) if not len(index_list): raise ValueError("key_or_list must not be the empty list") index = SON() for (key, value) in index_list: if not isinstance(key, string_type): raise TypeError("first item in each key pair must be a string") if not isinstance(value, (string_type, int, abc.Mapping)): raise TypeError("second item in each key pair must be 1, -1, " "'2d', or another valid MongoDB index specifier.") index[key] = value return index def _check_command_response(response, max_wire_version, msg=None, allowable_errors=None, parse_write_concern_error=False): """Check the response to a command for errors. """ if "ok" not in response: # Server didn't recognize our message as a command. 
raise OperationFailure(response.get("$err"), response.get("code"), response, max_wire_version) if parse_write_concern_error and 'writeConcernError' in response: _raise_write_concern_error(response['writeConcernError']) if not response["ok"]: details = response # Mongos returns the error details in a 'raw' object # for some errors. if "raw" in response: for shard in itervalues(response["raw"]): # Grab the first non-empty raw error from a shard. if shard.get("errmsg") and not shard.get("ok"): details = shard break errmsg = details["errmsg"] if (allowable_errors is None or (errmsg not in allowable_errors and details.get("code") not in allowable_errors)): code = details.get("code") # Server is "not master" or "recovering" if code in _NOT_MASTER_CODES: raise NotMasterError(errmsg, response) elif ("not master" in errmsg or "node is recovering" in errmsg): raise NotMasterError(errmsg, response) # Server assertion failures if errmsg == "db assertion failure": errmsg = ("db assertion failure, assertion: '%s'" % details.get("assertion", "")) raise OperationFailure(errmsg, details.get("assertionCode"), response, max_wire_version) # Other errors # findAndModify with upsert can raise duplicate key error if code in (11000, 11001, 12582): raise DuplicateKeyError(errmsg, code, response, max_wire_version) elif code == 50: raise ExecutionTimeout(errmsg, code, response, max_wire_version) elif code == 43: raise CursorNotFound(errmsg, code, response, max_wire_version) msg = msg or "%s" raise OperationFailure(msg % errmsg, code, response, max_wire_version) def _check_gle_response(result, max_wire_version): """Return getlasterror response as a dict, or raise OperationFailure.""" # Did getlasterror itself fail? _check_command_response(result, max_wire_version) if result.get("wtimeout", False): # MongoDB versions before 1.8.0 return the error message in an "errmsg" # field. If "errmsg" exists "err" will also exist set to None, so we # have to check for "errmsg" first. raise WTimeoutError(result.get("errmsg", result.get("err")), result.get("code"), result) error_msg = result.get("err", "") if error_msg is None: return result if error_msg.startswith("not master"): raise NotMasterError(error_msg, result) details = result # mongos returns the error code in an error object for some errors. if "errObjects" in result: for errobj in result["errObjects"]: if errobj.get("err") == error_msg: details = errobj break code = details.get("code") if code in (11000, 11001, 12582): raise DuplicateKeyError(details["err"], code, result) raise OperationFailure(details["err"], code, result) def _raise_last_write_error(write_errors): # If the last batch had multiple errors only report # the last error to emulate continue_on_error. error = write_errors[-1] if error.get("code") == 11000: raise DuplicateKeyError(error.get("errmsg"), 11000, error) raise WriteError(error.get("errmsg"), error.get("code"), error) def _raise_write_concern_error(error): if "errInfo" in error and error["errInfo"].get('wtimeout'): # Make sure we raise WTimeoutError raise WTimeoutError( error.get("errmsg"), error.get("code"), error) raise WriteConcernError( error.get("errmsg"), error.get("code"), error) def _check_write_command_response(result): """Backward compatibility helper for write command error handling. 
""" # Prefer write errors over write concern errors write_errors = result.get("writeErrors") if write_errors: _raise_last_write_error(write_errors) error = result.get("writeConcernError") if error: _raise_write_concern_error(error) def _raise_last_error(bulk_write_result): """Backward compatibility helper for insert error handling. """ # Prefer write errors over write concern errors write_errors = bulk_write_result.get("writeErrors") if write_errors: _raise_last_write_error(write_errors) _raise_write_concern_error(bulk_write_result["writeConcernErrors"][-1]) def _fields_list_to_dict(fields, option_name): """Takes a sequence of field names and returns a matching dictionary. ["a", "b"] becomes {"a": 1, "b": 1} and ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} """ if isinstance(fields, abc.Mapping): return fields if isinstance(fields, (abc.Sequence, abc.Set)): if not all(isinstance(field, string_type) for field in fields): raise TypeError("%s must be a list of key names, each an " "instance of %s" % (option_name, string_type.__name__)) return dict.fromkeys(fields, 1) raise TypeError("%s must be a mapping or " "list of key names" % (option_name,)) def _handle_exception(): """Print exceptions raised by subscribers to stderr.""" # Heavily influenced by logging.Handler.handleError. # See note here: # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ if sys.stderr: einfo = sys.exc_info() try: traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) except IOError: pass finally: del einfo pymongo-3.11.0/pymongo/ismaster.py000066400000000000000000000121701374256237000171670ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Parse a response to the 'ismaster' command.""" import itertools from bson.py3compat import imap from pymongo import common from pymongo.server_type import SERVER_TYPE def _get_server_type(doc): """Determine the server type from an ismaster response.""" if not doc.get('ok'): return SERVER_TYPE.Unknown if doc.get('isreplicaset'): return SERVER_TYPE.RSGhost elif doc.get('setName'): if doc.get('hidden'): return SERVER_TYPE.RSOther elif doc.get('ismaster'): return SERVER_TYPE.RSPrimary elif doc.get('secondary'): return SERVER_TYPE.RSSecondary elif doc.get('arbiterOnly'): return SERVER_TYPE.RSArbiter else: return SERVER_TYPE.RSOther elif doc.get('msg') == 'isdbgrid': return SERVER_TYPE.Mongos else: return SERVER_TYPE.Standalone class IsMaster(object): __slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable', '_awaitable') def __init__(self, doc, awaitable=False): """Parse an ismaster response from the server.""" self._server_type = _get_server_type(doc) self._doc = doc self._is_writable = self._server_type in ( SERVER_TYPE.RSPrimary, SERVER_TYPE.Standalone, SERVER_TYPE.Mongos) self._is_readable = ( self.server_type == SERVER_TYPE.RSSecondary or self._is_writable) self._awaitable = awaitable @property def document(self): """The complete ismaster command response document. .. 
versionadded:: 3.4 """ return self._doc.copy() @property def server_type(self): return self._server_type @property def all_hosts(self): """List of hosts, passives, and arbiters known to this server.""" return set(imap(common.clean_node, itertools.chain( self._doc.get('hosts', []), self._doc.get('passives', []), self._doc.get('arbiters', [])))) @property def tags(self): """Replica set member tags or empty dict.""" return self._doc.get('tags', {}) @property def primary(self): """This server's opinion about who the primary is, or None.""" if self._doc.get('primary'): return common.partition_node(self._doc['primary']) else: return None @property def replica_set_name(self): """Replica set name or None.""" return self._doc.get('setName') @property def max_bson_size(self): return self._doc.get('maxBsonObjectSize', common.MAX_BSON_SIZE) @property def max_message_size(self): return self._doc.get('maxMessageSizeBytes', 2 * self.max_bson_size) @property def max_write_batch_size(self): return self._doc.get('maxWriteBatchSize', common.MAX_WRITE_BATCH_SIZE) @property def min_wire_version(self): return self._doc.get('minWireVersion', common.MIN_WIRE_VERSION) @property def max_wire_version(self): return self._doc.get('maxWireVersion', common.MAX_WIRE_VERSION) @property def set_version(self): return self._doc.get('setVersion') @property def election_id(self): return self._doc.get('electionId') @property def cluster_time(self): return self._doc.get('$clusterTime') @property def logical_session_timeout_minutes(self): return self._doc.get('logicalSessionTimeoutMinutes') @property def is_writable(self): return self._is_writable @property def is_readable(self): return self._is_readable @property def me(self): me = self._doc.get('me') if me: return common.clean_node(me) @property def last_write_date(self): return self._doc.get('lastWrite', {}).get('lastWriteDate') @property def compressors(self): return self._doc.get('compression') @property def sasl_supported_mechs(self): """Supported authentication mechanisms for the current user. For example:: >>> ismaster.sasl_supported_mechs ["SCRAM-SHA-1", "SCRAM-SHA-256"] """ return self._doc.get('saslSupportedMechs', []) @property def speculative_authenticate(self): """The speculativeAuthenticate field.""" return self._doc.get('speculativeAuthenticate') @property def topology_version(self): return self._doc.get('topologyVersion') @property def awaitable(self): return self._awaitable pymongo-3.11.0/pymongo/max_staleness_selectors.py000066400000000000000000000103461374256237000222740ustar00rootroot00000000000000# Copyright 2016 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Criteria to select ServerDescriptions based on maxStalenessSeconds. 
The Max Staleness Spec says: When there is a known primary P, a secondary S's staleness is estimated with this formula: (S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate) + heartbeatFrequencyMS When there is no known primary, a secondary S's staleness is estimated with: SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS where "SMax" is the secondary with the greatest lastWriteDate. """ from pymongo.errors import ConfigurationError from pymongo.server_type import SERVER_TYPE # Constant defined in Max Staleness Spec: An idle primary writes a no-op every # 10 seconds to refresh secondaries' lastWriteDate values. IDLE_WRITE_PERIOD = 10 SMALLEST_MAX_STALENESS = 90 def _validate_max_staleness(max_staleness, heartbeat_frequency): # We checked for max staleness -1 before this, it must be positive here. if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD: raise ConfigurationError( "maxStalenessSeconds must be at least heartbeatFrequencyMS +" " %d seconds. maxStalenessSeconds is set to %d," " heartbeatFrequencyMS is set to %d." % ( IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000)) if max_staleness < SMALLEST_MAX_STALENESS: raise ConfigurationError( "maxStalenessSeconds must be at least %d. " "maxStalenessSeconds is set to %d." % ( SMALLEST_MAX_STALENESS, max_staleness)) def _with_primary(max_staleness, selection): """Apply max_staleness, in seconds, to a Selection with a known primary.""" primary = selection.primary sds = [] for s in selection.server_descriptions: if s.server_type == SERVER_TYPE.RSSecondary: # See max-staleness.rst for explanation of this formula. staleness = ( (s.last_update_time - s.last_write_date) - (primary.last_update_time - primary.last_write_date) + selection.heartbeat_frequency) if staleness <= max_staleness: sds.append(s) else: sds.append(s) return selection.with_server_descriptions(sds) def _no_primary(max_staleness, selection): """Apply max_staleness, in seconds, to a Selection with no known primary.""" # Secondary that's replicated the most recent writes. smax = selection.secondary_with_max_last_write_date() if not smax: # No secondaries and no primary, short-circuit out of here. return selection.with_server_descriptions([]) sds = [] for s in selection.server_descriptions: if s.server_type == SERVER_TYPE.RSSecondary: # See max-staleness.rst for explanation of this formula. staleness = (smax.last_write_date - s.last_write_date + selection.heartbeat_frequency) if staleness <= max_staleness: sds.append(s) else: sds.append(s) return selection.with_server_descriptions(sds) def select(max_staleness, selection): """Apply max_staleness, in seconds, to a Selection.""" if max_staleness == -1: return selection # Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or # ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness < # heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90. _validate_max_staleness(max_staleness, selection.heartbeat_frequency) if selection.primary: return _with_primary(max_staleness, selection) else: return _no_primary(max_staleness, selection) pymongo-3.11.0/pymongo/message.py000066400000000000000000001673001374256237000167720ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
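# --- Illustrative worked example (not part of the driver) ----------------
# A numeric walk-through of the staleness estimate described in
# max_staleness_selectors above, for the case with a known primary P and a
# secondary S. All timestamps are hypothetical and expressed in seconds.
def _staleness_with_primary_example():
    heartbeat_frequency = 10                     # heartbeatFrequencyMS, in seconds
    s_last_update, s_last_write = 1000.0, 900.0  # secondary S
    p_last_update, p_last_write = 1000.0, 990.0  # primary P
    staleness = ((s_last_update - s_last_write) -
                 (p_last_update - p_last_write) + heartbeat_frequency)
    # (1000 - 900) - (1000 - 990) + 10 = 100 seconds, so S is eligible only
    # when maxStalenessSeconds >= 100.
    return staleness
# -------------------------------------------------------------------------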
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for creating `messages `_ to be sent to MongoDB. .. note:: This module is for internal use and is generally not needed by application developers. """ import datetime import random import struct import bson from bson import (CodecOptions, decode, encode, _dict_to_bson, _make_c_string) from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.raw_bson import _inflate_bson, DEFAULT_RAW_BSON_OPTIONS from bson.py3compat import b, StringIO from bson.son import SON try: from pymongo import _cmessage _use_c = True except ImportError: _use_c = False from pymongo.errors import (ConfigurationError, CursorNotFound, DocumentTooLarge, ExecutionTimeout, InvalidOperation, NotMasterError, OperationFailure, ProtocolError) from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 # Overhead allowed for encoded command documents. _COMMAND_OVERHEAD = 16382 _INSERT = 0 _UPDATE = 1 _DELETE = 2 _EMPTY = b'' _BSONOBJ = b'\x03' _ZERO_8 = b'\x00' _ZERO_16 = b'\x00\x00' _ZERO_32 = b'\x00\x00\x00\x00' _ZERO_64 = b'\x00\x00\x00\x00\x00\x00\x00\x00' _SKIPLIM = b'\x00\x00\x00\x00\xff\xff\xff\xff' _OP_MAP = { _INSERT: b'\x04documents\x00\x00\x00\x00\x00', _UPDATE: b'\x04updates\x00\x00\x00\x00\x00', _DELETE: b'\x04deletes\x00\x00\x00\x00\x00', } _FIELD_MAP = { 'insert': 'documents', 'update': 'updates', 'delete': 'deletes' } _UJOIN = u"%s.%s" _UNICODE_REPLACE_CODEC_OPTIONS = CodecOptions( unicode_decode_error_handler='replace') def _randint(): """Generate a pseudo random 32 bit integer.""" return random.randint(MIN_INT32, MAX_INT32) def _maybe_add_read_preference(spec, read_preference): """Add $readPreference to spec when appropriate.""" mode = read_preference.mode document = read_preference.document # Only add $readPreference if it's something other than primary to avoid # problems with mongos versions that don't support read preferences. Also, # for maximum backwards compatibility, don't add $readPreference for # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting # the slaveOkay bit has the same effect). if mode and ( mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1): if "$query" not in spec: spec = SON([("$query", spec)]) spec["$readPreference"] = document return spec def _convert_exception(exception): """Convert an Exception into a failure document for publishing.""" return {'errmsg': str(exception), 'errtype': exception.__class__.__name__} def _convert_write_result(operation, command, result): """Convert a legacy write result to write command format.""" # Based on _merge_legacy from bulk.py affected = result.get("n", 0) res = {"ok": 1, "n": affected} errmsg = result.get("errmsg", result.get("err", "")) if errmsg: # The write was successful on at least the primary so don't return. if result.get("wtimeout"): res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} else: # The write failed. 
error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} if "errInfo" in result: error["errInfo"] = result["errInfo"] res["writeErrors"] = [error] return res if operation == "insert": # GLE result for insert is always 0 in most MongoDB versions. res["n"] = len(command['documents']) elif operation == "update": if "upserted" in result: res["upserted"] = [{"index": 0, "_id": result["upserted"]}] # Versions of MongoDB before 2.6 don't return the _id for an # upsert if _id is not an ObjectId. elif result.get("updatedExisting") is False and affected == 1: # If _id is in both the update document *and* the query spec # the update document _id takes precedence. update = command['updates'][0] _id = update["u"].get("_id", update["q"].get("_id")) res["upserted"] = [{"index": 0, "_id": _id}] return res _OPTIONS = SON([ ('tailable', 2), ('oplogReplay', 8), ('noCursorTimeout', 16), ('awaitData', 32), ('allowPartialResults', 128)]) _MODIFIERS = SON([ ('$query', 'filter'), ('$orderby', 'sort'), ('$hint', 'hint'), ('$comment', 'comment'), ('$maxScan', 'maxScan'), ('$maxTimeMS', 'maxTimeMS'), ('$max', 'max'), ('$min', 'min'), ('$returnKey', 'returnKey'), ('$showRecordId', 'showRecordId'), ('$showDiskLoc', 'showRecordId'), # <= MongoDb 3.0 ('$snapshot', 'snapshot')]) def _gen_find_command(coll, spec, projection, skip, limit, batch_size, options, read_concern, collation=None, session=None, allow_disk_use=None): """Generate a find command document.""" cmd = SON([('find', coll)]) if '$query' in spec: cmd.update([(_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) for key, val in spec.items()]) if '$explain' in cmd: cmd.pop('$explain') if '$readPreference' in cmd: cmd.pop('$readPreference') else: cmd['filter'] = spec if projection: cmd['projection'] = projection if skip: cmd['skip'] = skip if limit: cmd['limit'] = abs(limit) if limit < 0: cmd['singleBatch'] = True if batch_size: cmd['batchSize'] = batch_size if read_concern.level and not (session and session.in_transaction): cmd['readConcern'] = read_concern.document if collation: cmd['collation'] = collation if allow_disk_use is not None: cmd['allowDiskUse'] = allow_disk_use if options: cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) return cmd def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms): """Generate a getMore command document.""" cmd = SON([('getMore', cursor_id), ('collection', coll)]) if batch_size: cmd['batchSize'] = batch_size if max_await_time_ms is not None: cmd['maxTimeMS'] = max_await_time_ms return cmd class _Query(object): """A query operation.""" __slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec', 'fields', 'codec_options', 'read_preference', 'limit', 'batch_size', 'name', 'read_concern', 'collation', 'session', 'client', 'allow_disk_use', '_as_command') # For compatibility with the _GetMore class. 
exhaust_mgr = None cursor_id = None def __init__(self, flags, db, coll, ntoskip, spec, fields, codec_options, read_preference, limit, batch_size, read_concern, collation, session, client, allow_disk_use): self.flags = flags self.db = db self.coll = coll self.ntoskip = ntoskip self.spec = spec self.fields = fields self.codec_options = codec_options self.read_preference = read_preference self.read_concern = read_concern self.limit = limit self.batch_size = batch_size self.collation = collation self.session = session self.client = client self.allow_disk_use = allow_disk_use self.name = 'find' self._as_command = None def namespace(self): return _UJOIN % (self.db, self.coll) def use_command(self, sock_info, exhaust): use_find_cmd = False if sock_info.max_wire_version >= 4: if not exhaust: use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( 'read concern level of %s is not valid ' 'with a max wire version of %d.' % (self.read_concern.level, sock_info.max_wire_version)) if sock_info.max_wire_version < 5 and self.collation is not None: raise ConfigurationError( 'Specifying a collation is unsupported with a max wire ' 'version of %d.' % (sock_info.max_wire_version,)) if sock_info.max_wire_version < 4 and self.allow_disk_use is not None: raise ConfigurationError( 'Specifying allowDiskUse is unsupported with a max wire ' 'version of %d.' % (sock_info.max_wire_version,)) sock_info.validate_session(self.client, self.session) return use_find_cmd def as_command(self, sock_info): """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. if self._as_command is not None: return self._as_command explain = '$explain' in self.spec cmd = _gen_find_command( self.coll, self.spec, self.fields, self.ntoskip, self.limit, self.batch_size, self.flags, self.read_concern, self.collation, self.session, self.allow_disk_use) if explain: self.name = 'explain' cmd = SON([('explain', cmd)]) session = self.session if session: session._apply_to(cmd, False, self.read_preference) # Explain does not support readConcern. if (not explain and session.options.causal_consistency and session.operation_time is not None and not session.in_transaction): cmd.setdefault( 'readConcern', {})[ 'afterClusterTime'] = session.operation_time sock_info.send_cluster_time(cmd, session, self.client) # Support auto encryption client = self.client if (client._encrypter and not client._encrypter._bypass_auto_encryption): cmd = client._encrypter.encrypt( self.db, cmd, False, self.codec_options) self._as_command = cmd, self.db return self._as_command def get_message(self, set_slave_ok, sock_info, use_cmd=False): """Get a query message, possibly setting the slaveOk bit.""" if set_slave_ok: # Set the slaveOk bit. flags = self.flags | 4 else: flags = self.flags ns = self.namespace() spec = self.spec if use_cmd: spec = self.as_command(sock_info)[0] if sock_info.op_msg_enabled: request_id, msg, size, _ = _op_msg( 0, spec, self.db, self.read_preference, set_slave_ok, False, self.codec_options, ctx=sock_info.compression_context) return request_id, msg, size ns = _UJOIN % (self.db, "$cmd") ntoreturn = -1 # All DB commands return 1 document else: # OP_QUERY treats ntoreturn of -1 and 1 the same, return # one document and close the cursor. We have to use 2 for # batch size if 1 is specified. 
ntoreturn = self.batch_size == 1 and 2 or self.batch_size if self.limit: if ntoreturn: ntoreturn = min(self.limit, ntoreturn) else: ntoreturn = self.limit if sock_info.is_mongos: spec = _maybe_add_read_preference(spec, self.read_preference) return query(flags, ns, self.ntoskip, ntoreturn, spec, None if use_cmd else self.fields, self.codec_options, ctx=sock_info.compression_context) class _GetMore(object): """A getmore operation.""" __slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms', 'codec_options', 'read_preference', 'session', 'client', 'exhaust_mgr', '_as_command') name = 'getMore' def __init__(self, db, coll, ntoreturn, cursor_id, codec_options, read_preference, session, client, max_await_time_ms, exhaust_mgr): self.db = db self.coll = coll self.ntoreturn = ntoreturn self.cursor_id = cursor_id self.codec_options = codec_options self.read_preference = read_preference self.session = session self.client = client self.max_await_time_ms = max_await_time_ms self.exhaust_mgr = exhaust_mgr self._as_command = None def namespace(self): return _UJOIN % (self.db, self.coll) def use_command(self, sock_info, exhaust): sock_info.validate_session(self.client, self.session) return sock_info.max_wire_version >= 4 and not exhaust def as_command(self, sock_info): """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: return self._as_command cmd = _gen_get_more_command(self.cursor_id, self.coll, self.ntoreturn, self.max_await_time_ms) if self.session: self.session._apply_to(cmd, False, self.read_preference) sock_info.send_cluster_time(cmd, self.session, self.client) # Support auto encryption client = self.client if (client._encrypter and not client._encrypter._bypass_auto_encryption): cmd = client._encrypter.encrypt( self.db, cmd, False, self.codec_options) self._as_command = cmd, self.db return self._as_command def get_message(self, dummy0, sock_info, use_cmd=False): """Get a getmore message.""" ns = self.namespace() ctx = sock_info.compression_context if use_cmd: spec = self.as_command(sock_info)[0] if sock_info.op_msg_enabled: request_id, msg, size, _ = _op_msg( 0, spec, self.db, None, False, False, self.codec_options, ctx=sock_info.compression_context) return request_id, msg, size ns = _UJOIN % (self.db, "$cmd") return query(0, ns, 0, -1, spec, None, self.codec_options, ctx=ctx) return get_more(ns, self.ntoreturn, self.cursor_id, ctx) # TODO: Use OP_MSG once the server is able to respond with document streams. class _RawBatchQuery(_Query): def use_command(self, socket_info, exhaust): # Compatibility checks. super(_RawBatchQuery, self).use_command(socket_info, exhaust) return False def get_message(self, set_slave_ok, sock_info, use_cmd=False): # Always pass False for use_cmd. return super(_RawBatchQuery, self).get_message( set_slave_ok, sock_info, False) class _RawBatchGetMore(_GetMore): def use_command(self, socket_info, exhaust): return False def get_message(self, set_slave_ok, sock_info, use_cmd=False): # Always pass False for use_cmd. 
return super(_RawBatchGetMore, self).get_message( set_slave_ok, sock_info, False) class _CursorAddress(tuple): """The server address (host, port) of a cursor, with namespace property.""" def __new__(cls, address, namespace): self = tuple.__new__(cls, address) self.__namespace = namespace return self @property def namespace(self): """The namespace this cursor.""" return self.__namespace def __hash__(self): # Two _CursorAddress instances with different namespaces # must not hash the same. return (self + (self.__namespace,)).__hash__() def __eq__(self, other): if isinstance(other, _CursorAddress): return (tuple(self) == tuple(other) and self.namespace == other.namespace) return NotImplemented def __ne__(self, other): return not self == other _pack_compression_header = struct.Struct(" ctx.max_bson_size) message_length += encoded_length if message_length < ctx.max_message_size and not too_large: data.write(encoded) to_send.append(doc) has_docs = True continue if has_docs: # We have enough data, send this message. try: if compress: rid, msg = None, data.getvalue() else: rid, msg = _insert_message(data.getvalue(), send_safe) ctx.legacy_bulk_insert( rid, msg, 0, send_safe, to_send, compress) # Exception type could be OperationFailure or a subtype # (e.g. DuplicateKeyError) except OperationFailure as exc: # Like it says, continue on error... if continue_on_error: # Store exception details to re-raise after the final batch. last_error = exc # With unacknowledged writes just return at the first error. elif not safe: return # With acknowledged writes raise immediately. else: raise if too_large: _raise_document_too_large( "insert", encoded_length, ctx.max_bson_size) message_length = begin_loc + encoded_length data.seek(begin_loc) data.truncate() data.write(encoded) to_send = [doc] if not has_docs: raise InvalidOperation("cannot do an empty bulk insert") if compress: request_id, msg = None, data.getvalue() else: request_id, msg = _insert_message(data.getvalue(), safe) ctx.legacy_bulk_insert(request_id, msg, 0, safe, to_send, compress) # Re-raise any exception stored due to continue_on_error if last_error is not None: raise last_error if _use_c: _do_batched_insert = _cmessage._do_batched_insert # OP_MSG ------------------------------------------------------------- _OP_MSG_MAP = { _INSERT: b'documents\x00', _UPDATE: b'updates\x00', _DELETE: b'deletes\x00', } def _batched_op_msg_impl( operation, command, docs, check_keys, ack, opts, ctx, buf): """Create a batched OP_MSG write.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size max_message_size = ctx.max_message_size flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" # Flags buf.write(flags) # Type 0 Section buf.write(b"\x00") buf.write(_dict_to_bson(command, False, opts)) # Type 1 Section buf.write(b"\x01") size_location = buf.tell() # Save space for size buf.write(b"\x00\x00\x00\x00") try: buf.write(_OP_MSG_MAP[operation]) except KeyError: raise InvalidOperation('Unknown command') if operation in (_UPDATE, _DELETE): check_keys = False to_send = [] idx = 0 for doc in docs: # Encode the current operation value = _dict_to_bson(doc, check_keys, opts) doc_length = len(value) new_message_size = buf.tell() + doc_length # Does first document exceed max_message_size? doc_too_large = (idx == 0 and (new_message_size > max_message_size)) # When OP_MSG is used unacknowleged we have to check # document size client side or applications won't be notified. 
# Otherwise we let the server deal with documents that are too large # since ordered=False causes those documents to be skipped instead of # halting the bulk write operation. unacked_doc_too_large = (not ack and (doc_length > max_bson_size)) if doc_too_large or unacked_doc_too_large: write_op = list(_FIELD_MAP.keys())[operation] _raise_document_too_large( write_op, len(value), max_bson_size) # We have enough data, return this batch. if new_message_size > max_message_size: break buf.write(value) to_send.append(doc) idx += 1 # We have enough documents, return this batch. if idx == max_write_batch_size: break # Write type 1 section size length = buf.tell() buf.seek(size_location) buf.write(_pack_int(length - size_location)) return to_send, length def _encode_batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx): """Encode the next batched insert, update, or delete operation as OP_MSG. """ buf = StringIO() to_send, _ = _batched_op_msg_impl( operation, command, docs, check_keys, ack, opts, ctx, buf) return buf.getvalue(), to_send if _use_c: _encode_batched_op_msg = _cmessage._encode_batched_op_msg def _batched_op_msg_compressed( operation, command, docs, check_keys, ack, opts, ctx): """Create the next batched insert, update, or delete operation with OP_MSG, compressed. """ data, to_send = _encode_batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx) request_id, msg = _compress( 2013, data, ctx.sock_info.compression_context) return request_id, msg, to_send def _batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx): """OP_MSG implementation entry point.""" buf = StringIO() # Save space for message length and request id buf.write(_ZERO_64) # responseTo, opCode buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") to_send, length = _batched_op_msg_impl( operation, command, docs, check_keys, ack, opts, ctx, buf) # Header - request id and message length buf.seek(4) request_id = _randint() buf.write(_pack_int(request_id)) buf.seek(0) buf.write(_pack_int(length)) return request_id, buf.getvalue(), to_send if _use_c: _batched_op_msg = _cmessage._batched_op_msg def _do_batched_op_msg( namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete operation using OP_MSG. """ command['$db'] = namespace.split('.', 1)[0] if 'writeConcern' in command: ack = bool(command['writeConcern'].get('w', 1)) else: ack = True if ctx.sock_info.compression_context: return _batched_op_msg_compressed( operation, command, docs, check_keys, ack, opts, ctx) return _batched_op_msg( operation, command, docs, check_keys, ack, opts, ctx) # End OP_MSG ----------------------------------------------------- def _batched_write_command_compressed( namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete command, compressed. """ data, to_send = _encode_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx) request_id, msg = _compress( 2004, data, ctx.sock_info.compression_context) return request_id, msg, to_send def _encode_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Encode the next batched insert, update, or delete command. 
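Returns a two-item tuple: the encoded message body and the list of documents that were actually included in this batch.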
""" buf = StringIO() to_send, _ = _batched_write_command_impl( namespace, operation, command, docs, check_keys, opts, ctx, buf) return buf.getvalue(), to_send if _use_c: _encode_batched_write_command = _cmessage._encode_batched_write_command def _batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete command. """ buf = StringIO() # Save space for message length and request id buf.write(_ZERO_64) # responseTo, opCode buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00") # Write OP_QUERY write command to_send, length = _batched_write_command_impl( namespace, operation, command, docs, check_keys, opts, ctx, buf) # Header - request id and message length buf.seek(4) request_id = _randint() buf.write(_pack_int(request_id)) buf.seek(0) buf.write(_pack_int(length)) return request_id, buf.getvalue(), to_send if _use_c: _batched_write_command = _cmessage._batched_write_command def _do_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Batched write commands entry point.""" if ctx.sock_info.compression_context: return _batched_write_command_compressed( namespace, operation, command, docs, check_keys, opts, ctx) return _batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx) def _do_bulk_write_command( namespace, operation, command, docs, check_keys, opts, ctx): """Bulk write commands entry point.""" if ctx.sock_info.max_wire_version > 5: return _do_batched_op_msg( namespace, operation, command, docs, check_keys, opts, ctx) return _do_batched_write_command( namespace, operation, command, docs, check_keys, opts, ctx) def _batched_write_command_impl( namespace, operation, command, docs, check_keys, opts, ctx, buf): """Create a batched OP_QUERY write command.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size # Max BSON object size + 16k - 2 bytes for ending NUL bytes. # Server guarantees there is enough room: SERVER-10643. max_cmd_size = max_bson_size + _COMMAND_OVERHEAD max_split_size = ctx.max_split_size # No options buf.write(_ZERO_32) # Namespace as C string buf.write(b(namespace)) buf.write(_ZERO_8) # Skip: 0, Limit: -1 buf.write(_SKIPLIM) # Where to write command document length command_start = buf.tell() buf.write(encode(command)) # Start of payload buf.seek(-1, 2) # Work around some Jython weirdness. buf.truncate() try: buf.write(_OP_MAP[operation]) except KeyError: raise InvalidOperation('Unknown command') if operation in (_UPDATE, _DELETE): check_keys = False # Where to write list document length list_start = buf.tell() - 4 to_send = [] idx = 0 for doc in docs: # Encode the current operation key = b(str(idx)) value = encode(doc, check_keys, opts) # Is there enough room to add this document? max_cmd_size accounts for # the two trailing null bytes. doc_too_large = len(value) > max_cmd_size if doc_too_large: write_op = list(_FIELD_MAP.keys())[operation] _raise_document_too_large( write_op, len(value), max_bson_size) enough_data = (idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size) enough_documents = (idx >= max_write_batch_size) if enough_data or enough_documents: break buf.write(_BSONOBJ) buf.write(key) buf.write(_ZERO_8) buf.write(value) to_send.append(doc) idx += 1 # Finalize the current OP_QUERY message. 
# Close list and command documents buf.write(_ZERO_16) # Write document lengths and request id length = buf.tell() buf.seek(list_start) buf.write(_pack_int(length - list_start - 1)) buf.seek(command_start) buf.write(_pack_int(length - command_start)) return to_send, length class _OpReply(object): """A MongoDB OP_REPLY response message.""" __slots__ = ("flags", "cursor_id", "number_returned", "documents") UNPACK_FROM = struct.Struct("1 section") # Convert Python 3 memoryview to bytes. Note we should call # memoryview.tobytes() if we start using memoryview in Python 2.7. payload_document = bytes(msg[5:]) return cls(flags, payload_document) _UNPACK_REPLY = { _OpReply.OP_CODE: _OpReply.unpack, _OpMsg.OP_CODE: _OpMsg.unpack, } def _first_batch(sock_info, db, coll, query, ntoreturn, slave_ok, codec_options, read_preference, cmd, listeners): """Simple query helper for retrieving a first (and possibly only) batch.""" query = _Query( 0, db, coll, 0, query, None, codec_options, read_preference, ntoreturn, 0, DEFAULT_READ_CONCERN, None, None, None, None) name = next(iter(cmd)) publish = listeners.enabled_for_commands if publish: start = datetime.datetime.now() request_id, msg, max_doc_size = query.get_message(slave_ok, sock_info) if publish: encoding_duration = datetime.datetime.now() - start listeners.publish_command_start( cmd, db, request_id, sock_info.address) start = datetime.datetime.now() sock_info.send_message(msg, max_doc_size) reply = sock_info.receive_message(request_id) try: docs = reply.unpack_response(None, codec_options) except Exception as exc: if publish: duration = (datetime.datetime.now() - start) + encoding_duration if isinstance(exc, (NotMasterError, OperationFailure)): failure = exc.details else: failure = _convert_exception(exc) listeners.publish_command_failure( duration, failure, name, request_id, sock_info.address) raise # listIndexes if 'cursor' in cmd: result = { u'cursor': { u'firstBatch': docs, u'id': reply.cursor_id, u'ns': u'%s.%s' % (db, coll) }, u'ok': 1.0 } # fsyncUnlock, currentOp else: result = docs[0] if docs else {} result[u'ok'] = 1.0 if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( duration, result, name, request_id, sock_info.address) return result pymongo-3.11.0/pymongo/mongo_client.py000066400000000000000000003103321374256237000200160ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Tools for connecting to MongoDB. .. seealso:: :doc:`/examples/high_availability` for examples of connecting to replica sets or sets of mongos servers. To get a :class:`~pymongo.database.Database` instance from a :class:`MongoClient` use either dictionary-style or attribute-style access: .. 
doctest:: >>> from pymongo import MongoClient >>> c = MongoClient() >>> c.test_database Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True), u'test_database') >>> c['test-database'] Database(MongoClient(host=['localhost:27017'], document_class=dict, tz_aware=False, connect=True), u'test-database') """ import contextlib import datetime import threading import warnings import weakref from collections import defaultdict from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.py3compat import (integer_types, string_type) from bson.son import SON from pymongo import (common, database, helpers, message, periodic_executor, uri_parser, client_session) from pymongo.change_stream import ClusterChangeStream from pymongo.client_options import ClientOptions from pymongo.command_cursor import CommandCursor from pymongo.cursor_manager import CursorManager from pymongo.errors import (AutoReconnect, BulkWriteError, ConfigurationError, ConnectionFailure, InvalidOperation, NotMasterError, OperationFailure, PyMongoError, ServerSelectionTimeoutError) from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import (writable_preferred_server_selector, writable_server_selector) from pymongo.server_type import SERVER_TYPE from pymongo.topology import (Topology, _ErrorContext) from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.settings import TopologySettings from pymongo.uri_parser import (_handle_option_deprecations, _handle_security_options, _normalize_options) from pymongo.write_concern import DEFAULT_WRITE_CONCERN class MongoClient(common.BaseObject): """ A client-side representation of a MongoDB cluster. Instances can represent either a standalone MongoDB server, a replica set, or a sharded cluster. Instances of this class are responsible for maintaining up-to-date state of the cluster, and possibly cache resources related to this, including background threads for monitoring, and connection pools. """ HOST = "localhost" PORT = 27017 # Define order to retrieve options from ClientOptions for __repr__. # No host/port; these are retrieved from TopologySettings. _constructor_args = ('document_class', 'tz_aware', 'connect') def __init__( self, host=None, port=None, document_class=dict, tz_aware=None, connect=None, type_registry=None, **kwargs): """Client for a MongoDB instance, a replica set, or a set of mongoses. The client object is thread-safe and has connection-pooling built in. If an operation fails because of a network error, :class:`~pymongo.errors.ConnectionFailure` is raised and the client reconnects in the background. Application code should handle this exception (recognizing that the operation failed) and then continue to execute. The `host` parameter can be a full `mongodb URI `_, in addition to a simple hostname. It can also be a list of hostnames or URIs. Any port specified in the host string(s) will override the `port` parameter. If multiple mongodb URIs containing database or auth information are passed, the last database, username, and password present will be used. For username and passwords reserved characters like ':', '/', '+' and '@' must be percent encoded following RFC 2396:: try: # Python 3.x from urllib.parse import quote_plus except ImportError: # Python 2.x from urllib import quote_plus uri = "mongodb://%s:%s@%s" % ( quote_plus(user), quote_plus(password), host) client = MongoClient(uri) Unix domain sockets are also supported. 
The socket path must be percent encoded in the URI:: uri = "mongodb://%s:%s@%s" % ( quote_plus(user), quote_plus(password), quote_plus(socket_path)) client = MongoClient(uri) But not when passed as a simple hostname:: client = MongoClient('/tmp/mongodb-27017.sock') Starting with version 3.6, PyMongo supports mongodb+srv:// URIs. The URI must include one, and only one, hostname. The hostname will be resolved to one or more DNS `SRV records `_ which will be used as the seed list for connecting to the MongoDB deployment. When using SRV URIs, the `authSource` and `replicaSet` configuration options can be specified using `TXT records `_. See the `Initial DNS Seedlist Discovery spec `_ for more details. Note that the use of SRV URIs implicitly enables TLS support. Pass tls=false in the URI to override. .. note:: MongoClient creation will block waiting for answers from DNS when mongodb+srv:// URIs are used. .. note:: Starting with version 3.0 the :class:`MongoClient` constructor no longer blocks while connecting to the server or servers, and it no longer raises :class:`~pymongo.errors.ConnectionFailure` if they are unavailable, nor :class:`~pymongo.errors.ConfigurationError` if the user's credentials are wrong. Instead, the constructor returns immediately and launches the connection process on background threads. You can check if the server is available like this:: from pymongo.errors import ConnectionFailure client = MongoClient() try: # The ismaster command is cheap and does not require auth. client.admin.command('ismaster') except ConnectionFailure: print("Server not available") .. warning:: When using PyMongo in a multiprocessing context, please read :ref:`multiprocessing` first. .. note:: Many of the following options can be passed using a MongoDB URI or keyword parameters. If the same option is passed in a URI and as a keyword parameter the keyword parameter takes precedence. :Parameters: - `host` (optional): hostname or IP address or Unix domain socket path of a single mongod or mongos instance to connect to, or a mongodb URI, or a list of hostnames / mongodb URIs. If `host` is an IPv6 literal it must be enclosed in '[' and ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for localhost). Multihomed and round robin DNS addresses are **not** supported. - `port` (optional): port number on which to connect - `document_class` (optional): default class to use for documents returned from queries on this client - `type_registry` (optional): instance of :class:`~bson.codec_options.TypeRegistry` to enable encoding and decoding of custom types. - `tz_aware` (optional): if ``True``, :class:`~datetime.datetime` instances returned as values in a document by this :class:`MongoClient` will be timezone aware (otherwise they will be naive) - `connect` (optional): if ``True`` (the default), immediately begin connecting to MongoDB in the background. Otherwise connect on the first operation. - `directConnection` (optional): if ``True``, forces this client to connect directly to the specified MongoDB host as a standalone. If ``false``, the client connects to the entire replica set of which the given MongoDB host(s) is a part. If this is ``True`` and a mongodb+srv:// URI or a URI containing multiple seeds is provided, an exception will be raised. | **Other optional parameters can be passed as keyword arguments:** - `maxPoolSize` (optional): The maximum allowable number of concurrent connections to each connected server. 
Requests to a server will block if there are `maxPoolSize` outstanding connections to the requested server. Defaults to 100. Cannot be 0. - `minPoolSize` (optional): The minimum required number of concurrent connections that the pool will maintain to each connected server. Default is 0. - `maxIdleTimeMS` (optional): The maximum number of milliseconds that a connection can remain idle in the pool before being removed and replaced. Defaults to `None` (no limit). - `socketTimeoutMS`: (integer or None) Controls how long (in milliseconds) the driver will wait for a response after sending an ordinary (non-monitoring) database operation before concluding that a network error has occurred. ``0`` or ``None`` means no timeout. Defaults to ``None`` (no timeout). - `connectTimeoutMS`: (integer or None) Controls how long (in milliseconds) the driver will wait during server monitoring when connecting a new socket to a server before concluding the server is unavailable. ``0`` or ``None`` means no timeout. Defaults to ``20000`` (20 seconds). - `server_selector`: (callable or None) Optional, user-provided function that augments server selection rules. The function should accept as an argument a list of :class:`~pymongo.server_description.ServerDescription` objects and return a list of server descriptions that should be considered suitable for the desired operation. - `serverSelectionTimeoutMS`: (integer) Controls how long (in milliseconds) the driver will wait to find an available, appropriate server to carry out a database operation; while it is waiting, multiple server monitoring operations may be carried out, each controlled by `connectTimeoutMS`. Defaults to ``30000`` (30 seconds). - `waitQueueTimeoutMS`: (integer or None) How long (in milliseconds) a thread will wait for a socket from the pool if the pool has no free sockets. Defaults to ``None`` (no timeout). - `waitQueueMultiple`: (integer or None) Multiplied by maxPoolSize to give the number of threads allowed to wait for a socket at one time. Defaults to ``None`` (no limit). - `heartbeatFrequencyMS`: (optional) The number of milliseconds between periodic server checks, or None to accept the default frequency of 10 seconds. - `appname`: (string or None) The name of the application that created this MongoClient instance. MongoDB 3.4 and newer will print this value in the server log upon establishing each connection. It is also recorded in the slow query log and profile collections. - `driver`: (pair or None) A driver implemented on top of PyMongo can pass a :class:`~pymongo.driver_info.DriverInfo` to add its name, version, and platform to the message printed in the server log when establishing a connection. - `event_listeners`: a list or tuple of event listeners. See :mod:`~pymongo.monitoring` for details. - `retryWrites`: (boolean) Whether supported write operations executed within this MongoClient will be retried once after a network error on MongoDB 3.6+. Defaults to ``True``. The supported write operations are: - :meth:`~pymongo.collection.Collection.bulk_write`, as long as :class:`~pymongo.operations.UpdateMany` or :class:`~pymongo.operations.DeleteMany` are not included. 
- :meth:`~pymongo.collection.Collection.delete_one` - :meth:`~pymongo.collection.Collection.insert_one` - :meth:`~pymongo.collection.Collection.insert_many` - :meth:`~pymongo.collection.Collection.replace_one` - :meth:`~pymongo.collection.Collection.update_one` - :meth:`~pymongo.collection.Collection.find_one_and_delete` - :meth:`~pymongo.collection.Collection.find_one_and_replace` - :meth:`~pymongo.collection.Collection.find_one_and_update` Unsupported write operations include, but are not limited to, :meth:`~pymongo.collection.Collection.aggregate` using the ``$out`` pipeline operator and any operation with an unacknowledged write concern (e.g. {w: 0})). See https://github.com/mongodb/specifications/blob/master/source/retryable-writes/retryable-writes.rst - `retryReads`: (boolean) Whether supported read operations executed within this MongoClient will be retried once after a network error on MongoDB 3.6+. Defaults to ``True``. The supported read operations are: :meth:`~pymongo.collection.Collection.find`, :meth:`~pymongo.collection.Collection.find_one`, :meth:`~pymongo.collection.Collection.aggregate` without ``$out``, :meth:`~pymongo.collection.Collection.distinct`, :meth:`~pymongo.collection.Collection.count`, :meth:`~pymongo.collection.Collection.estimated_document_count`, :meth:`~pymongo.collection.Collection.count_documents`, :meth:`pymongo.collection.Collection.watch`, :meth:`~pymongo.collection.Collection.list_indexes`, :meth:`pymongo.database.Database.watch`, :meth:`~pymongo.database.Database.list_collections`, :meth:`pymongo.mongo_client.MongoClient.watch`, and :meth:`~pymongo.mongo_client.MongoClient.list_databases`. Unsupported read operations include, but are not limited to: :meth:`~pymongo.collection.Collection.map_reduce`, :meth:`~pymongo.collection.Collection.inline_map_reduce`, :meth:`~pymongo.database.Database.command`, and any getMore operation on a cursor. Enabling retryable reads makes applications more resilient to transient errors such as network failures, database upgrades, and replica set failovers. For an exact definition of which errors trigger a retry, see the `retryable reads specification `_. - `socketKeepAlive`: (boolean) **DEPRECATED** Whether to send periodic keep-alive packets on connected sockets. Defaults to ``True``. Disabling it is not recommended, see https://docs.mongodb.com/manual/faq/diagnostics/#does-tcp-keepalive-time-affect-mongodb-deployments", - `compressors`: Comma separated list of compressors for wire protocol compression. The list is used to negotiate a compressor with the server. Currently supported options are "snappy", "zlib" and "zstd". Support for snappy requires the `python-snappy `_ package. zlib support requires the Python standard library zlib module. zstd requires the `zstandard `_ package. By default no compression is used. Compression support must also be enabled on the server. MongoDB 3.4+ supports snappy compression. MongoDB 3.6 adds support for zlib. MongoDB 4.2 adds support for zstd. - `zlibCompressionLevel`: (int) The zlib compression level to use when zlib is used as the wire protocol compressor. Supported values are -1 through 9. -1 tells the zlib library to use its default compression level (usually 6). 0 means no compression. 1 is best speed. 9 is best compression. Defaults to -1. - `uuidRepresentation`: The BSON representation to use when encoding from and decoding to instances of :class:`~uuid.UUID`. Valid values are `pythonLegacy` (the default), `javaLegacy`, `csharpLegacy`, `standard` and `unspecified`. 
New applications should consider setting this to `standard` for cross language compatibility. See :ref:`handling-uuid-data-example` for details. | **Write Concern options:** | (Only set if passed. No default values.) - `w`: (integer or string) If this is a replica set, write operations will block until they have been replicated to the specified number or tagged set of servers. `w=` always includes the replica set primary (e.g. w=3 means write to the primary and wait until replicated to **two** secondaries). Passing w=0 **disables write acknowledgement** and all other write concern options. - `wTimeoutMS`: (integer) Used in conjunction with `w`. Specify a value in milliseconds to control how long to wait for write propagation to complete. If replication does not complete in the given timeframe, a timeout exception is raised. Passing wTimeoutMS=0 will cause **write operations to wait indefinitely**. - `journal`: If ``True`` block until write operations have been committed to the journal. Cannot be used in combination with `fsync`. Prior to MongoDB 2.6 this option was ignored if the server was running without journaling. Starting with MongoDB 2.6 write operations will fail with an exception if this option is used when the server is running without journaling. - `fsync`: If ``True`` and the server is running without journaling, blocks until the server has synced all data files to disk. If the server is running with journaling, this acts the same as the `j` option, blocking until write operations have been committed to the journal. Cannot be used in combination with `j`. | **Replica set keyword arguments for connecting with a replica set - either directly or via a mongos:** - `replicaSet`: (string or None) The name of the replica set to connect to. The driver will verify that all servers it connects to match this name. Implies that the hosts specified are a seed list and the driver should attempt to find all members of the set. Defaults to ``None``. | **Read Preference:** - `readPreference`: The replica set read preference for this client. One of ``primary``, ``primaryPreferred``, ``secondary``, ``secondaryPreferred``, or ``nearest``. Defaults to ``primary``. - `readPreferenceTags`: Specifies a tag set as a comma-separated list of colon-separated key-value pairs. For example ``dc:ny,rack:1``. Defaults to ``None``. - `maxStalenessSeconds`: (integer) The maximum estimated length of time a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations. Defaults to ``-1``, meaning no maximum. If maxStalenessSeconds is set, it must be a positive integer greater than or equal to 90 seconds. .. seealso:: :doc:`/examples/server_selection` | **Authentication:** - `username`: A string. - `password`: A string. Although username and password must be percent-escaped in a MongoDB URI, they must not be percent-escaped when passed as parameters. In this example, both the space and slash special characters are passed as-is:: MongoClient(username="user name", password="pass/word") - `authSource`: The database to authenticate on. Defaults to the database specified in the URI, if provided, or to "admin". - `authMechanism`: See :data:`~pymongo.auth.MECHANISMS` for options. If no mechanism is specified, PyMongo automatically uses MONGODB-CR when connected to a pre-3.0 version of MongoDB, SCRAM-SHA-1 when connected to MongoDB 3.0 through 3.6, and negotiates the mechanism to use (SCRAM-SHA-1 or SCRAM-SHA-256) when connected to MongoDB 4.0+. 
- `authMechanismProperties`: Used to specify authentication mechanism specific options. To specify the service name for GSSAPI authentication pass authMechanismProperties='SERVICE_NAME:'. To specify the session token for MONGODB-AWS authentication pass ``authMechanismProperties='AWS_SESSION_TOKEN:'``. .. seealso:: :doc:`/examples/authentication` | **TLS/SSL configuration:** - `tls`: (boolean) If ``True``, create the connection to the server using transport layer security. Defaults to ``False``. - `tlsInsecure`: (boolean) Specify whether TLS constraints should be relaxed as much as possible. Setting ``tlsInsecure=True`` implies ``tlsAllowInvalidCertificates=True`` and ``tlsAllowInvalidHostnames=True``. Defaults to ``False``. Think very carefully before setting this to ``True`` as it dramatically reduces the security of TLS. - `tlsAllowInvalidCertificates`: (boolean) If ``True``, continues the TLS handshake regardless of the outcome of the certificate verification process. If this is ``False``, and a value is not provided for ``tlsCAFile``, PyMongo will attempt to load system provided CA certificates. If the python version in use does not support loading system CA certificates then the ``tlsCAFile`` parameter must point to a file of CA certificates. ``tlsAllowInvalidCertificates=False`` implies ``tls=True``. Defaults to ``False``. Think very carefully before setting this to ``True`` as that could make your application vulnerable to man-in-the-middle attacks. - `tlsAllowInvalidHostnames`: (boolean) If ``True``, disables TLS hostname verification. ``tlsAllowInvalidHostnames=False`` implies ``tls=True``. Defaults to ``False``. Think very carefully before setting this to ``True`` as that could make your application vulnerable to man-in-the-middle attacks. - `tlsCAFile`: A file containing a single or a bundle of "certification authority" certificates, which are used to validate certificates passed from the other end of the connection. Implies ``tls=True``. Defaults to ``None``. - `tlsCertificateKeyFile`: A file containing the client certificate and private key. If you want to pass the certificate and private key as separate files, use the ``ssl_certfile`` and ``ssl_keyfile`` options instead. Implies ``tls=True``. Defaults to ``None``. - `tlsCRLFile`: A file containing a PEM or DER formatted certificate revocation list. Only supported by python 2.7.9+ (pypy 2.5.1+). Implies ``tls=True``. Defaults to ``None``. - `tlsCertificateKeyFilePassword`: The password or passphrase for decrypting the private key in ``tlsCertificateKeyFile`` or ``ssl_keyfile``. Only necessary if the private key is encrypted. Only supported by python 2.7.9+ (pypy 2.5.1+) and 3.3+. Defaults to ``None``. - `tlsDisableOCSPEndpointCheck`: (boolean) If ``True``, disables certificate revocation status checking via the OCSP responder specified on the server certificate. Defaults to ``False``. - `ssl`: (boolean) Alias for ``tls``. - `ssl_certfile`: The certificate file used to identify the local connection against mongod. Implies ``tls=True``. Defaults to ``None``. - `ssl_keyfile`: The private keyfile used to identify the local connection against mongod. Can be omitted if the keyfile is included with the ``tlsCertificateKeyFile``. Implies ``tls=True``. Defaults to ``None``. | **Read Concern options:** | (If not set explicitly, this will use the server default) - `readConcernLevel`: (string) The read concern level specifies the level of isolation for read operations. 
For example, a read operation using a read concern level of ``majority`` will only return data that has been written to a majority of nodes. If the level is left unspecified, the server default will be used. | **Client side encryption options:** | (If not set explicitly, client side encryption will not be enabled.) - `auto_encryption_opts`: A :class:`~pymongo.encryption_options.AutoEncryptionOpts` which configures this client to automatically encrypt collection commands and automatically decrypt results. See :ref:`automatic-client-side-encryption` for an example. .. mongodoc:: connections .. versionchanged:: 3.11 Added the following keyword arguments and URI options: - ``tlsDisableOCSPEndpointCheck`` - ``directConnection`` .. versionchanged:: 3.9 Added the ``retryReads`` keyword argument and URI option. Added the ``tlsInsecure`` keyword argument and URI option. The following keyword arguments and URI options were deprecated: - ``wTimeout`` was deprecated in favor of ``wTimeoutMS``. - ``j`` was deprecated in favor of ``journal``. - ``ssl_cert_reqs`` was deprecated in favor of ``tlsAllowInvalidCertificates``. - ``ssl_match_hostname`` was deprecated in favor of ``tlsAllowInvalidHostnames``. - ``ssl_ca_certs`` was deprecated in favor of ``tlsCAFile``. - ``ssl_certfile`` was deprecated in favor of ``tlsCertificateKeyFile``. - ``ssl_crlfile`` was deprecated in favor of ``tlsCRLFile``. - ``ssl_pem_passphrase`` was deprecated in favor of ``tlsCertificateKeyFilePassword``. .. versionchanged:: 3.9 ``retryWrites`` now defaults to ``True``. .. versionchanged:: 3.8 Added the ``server_selector`` keyword argument. Added the ``type_registry`` keyword argument. .. versionchanged:: 3.7 Added the ``driver`` keyword argument. .. versionchanged:: 3.6 Added support for mongodb+srv:// URIs. Added the ``retryWrites`` keyword argument and URI option. .. versionchanged:: 3.5 Add ``username`` and ``password`` options. Document the ``authSource``, ``authMechanism``, and ``authMechanismProperties`` options. Deprecated the ``socketKeepAlive`` keyword argument and URI option. ``socketKeepAlive`` now defaults to ``True``. .. versionchanged:: 3.0 :class:`~pymongo.mongo_client.MongoClient` is now the one and only client class for a standalone server, mongos, or replica set. It includes the functionality that had been split into :class:`~pymongo.mongo_client.MongoReplicaSetClient`: it can connect to a replica set, discover all its members, and monitor the set for stepdowns, elections, and reconfigs. The :class:`~pymongo.mongo_client.MongoClient` constructor no longer blocks while connecting to the server or servers, and it no longer raises :class:`~pymongo.errors.ConnectionFailure` if they are unavailable, nor :class:`~pymongo.errors.ConfigurationError` if the user's credentials are wrong. Instead, the constructor returns immediately and launches the connection process on background threads. Therefore the ``alive`` method is removed since it no longer provides meaningful information; even if the client is disconnected, it may discover a server in time to fulfill the next operation. In PyMongo 2.x, :class:`~pymongo.MongoClient` accepted a list of standalone MongoDB servers and used the first it could connect to:: MongoClient(['host1.com:27017', 'host2.com:27017']) A list of multiple standalones is no longer supported; if multiple servers are listed they must be members of the same replica set, or mongoses in the same sharded cluster. The behavior for a list of mongoses is changed from "high availability" to "load balancing". 
Before, the client connected to the lowest-latency mongos in the list, and used it until a network error prompted it to re-evaluate all mongoses' latencies and reconnect to one of them. In PyMongo 3, the client monitors its network latency to all the mongoses continuously, and distributes operations evenly among those with the lowest latency. See :ref:`mongos-load-balancing` for more information. The ``connect`` option is added. The ``start_request``, ``in_request``, and ``end_request`` methods are removed, as well as the ``auto_start_request`` option. The ``copy_database`` method is removed, see the :doc:`copy_database examples
` for alternatives. The :meth:`MongoClient.disconnect` method is removed; it was a synonym for :meth:`~pymongo.MongoClient.close`. :class:`~pymongo.mongo_client.MongoClient` no longer returns an instance of :class:`~pymongo.database.Database` for attribute names with leading underscores. You must use dict-style lookups instead:: client['__my_database__'] Not:: client.__my_database__ """ if host is None: host = self.HOST if isinstance(host, string_type): host = [host] if port is None: port = self.PORT if not isinstance(port, int): raise TypeError("port must be an instance of int") # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. pool_class = kwargs.pop('_pool_class', None) monitor_class = kwargs.pop('_monitor_class', None) condition_class = kwargs.pop('_condition_class', None) # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) keyword_opts['document_class'] = document_class seeds = set() username = None password = None dbase = None opts = {} fqdn = None for entity in host: if "://" in entity: # Determine connection timeout from kwargs. timeout = keyword_opts.get("connecttimeoutms") if timeout is not None: timeout = common.validate_timeout_or_none_or_zero( keyword_opts.cased_key("connecttimeoutms"), timeout) res = uri_parser.parse_uri( entity, port, validate=True, warn=True, normalize=False, connect_timeout=timeout) seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password dbase = res["database"] or dbase opts = res["options"] fqdn = res["fqdn"] else: seeds.update(uri_parser.split_hosts(entity, port)) if not seeds: raise ConfigurationError("need to specify at least one host") # Add options with named keyword arguments to the parsed kwarg options. if type_registry is not None: keyword_opts['type_registry'] = type_registry if tz_aware is None: tz_aware = opts.get('tz_aware', False) if connect is None: connect = opts.get('connect', True) keyword_opts['tz_aware'] = tz_aware keyword_opts['connect'] = connect # Handle deprecated options in kwarg options. keyword_opts = _handle_option_deprecations(keyword_opts) # Validate kwarg options. keyword_opts = common._CaseInsensitiveDictionary(dict(common.validate( keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())) # Override connection string options with kwarg options. opts.update(keyword_opts) # Handle security-option conflicts in combined options. opts = _handle_security_options(opts) # Normalize combined options. opts = _normalize_options(opts) # Ensure directConnection was not True if there are multiple seeds. if len(seeds) > 1 and opts.get('directconnection'): raise ConfigurationError( "Cannot specify multiple hosts with directConnection=true") # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) if 'socketkeepalive' in opts: warnings.warn( "The socketKeepAlive option is deprecated. 
It now" "defaults to true and disabling it is not recommended, see " "https://docs.mongodb.com/manual/faq/diagnostics/" "#does-tcp-keepalive-time-affect-mongodb-deployments", DeprecationWarning, stacklevel=2) self.__options = options = ClientOptions( username, password, dbase, opts) self.__default_database_name = dbase self.__lock = threading.Lock() self.__cursor_manager = None self.__kill_cursors_queue = [] self._event_listeners = options.pool_options.event_listeners # Cache of existing indexes used by ensure_index ops. self.__index_cache = {} self.__index_cache_lock = threading.Lock() super(MongoClient, self).__init__(options.codec_options, options.read_preference, options.write_concern, options.read_concern) self.__all_credentials = {} creds = options.credentials if creds: self._cache_credentials(creds.source, creds) self._topology_settings = TopologySettings( seeds=seeds, replica_set_name=options.replica_set_name, pool_class=pool_class, pool_options=options.pool_options, monitor_class=monitor_class, condition_class=condition_class, local_threshold_ms=options.local_threshold_ms, server_selection_timeout=options.server_selection_timeout, server_selector=options.server_selector, heartbeat_frequency=options.heartbeat_frequency, fqdn=fqdn, direct_connection=options.direct_connection) self._topology = Topology(self._topology_settings) def target(): client = self_ref() if client is None: return False # Stop the executor. MongoClient._process_periodic_tasks(client) return True executor = periodic_executor.PeriodicExecutor( interval=common.KILL_CURSOR_FREQUENCY, min_interval=0.5, target=target, name="pymongo_kill_cursors_thread") # We strongly reference the executor and it weakly references us via # this closure. When the client is freed, stop the executor soon. self_ref = weakref.ref(self, executor.close) self._kill_cursors_executor = executor if connect: self._get_topology() self._encrypter = None if self.__options.auto_encryption_opts: from pymongo.encryption import _Encrypter self._encrypter = _Encrypter.create( self, self.__options.auto_encryption_opts) def _cache_credentials(self, source, credentials, connect=False): """Save a set of authentication credentials. The credentials are used to login a socket whenever one is created. If `connect` is True, verify the credentials on the server first. """ # Don't let other threads affect this call's data. all_credentials = self.__all_credentials.copy() if source in all_credentials: # Nothing to do if we already have these credentials. if credentials == all_credentials[source]: return raise OperationFailure('Another user is already authenticated ' 'to this database. You must logout first.') if connect: server = self._get_topology().select_server( writable_preferred_server_selector) # get_socket() logs out of the database if logged in with old # credentials, and logs in with new ones. with server.get_socket(all_credentials) as sock_info: sock_info.authenticate(credentials) # If several threads run _cache_credentials at once, last one wins. 
self.__all_credentials[source] = credentials def _purge_credentials(self, source): """Purge credentials from the authentication cache.""" self.__all_credentials.pop(source, None) def _cached(self, dbname, coll, index): """Test if `index` is cached.""" cache = self.__index_cache now = datetime.datetime.utcnow() with self.__index_cache_lock: return (dbname in cache and coll in cache[dbname] and index in cache[dbname][coll] and now < cache[dbname][coll][index]) def _cache_index(self, dbname, collection, index, cache_for): """Add an index to the index cache for ensure_index operations.""" now = datetime.datetime.utcnow() expire = datetime.timedelta(seconds=cache_for) + now with self.__index_cache_lock: if dbname not in self.__index_cache: self.__index_cache[dbname] = {} self.__index_cache[dbname][collection] = {} self.__index_cache[dbname][collection][index] = expire elif collection not in self.__index_cache[dbname]: self.__index_cache[dbname][collection] = {} self.__index_cache[dbname][collection][index] = expire else: self.__index_cache[dbname][collection][index] = expire def _purge_index(self, database_name, collection_name=None, index_name=None): """Purge an index from the index cache. If `index_name` is None purge an entire collection. If `collection_name` is None purge an entire database. """ with self.__index_cache_lock: if not database_name in self.__index_cache: return if collection_name is None: del self.__index_cache[database_name] return if not collection_name in self.__index_cache[database_name]: return if index_name is None: del self.__index_cache[database_name][collection_name] return if index_name in self.__index_cache[database_name][collection_name]: del self.__index_cache[database_name][collection_name][index_name] def _server_property(self, attr_name): """An attribute of the current server's description. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. Not threadsafe if used multiple times in a single method, since the server may change. In such cases, store a local reference to a ServerDescription first, then use its properties. """ server = self._topology.select_server( writable_server_selector) return getattr(server.description, attr_name) def watch(self, pipeline=None, full_document=None, resume_after=None, max_await_time_ms=None, batch_size=None, collation=None, start_at_operation_time=None, session=None, start_after=None): """Watch changes on this cluster. Performs an aggregation with an implicit initial ``$changeStream`` stage and returns a :class:`~pymongo.change_stream.ClusterChangeStream` cursor which iterates over changes on all databases on this cluster. Introduced in MongoDB 4.0. .. code-block:: python with client.watch() as stream: for change in stream: print(change) The :class:`~pymongo.change_stream.ClusterChangeStream` iterable blocks until the next change document is returned or an error is raised. If the :meth:`~pymongo.change_stream.ClusterChangeStream.next` method encounters a network error when retrieving a batch from the server, it will automatically attempt to recreate the cursor such that no change events are missed. Any error encountered during the resume attempt indicates there may be an outage and will be raised. .. 
code-block:: python try: with client.watch( [{'$match': {'operationType': 'insert'}}]) as stream: for insert_change in stream: print(insert_change) except pymongo.errors.PyMongoError: # The ChangeStream encountered an unrecoverable error or the # resume attempt failed to recreate the cursor. logging.error('...') For a precise description of the resume process see the `change streams specification`_. :Parameters: - `pipeline` (optional): A list of aggregation pipeline stages to append to an initial ``$changeStream`` stage. Not all pipeline stages are valid after a ``$changeStream`` stage, see the MongoDB documentation on change streams for the supported stages. - `full_document` (optional): The fullDocument to pass as an option to the ``$changeStream`` stage. Allowed values: 'updateLookup'. When set to 'updateLookup', the change notification for partial updates will include both a delta describing the changes to the document, as well as a copy of the entire document that was changed from some time after the change occurred. - `resume_after` (optional): A resume token. If provided, the change stream will start returning changes that occur directly after the operation specified in the resume token. A resume token is the _id value of a change document. - `max_await_time_ms` (optional): The maximum time in milliseconds for the server to wait for changes before responding to a getMore operation. - `batch_size` (optional): The maximum number of documents to return per batch. - `collation` (optional): The :class:`~pymongo.collation.Collation` to use for the aggregation. - `start_at_operation_time` (optional): If provided, the resulting change stream will only return changes that occurred at or after the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `start_after` (optional): The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. :Returns: A :class:`~pymongo.change_stream.ClusterChangeStream` cursor. .. versionchanged:: 3.9 Added the ``start_after`` parameter. .. versionadded:: 3.7 .. mongodoc:: changeStreams .. _change streams specification: https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.rst """ return ClusterChangeStream( self.admin, pipeline, full_document, resume_after, max_await_time_ms, batch_size, collation, start_at_operation_time, session, start_after) @property def event_listeners(self): """The event listeners registered for this client. See :mod:`~pymongo.monitoring` for details. """ return self._event_listeners.event_listeners @property def address(self): """(host, port) of the current standalone, primary, or mongos, or None. Accessing :attr:`address` raises :exc:`~.errors.InvalidOperation` if the client is load-balancing among mongoses, since there is no single address. Use :attr:`nodes` instead. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. .. 
versionadded:: 3.0 """ topology_type = self._topology._description.topology_type if topology_type == TOPOLOGY_TYPE.Sharded: raise InvalidOperation( 'Cannot use "address" property when load balancing among' ' mongoses, use "nodes" instead.') if topology_type not in (TOPOLOGY_TYPE.ReplicaSetWithPrimary, TOPOLOGY_TYPE.Single): return None return self._server_property('address') @property def primary(self): """The (host, port) of the current primary of the replica set. Returns ``None`` if this client is not connected to a replica set, there is no primary, or this client was created without the `replicaSet` option. .. versionadded:: 3.0 MongoClient gained this property in version 3.0 when MongoReplicaSetClient's functionality was merged in. """ return self._topology.get_primary() @property def secondaries(self): """The secondary members known to this client. A sequence of (host, port) pairs. Empty if this client is not connected to a replica set, there are no visible secondaries, or this client was created without the `replicaSet` option. .. versionadded:: 3.0 MongoClient gained this property in version 3.0 when MongoReplicaSetClient's functionality was merged in. """ return self._topology.get_secondaries() @property def arbiters(self): """Arbiters in the replica set. A sequence of (host, port) pairs. Empty if this client is not connected to a replica set, there are no arbiters, or this client was created without the `replicaSet` option. """ return self._topology.get_arbiters() @property def is_primary(self): """If this client is connected to a server that can accept writes. True if the current server is a standalone, mongos, or the primary of a replica set. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. """ return self._server_property('is_writable') @property def is_mongos(self): """If this client is connected to mongos. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available.. """ return self._server_property('server_type') == SERVER_TYPE.Mongos @property def max_pool_size(self): """The maximum allowable number of concurrent connections to each connected server. Requests to a server will block if there are `maxPoolSize` outstanding connections to the requested server. Defaults to 100. Cannot be 0. When a server's pool has reached `max_pool_size`, operations for that server block waiting for a socket to be returned to the pool. If ``waitQueueTimeoutMS`` is set, a blocked operation will raise :exc:`~pymongo.errors.ConnectionFailure` after a timeout. By default ``waitQueueTimeoutMS`` is not set. """ return self.__options.pool_options.max_pool_size @property def min_pool_size(self): """The minimum required number of concurrent connections that the pool will maintain to each connected server. Default is 0. """ return self.__options.pool_options.min_pool_size @property def max_idle_time_ms(self): """The maximum number of milliseconds that a connection can remain idle in the pool before being removed and replaced. Defaults to `None` (no limit). """ seconds = self.__options.pool_options.max_idle_time_seconds if seconds is None: return None return 1000 * seconds @property def nodes(self): """Set of all currently connected servers. .. warning:: When connected to a replica set the value of :attr:`nodes` can change over time as :class:`MongoClient`'s view of the replica set changes. 
:attr:`nodes` can also be an empty set when :class:`MongoClient` is first instantiated and hasn't yet connected to any servers, or a network partition causes it to lose connection to all servers. """ description = self._topology.description return frozenset(s.address for s in description.known_servers) @property def max_bson_size(self): """The largest BSON object the connected server accepts in bytes. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. """ return self._server_property('max_bson_size') @property def max_message_size(self): """The largest message the connected server accepts in bytes. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. """ return self._server_property('max_message_size') @property def max_write_batch_size(self): """The maxWriteBatchSize reported by the server. If the client is not connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. Returns a default value when connected to server versions prior to MongoDB 2.6. """ return self._server_property('max_write_batch_size') @property def local_threshold_ms(self): """The local threshold for this instance.""" return self.__options.local_threshold_ms @property def server_selection_timeout(self): """The server selection timeout for this instance in seconds.""" return self.__options.server_selection_timeout @property def retry_writes(self): """If this instance should retry supported write operations.""" return self.__options.retry_writes @property def retry_reads(self): """If this instance should retry supported write operations.""" return self.__options.retry_reads def _is_writable(self): """Attempt to connect to a writable server, or return False. """ topology = self._get_topology() # Starts monitors if necessary. try: svr = topology.select_server(writable_server_selector) # When directly connected to a secondary, arbiter, etc., # select_server returns it, whatever the selector. Check # again if the server is writable. return svr.description.is_writable except ConnectionFailure: return False def _end_sessions(self, session_ids): """Send endSessions command(s) with the given session ids.""" try: # Use SocketInfo.command directly to avoid implicitly creating # another session. with self._socket_for_reads( ReadPreference.PRIMARY_PREFERRED, None) as (sock_info, slave_ok): if not sock_info.supports_sessions: return for i in range(0, len(session_ids), common._MAX_END_SESSIONS): spec = SON([('endSessions', session_ids[i:i + common._MAX_END_SESSIONS])]) sock_info.command( 'admin', spec, slave_ok=slave_ok, client=self) except PyMongoError: # Drivers MUST ignore any errors returned by the endSessions # command. pass def close(self): """Cleanup client resources and disconnect from MongoDB. On MongoDB >= 3.6, end all server sessions created by this client by sending one or more endSessions commands. Close all sockets in the connection pools and stop the monitor threads. If this instance is used again it will be automatically re-opened and the threads restarted unless auto encryption is enabled. A client enabled with auto encryption cannot be used again after being closed; any attempt will raise :exc:`~.errors.InvalidOperation`. .. versionchanged:: 3.6 End all server sessions created by this client. 
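A minimal usage sketch (assuming a locally running MongoDB on the default host and port)::

            client = MongoClient()
            try:
                # The ismaster command is cheap and does not require auth.
                client.admin.command('ismaster')
            finally:
                # Release sockets, stop monitor threads, and end server sessions.
                client.close()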
""" session_ids = self._topology.pop_all_sessions() if session_ids: self._end_sessions(session_ids) # Stop the periodic task thread and then send pending killCursor # requests before closing the topology. self._kill_cursors_executor.close() self._process_kill_cursors() self._topology.close() if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() def set_cursor_manager(self, manager_class): """DEPRECATED - Set this client's cursor manager. Raises :class:`TypeError` if `manager_class` is not a subclass of :class:`~pymongo.cursor_manager.CursorManager`. A cursor manager handles closing cursors. Different managers can implement different policies in terms of when to actually kill a cursor that has been closed. :Parameters: - `manager_class`: cursor manager to use .. versionchanged:: 3.3 Deprecated, for real this time. .. versionchanged:: 3.0 Undeprecated. """ warnings.warn( "set_cursor_manager is Deprecated", DeprecationWarning, stacklevel=2) manager = manager_class(self) if not isinstance(manager, CursorManager): raise TypeError("manager_class must be a subclass of " "CursorManager") self.__cursor_manager = manager def _get_topology(self): """Get the internal :class:`~pymongo.topology.Topology` object. If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ self._topology.open() with self.__lock: self._kill_cursors_executor.open() return self._topology @contextlib.contextmanager def _get_socket(self, server, session, exhaust=False): with _MongoClientErrorHandler(self, server, session) as err_handler: with server.get_socket( self.__all_credentials, checkout=exhaust) as sock_info: err_handler.contribute_socket(sock_info) if (self._encrypter and not self._encrypter._bypass_auto_encryption and sock_info.max_wire_version < 8): raise ConfigurationError( 'Auto-encryption requires a minimum MongoDB version ' 'of 4.2') yield sock_info def _select_server(self, server_selector, session, address=None): """Select a server to run an operation on this client. :Parameters: - `server_selector`: The server selector to use if the session is not pinned and no address is given. - `session`: The ClientSession for the next operation, or None. May be pinned to a mongos server address. - `address` (optional): Address when sending a message to a specific server, used for getMore. """ try: topology = self._get_topology() address = address or (session and session._pinned_address) if address: # We're running a getMore or this session is pinned to a mongos. server = topology.select_server_by_address(address) if not server: raise AutoReconnect('server %s:%d no longer available' % address) else: server = topology.select_server(server_selector) # Pin this session to the selected server if it's performing a # sharded transaction. if server.description.mongos and (session and session.in_transaction): session._pin_mongos(server) return server except PyMongoError as exc: # Server selection errors in a transaction are transient. 
if session and session.in_transaction: exc._add_error_label("TransientTransactionError") session._unpin_mongos() raise def _socket_for_writes(self, session): server = self._select_server(writable_server_selector, session) return self._get_socket(server, session) @contextlib.contextmanager def _slaveok_for_server(self, read_preference, server, session, exhaust=False): assert read_preference is not None, "read_preference must not be None" # Get a socket for a server matching the read preference, and yield # sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to # mongods with topology type Single. If the server type is Mongos, # follow the rules for passing read preference to mongos, even for # topology type Single." # Thread safe: if the type is single it cannot change. topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single with self._get_socket(server, session, exhaust=exhaust) as sock_info: slave_ok = (single and not sock_info.is_mongos) or ( read_preference != ReadPreference.PRIMARY) yield sock_info, slave_ok @contextlib.contextmanager def _socket_for_reads(self, read_preference, session): assert read_preference is not None, "read_preference must not be None" # Get a socket for a server matching the read preference, and yield # sock_info, slave_ok. Server Selection Spec: "slaveOK must be sent to # mongods with topology type Single. If the server type is Mongos, # follow the rules for passing read preference to mongos, even for # topology type Single." # Thread safe: if the type is single it cannot change. topology = self._get_topology() single = topology.description.topology_type == TOPOLOGY_TYPE.Single server = self._select_server(read_preference, session) with self._get_socket(server, session) as sock_info: slave_ok = (single and not sock_info.is_mongos) or ( read_preference != ReadPreference.PRIMARY) yield sock_info, slave_ok def _run_operation_with_response(self, operation, unpack_res, exhaust=False, address=None): """Run a _Query/_GetMore operation and return a Response. :Parameters: - `operation`: a _Query or _GetMore object. - `unpack_res`: A callable that decodes the wire protocol response. - `exhaust` (optional): If True, the socket used stays checked out. It is returned along with its Pool in the Response. - `address` (optional): Optional address when sending a message to a specific server, used for getMore. """ if operation.exhaust_mgr: server = self._select_server( operation.read_preference, operation.session, address=address) with _MongoClientErrorHandler( self, server, operation.session) as err_handler: err_handler.contribute_socket(operation.exhaust_mgr.sock) return server.run_operation_with_response( operation.exhaust_mgr.sock, operation, True, self._event_listeners, exhaust, unpack_res) def _cmd(session, server, sock_info, slave_ok): return server.run_operation_with_response( sock_info, operation, slave_ok, self._event_listeners, exhaust, unpack_res) return self._retryable_read( _cmd, operation.read_preference, operation.session, address=address, retryable=isinstance(operation, message._Query), exhaust=exhaust) def _retry_with_session(self, retryable, func, session, bulk): """Execute an operation with at most one consecutive retries Returns func()'s return value on success. On error retries the same command once. Re-raises any exception thrown by func(). 
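# Sketch of the application-level knob behind the slave_ok computation above:
# a non-primary read preference allows reads to be served by secondaries.
# The replica-set URI and collection name are hypothetical:
from pymongo import MongoClient, ReadPreference

client = MongoClient('mongodb://localhost:27017/?replicaSet=rs0')
coll = client.test.get_collection(
    'docs', read_preference=ReadPreference.SECONDARY_PREFERRED)
doc = coll.find_one()  # may be answered by a secondary when one is available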
""" retryable = (retryable and self.retry_writes and session and not session.in_transaction) return self._retry_internal(retryable, func, session, bulk) def _retry_internal(self, retryable, func, session, bulk): """Internal retryable write helper.""" max_wire_version = 0 last_error = None retrying = False def is_retrying(): return bulk.retrying if bulk else retrying # Increment the transaction id up front to ensure any retry attempt # will use the proper txnNumber, even if server or socket selection # fails before the command can be sent. if retryable and session and not session.in_transaction: session._start_retryable_write() if bulk: bulk.started_retryable_write = True while True: try: server = self._select_server(writable_server_selector, session) supports_session = ( session is not None and server.description.retryable_writes_supported) with self._get_socket(server, session) as sock_info: max_wire_version = sock_info.max_wire_version if retryable and not supports_session: if is_retrying(): # A retry is not possible because this server does # not support sessions raise the last error. raise last_error retryable = False return func(session, sock_info, retryable) except ServerSelectionTimeoutError: if is_retrying(): # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry # attempt. Raise the original exception instead. raise last_error # A ServerSelectionTimeoutError error indicates that there may # be a persistent outage. Attempting to retry in this case will # most likely be a waste of time. raise except PyMongoError as exc: if not retryable: raise # Add the RetryableWriteError label, if applicable. _add_retryable_write_error(exc, max_wire_version) retryable_error = exc.has_error_label("RetryableWriteError") if retryable_error: session._unpin_mongos() if is_retrying() or not retryable_error: raise if bulk: bulk.retrying = True else: retrying = True last_error = exc def _retryable_read(self, func, read_pref, session, address=None, retryable=True, exhaust=False): """Execute an operation with at most one consecutive retries Returns func()'s return value on success. On error retries the same command once. Re-raises any exception thrown by func(). """ retryable = (retryable and self.retry_reads and not (session and session.in_transaction)) last_error = None retrying = False while True: try: server = self._select_server( read_pref, session, address=address) if not server.description.retryable_reads_supported: retryable = False with self._slaveok_for_server(read_pref, server, session, exhaust=exhaust) as (sock_info, slave_ok): if retrying and not retryable: # A retry is not possible because this server does # not support retryable reads, raise the last error. raise last_error return func(session, server, sock_info, slave_ok) except ServerSelectionTimeoutError: if retrying: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry # attempt. Raise the original exception instead. raise last_error # A ServerSelectionTimeoutError error indicates that there may # be a persistent outage. Attempting to retry in this case will # most likely be a waste of time. 
raise except ConnectionFailure as exc: if not retryable or retrying: raise retrying = True last_error = exc except OperationFailure as exc: if not retryable or retrying: raise if exc.code not in helpers._RETRYABLE_ERROR_CODES: raise retrying = True last_error = exc def _retryable_write(self, retryable, func, session): """Internal retryable write helper.""" with self._tmp_session(session) as s: return self._retry_with_session(retryable, func, s, None) def _handle_getlasterror(self, address, error_msg): """Clear our pool for a server, mark it Unknown, and check it soon.""" self._topology.handle_getlasterror(address, error_msg) def __eq__(self, other): if isinstance(other, self.__class__): return self.address == other.address return NotImplemented def __ne__(self, other): return not self == other def _repr_helper(self): def option_repr(option, value): """Fix options whose __repr__ isn't usable in a constructor.""" if option == 'document_class': if value is dict: return 'document_class=dict' else: return 'document_class=%s.%s' % (value.__module__, value.__name__) if option in common.TIMEOUT_OPTIONS and value is not None: return "%s=%s" % (option, int(value * 1000)) return '%s=%r' % (option, value) # Host first... options = ['host=%r' % [ '%s:%d' % (host, port) if port is not None else host for host, port in self._topology_settings.seeds]] # ... then everything in self._constructor_args... options.extend( option_repr(key, self.__options._options[key]) for key in self._constructor_args) # ... then everything else. options.extend( option_repr(key, self.__options._options[key]) for key in self.__options._options if key not in set(self._constructor_args) and key != 'username' and key != 'password') return ', '.join(options) def __repr__(self): return ("MongoClient(%s)" % (self._repr_helper(),)) def __getattr__(self, name): """Get a database by name. Raises :class:`~pymongo.errors.InvalidName` if an invalid database name is used. :Parameters: - `name`: the name of the database to get """ if name.startswith('_'): raise AttributeError( "MongoClient has no attribute %r. To access the %s" " database, use client[%r]." % (name, name, name)) return self.__getitem__(name) def __getitem__(self, name): """Get a database by name. Raises :class:`~pymongo.errors.InvalidName` if an invalid database name is used. :Parameters: - `name`: the name of the database to get """ return database.Database(self, name) def close_cursor(self, cursor_id, address=None): """DEPRECATED - Send a kill cursors message soon with the given id. Raises :class:`TypeError` if `cursor_id` is not an instance of ``(int, long)``. What closing the cursor actually means depends on this client's cursor manager. This method may be called from a :class:`~pymongo.cursor.Cursor` destructor during garbage collection, so it isn't safe to take a lock or do network I/O. Instead, we schedule the cursor to be closed soon on a background thread. :Parameters: - `cursor_id`: id of cursor to close - `address` (optional): (host, port) pair of the cursor's server. If it is not provided, the client attempts to close the cursor on the primary or standalone, or a mongos server. .. versionchanged:: 3.7 Deprecated. .. versionchanged:: 3.0 Added ``address`` parameter. 
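# The two equivalent spellings provided by __getattr__/__getitem__ above
# (the database name is hypothetical):
from pymongo import MongoClient

client = MongoClient()
db1 = client.inventory      # attribute access
db2 = client['inventory']   # item access; also works for names that are not valid identifiers
assert db1.name == db2.name == 'inventory'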
""" warnings.warn( "close_cursor is deprecated.", DeprecationWarning, stacklevel=2) if not isinstance(cursor_id, integer_types): raise TypeError("cursor_id must be an instance of (int, long)") self._close_cursor(cursor_id, address) def _close_cursor(self, cursor_id, address): """Send a kill cursors message with the given id. What closing the cursor actually means depends on this client's cursor manager. If there is none, the cursor is closed asynchronously on a background thread. """ if self.__cursor_manager is not None: self.__cursor_manager.close(cursor_id, address) else: self.__kill_cursors_queue.append((address, [cursor_id])) def _close_cursor_now(self, cursor_id, address=None, session=None): """Send a kill cursors message with the given id. What closing the cursor actually means depends on this client's cursor manager. If there is none, the cursor is closed synchronously on the current thread. """ if not isinstance(cursor_id, integer_types): raise TypeError("cursor_id must be an instance of (int, long)") if self.__cursor_manager is not None: self.__cursor_manager.close(cursor_id, address) else: try: self._kill_cursors( [cursor_id], address, self._get_topology(), session) except PyMongoError: # Make another attempt to kill the cursor later. self.__kill_cursors_queue.append((address, [cursor_id])) def kill_cursors(self, cursor_ids, address=None): """DEPRECATED - Send a kill cursors message soon with the given ids. Raises :class:`TypeError` if `cursor_ids` is not an instance of ``list``. :Parameters: - `cursor_ids`: list of cursor ids to kill - `address` (optional): (host, port) pair of the cursor's server. If it is not provided, the client attempts to close the cursor on the primary or standalone, or a mongos server. .. versionchanged:: 3.3 Deprecated. .. versionchanged:: 3.0 Now accepts an `address` argument. Schedules the cursors to be closed on a background thread instead of sending the message immediately. """ warnings.warn( "kill_cursors is deprecated.", DeprecationWarning, stacklevel=2) if not isinstance(cursor_ids, list): raise TypeError("cursor_ids must be a list") # "Atomic", needs no lock. self.__kill_cursors_queue.append((address, cursor_ids)) def _kill_cursors(self, cursor_ids, address, topology, session): """Send a kill cursors message with the given ids.""" listeners = self._event_listeners publish = listeners.enabled_for_commands if address: # address could be a tuple or _CursorAddress, but # select_server_by_address needs (host, port). server = topology.select_server_by_address(tuple(address)) else: # Application called close_cursor() with no address. server = topology.select_server(writable_server_selector) try: namespace = address.namespace db, coll = namespace.split('.', 1) except AttributeError: namespace = None db = coll = "OP_KILL_CURSORS" spec = SON([('killCursors', coll), ('cursors', cursor_ids)]) with server.get_socket(self.__all_credentials) as sock_info: if sock_info.max_wire_version >= 4 and namespace is not None: sock_info.command(db, spec, session=session, client=self) else: if publish: start = datetime.datetime.now() request_id, msg = message.kill_cursors(cursor_ids) if publish: duration = datetime.datetime.now() - start # Here and below, address could be a tuple or # _CursorAddress. We always want to publish a # tuple to match the rest of the monitoring # API. 
listeners.publish_command_start( spec, db, request_id, tuple(address)) start = datetime.datetime.now() try: sock_info.send_message(msg, 0) except Exception as exc: if publish: dur = ((datetime.datetime.now() - start) + duration) listeners.publish_command_failure( dur, message._convert_exception(exc), 'killCursors', request_id, tuple(address)) raise if publish: duration = ((datetime.datetime.now() - start) + duration) # OP_KILL_CURSORS returns no reply, fake one. reply = {'cursorsUnknown': cursor_ids, 'ok': 1} listeners.publish_command_success( duration, reply, 'killCursors', request_id, tuple(address)) def _process_kill_cursors(self): """Process any pending kill cursors requests.""" address_to_cursor_ids = defaultdict(list) # Other threads or the GC may append to the queue concurrently. while True: try: address, cursor_ids = self.__kill_cursors_queue.pop() except IndexError: break address_to_cursor_ids[address].extend(cursor_ids) # Don't re-open topology if it's closed and there's no pending cursors. if address_to_cursor_ids: topology = self._get_topology() for address, cursor_ids in address_to_cursor_ids.items(): try: self._kill_cursors( cursor_ids, address, topology, session=None) except Exception: helpers._handle_exception() # This method is run periodically by a background thread. def _process_periodic_tasks(self): """Process any pending kill cursors requests and maintain connection pool parameters.""" self._process_kill_cursors() try: self._topology.update_pool(self.__all_credentials) except Exception: helpers._handle_exception() def __start_session(self, implicit, **kwargs): # Driver Sessions Spec: "If startSession is called when multiple users # are authenticated drivers MUST raise an error with the error message # 'Cannot call startSession when multiple users are authenticated.'" authset = set(self.__all_credentials.values()) if len(authset) > 1: raise InvalidOperation("Cannot call start_session when" " multiple users are authenticated") # Raises ConfigurationError if sessions are not supported. server_session = self._get_server_session() opts = client_session.SessionOptions(**kwargs) return client_session.ClientSession( self, server_session, opts, authset, implicit) def start_session(self, causal_consistency=True, default_transaction_options=None): """Start a logical session. This method takes the same parameters as :class:`~pymongo.client_session.SessionOptions`. See the :mod:`~pymongo.client_session` module for details and examples. Requires MongoDB 3.6. It is an error to call :meth:`start_session` if this client has been authenticated to multiple databases using the deprecated method :meth:`~pymongo.database.Database.authenticate`. A :class:`~pymongo.client_session.ClientSession` may only be used with the MongoClient that started it. :class:`ClientSession` instances are **not thread-safe or fork-safe**. They can only be used by one thread or process at a time. A single :class:`ClientSession` cannot be used to run multiple operations concurrently. :Returns: An instance of :class:`~pymongo.client_session.ClientSession`. .. 
versionadded:: 3.6 """ return self.__start_session( False, causal_consistency=causal_consistency, default_transaction_options=default_transaction_options) def _get_server_session(self): """Internal: start or resume a _ServerSession.""" return self._topology.get_server_session() def _return_server_session(self, server_session, lock): """Internal: return a _ServerSession to the pool.""" return self._topology.return_server_session(server_session, lock) def _ensure_session(self, session=None): """If provided session is None, lend a temporary session.""" if session: return session try: # Don't make implicit sessions causally consistent. Applications # should always opt-in. return self.__start_session(True, causal_consistency=False) except (ConfigurationError, InvalidOperation): # Sessions not supported, or multiple users authenticated. return None @contextlib.contextmanager def _tmp_session(self, session, close=True): """If provided session is None, lend a temporary session.""" if session: # Don't call end_session. yield session return s = self._ensure_session(session) if s and close: with s: # Call end_session when we exit this scope. yield s elif s: try: # Only call end_session on error. yield s except Exception: s.end_session() raise else: yield None def _send_cluster_time(self, command, session): topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: if topology_time['clusterTime'] > session_time['clusterTime']: cluster_time = topology_time else: cluster_time = session_time else: cluster_time = topology_time or session_time if cluster_time: command['$clusterTime'] = cluster_time def _process_response(self, reply, session): self._topology.receive_cluster_time(reply.get('$clusterTime')) if session is not None: session._process_response(reply) def server_info(self, session=None): """Get information about the MongoDB server we're connected to. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. """ return self.admin.command("buildinfo", read_preference=ReadPreference.PRIMARY, session=session) def list_databases(self, session=None, **kwargs): """Get a cursor over the databases of the connected server. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. - `**kwargs` (optional): Optional parameters of the `listDatabases command `_ can be passed as keyword arguments to this method. The supported options differ by server version. :Returns: An instance of :class:`~pymongo.command_cursor.CommandCursor`. .. versionadded:: 3.6 """ cmd = SON([("listDatabases", 1)]) cmd.update(kwargs) admin = self._database_default_options("admin") res = admin._retryable_read_command(cmd, session=session) # listDatabases doesn't return a cursor (yet). Fake one. cursor = { "id": 0, "firstBatch": res["databases"], "ns": "admin.$cmd", } return CommandCursor(admin["$cmd"], cursor, None) def list_database_names(self, session=None): """Get a list of the names of all databases on the connected server. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionadded:: 3.6 """ return [doc["name"] for doc in self.list_databases(session, nameOnly=True)] def database_names(self, session=None): """**DEPRECATED**: Get a list of the names of all databases on the connected server. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. 
versionchanged:: 3.7 Deprecated. Use :meth:`list_database_names` instead. .. versionchanged:: 3.6 Added ``session`` parameter. """ warnings.warn("database_names is deprecated. Use list_database_names " "instead.", DeprecationWarning, stacklevel=2) return self.list_database_names(session) def drop_database(self, name_or_database, session=None): """Drop a database. Raises :class:`TypeError` if `name_or_database` is not an instance of :class:`basestring` (:class:`str` in python 3) or :class:`~pymongo.database.Database`. :Parameters: - `name_or_database`: the name of a database to drop, or a :class:`~pymongo.database.Database` instance representing the database to drop - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. .. note:: The :attr:`~pymongo.mongo_client.MongoClient.write_concern` of this client is automatically applied to this operation when using MongoDB >= 3.4. .. versionchanged:: 3.4 Apply this client's write concern automatically to this operation when connected to MongoDB >= 3.4. """ name = name_or_database if isinstance(name, database.Database): name = name.name if not isinstance(name, string_type): raise TypeError("name_or_database must be an instance " "of %s or a Database" % (string_type.__name__,)) self._purge_index(name) with self._socket_for_writes(session) as sock_info: self[name]._command( sock_info, "dropDatabase", read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), parse_write_concern_error=True, session=session) def get_default_database(self, default=None, codec_options=None, read_preference=None, write_concern=None, read_concern=None): """Get the database named in the MongoDB connection URI. >>> uri = 'mongodb://host/my_database' >>> client = MongoClient(uri) >>> db = client.get_default_database() >>> assert db.name == 'my_database' >>> db = client.get_database() >>> assert db.name == 'my_database' Useful in scripts where you want to choose which database to use based only on the URI in a configuration file. :Parameters: - `default` (optional): the database name to use if no database name was provided in the URI. - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) the :attr:`codec_options` of this :class:`MongoClient` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` for options. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) the :attr:`write_concern` of this :class:`MongoClient` is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) the :attr:`read_concern` of this :class:`MongoClient` is used. .. versionchanged:: 3.8 Undeprecated. Added the ``default``, ``codec_options``, ``read_preference``, ``write_concern`` and ``read_concern`` parameters. .. versionchanged:: 3.5 Deprecated, use :meth:`get_database` instead. 
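# Sketch of the database-level helpers documented above (the database name is
# hypothetical, and dropping it requires appropriate privileges):
from pymongo import MongoClient

client = MongoClient()
print(client.list_database_names())  # e.g. ['admin', 'config', 'local']
client.drop_database('scratch_db')   # accepts a name or a Database instance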
""" if self.__default_database_name is None and default is None: raise ConfigurationError( 'No default database name defined or provided.') return database.Database( self, self.__default_database_name or default, codec_options, read_preference, write_concern, read_concern) def get_database(self, name=None, codec_options=None, read_preference=None, write_concern=None, read_concern=None): """Get a :class:`~pymongo.database.Database` with the given name and options. Useful for creating a :class:`~pymongo.database.Database` with different codec options, read preference, and/or write concern from this :class:`MongoClient`. >>> client.read_preference Primary() >>> db1 = client.test >>> db1.read_preference Primary() >>> from pymongo import ReadPreference >>> db2 = client.get_database( ... 'test', read_preference=ReadPreference.SECONDARY) >>> db2.read_preference Secondary(tag_sets=None) :Parameters: - `name` (optional): The name of the database - a string. If ``None`` (the default) the database named in the MongoDB connection URI is returned. - `codec_options` (optional): An instance of :class:`~bson.codec_options.CodecOptions`. If ``None`` (the default) the :attr:`codec_options` of this :class:`MongoClient` is used. - `read_preference` (optional): The read preference to use. If ``None`` (the default) the :attr:`read_preference` of this :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` for options. - `write_concern` (optional): An instance of :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the default) the :attr:`write_concern` of this :class:`MongoClient` is used. - `read_concern` (optional): An instance of :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the default) the :attr:`read_concern` of this :class:`MongoClient` is used. .. versionchanged:: 3.5 The `name` parameter is now optional, defaulting to the database named in the MongoDB connection URI. """ if name is None: if self.__default_database_name is None: raise ConfigurationError('No default database defined') name = self.__default_database_name return database.Database( self, name, codec_options, read_preference, write_concern, read_concern) def _database_default_options(self, name): """Get a Database instance with the default settings.""" return self.get_database( name, codec_options=DEFAULT_CODEC_OPTIONS, read_preference=ReadPreference.PRIMARY, write_concern=DEFAULT_WRITE_CONCERN) @property def is_locked(self): """**DEPRECATED**: Is this server locked? While locked, all write operations are blocked, although read operations may still be allowed. Use :meth:`unlock` to unlock. Deprecated. Users of MongoDB version 3.2 or newer can run the `currentOp command`_ directly with :meth:`~pymongo.database.Database.command`:: is_locked = client.admin.command('currentOp').get('fsyncLock') Users of MongoDB version 2.6 and 3.0 can query the "inprog" virtual collection:: is_locked = client.admin["$cmd.sys.inprog"].find_one().get('fsyncLock') .. versionchanged:: 3.11 Deprecated. .. _currentOp command: https://docs.mongodb.com/manual/reference/command/currentOp/ """ warnings.warn("is_locked is deprecated. See the documentation for " "more information.", DeprecationWarning, stacklevel=2) ops = self._database_default_options('admin')._current_op() return bool(ops.get('fsyncLock', 0)) def fsync(self, **kwargs): """**DEPRECATED**: Flush all pending writes to datafiles. Optional parameters can be passed as keyword arguments: - `lock`: If True lock the server to disallow writes. 
- `async`: If True don't block while synchronizing. - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. .. note:: Starting with Python 3.7 `async` is a reserved keyword. The async option to the fsync command can be passed using a dictionary instead:: options = {'async': True} client.fsync(**options) Deprecated. Run the `fsync command`_ directly with :meth:`~pymongo.database.Database.command` instead. For example:: client.admin.command('fsync', lock=True) .. versionchanged:: 3.11 Deprecated. .. versionchanged:: 3.6 Added ``session`` parameter. .. warning:: `async` and `lock` can not be used together. .. warning:: MongoDB does not support the `async` option on Windows and will raise an exception on that platform. .. _fsync command: https://docs.mongodb.com/manual/reference/command/fsync/ """ warnings.warn("fsync is deprecated. Use " "client.admin.command('fsync') instead.", DeprecationWarning, stacklevel=2) self.admin.command("fsync", read_preference=ReadPreference.PRIMARY, **kwargs) def unlock(self, session=None): """**DEPRECATED**: Unlock a previously locked server. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. Deprecated. Users of MongoDB version 3.2 or newer can run the `fsyncUnlock command`_ directly with :meth:`~pymongo.database.Database.command`:: client.admin.command('fsyncUnlock') Users of MongoDB version 2.6 and 3.0 can query the "unlock" virtual collection:: client.admin["$cmd.sys.unlock"].find_one() .. versionchanged:: 3.11 Deprecated. .. versionchanged:: 3.6 Added ``session`` parameter. .. _fsyncUnlock command: https://docs.mongodb.com/manual/reference/command/fsyncUnlock/ """ warnings.warn("unlock is deprecated. Use " "client.admin.command('fsyncUnlock') instead. For " "MongoDB 2.6 and 3.0, see the documentation for " "more information.", DeprecationWarning, stacklevel=2) cmd = SON([("fsyncUnlock", 1)]) with self._socket_for_writes(session) as sock_info: if sock_info.max_wire_version >= 4: try: with self._tmp_session(session) as s: sock_info.command( "admin", cmd, session=s, client=self) except OperationFailure as exc: # Ignore "DB not locked" to replicate old behavior if exc.code != 125: raise else: message._first_batch(sock_info, "admin", "$cmd.sys.unlock", {}, -1, True, self.codec_options, ReadPreference.PRIMARY, cmd, self._event_listeners) def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def __iter__(self): return self def __next__(self): raise TypeError("'MongoClient' object is not iterable") next = __next__ def _retryable_error_doc(exc): """Return the server response from PyMongo exception or None.""" if isinstance(exc, BulkWriteError): # Check the last writeConcernError to determine if this # BulkWriteError is retryable. wces = exc.details['writeConcernErrors'] wce = wces[-1] if wces else None return wce if isinstance(exc, (NotMasterError, OperationFailure)): return exc.details return None def _add_retryable_write_error(exc, max_wire_version): doc = _retryable_error_doc(exc) if doc: code = doc.get('code', 0) # retryWrites on MMAPv1 should raise an actionable error. if (code == 20 and str(exc).startswith("Transaction numbers")): errmsg = ( "This MongoDB deployment does not support " "retryable writes. Please add retryWrites=false " "to your connection string.") raise OperationFailure(errmsg, code, exc.details) if max_wire_version >= 9: # In MongoDB 4.4+, the server reports the error labels. 
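# Sketch of how an application can act on the error labels attached by the
# retryable-write helpers in this module; the collection name is hypothetical:
from pymongo import MongoClient
from pymongo.errors import PyMongoError

client = MongoClient()
try:
    client.test.docs.insert_one({'x': 1})
except PyMongoError as exc:
    if exc.has_error_label('RetryableWriteError'):
        print('the failed write is safe to retry once')
    raise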
for label in doc.get('errorLabels', []): exc._add_error_label(label) else: if code in helpers._RETRYABLE_ERROR_CODES: exc._add_error_label("RetryableWriteError") # Connection errors are always retryable except NotMasterError which is # handled above. if (isinstance(exc, ConnectionFailure) and not isinstance(exc, NotMasterError)): exc._add_error_label("RetryableWriteError") class _MongoClientErrorHandler(object): """Handle errors raised when executing an operation.""" __slots__ = ('client', 'server_address', 'session', 'max_wire_version', 'sock_generation', 'completed_handshake') def __init__(self, client, server, session): self.client = client self.server_address = server.description.address self.session = session self.max_wire_version = common.MIN_WIRE_VERSION # XXX: When get_socket fails, this generation could be out of date: # "Note that when a network error occurs before the handshake # completes then the error's generation number is the generation # of the pool at the time the connection attempt was started." self.sock_generation = server.pool.generation self.completed_handshake = False def contribute_socket(self, sock_info): """Provide socket information to the error handler.""" self.max_wire_version = sock_info.max_wire_version self.sock_generation = sock_info.generation self.completed_handshake = True def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is None: return if self.session: if issubclass(exc_type, ConnectionFailure): if self.session.in_transaction: exc_val._add_error_label("TransientTransactionError") self.session._server_session.mark_dirty() if issubclass(exc_type, PyMongoError): if (exc_val.has_error_label("TransientTransactionError") or exc_val.has_error_label("RetryableWriteError")): self.session._unpin_mongos() err_ctx = _ErrorContext( exc_val, self.max_wire_version, self.sock_generation, self.completed_handshake) self.client._topology.handle_error(self.server_address, err_ctx) pymongo-3.11.0/pymongo/mongo_replica_set_client.py000066400000000000000000000036431374256237000223740ustar00rootroot00000000000000# Copyright 2011-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Deprecated. See :doc:`/examples/high_availability`.""" import warnings from pymongo import mongo_client class MongoReplicaSetClient(mongo_client.MongoClient): """Deprecated alias for :class:`~pymongo.mongo_client.MongoClient`. :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient` will be removed in a future version of PyMongo. .. versionchanged:: 3.0 :class:`~pymongo.mongo_client.MongoClient` is now the one and only client class for a standalone server, mongos, or replica set. It includes the functionality that had been split into :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`: it can connect to a replica set, discover all its members, and monitor the set for stepdowns, elections, and reconfigs. 
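# The recommended replacement for MongoReplicaSetClient is a plain MongoClient
# pointed at the replica set; the host names and set name below are hypothetical:
from pymongo import MongoClient

client = MongoClient('mongodb://host1:27017,host2:27017/?replicaSet=rs0')
print(client.primary)      # (host, port) of the discovered primary, or None
print(client.secondaries)  # set of discovered secondary addresses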
The ``refresh`` method is removed from :class:`~pymongo.mongo_replica_set_client.MongoReplicaSetClient`, as are the ``seeds`` and ``hosts`` properties. """ def __init__(self, *args, **kwargs): warnings.warn('MongoReplicaSetClient is deprecated, use MongoClient' ' to connect to a replica set', DeprecationWarning, stacklevel=2) super(MongoReplicaSetClient, self).__init__(*args, **kwargs) def __repr__(self): return "MongoReplicaSetClient(%s)" % (self._repr_helper(),) pymongo-3.11.0/pymongo/monitor.py000066400000000000000000000350761374256237000170410ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Class to monitor a MongoDB server on a background thread.""" import atexit import threading import weakref from pymongo import common, periodic_executor from pymongo.errors import (NotMasterError, OperationFailure, _OperationCancelled) from pymongo.ismaster import IsMaster from pymongo.monotonic import time as _time from pymongo.periodic_executor import _shutdown_executors from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription from pymongo.srv_resolver import _SrvResolver class MonitorBase(object): def __init__(self, topology, name, interval, min_interval): """Base class to do periodic work on a background thread. The the background thread is signaled to stop when the Topology or this instance is freed. """ # We strongly reference the executor and it weakly references us via # this closure. When the monitor is freed, stop the executor soon. def target(): monitor = self_ref() if monitor is None: return False # Stop the executor. monitor._run() return True executor = periodic_executor.PeriodicExecutor( interval=interval, min_interval=min_interval, target=target, name=name) self._executor = executor def _on_topology_gc(dummy=None): # This prevents GC from waiting 10 seconds for isMaster to complete # See test_cleanup_executors_on_client_del. monitor = self_ref() if monitor: monitor.gc_safe_close() # Avoid cycles. When self or topology is freed, stop executor soon. self_ref = weakref.ref(self, executor.close) self._topology = weakref.proxy(topology, _on_topology_gc) _register(self) def open(self): """Start monitoring, or restart after a fork. Multiple calls have no effect. """ self._executor.open() def gc_safe_close(self): """GC safe close.""" self._executor.close() def close(self): """Close and stop monitoring. open() restarts the monitor after closing. """ self.gc_safe_close() def join(self, timeout=None): """Wait for the monitor to stop.""" self._executor.join(timeout) def request_check(self): """If the monitor is sleeping, wake it soon.""" self._executor.wake() class Monitor(MonitorBase): def __init__( self, server_description, topology, pool, topology_settings): """Class to monitor a MongoDB server on a background thread. Pass an initial ServerDescription, a Topology, a Pool, and TopologySettings. The Topology is weakly referenced. The Pool must be exclusive to this Monitor. 
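# The polling cadence of the monitor described above is configurable from the
# client via heartbeatFrequencyMS; successive checks of the same server are
# additionally rate-limited by common.MIN_HEARTBEAT_INTERVAL. URI is illustrative:
from pymongo import MongoClient

client = MongoClient('mongodb://localhost:27017',
                     heartbeatFrequencyMS=5000)  # poll each server roughly every 5 seconds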
""" super(Monitor, self).__init__( topology, "pymongo_server_monitor_thread", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL) self._server_description = server_description self._pool = pool self._settings = topology_settings self._listeners = self._settings._pool_options.event_listeners pub = self._listeners is not None self._publish = pub and self._listeners.enabled_for_server_heartbeat self._cancel_context = None self._rtt_monitor = _RttMonitor( topology, topology_settings, topology._create_pool_for_monitor( server_description.address)) self.heartbeater = None def cancel_check(self): """Cancel any concurrent isMaster check. Note: this is called from a weakref.proxy callback and MUST NOT take any locks. """ context = self._cancel_context if context: # Note: we cannot close the socket because doing so may cause # concurrent reads/writes to hang until a timeout occurs # (depending on the platform). context.cancel() def _start_rtt_monitor(self): """Start an _RttMonitor that periodically runs ping.""" # If this monitor is closed directly before (or during) this open() # call, the _RttMonitor will not be closed. Checking if this monitor # was closed directly after resolves the race. self._rtt_monitor.open() if self._executor._stopped: self._rtt_monitor.close() def gc_safe_close(self): self._executor.close() self._rtt_monitor.gc_safe_close() self.cancel_check() def close(self): self.gc_safe_close() self._rtt_monitor.close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. self._reset_connection() def _reset_connection(self): # Clear our pooled connection. self._pool.reset() def _run(self): try: prev_sd = self._server_description try: self._server_description = self._check_server() except _OperationCancelled as exc: # Already closed the connection, wait for the next check. self._server_description = ServerDescription( self._server_description.address, error=exc) if prev_sd.is_server_type_known: # Immediately retry since we've already waited 500ms to # discover that we've been cancelled. self._executor.skip_sleep() return # Update the Topology and clear the server pool on error. self._topology.on_change(self._server_description, reset_pool=self._server_description.error) if (self._server_description.is_server_type_known and self._server_description.topology_version): self._start_rtt_monitor() # Immediately check for the next streaming response. self._executor.skip_sleep() if self._server_description.error and prev_sd.is_server_type_known: # Immediately retry on network errors. self._executor.skip_sleep() except ReferenceError: # Topology was garbage-collected. self.close() def _check_server(self): """Call isMaster or read the next streaming response. Returns a ServerDescription. """ start = _time() try: try: return self._check_once() except (OperationFailure, NotMasterError) as exc: # Update max cluster time even when isMaster fails. self._topology.receive_cluster_time( exc.details.get('$clusterTime')) raise except ReferenceError: raise except Exception as error: sd = self._server_description address = sd.address duration = _time() - start if self._publish: awaited = sd.is_server_type_known and sd.topology_version self._listeners.publish_server_heartbeat_failed( address, duration, error, awaited) self._reset_connection() if isinstance(error, _OperationCancelled): raise self._rtt_monitor.reset() # Server type defaults to Unknown. 
return ServerDescription(address, error=error) def _check_once(self): """A single attempt to call ismaster. Returns a ServerDescription, or raises an exception. """ address = self._server_description.address if self._publish: self._listeners.publish_server_heartbeat_started(address) if self._cancel_context and self._cancel_context.cancelled: self._reset_connection() with self._pool.get_socket({}) as sock_info: self._cancel_context = sock_info.cancel_context response, round_trip_time = self._check_with_socket(sock_info) if not response.awaitable: self._rtt_monitor.add_sample(round_trip_time) sd = ServerDescription(address, response, self._rtt_monitor.average()) if self._publish: self._listeners.publish_server_heartbeat_succeeded( address, round_trip_time, response, response.awaitable) return sd def _check_with_socket(self, conn): """Return (IsMaster, round_trip_time). Can raise ConnectionFailure or OperationFailure. """ cluster_time = self._topology.max_cluster_time() start = _time() if conn.more_to_come: # Read the next streaming isMaster (MongoDB 4.4+). response = IsMaster(conn._next_reply(), awaitable=True) elif (conn.performed_handshake and self._server_description.topology_version): # Initiate streaming isMaster (MongoDB 4.4+). response = conn._ismaster( cluster_time, self._server_description.topology_version, self._settings.heartbeat_frequency, None) else: # New connection handshake or polling isMaster (MongoDB <4.4). response = conn._ismaster(cluster_time, None, None, None) return response, _time() - start class SrvMonitor(MonitorBase): def __init__(self, topology, topology_settings): """Class to poll SRV records on a background thread. Pass a Topology and a TopologySettings. The Topology is weakly referenced. """ super(SrvMonitor, self).__init__( topology, "pymongo_srv_polling_thread", common.MIN_SRV_RESCAN_INTERVAL, topology_settings.heartbeat_frequency) self._settings = topology_settings self._seedlist = self._settings._seeds self._fqdn = self._settings.fqdn def _run(self): seedlist = self._get_seedlist() if seedlist: self._seedlist = seedlist try: self._topology.on_srv_update(self._seedlist) except ReferenceError: # Topology was garbage-collected. self.close() def _get_seedlist(self): """Poll SRV records for a seedlist. Returns a list of ServerDescriptions. """ try: seedlist, ttl = _SrvResolver(self._fqdn).get_hosts_and_min_ttl() if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception except Exception: # As per the spec, upon encountering an error: # - An error must not be raised # - SRV records must be rescanned every heartbeatFrequencyMS # - Topology must be left unchanged self.request_check() return None else: self._executor.update_interval( max(ttl, common.MIN_SRV_RESCAN_INTERVAL)) return seedlist class _RttMonitor(MonitorBase): def __init__(self, topology, topology_settings, pool): """Maintain round trip times for a server. The Topology is weakly referenced. """ super(_RttMonitor, self).__init__( topology, "pymongo_server_rtt_thread", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL) self._pool = pool self._moving_average = MovingAverage() self._lock = threading.Lock() def close(self): self.gc_safe_close() # Increment the generation and maybe close the socket. If the executor # thread has the socket checked out, it will be closed when checked in. 
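# The SrvMonitor above applies to clients created with the "mongodb+srv://"
# scheme, whose seed list is resolved from DNS SRV records and re-polled
# periodically. Requires the dnspython package; the hostname is hypothetical:
from pymongo import MongoClient

client = MongoClient('mongodb+srv://cluster0.example.net/?retryWrites=true')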
self._pool.reset() def add_sample(self, sample): """Add a RTT sample.""" with self._lock: self._moving_average.add_sample(sample) def average(self): """Get the calculated average, or None if no samples yet.""" with self._lock: return self._moving_average.get() def reset(self): """Reset the average RTT.""" with self._lock: return self._moving_average.reset() def _run(self): try: # NOTE: This thread is only run when when using the streaming # heartbeat protocol (MongoDB 4.4+). # XXX: Skip check if the server is unknown? rtt = self._ping() self.add_sample(rtt) except ReferenceError: # Topology was garbage-collected. self.close() except Exception: self._pool.reset() def _ping(self): """Run an "isMaster" command and return the RTT.""" with self._pool.get_socket({}) as sock_info: if self._executor._stopped: raise Exception('_RttMonitor closed') start = _time() sock_info.ismaster() return _time() - start # Close monitors to cancel any in progress streaming checks before joining # executor threads. For an explanation of how this works see the comment # about _EXECUTORS in periodic_executor.py. _MONITORS = set() def _register(monitor): ref = weakref.ref(monitor, _unregister) _MONITORS.add(ref) def _unregister(monitor_ref): _MONITORS.remove(monitor_ref) def _shutdown_monitors(): if _MONITORS is None: return # Copy the set. Closing monitors removes them. monitors = list(_MONITORS) # Close all monitors. for ref in monitors: monitor = ref() if monitor: monitor.gc_safe_close() monitor = None def _shutdown_resources(): # _shutdown_monitors/_shutdown_executors may already be GC'd at shutdown. shutdown = _shutdown_monitors if shutdown: shutdown() shutdown = _shutdown_executors if shutdown: shutdown() atexit.register(_shutdown_resources) pymongo-3.11.0/pymongo/monitoring.py000066400000000000000000001507661374256237000175430ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Tools to monitor driver events. .. versionadded:: 3.1 .. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below are included in the PyMongo distribution under the :mod:`~pymongo.event_loggers` submodule. Use :func:`register` to register global listeners for specific events. Listeners must inherit from one of the abstract classes below and implement the correct functions for that class. 
For example, a simple command logger might be implemented like this:: import logging from pymongo import monitoring class CommandLogger(monitoring.CommandListener): def started(self, event): logging.info("Command {0.command_name} with request id " "{0.request_id} started on server " "{0.connection_id}".format(event)) def succeeded(self, event): logging.info("Command {0.command_name} with request id " "{0.request_id} on server {0.connection_id} " "succeeded in {0.duration_micros} " "microseconds".format(event)) def failed(self, event): logging.info("Command {0.command_name} with request id " "{0.request_id} on server {0.connection_id} " "failed in {0.duration_micros} " "microseconds".format(event)) monitoring.register(CommandLogger()) Server discovery and monitoring events are also available. For example:: class ServerLogger(monitoring.ServerListener): def opened(self, event): logging.info("Server {0.server_address} added to topology " "{0.topology_id}".format(event)) def description_changed(self, event): previous_server_type = event.previous_description.server_type new_server_type = event.new_description.server_type if new_server_type != previous_server_type: # server_type_name was added in PyMongo 3.4 logging.info( "Server {0.server_address} changed type from " "{0.previous_description.server_type_name} to " "{0.new_description.server_type_name}".format(event)) def closed(self, event): logging.warning("Server {0.server_address} removed from topology " "{0.topology_id}".format(event)) class HeartbeatLogger(monitoring.ServerHeartbeatListener): def started(self, event): logging.info("Heartbeat sent to server " "{0.connection_id}".format(event)) def succeeded(self, event): # The reply.document attribute was added in PyMongo 3.4. logging.info("Heartbeat to server {0.connection_id} " "succeeded with reply " "{0.reply.document}".format(event)) def failed(self, event): logging.warning("Heartbeat to server {0.connection_id} " "failed with error {0.reply}".format(event)) class TopologyLogger(monitoring.TopologyListener): def opened(self, event): logging.info("Topology with id {0.topology_id} " "opened".format(event)) def description_changed(self, event): logging.info("Topology description updated for " "topology id {0.topology_id}".format(event)) previous_topology_type = event.previous_description.topology_type new_topology_type = event.new_description.topology_type if new_topology_type != previous_topology_type: # topology_type_name was added in PyMongo 3.4 logging.info( "Topology {0.topology_id} changed type from " "{0.previous_description.topology_type_name} to " "{0.new_description.topology_type_name}".format(event)) # The has_writable_server and has_readable_server methods # were added in PyMongo 3.4. if not event.new_description.has_writable_server(): logging.warning("No writable servers available.") if not event.new_description.has_readable_server(): logging.warning("No readable servers available.") def closed(self, event): logging.info("Topology with id {0.topology_id} " "closed".format(event)) Connection monitoring and pooling events are also available. 
For example:: class ConnectionPoolLogger(ConnectionPoolListener): def pool_created(self, event): logging.info("[pool {0.address}] pool created".format(event)) def pool_cleared(self, event): logging.info("[pool {0.address}] pool cleared".format(event)) def pool_closed(self, event): logging.info("[pool {0.address}] pool closed".format(event)) def connection_created(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection created".format(event)) def connection_ready(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection setup succeeded".format(event)) def connection_closed(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection closed, reason: " "{0.reason}".format(event)) def connection_check_out_started(self, event): logging.info("[pool {0.address}] connection check out " "started".format(event)) def connection_check_out_failed(self, event): logging.info("[pool {0.address}] connection check out " "failed, reason: {0.reason}".format(event)) def connection_checked_out(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection checked out of pool".format(event)) def connection_checked_in(self, event): logging.info("[pool {0.address}][conn #{0.connection_id}] " "connection checked into pool".format(event)) Event listeners can also be registered per instance of :class:`~pymongo.mongo_client.MongoClient`:: client = MongoClient(event_listeners=[CommandLogger()]) Note that previously registered global listeners are automatically included when configuring per client event listeners. Registering a new global listener will not add that listener to existing client instances. .. note:: Events are delivered **synchronously**. Application threads block waiting for event handlers (e.g. :meth:`~CommandListener.started`) to return. Care must be taken to ensure that your event handlers are efficient enough to not adversely affect overall application performance. .. warning:: The command documents published through this API are *not* copies. If you intend to modify them in any way you must copy them in your event handler first. """ from collections import namedtuple from bson.py3compat import abc from pymongo.helpers import _handle_exception _Listeners = namedtuple('Listeners', ('command_listeners', 'server_listeners', 'server_heartbeat_listeners', 'topology_listeners', 'cmap_listeners')) _LISTENERS = _Listeners([], [], [], [], []) class _EventListener(object): """Abstract base class for all event listeners.""" class CommandListener(_EventListener): """Abstract base class for command listeners. Handles `CommandStartedEvent`, `CommandSucceededEvent`, and `CommandFailedEvent`. """ def started(self, event): """Abstract method to handle a `CommandStartedEvent`. :Parameters: - `event`: An instance of :class:`CommandStartedEvent`. """ raise NotImplementedError def succeeded(self, event): """Abstract method to handle a `CommandSucceededEvent`. :Parameters: - `event`: An instance of :class:`CommandSucceededEvent`. """ raise NotImplementedError def failed(self, event): """Abstract method to handle a `CommandFailedEvent`. :Parameters: - `event`: An instance of :class:`CommandFailedEvent`. """ raise NotImplementedError class ConnectionPoolListener(_EventListener): """Abstract base class for connection pool listeners. 
Handles all of the connection pool events defined in the Connection Monitoring and Pooling Specification: :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, :class:`ConnectionCheckOutStartedEvent`, :class:`ConnectionCheckOutFailedEvent`, :class:`ConnectionCheckedOutEvent`, and :class:`ConnectionCheckedInEvent`. .. versionadded:: 3.9 """ def pool_created(self, event): """Abstract method to handle a :class:`PoolCreatedEvent`. Emitted when a Connection Pool is created. :Parameters: - `event`: An instance of :class:`PoolCreatedEvent`. """ raise NotImplementedError def pool_cleared(self, event): """Abstract method to handle a `PoolClearedEvent`. Emitted when a Connection Pool is cleared. :Parameters: - `event`: An instance of :class:`PoolClearedEvent`. """ raise NotImplementedError def pool_closed(self, event): """Abstract method to handle a `PoolClosedEvent`. Emitted when a Connection Pool is closed. :Parameters: - `event`: An instance of :class:`PoolClosedEvent`. """ raise NotImplementedError def connection_created(self, event): """Abstract method to handle a :class:`ConnectionCreatedEvent`. Emitted when a Connection Pool creates a Connection object. :Parameters: - `event`: An instance of :class:`ConnectionCreatedEvent`. """ raise NotImplementedError def connection_ready(self, event): """Abstract method to handle a :class:`ConnectionReadyEvent`. Emitted when a Connection has finished its setup, and is now ready to use. :Parameters: - `event`: An instance of :class:`ConnectionReadyEvent`. """ raise NotImplementedError def connection_closed(self, event): """Abstract method to handle a :class:`ConnectionClosedEvent`. Emitted when a Connection Pool closes a Connection. :Parameters: - `event`: An instance of :class:`ConnectionClosedEvent`. """ raise NotImplementedError def connection_check_out_started(self, event): """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. Emitted when the driver starts attempting to check out a connection. :Parameters: - `event`: An instance of :class:`ConnectionCheckOutStartedEvent`. """ raise NotImplementedError def connection_check_out_failed(self, event): """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. Emitted when the driver's attempt to check out a connection fails. :Parameters: - `event`: An instance of :class:`ConnectionCheckOutFailedEvent`. """ raise NotImplementedError def connection_checked_out(self, event): """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. Emitted when the driver successfully checks out a Connection. :Parameters: - `event`: An instance of :class:`ConnectionCheckedOutEvent`. """ raise NotImplementedError def connection_checked_in(self, event): """Abstract method to handle a :class:`ConnectionCheckedInEvent`. Emitted when the driver checks in a Connection back to the Connection Pool. :Parameters: - `event`: An instance of :class:`ConnectionCheckedInEvent`. """ raise NotImplementedError class ServerHeartbeatListener(_EventListener): """Abstract base class for server heartbeat listeners. Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, and `ServerHeartbeatFailedEvent`. .. versionadded:: 3.3 """ def started(self, event): """Abstract method to handle a `ServerHeartbeatStartedEvent`. :Parameters: - `event`: An instance of :class:`ServerHeartbeatStartedEvent`. 
""" raise NotImplementedError def succeeded(self, event): """Abstract method to handle a `ServerHeartbeatSucceededEvent`. :Parameters: - `event`: An instance of :class:`ServerHeartbeatSucceededEvent`. """ raise NotImplementedError def failed(self, event): """Abstract method to handle a `ServerHeartbeatFailedEvent`. :Parameters: - `event`: An instance of :class:`ServerHeartbeatFailedEvent`. """ raise NotImplementedError class TopologyListener(_EventListener): """Abstract base class for topology monitoring listeners. Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and `TopologyClosedEvent`. .. versionadded:: 3.3 """ def opened(self, event): """Abstract method to handle a `TopologyOpenedEvent`. :Parameters: - `event`: An instance of :class:`TopologyOpenedEvent`. """ raise NotImplementedError def description_changed(self, event): """Abstract method to handle a `TopologyDescriptionChangedEvent`. :Parameters: - `event`: An instance of :class:`TopologyDescriptionChangedEvent`. """ raise NotImplementedError def closed(self, event): """Abstract method to handle a `TopologyClosedEvent`. :Parameters: - `event`: An instance of :class:`TopologyClosedEvent`. """ raise NotImplementedError class ServerListener(_EventListener): """Abstract base class for server listeners. Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and `ServerClosedEvent`. .. versionadded:: 3.3 """ def opened(self, event): """Abstract method to handle a `ServerOpeningEvent`. :Parameters: - `event`: An instance of :class:`ServerOpeningEvent`. """ raise NotImplementedError def description_changed(self, event): """Abstract method to handle a `ServerDescriptionChangedEvent`. :Parameters: - `event`: An instance of :class:`ServerDescriptionChangedEvent`. """ raise NotImplementedError def closed(self, event): """Abstract method to handle a `ServerClosedEvent`. :Parameters: - `event`: An instance of :class:`ServerClosedEvent`. """ raise NotImplementedError def _to_micros(dur): """Convert duration 'dur' to microseconds.""" return int(dur.total_seconds() * 10e5) def _validate_event_listeners(option, listeners): """Validate event listeners""" if not isinstance(listeners, abc.Sequence): raise TypeError("%s must be a list or tuple" % (option,)) for listener in listeners: if not isinstance(listener, _EventListener): raise TypeError("Listeners for %s must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " "ConnectionPoolListener." % (option,)) return listeners def register(listener): """Register a global event listener. :Parameters: - `listener`: A subclasses of :class:`CommandListener`, :class:`ServerHeartbeatListener`, :class:`ServerListener`, :class:`TopologyListener`, or :class:`ConnectionPoolListener`. """ if not isinstance(listener, _EventListener): raise TypeError("Listeners for %s must be either a " "CommandListener, ServerHeartbeatListener, " "ServerListener, TopologyListener, or " "ConnectionPoolListener." 
% (listener,)) if isinstance(listener, CommandListener): _LISTENERS.command_listeners.append(listener) if isinstance(listener, ServerHeartbeatListener): _LISTENERS.server_heartbeat_listeners.append(listener) if isinstance(listener, ServerListener): _LISTENERS.server_listeners.append(listener) if isinstance(listener, TopologyListener): _LISTENERS.topology_listeners.append(listener) if isinstance(listener, ConnectionPoolListener): _LISTENERS.cmap_listeners.append(listener) # Note - to avoid bugs from forgetting which if these is all lowercase and # which are camelCase, and at the same time avoid having to add a test for # every command, use all lowercase here and test against command_name.lower(). _SENSITIVE_COMMANDS = set( ["authenticate", "saslstart", "saslcontinue", "getnonce", "createuser", "updateuser", "copydbgetnonce", "copydbsaslstart", "copydb"]) class _CommandEvent(object): """Base class for command events.""" __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id") def __init__(self, command_name, request_id, connection_id, operation_id): self.__cmd_name = command_name self.__rqst_id = request_id self.__conn_id = connection_id self.__op_id = operation_id @property def command_name(self): """The command name.""" return self.__cmd_name @property def request_id(self): """The request id for this operation.""" return self.__rqst_id @property def connection_id(self): """The address (host, port) of the server this command was sent to.""" return self.__conn_id @property def operation_id(self): """An id for this series of events or None.""" return self.__op_id class CommandStartedEvent(_CommandEvent): """Event published when a command starts. :Parameters: - `command`: The command document. - `database_name`: The name of the database this command was run against. - `request_id`: The request id for this operation. - `connection_id`: The address (host, port) of the server this command was sent to. - `operation_id`: An optional identifier for a series of related events. """ __slots__ = ("__cmd", "__db") def __init__(self, command, database_name, *args): if not command: raise ValueError("%r is not a valid command" % (command,)) # Command name must be first key. command_name = next(iter(command)) super(CommandStartedEvent, self).__init__(command_name, *args) if command_name.lower() in _SENSITIVE_COMMANDS: self.__cmd = {} else: self.__cmd = command self.__db = database_name @property def command(self): """The command document.""" return self.__cmd @property def database_name(self): """The name of the database this command was run against.""" return self.__db def __repr__(self): return "<%s %s db: %r, command: %r, operation_id: %s>" % ( self.__class__.__name__, self.connection_id, self.database_name, self.command_name, self.operation_id) class CommandSucceededEvent(_CommandEvent): """Event published when a command succeeds. :Parameters: - `duration`: The command duration as a datetime.timedelta. - `reply`: The server reply document. - `command_name`: The command name. - `request_id`: The request id for this operation. - `connection_id`: The address (host, port) of the server this command was sent to. - `operation_id`: An optional identifier for a series of related events. 
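# --- Illustrative usage sketch, not part of this module: a listener can be
# registered globally with register() before any clients are constructed, or
# passed to a single client through MongoClient's ``event_listeners`` option;
# both paths are merged by _EventListeners below. LoggingPoolListener here is
# the hypothetical example class sketched earlier in this file.
#
#   from pymongo import MongoClient, monitoring
#
#   monitoring.register(LoggingPoolListener())                      # all clients
#   client = MongoClient(event_listeners=[LoggingPoolListener()])   # one client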
""" __slots__ = ("__duration_micros", "__reply") def __init__(self, duration, reply, command_name, request_id, connection_id, operation_id): super(CommandSucceededEvent, self).__init__( command_name, request_id, connection_id, operation_id) self.__duration_micros = _to_micros(duration) if command_name.lower() in _SENSITIVE_COMMANDS: self.__reply = {} else: self.__reply = reply @property def duration_micros(self): """The duration of this operation in microseconds.""" return self.__duration_micros @property def reply(self): """The server failure document for this operation.""" return self.__reply def __repr__(self): return "<%s %s command: %r, operation_id: %s, duration_micros: %s>" % ( self.__class__.__name__, self.connection_id, self.command_name, self.operation_id, self.duration_micros) class CommandFailedEvent(_CommandEvent): """Event published when a command fails. :Parameters: - `duration`: The command duration as a datetime.timedelta. - `failure`: The server reply document. - `command_name`: The command name. - `request_id`: The request id for this operation. - `connection_id`: The address (host, port) of the server this command was sent to. - `operation_id`: An optional identifier for a series of related events. """ __slots__ = ("__duration_micros", "__failure") def __init__(self, duration, failure, *args): super(CommandFailedEvent, self).__init__(*args) self.__duration_micros = _to_micros(duration) self.__failure = failure @property def duration_micros(self): """The duration of this operation in microseconds.""" return self.__duration_micros @property def failure(self): """The server failure document for this operation.""" return self.__failure def __repr__(self): return ( "<%s %s command: %r, operation_id: %s, duration_micros: %s, " "failure: %r>" % ( self.__class__.__name__, self.connection_id, self.command_name, self.operation_id, self.duration_micros, self.failure)) class _PoolEvent(object): """Base class for pool events.""" __slots__ = ("__address",) def __init__(self, address): self.__address = address @property def address(self): """The address (host, port) pair of the server the pool is attempting to connect to. """ return self.__address def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.__address) class PoolCreatedEvent(_PoolEvent): """Published when a Connection Pool is created. :Parameters: - `address`: The address (host, port) pair of the server this Pool is attempting to connect to. .. versionadded:: 3.9 """ __slots__ = ("__options",) def __init__(self, address, options): super(PoolCreatedEvent, self).__init__(address) self.__options = options @property def options(self): """Any non-default pool options that were set on this Connection Pool. """ return self.__options def __repr__(self): return '%s(%r, %r)' % ( self.__class__.__name__, self.address, self.__options) class PoolClearedEvent(_PoolEvent): """Published when a Connection Pool is cleared. :Parameters: - `address`: The address (host, port) pair of the server this Pool is attempting to connect to. .. versionadded:: 3.9 """ __slots__ = () class PoolClosedEvent(_PoolEvent): """Published when a Connection Pool is closed. :Parameters: - `address`: The address (host, port) pair of the server this Pool is attempting to connect to. .. versionadded:: 3.9 """ __slots__ = () class ConnectionClosedReason(object): """An enum that defines values for `reason` on a :class:`ConnectionClosedEvent`. .. 
versionadded:: 3.9 """ STALE = 'stale' """The pool was cleared, making the connection no longer valid.""" IDLE = 'idle' """The connection became stale by being idle for too long (maxIdleTimeMS). """ ERROR = 'error' """The connection experienced an error, making it no longer valid.""" POOL_CLOSED = 'poolClosed' """The pool was closed, making the connection no longer valid.""" class ConnectionCheckOutFailedReason(object): """An enum that defines values for `reason` on a :class:`ConnectionCheckOutFailedEvent`. .. versionadded:: 3.9 """ TIMEOUT = 'timeout' """The connection check out attempt exceeded the specified timeout.""" POOL_CLOSED = 'poolClosed' """The pool was previously closed, and cannot provide new connections.""" CONN_ERROR = 'connectionError' """The connection check out attempt experienced an error while setting up a new connection. """ class _ConnectionEvent(object): """Private base class for some connection events.""" __slots__ = ("__address", "__connection_id") def __init__(self, address, connection_id): self.__address = address self.__connection_id = connection_id @property def address(self): """The address (host, port) pair of the server this connection is attempting to connect to. """ return self.__address @property def connection_id(self): """The ID of the Connection.""" return self.__connection_id def __repr__(self): return '%s(%r, %r)' % ( self.__class__.__name__, self.__address, self.__connection_id) class ConnectionCreatedEvent(_ConnectionEvent): """Published when a Connection Pool creates a Connection object. NOTE: This connection is not ready for use until the :class:`ConnectionReadyEvent` is published. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. - `connection_id`: The integer ID of the Connection in this Pool. .. versionadded:: 3.9 """ __slots__ = () class ConnectionReadyEvent(_ConnectionEvent): """Published when a Connection has finished its setup, and is ready to use. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. - `connection_id`: The integer ID of the Connection in this Pool. .. versionadded:: 3.9 """ __slots__ = () class ConnectionClosedEvent(_ConnectionEvent): """Published when a Connection is closed. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. - `connection_id`: The integer ID of the Connection in this Pool. - `reason`: A reason explaining why this connection was closed. .. versionadded:: 3.9 """ __slots__ = ("__reason",) def __init__(self, address, connection_id, reason): super(ConnectionClosedEvent, self).__init__(address, connection_id) self.__reason = reason @property def reason(self): """A reason explaining why this connection was closed. The reason must be one of the strings from the :class:`ConnectionClosedReason` enum. """ return self.__reason def __repr__(self): return '%s(%r, %r, %r)' % ( self.__class__.__name__, self.address, self.connection_id, self.__reason) class ConnectionCheckOutStartedEvent(object): """Published when the driver starts attempting to check out a connection. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. .. versionadded:: 3.9 """ __slots__ = ("__address",) def __init__(self, address): self.__address = address @property def address(self): """The address (host, port) pair of the server this connection is attempting to connect to. 
""" return self.__address def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.__address) class ConnectionCheckOutFailedEvent(object): """Published when the driver's attempt to check out a connection fails. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. - `reason`: A reason explaining why connection check out failed. .. versionadded:: 3.9 """ __slots__ = ("__address", "__reason") def __init__(self, address, reason): self.__address = address self.__reason = reason @property def address(self): """The address (host, port) pair of the server this connection is attempting to connect to. """ return self.__address @property def reason(self): """A reason explaining why connection check out failed. The reason must be one of the strings from the :class:`ConnectionCheckOutFailedReason` enum. """ return self.__reason def __repr__(self): return '%s(%r, %r)' % ( self.__class__.__name__, self.__address, self.__reason) class ConnectionCheckedOutEvent(_ConnectionEvent): """Published when the driver successfully checks out a Connection. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. - `connection_id`: The integer ID of the Connection in this Pool. .. versionadded:: 3.9 """ __slots__ = () class ConnectionCheckedInEvent(_ConnectionEvent): """Published when the driver checks in a Connection into the Pool. :Parameters: - `address`: The address (host, port) pair of the server this Connection is attempting to connect to. - `connection_id`: The integer ID of the Connection in this Pool. .. versionadded:: 3.9 """ __slots__ = () class _ServerEvent(object): """Base class for server events.""" __slots__ = ("__server_address", "__topology_id") def __init__(self, server_address, topology_id): self.__server_address = server_address self.__topology_id = topology_id @property def server_address(self): """The address (host, port) pair of the server""" return self.__server_address @property def topology_id(self): """A unique identifier for the topology this server is a part of.""" return self.__topology_id def __repr__(self): return "<%s %s topology_id: %s>" % ( self.__class__.__name__, self.server_address, self.topology_id) class ServerDescriptionChangedEvent(_ServerEvent): """Published when server description changes. .. versionadded:: 3.3 """ __slots__ = ('__previous_description', '__new_description') def __init__(self, previous_description, new_description, *args): super(ServerDescriptionChangedEvent, self).__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @property def previous_description(self): """The previous :class:`~pymongo.server_description.ServerDescription`.""" return self.__previous_description @property def new_description(self): """The new :class:`~pymongo.server_description.ServerDescription`.""" return self.__new_description def __repr__(self): return "<%s %s changed from: %s, to: %s>" % ( self.__class__.__name__, self.server_address, self.previous_description, self.new_description) class ServerOpeningEvent(_ServerEvent): """Published when server is initialized. .. versionadded:: 3.3 """ __slots__ = () class ServerClosedEvent(_ServerEvent): """Published when server is closed. .. 
versionadded:: 3.3 """ __slots__ = () class TopologyEvent(object): """Base class for topology description events.""" __slots__ = ('__topology_id') def __init__(self, topology_id): self.__topology_id = topology_id @property def topology_id(self): """A unique identifier for the topology this server is a part of.""" return self.__topology_id def __repr__(self): return "<%s topology_id: %s>" % ( self.__class__.__name__, self.topology_id) class TopologyDescriptionChangedEvent(TopologyEvent): """Published when the topology description changes. .. versionadded:: 3.3 """ __slots__ = ('__previous_description', '__new_description') def __init__(self, previous_description, new_description, *args): super(TopologyDescriptionChangedEvent, self).__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @property def previous_description(self): """The previous :class:`~pymongo.topology_description.TopologyDescription`.""" return self.__previous_description @property def new_description(self): """The new :class:`~pymongo.topology_description.TopologyDescription`.""" return self.__new_description def __repr__(self): return "<%s topology_id: %s changed from: %s, to: %s>" % ( self.__class__.__name__, self.topology_id, self.previous_description, self.new_description) class TopologyOpenedEvent(TopologyEvent): """Published when the topology is initialized. .. versionadded:: 3.3 """ __slots__ = () class TopologyClosedEvent(TopologyEvent): """Published when the topology is closed. .. versionadded:: 3.3 """ __slots__ = () class _ServerHeartbeatEvent(object): """Base class for server heartbeat events.""" __slots__ = ('__connection_id') def __init__(self, connection_id): self.__connection_id = connection_id @property def connection_id(self): """The address (host, port) of the server this heartbeat was sent to.""" return self.__connection_id def __repr__(self): return "<%s %s>" % (self.__class__.__name__, self.connection_id) class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): """Published when a heartbeat is started. .. versionadded:: 3.3 """ __slots__ = () class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): """Fired when the server heartbeat succeeds. .. versionadded:: 3.3 """ __slots__ = ('__duration', '__reply', '__awaited') def __init__(self, duration, reply, connection_id, awaited=False): super(ServerHeartbeatSucceededEvent, self).__init__(connection_id) self.__duration = duration self.__reply = reply self.__awaited = awaited @property def duration(self): """The duration of this heartbeat in microseconds.""" return self.__duration @property def reply(self): """An instance of :class:`~pymongo.ismaster.IsMaster`.""" return self.__reply @property def awaited(self): """Whether the heartbeat was awaited. If true, then :meth:`duration` reflects the sum of the round trip time to the server and the time that the server waited before sending a response. """ return self.__awaited def __repr__(self): return "<%s %s duration: %s, awaited: %s, reply: %s>" % ( self.__class__.__name__, self.connection_id, self.duration, self.awaited, self.reply) class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): """Fired when the server heartbeat fails, either with an "ok: 0" or a socket exception. .. 
versionadded:: 3.3 """ __slots__ = ('__duration', '__reply', '__awaited') def __init__(self, duration, reply, connection_id, awaited=False): super(ServerHeartbeatFailedEvent, self).__init__(connection_id) self.__duration = duration self.__reply = reply self.__awaited = awaited @property def duration(self): """The duration of this heartbeat in microseconds.""" return self.__duration @property def reply(self): """A subclass of :exc:`Exception`.""" return self.__reply @property def awaited(self): """Whether the heartbeat was awaited. If true, then :meth:`duration` reflects the sum of the round trip time to the server and the time that the server waited before sending a response. """ return self.__awaited def __repr__(self): return "<%s %s duration: %s, awaited: %s, reply: %r>" % ( self.__class__.__name__, self.connection_id, self.duration, self.awaited, self.reply) class _EventListeners(object): """Configure event listeners for a client instance. Any event listeners registered globally are included by default. :Parameters: - `listeners`: A list of event listeners. """ def __init__(self, listeners): self.__command_listeners = _LISTENERS.command_listeners[:] self.__server_listeners = _LISTENERS.server_listeners[:] lst = _LISTENERS.server_heartbeat_listeners self.__server_heartbeat_listeners = lst[:] self.__topology_listeners = _LISTENERS.topology_listeners[:] self.__cmap_listeners = _LISTENERS.cmap_listeners[:] if listeners is not None: for lst in listeners: if isinstance(lst, CommandListener): self.__command_listeners.append(lst) if isinstance(lst, ServerListener): self.__server_listeners.append(lst) if isinstance(lst, ServerHeartbeatListener): self.__server_heartbeat_listeners.append(lst) if isinstance(lst, TopologyListener): self.__topology_listeners.append(lst) if isinstance(lst, ConnectionPoolListener): self.__cmap_listeners.append(lst) self.__enabled_for_commands = bool(self.__command_listeners) self.__enabled_for_server = bool(self.__server_listeners) self.__enabled_for_server_heartbeat = bool( self.__server_heartbeat_listeners) self.__enabled_for_topology = bool(self.__topology_listeners) self.__enabled_for_cmap = bool(self.__cmap_listeners) @property def enabled_for_commands(self): """Are any CommandListener instances registered?""" return self.__enabled_for_commands @property def enabled_for_server(self): """Are any ServerListener instances registered?""" return self.__enabled_for_server @property def enabled_for_server_heartbeat(self): """Are any ServerHeartbeatListener instances registered?""" return self.__enabled_for_server_heartbeat @property def enabled_for_topology(self): """Are any TopologyListener instances registered?""" return self.__enabled_for_topology @property def enabled_for_cmap(self): """Are any ConnectionPoolListener instances registered?""" return self.__enabled_for_cmap def event_listeners(self): """List of registered event listeners.""" return (self.__command_listeners[:], self.__server_heartbeat_listeners[:], self.__server_listeners[:], self.__topology_listeners[:]) def publish_command_start(self, command, database_name, request_id, connection_id, op_id=None): """Publish a CommandStartedEvent to all command listeners. :Parameters: - `command`: The command document. - `database_name`: The name of the database this command was run against. - `request_id`: The request id for this operation. - `connection_id`: The address (host, port) of the server this command was sent to. - `op_id`: The (optional) operation id for this operation. 
""" if op_id is None: op_id = request_id event = CommandStartedEvent( command, database_name, request_id, connection_id, op_id) for subscriber in self.__command_listeners: try: subscriber.started(event) except Exception: _handle_exception() def publish_command_success(self, duration, reply, command_name, request_id, connection_id, op_id=None): """Publish a CommandSucceededEvent to all command listeners. :Parameters: - `duration`: The command duration as a datetime.timedelta. - `reply`: The server reply document. - `command_name`: The command name. - `request_id`: The request id for this operation. - `connection_id`: The address (host, port) of the server this command was sent to. - `op_id`: The (optional) operation id for this operation. """ if op_id is None: op_id = request_id event = CommandSucceededEvent( duration, reply, command_name, request_id, connection_id, op_id) for subscriber in self.__command_listeners: try: subscriber.succeeded(event) except Exception: _handle_exception() def publish_command_failure(self, duration, failure, command_name, request_id, connection_id, op_id=None): """Publish a CommandFailedEvent to all command listeners. :Parameters: - `duration`: The command duration as a datetime.timedelta. - `failure`: The server reply document or failure description document. - `command_name`: The command name. - `request_id`: The request id for this operation. - `connection_id`: The address (host, port) of the server this command was sent to. - `op_id`: The (optional) operation id for this operation. """ if op_id is None: op_id = request_id event = CommandFailedEvent( duration, failure, command_name, request_id, connection_id, op_id) for subscriber in self.__command_listeners: try: subscriber.failed(event) except Exception: _handle_exception() def publish_server_heartbeat_started(self, connection_id): """Publish a ServerHeartbeatStartedEvent to all server heartbeat listeners. :Parameters: - `connection_id`: The address (host, port) pair of the connection. """ event = ServerHeartbeatStartedEvent(connection_id) for subscriber in self.__server_heartbeat_listeners: try: subscriber.started(event) except Exception: _handle_exception() def publish_server_heartbeat_succeeded(self, connection_id, duration, reply, awaited): """Publish a ServerHeartbeatSucceededEvent to all server heartbeat listeners. :Parameters: - `connection_id`: The address (host, port) pair of the connection. - `duration`: The execution time of the event in the highest possible resolution for the platform. - `reply`: The command reply. - `awaited`: True if the response was awaited. """ event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) for subscriber in self.__server_heartbeat_listeners: try: subscriber.succeeded(event) except Exception: _handle_exception() def publish_server_heartbeat_failed(self, connection_id, duration, reply, awaited): """Publish a ServerHeartbeatFailedEvent to all server heartbeat listeners. :Parameters: - `connection_id`: The address (host, port) pair of the connection. - `duration`: The execution time of the event in the highest possible resolution for the platform. - `reply`: The command reply. - `awaited`: True if the response was awaited. """ event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) for subscriber in self.__server_heartbeat_listeners: try: subscriber.failed(event) except Exception: _handle_exception() def publish_server_opened(self, server_address, topology_id): """Publish a ServerOpeningEvent to all server listeners. 
:Parameters: - `server_address`: The address (host, port) pair of the server. - `topology_id`: A unique identifier for the topology this server is a part of. """ event = ServerOpeningEvent(server_address, topology_id) for subscriber in self.__server_listeners: try: subscriber.opened(event) except Exception: _handle_exception() def publish_server_closed(self, server_address, topology_id): """Publish a ServerClosedEvent to all server listeners. :Parameters: - `server_address`: The address (host, port) pair of the server. - `topology_id`: A unique identifier for the topology this server is a part of. """ event = ServerClosedEvent(server_address, topology_id) for subscriber in self.__server_listeners: try: subscriber.closed(event) except Exception: _handle_exception() def publish_server_description_changed(self, previous_description, new_description, server_address, topology_id): """Publish a ServerDescriptionChangedEvent to all server listeners. :Parameters: - `previous_description`: The previous server description. - `server_address`: The address (host, port) pair of the server. - `new_description`: The new server description. - `topology_id`: A unique identifier for the topology this server is a part of. """ event = ServerDescriptionChangedEvent(previous_description, new_description, server_address, topology_id) for subscriber in self.__server_listeners: try: subscriber.description_changed(event) except Exception: _handle_exception() def publish_topology_opened(self, topology_id): """Publish a TopologyOpenedEvent to all topology listeners. :Parameters: - `topology_id`: A unique identifier for the topology this server is a part of. """ event = TopologyOpenedEvent(topology_id) for subscriber in self.__topology_listeners: try: subscriber.opened(event) except Exception: _handle_exception() def publish_topology_closed(self, topology_id): """Publish a TopologyClosedEvent to all topology listeners. :Parameters: - `topology_id`: A unique identifier for the topology this server is a part of. """ event = TopologyClosedEvent(topology_id) for subscriber in self.__topology_listeners: try: subscriber.closed(event) except Exception: _handle_exception() def publish_topology_description_changed(self, previous_description, new_description, topology_id): """Publish a TopologyDescriptionChangedEvent to all topology listeners. :Parameters: - `previous_description`: The previous topology description. - `new_description`: The new topology description. - `topology_id`: A unique identifier for the topology this server is a part of. """ event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) for subscriber in self.__topology_listeners: try: subscriber.description_changed(event) except Exception: _handle_exception() def publish_pool_created(self, address, options): """Publish a :class:`PoolCreatedEvent` to all pool listeners. """ event = PoolCreatedEvent(address, options) for subscriber in self.__cmap_listeners: try: subscriber.pool_created(event) except Exception: _handle_exception() def publish_pool_cleared(self, address): """Publish a :class:`PoolClearedEvent` to all pool listeners. """ event = PoolClearedEvent(address) for subscriber in self.__cmap_listeners: try: subscriber.pool_cleared(event) except Exception: _handle_exception() def publish_pool_closed(self, address): """Publish a :class:`PoolClosedEvent` to all pool listeners. 
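# --- Illustrative sketch, not part of this module: a TopologyListener fed by
# the publish_topology_* methods above; it records the latest topology
# description per topology_id using only attributes defined on the topology
# events in this file.
class TopologyTracker(TopologyListener):
    def __init__(self):
        self.descriptions = {}

    def opened(self, event):
        self.descriptions[event.topology_id] = None

    def description_changed(self, event):
        self.descriptions[event.topology_id] = event.new_description

    def closed(self, event):
        self.descriptions.pop(event.topology_id, None)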
""" event = PoolClosedEvent(address) for subscriber in self.__cmap_listeners: try: subscriber.pool_closed(event) except Exception: _handle_exception() def publish_connection_created(self, address, connection_id): """Publish a :class:`ConnectionCreatedEvent` to all connection listeners. """ event = ConnectionCreatedEvent(address, connection_id) for subscriber in self.__cmap_listeners: try: subscriber.connection_created(event) except Exception: _handle_exception() def publish_connection_ready(self, address, connection_id): """Publish a :class:`ConnectionReadyEvent` to all connection listeners. """ event = ConnectionReadyEvent(address, connection_id) for subscriber in self.__cmap_listeners: try: subscriber.connection_ready(event) except Exception: _handle_exception() def publish_connection_closed(self, address, connection_id, reason): """Publish a :class:`ConnectionClosedEvent` to all connection listeners. """ event = ConnectionClosedEvent(address, connection_id, reason) for subscriber in self.__cmap_listeners: try: subscriber.connection_closed(event) except Exception: _handle_exception() def publish_connection_check_out_started(self, address): """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection listeners. """ event = ConnectionCheckOutStartedEvent(address) for subscriber in self.__cmap_listeners: try: subscriber.connection_check_out_started(event) except Exception: _handle_exception() def publish_connection_check_out_failed(self, address, reason): """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection listeners. """ event = ConnectionCheckOutFailedEvent(address, reason) for subscriber in self.__cmap_listeners: try: subscriber.connection_check_out_started(event) except Exception: _handle_exception() def publish_connection_checked_out(self, address, connection_id): """Publish a :class:`ConnectionCheckedOutEvent` to all connection listeners. """ event = ConnectionCheckedOutEvent(address, connection_id) for subscriber in self.__cmap_listeners: try: subscriber.connection_checked_out(event) except Exception: _handle_exception() def publish_connection_checked_in(self, address, connection_id): """Publish a :class:`ConnectionCheckedInEvent` to all connection listeners. """ event = ConnectionCheckedInEvent(address, connection_id) for subscriber in self.__cmap_listeners: try: subscriber.connection_checked_in(event) except Exception: _handle_exception() pymongo-3.11.0/pymongo/monotonic.py000066400000000000000000000021141374256237000173420ustar00rootroot00000000000000# Copyright 2014-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Time. Monotonic if possible. """ from __future__ import absolute_import __all__ = ['time'] try: # Patches standard time module. # From https://pypi.python.org/pypi/Monotime. import monotime except ImportError: pass try: # From https://pypi.python.org/pypi/monotonic. from monotonic import monotonic as time except ImportError: try: # Monotime or Python 3. from time import monotonic as time except ImportError: # Not monotonic. 
from time import time pymongo-3.11.0/pymongo/network.py000066400000000000000000000275261374256237000170440ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Internal network layer helper methods.""" import datetime import errno import socket import struct from bson import _decode_all_selective from bson.py3compat import PY3 from pymongo import helpers, message from pymongo.common import MAX_MESSAGE_SIZE from pymongo.compression_support import decompress, _NO_COMPRESSION from pymongo.errors import (AutoReconnect, NotMasterError, OperationFailure, ProtocolError, NetworkTimeout, _OperationCancelled) from pymongo.message import _UNPACK_REPLY, _OpMsg from pymongo.monotonic import time from pymongo.socket_checker import _errno_from_exception _UNPACK_HEADER = struct.Struct(" max_bson_size): message._raise_document_too_large(name, size, max_bson_size) else: request_id, msg, size = message.query( flags, ns, 0, -1, spec, None, codec_options, check_keys, compression_ctx) if (max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD): message._raise_document_too_large( name, size, max_bson_size + message._COMMAND_OVERHEAD) if publish: encoding_duration = datetime.datetime.now() - start listeners.publish_command_start(orig, dbname, request_id, address) start = datetime.datetime.now() try: sock_info.sock.sendall(msg) if use_op_msg and unacknowledged: # Unacknowledged, fake a successful command response. 
reply = None response_doc = {"ok": 1} else: reply = receive_message(sock_info, request_id) sock_info.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( codec_options=codec_options, user_fields=user_fields) response_doc = unpacked_docs[0] if client: client._process_response(response_doc, session) if check: helpers._check_command_response( response_doc, sock_info.max_wire_version, None, allowable_errors, parse_write_concern_error=parse_write_concern_error) except Exception as exc: if publish: duration = (datetime.datetime.now() - start) + encoding_duration if isinstance(exc, (NotMasterError, OperationFailure)): failure = exc.details else: failure = message._convert_exception(exc) listeners.publish_command_failure( duration, failure, name, request_id, address) raise if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( duration, response_doc, name, request_id, address) if client and client._encrypter and reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) response_doc = _decode_all_selective(decrypted, codec_options, user_fields)[0] return response_doc _UNPACK_COMPRESSION_HEADER = struct.Struct(" max_message_size: raise ProtocolError("Message length (%r) is larger than server max " "message size (%r)" % (length, max_message_size)) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( _receive_data_on_socket(sock_info, 9, deadline)) data = decompress( _receive_data_on_socket(sock_info, length - 25, deadline), compressor_id) else: data = _receive_data_on_socket(sock_info, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] except KeyError: raise ProtocolError("Got opcode %r but expected " "%r" % (op_code, _UNPACK_REPLY.keys())) return unpack_reply(data) _POLL_TIMEOUT = 0.5 def wait_for_read(sock_info, deadline): """Block until at least one byte is read, or a timeout, or a cancel.""" context = sock_info.cancel_context # Only Monitor connections can be cancelled. if context: sock = sock_info.sock while True: # SSLSocket can have buffered data which won't be caught by select. if hasattr(sock, 'pending') and sock.pending() > 0: readable = True else: # Wait up to 500ms for the socket to become readable and then # check for cancellation. if deadline: timeout = max(min(deadline - time(), _POLL_TIMEOUT), 0.001) else: timeout = _POLL_TIMEOUT readable = sock_info.socket_checker.select( sock, read=True, timeout=timeout) if context.cancelled: raise _OperationCancelled('isMaster cancelled') if readable: return if deadline and time() > deadline: raise socket.timeout("timed out") # memoryview was introduced in Python 2.7 but we only use it on Python 3 # because before 2.7.4 the struct module did not support memoryview: # https://bugs.python.org/issue10212. # In Jython, using slice assignment on a memoryview results in a # NullPointerException. 
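# --- Illustrative sketch, not part of this module: receive_message() above
# reads a standard 16-byte MongoDB wire protocol header of four little-endian
# int32s (messageLength, requestID, responseTo, opCode). A stand-alone parse
# of such a header might look like this; the field names are descriptive only.
import struct

_HEADER = struct.Struct("<iiii")

def parse_wire_header(header_bytes):
    length, request_id, response_to, op_code = _HEADER.unpack(header_bytes)
    return {"length": length, "request_id": request_id,
            "response_to": response_to, "op_code": op_code}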
if not PY3: def _receive_data_on_socket(sock_info, length, deadline): buf = bytearray(length) i = 0 while length: try: wait_for_read(sock_info, deadline) chunk = sock_info.sock.recv(length) except (IOError, OSError) as exc: if _errno_from_exception(exc) == errno.EINTR: continue raise if chunk == b"": raise AutoReconnect("connection closed") buf[i:i + len(chunk)] = chunk i += len(chunk) length -= len(chunk) return bytes(buf) else: def _receive_data_on_socket(sock_info, length, deadline): buf = bytearray(length) mv = memoryview(buf) bytes_read = 0 while bytes_read < length: try: wait_for_read(sock_info, deadline) chunk_length = sock_info.sock.recv_into(mv[bytes_read:]) except (IOError, OSError) as exc: if _errno_from_exception(exc) == errno.EINTR: continue raise if chunk_length == 0: raise AutoReconnect("connection closed") bytes_read += chunk_length return mv pymongo-3.11.0/pymongo/ocsp_cache.py000066400000000000000000000062121374256237000174270ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for caching OCSP responses.""" from collections import namedtuple from datetime import datetime as _datetime from threading import Lock class _OCSPCache(object): """A cache for OCSP responses.""" CACHE_KEY_TYPE = namedtuple('OcspResponseCacheKey', ['hash_algorithm', 'issuer_name_hash', 'issuer_key_hash', 'serial_number']) def __init__(self): self._data = {} # Hold this lock when accessing _data. self._lock = Lock() def _get_cache_key(self, ocsp_request): return self.CACHE_KEY_TYPE( hash_algorithm=ocsp_request.hash_algorithm.name.lower(), issuer_name_hash=ocsp_request.issuer_name_hash, issuer_key_hash=ocsp_request.issuer_key_hash, serial_number=ocsp_request.serial_number) def __setitem__(self, key, value): """Add/update a cache entry. 'key' is of type cryptography.x509.ocsp.OCSPRequest 'value' is of type cryptography.x509.ocsp.OCSPResponse Validity of the OCSP response must be checked by caller. """ with self._lock: cache_key = self._get_cache_key(key) # As per the OCSP protocol, if the response's nextUpdate field is # not set, the responder is indicating that newer revocation # information is available all the time. if value.next_update is None: self._data.pop(cache_key, None) return # Do nothing if the response is invalid. if not (value.this_update <= _datetime.utcnow() < value.next_update): return # Cache new response OR update cached response if new response # has longer validity. cached_value = self._data.get(cache_key, None) if (cached_value is None or cached_value.next_update < value.next_update): self._data[cache_key] = value def __getitem__(self, item): """Get a cache entry if it exists. 'item' is of type cryptography.x509.ocsp.OCSPRequest Raises KeyError if the item is not in the cache. """ with self._lock: cache_key = self._get_cache_key(item) value = self._data[cache_key] # Return cached response if it is still valid. 
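# --- Illustrative usage sketch, not part of this module: the cache maps a
# cryptography OCSPRequest (keyed on hash algorithm, issuer hashes and serial
# number as above) to an OCSPResponse, and only returns entries whose
# this_update/next_update window still covers the current time.
#
#   cache = _OCSPCache()
#   cache[ocsp_request] = ocsp_response     # stored only while still valid
#   try:
#       response = cache[ocsp_request]
#   except KeyError:
#       response = None                     # expired or never cached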
if (value.this_update <= _datetime.utcnow() < value.next_update): return value self._data.pop(cache_key, None) raise KeyError(cache_key) pymongo-3.11.0/pymongo/ocsp_support.py000066400000000000000000000341671374256237000201120ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Support for requesting and verifying OCSP responses.""" import logging as _logging import re as _re from datetime import datetime as _datetime from cryptography.exceptions import InvalidSignature as _InvalidSignature from cryptography.hazmat.backends import default_backend as _default_backend from cryptography.hazmat.primitives.asymmetric.dsa import ( DSAPublicKey as _DSAPublicKey) from cryptography.hazmat.primitives.asymmetric.ec import ( ECDSA as _ECDSA, EllipticCurvePublicKey as _EllipticCurvePublicKey) from cryptography.hazmat.primitives.asymmetric.padding import ( PKCS1v15 as _PKCS1v15) from cryptography.hazmat.primitives.asymmetric.rsa import ( RSAPublicKey as _RSAPublicKey) from cryptography.hazmat.primitives.hashes import ( Hash as _Hash, SHA1 as _SHA1) from cryptography.hazmat.primitives.serialization import ( Encoding as _Encoding, PublicFormat as _PublicFormat) from cryptography.x509 import ( AuthorityInformationAccess as _AuthorityInformationAccess, ExtendedKeyUsage as _ExtendedKeyUsage, ExtensionNotFound as _ExtensionNotFound, load_pem_x509_certificate as _load_pem_x509_certificate, TLSFeature as _TLSFeature, TLSFeatureType as _TLSFeatureType) from cryptography.x509.oid import ( AuthorityInformationAccessOID as _AuthorityInformationAccessOID, ExtendedKeyUsageOID as _ExtendedKeyUsageOID) from cryptography.x509.ocsp import ( load_der_ocsp_response as _load_der_ocsp_response, OCSPCertStatus as _OCSPCertStatus, OCSPRequestBuilder as _OCSPRequestBuilder, OCSPResponseStatus as _OCSPResponseStatus) from requests import post as _post from requests.exceptions import RequestException as _RequestException # Note: the functions in this module generally return 1 or 0. The reason # is simple. The entry point, ocsp_callback, is registered as a callback # with OpenSSL through PyOpenSSL. The callback must return 1 (success) or # 0 (failure). _LOGGER = _logging.getLogger(__name__) _CERT_REGEX = _re.compile( b'-----BEGIN CERTIFICATE[^\r\n]+.+?-----END CERTIFICATE[^\r\n]+', _re.DOTALL) def _load_trusted_ca_certs(cafile): """Parse the tlsCAFile into a list of certificates.""" with open(cafile, 'rb') as f: data = f.read() # Load all the certs in the file. trusted_ca_certs = [] backend = _default_backend() for cert_data in _re.findall(_CERT_REGEX, data): trusted_ca_certs.append( _load_pem_x509_certificate(cert_data, backend)) return trusted_ca_certs def _get_issuer_cert(cert, chain, trusted_ca_certs): issuer_name = cert.issuer for candidate in chain: if candidate.subject == issuer_name: return candidate # Depending on the server's TLS library, the peer's cert chain may not # include the self signed root CA. 
In this case we check the user # provided tlsCAFile (ssl_ca_certs) for the issuer. # Remove once we use the verified peer cert chain in PYTHON-2147. if trusted_ca_certs: for candidate in trusted_ca_certs: if candidate.subject == issuer_name: return candidate return None def _verify_signature(key, signature, algorithm, data): # See cryptography.x509.Certificate.public_key # for the public key types. try: if isinstance(key, _RSAPublicKey): key.verify(signature, data, _PKCS1v15(), algorithm) elif isinstance(key, _DSAPublicKey): key.verify(signature, data, algorithm) elif isinstance(key, _EllipticCurvePublicKey): key.verify(signature, data, _ECDSA(algorithm)) else: key.verify(signature, data) except _InvalidSignature: return 0 return 1 def _get_extension(cert, klass): try: return cert.extensions.get_extension_for_class(klass) except _ExtensionNotFound: return None def _public_key_hash(cert): public_key = cert.public_key() # https://tools.ietf.org/html/rfc2560#section-4.2.1 # "KeyHash ::= OCTET STRING -- SHA-1 hash of responder's public key # (excluding the tag and length fields)" # https://stackoverflow.com/a/46309453/600498 if isinstance(public_key, _RSAPublicKey): pbytes = public_key.public_bytes( _Encoding.DER, _PublicFormat.PKCS1) elif isinstance(public_key, _EllipticCurvePublicKey): pbytes = public_key.public_bytes( _Encoding.X962, _PublicFormat.UncompressedPoint) else: pbytes = public_key.public_bytes( _Encoding.DER, _PublicFormat.SubjectPublicKeyInfo) digest = _Hash(_SHA1(), backend=_default_backend()) digest.update(pbytes) return digest.finalize() def _get_certs_by_key_hash(certificates, issuer, responder_key_hash): return [ cert for cert in certificates if _public_key_hash(cert) == responder_key_hash and cert.issuer == issuer.subject] def _get_certs_by_name(certificates, issuer, responder_name): return [ cert for cert in certificates if cert.subject == responder_name and cert.issuer == issuer.subject] def _verify_response_signature(issuer, response): # Response object will have a responder_name or responder_key_hash # not both. name = response.responder_name rkey_hash = response.responder_key_hash ikey_hash = response.issuer_key_hash if name is not None and name == issuer.subject or rkey_hash == ikey_hash: _LOGGER.debug("Responder is issuer") # Responder is the issuer responder_cert = issuer else: _LOGGER.debug("Responder is a delegate") # Responder is a delegate # https://tools.ietf.org/html/rfc6960#section-2.6 # RFC6960, Section 3.2, Number 3 certs = response.certificates if response.responder_name is not None: responder_certs = _get_certs_by_name(certs, issuer, name) _LOGGER.debug("Using responder name") else: responder_certs = _get_certs_by_key_hash(certs, issuer, rkey_hash) _LOGGER.debug("Using key hash") if not responder_certs: _LOGGER.debug("No matching or valid responder certs.") return 0 # XXX: Can there be more than one? If so, should we try each one # until we find one that passes signature verification? 
responder_cert = responder_certs[0] # RFC6960, Section 3.2, Number 4 ext = _get_extension(responder_cert, _ExtendedKeyUsage) if not ext or _ExtendedKeyUsageOID.OCSP_SIGNING not in ext.value: _LOGGER.debug("Delegate not authorized for OCSP signing") return 0 if not _verify_signature( issuer.public_key(), responder_cert.signature, responder_cert.signature_hash_algorithm, responder_cert.tbs_certificate_bytes): _LOGGER.debug("Delegate signature verification failed") return 0 # RFC6960, Section 3.2, Number 2 ret = _verify_signature( responder_cert.public_key(), response.signature, response.signature_hash_algorithm, response.tbs_response_bytes) if not ret: _LOGGER.debug("Response signature verification failed") return ret def _build_ocsp_request(cert, issuer): # https://cryptography.io/en/latest/x509/ocsp/#creating-requests builder = _OCSPRequestBuilder() builder = builder.add_certificate(cert, issuer, _SHA1()) return builder.build() def _verify_response(issuer, response): _LOGGER.debug("Verifying response") # RFC6960, Section 3.2, Number 2, 3 and 4 happen here. res = _verify_response_signature(issuer, response) if not res: return 0 # Note that we are not using a "tolerence period" as discussed in # https://tools.ietf.org/rfc/rfc5019.txt? now = _datetime.utcnow() # RFC6960, Section 3.2, Number 5 if response.this_update > now: _LOGGER.debug("thisUpdate is in the future") return 0 # RFC6960, Section 3.2, Number 6 if response.next_update and response.next_update < now: _LOGGER.debug("nextUpdate is in the past") return 0 return 1 def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache): ocsp_request = _build_ocsp_request(cert, issuer) try: ocsp_response = ocsp_response_cache[ocsp_request] _LOGGER.debug("Using cached OCSP response.") except KeyError: try: response = _post( uri, data=ocsp_request.public_bytes(_Encoding.DER), headers={'Content-Type': 'application/ocsp-request'}, timeout=5) except _RequestException as exc: _LOGGER.debug("HTTP request failed: %s", exc) return None if response.status_code != 200: _LOGGER.debug("HTTP request returned %d", response.status_code) return None ocsp_response = _load_der_ocsp_response(response.content) _LOGGER.debug( "OCSP response status: %r", ocsp_response.response_status) if ocsp_response.response_status != _OCSPResponseStatus.SUCCESSFUL: return None # RFC6960, Section 3.2, Number 1. Only relevant if we need to # talk to the responder directly. # Accessing response.serial_number raises if response status is not # SUCCESSFUL. if ocsp_response.serial_number != ocsp_request.serial_number: _LOGGER.debug("Response serial number does not match request") return None if not _verify_response(issuer, ocsp_response): # The response failed verification. 
return None _LOGGER.debug("Caching OCSP response.") ocsp_response_cache[ocsp_request] = ocsp_response return ocsp_response def _ocsp_callback(conn, ocsp_bytes, user_data): """Callback for use with OpenSSL.SSL.Context.set_ocsp_client_callback.""" cert = conn.get_peer_certificate() if cert is None: _LOGGER.debug("No peer cert?") return 0 cert = cert.to_cryptography() chain = conn.get_peer_cert_chain() if not chain: _LOGGER.debug("No peer cert chain?") return 0 chain = [cer.to_cryptography() for cer in chain] issuer = _get_issuer_cert(cert, chain, user_data.trusted_ca_certs) must_staple = False # https://tools.ietf.org/html/rfc7633#section-4.2.3.1 ext = _get_extension(cert, _TLSFeature) if ext is not None: for feature in ext.value: if feature == _TLSFeatureType.status_request: _LOGGER.debug("Peer presented a must-staple cert") must_staple = True break ocsp_response_cache = user_data.ocsp_response_cache # No stapled OCSP response if ocsp_bytes == b'': _LOGGER.debug("Peer did not staple an OCSP response") if must_staple: _LOGGER.debug("Must-staple cert with no stapled response, hard fail.") return 0 if not user_data.check_ocsp_endpoint: _LOGGER.debug("OCSP endpoint checking is disabled, soft fail.") # No stapled OCSP response, checking responder URI diabled, soft fail. return 1 # https://tools.ietf.org/html/rfc6960#section-3.1 ext = _get_extension(cert, _AuthorityInformationAccess) if ext is None: _LOGGER.debug("No authority access information, soft fail") # No stapled OCSP response, no responder URI, soft fail. return 1 uris = [desc.access_location.value for desc in ext.value if desc.access_method == _AuthorityInformationAccessOID.OCSP] if not uris: _LOGGER.debug("No OCSP URI, soft fail") # No responder URI, soft fail. return 1 if issuer is None: _LOGGER.debug("No issuer cert?") return 0 _LOGGER.debug("Requesting OCSP data") # When requesting data from an OCSP endpoint we only fail on # successful, valid responses with a certificate status of REVOKED. for uri in uris: _LOGGER.debug("Trying %s", uri) response = _get_ocsp_response( cert, issuer, uri, ocsp_response_cache) if response is None: # The endpoint didn't respond in time, or the response was # unsuccessful or didn't match the request, or the response # failed verification. continue _LOGGER.debug("OCSP cert status: %r", response.certificate_status) if response.certificate_status == _OCSPCertStatus.GOOD: return 1 if response.certificate_status == _OCSPCertStatus.REVOKED: return 0 # Soft fail if we couldn't get a definitive status. _LOGGER.debug("No definitive OCSP cert status, soft fail") return 1 _LOGGER.debug("Peer stapled an OCSP response") if issuer is None: _LOGGER.debug("No issuer cert?") return 0 response = _load_der_ocsp_response(ocsp_bytes) _LOGGER.debug( "OCSP response status: %r", response.response_status) # This happens in _request_ocsp when there is no stapled response so # we know if we can compare serial numbers for the request and response. if response.response_status != _OCSPResponseStatus.SUCCESSFUL: return 0 if not _verify_response(issuer, response): return 0 # Cache the verified, stapled response. ocsp_response_cache[_build_ocsp_request(cert, issuer)] = response _LOGGER.debug("OCSP cert status: %r", response.certificate_status) if response.certificate_status == _OCSPCertStatus.REVOKED: return 0 return 1 pymongo-3.11.0/pymongo/operations.py000066400000000000000000000421621374256237000175270ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Operation class definitions.""" from bson.py3compat import string_type from pymongo import helpers from pymongo.common import validate_boolean, validate_is_mapping, validate_list from pymongo.collation import validate_collation_or_none from pymongo.helpers import _gen_index_name, _index_document, _index_list class InsertOne(object): """Represents an insert_one operation.""" __slots__ = ("_doc",) def __init__(self, document): """Create an InsertOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. :Parameters: - `document`: The document to insert. If the document is missing an _id field one will be added. """ self._doc = document def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" bulkobj.add_insert(self._doc) def __repr__(self): return "InsertOne(%r)" % (self._doc,) def __eq__(self, other): if type(other) == type(self): return other._doc == self._doc return NotImplemented def __ne__(self, other): return not self == other class DeleteOne(object): """Represents a delete_one operation.""" __slots__ = ("_filter", "_collation", "_hint") def __init__(self, filter, collation=None, hint=None): """Create a DeleteOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. :Parameters: - `filter`: A query that matches the document to delete. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. .. versionchanged:: 3.11 Added the ``hint`` option. .. versionchanged:: 3.5 Added the `collation` option. """ if filter is not None: validate_is_mapping("filter", filter) if hint is not None: if not isinstance(hint, string_type): hint = helpers._index_document(hint) self._filter = filter self._collation = collation self._hint = hint def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" bulkobj.add_delete(self._filter, 1, collation=self._collation, hint=self._hint) def __repr__(self): return "DeleteOne(%r, %r)" % (self._filter, self._collation) def __eq__(self, other): if type(other) == type(self): return ((other._filter, other._collation) == (self._filter, self._collation)) return NotImplemented def __ne__(self, other): return not self == other class DeleteMany(object): """Represents a delete_many operation.""" __slots__ = ("_filter", "_collation", "_hint") def __init__(self, filter, collation=None, hint=None): """Create a DeleteMany instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. :Parameters: - `filter`: A query that matches the documents to delete. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. 
This option is only supported on MongoDB 3.4 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. .. versionchanged:: 3.11 Added the ``hint`` option. .. versionchanged:: 3.5 Added the `collation` option. """ if filter is not None: validate_is_mapping("filter", filter) if hint is not None: if not isinstance(hint, string_type): hint = helpers._index_document(hint) self._filter = filter self._collation = collation self._hint = hint def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" bulkobj.add_delete(self._filter, 0, collation=self._collation, hint=self._hint) def __repr__(self): return "DeleteMany(%r, %r)" % (self._filter, self._collation) def __eq__(self, other): if type(other) == type(self): return ((other._filter, other._collation) == (self._filter, self._collation)) return NotImplemented def __ne__(self, other): return not self == other class ReplaceOne(object): """Represents a replace_one operation.""" __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") def __init__(self, filter, replacement, upsert=False, collation=None, hint=None): """Create a ReplaceOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. :Parameters: - `filter`: A query that matches the document to replace. - `replacement`: The new document. - `upsert` (optional): If ``True``, perform an insert if no documents match the filter. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. .. versionchanged:: 3.11 Added the ``hint`` option. .. versionchanged:: 3.5 Added the ``collation`` option. 
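# --- Illustrative sketch, not part of this module: the ``hint`` accepted by
# the bulk write operations in this file may be either an index name or a key
# specification; non-string hints are normalized with helpers._index_document
# exactly as in the constructors above. The field names are made up for the
# example.
#
#   DeleteOne({"status": "done"}, hint="status_1")          # by index name
#   DeleteMany({"status": "done"}, hint=[("status", 1)])    # by key pattern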
""" if filter is not None: validate_is_mapping("filter", filter) if upsert is not None: validate_boolean("upsert", upsert) if hint is not None: if not isinstance(hint, string_type): hint = helpers._index_document(hint) self._filter = filter self._doc = replacement self._upsert = upsert self._collation = collation self._hint = hint def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" bulkobj.add_replace(self._filter, self._doc, self._upsert, collation=self._collation, hint=self._hint) def __eq__(self, other): if type(other) == type(self): return ( (other._filter, other._doc, other._upsert, other._collation, other._hint) == (self._filter, self._doc, self._upsert, self._collation, other._hint)) return NotImplemented def __ne__(self, other): return not self == other def __repr__(self): return "%s(%r, %r, %r, %r, %r)" % ( self.__class__.__name__, self._filter, self._doc, self._upsert, self._collation, self._hint) class _UpdateOp(object): """Private base class for update operations.""" __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") def __init__(self, filter, doc, upsert, collation, array_filters, hint): if filter is not None: validate_is_mapping("filter", filter) if upsert is not None: validate_boolean("upsert", upsert) if array_filters is not None: validate_list("array_filters", array_filters) if hint is not None: if not isinstance(hint, string_type): hint = helpers._index_document(hint) self._filter = filter self._doc = doc self._upsert = upsert self._collation = collation self._array_filters = array_filters self._hint = hint def __eq__(self, other): if type(other) == type(self): return ( (other._filter, other._doc, other._upsert, other._collation, other._array_filters, other._hint) == (self._filter, self._doc, self._upsert, self._collation, self._array_filters, self._hint)) return NotImplemented def __ne__(self, other): return not self == other def __repr__(self): return "%s(%r, %r, %r, %r, %r, %r)" % ( self.__class__.__name__, self._filter, self._doc, self._upsert, self._collation, self._array_filters, self._hint) class UpdateOne(_UpdateOp): """Represents an update_one operation.""" __slots__ = () def __init__(self, filter, update, upsert=False, collation=None, array_filters=None, hint=None): """Represents an update_one operation. For use with :meth:`~pymongo.collection.Collection.bulk_write`. :Parameters: - `filter`: A query that matches the document to update. - `update`: The modifications to apply. - `upsert` (optional): If ``True``, perform an insert if no documents match the filter. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `array_filters` (optional): A list of filters specifying which array elements an update should apply. Requires MongoDB 3.6+. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. .. versionchanged:: 3.11 Added the `hint` option. .. versionchanged:: 3.9 Added the ability to accept a pipeline as the `update`. .. versionchanged:: 3.6 Added the `array_filters` option. .. versionchanged:: 3.5 Added the `collation` option. 
""" super(UpdateOne, self).__init__(filter, update, upsert, collation, array_filters, hint) def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" bulkobj.add_update(self._filter, self._doc, False, self._upsert, collation=self._collation, array_filters=self._array_filters, hint=self._hint) class UpdateMany(_UpdateOp): """Represents an update_many operation.""" __slots__ = () def __init__(self, filter, update, upsert=False, collation=None, array_filters=None, hint=None): """Create an UpdateMany instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. :Parameters: - `filter`: A query that matches the documents to update. - `update`: The modifications to apply. - `upsert` (optional): If ``True``, perform an insert if no documents match the filter. - `collation` (optional): An instance of :class:`~pymongo.collation.Collation`. This option is only supported on MongoDB 3.4 and above. - `array_filters` (optional): A list of filters specifying which array elements an update should apply. Requires MongoDB 3.6+. - `hint` (optional): An index to use to support the query predicate specified either by its string name, or in the same format as passed to :meth:`~pymongo.collection.Collection.create_index` (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. .. versionchanged:: 3.11 Added the `hint` option. .. versionchanged:: 3.9 Added the ability to accept a pipeline as the `update`. .. versionchanged:: 3.6 Added the `array_filters` option. .. versionchanged:: 3.5 Added the `collation` option. """ super(UpdateMany, self).__init__(filter, update, upsert, collation, array_filters, hint) def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" bulkobj.add_update(self._filter, self._doc, True, self._upsert, collation=self._collation, array_filters=self._array_filters, hint=self._hint) class IndexModel(object): """Represents an index to create.""" __slots__ = ("__document",) def __init__(self, keys, **kwargs): """Create an Index instance. For use with :meth:`~pymongo.collection.Collection.create_indexes`. Takes either a single key or a list of (key, direction) pairs. The key(s) must be an instance of :class:`basestring` (:class:`str` in python 3), and the direction(s) must be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOHAYSTACK`, :data:`~pymongo.GEOSPHERE`, :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). Valid options include, but are not limited to: - `name`: custom name to use for this index - if none is given, a name will be generated. - `unique`: if ``True``, creates a uniqueness constraint on the index. - `background`: if ``True``, this index should be created in the background. - `sparse`: if ``True``, omit from the index any documents that lack the indexed field. - `bucketSize`: for use with geoHaystack indexes. Number of documents to group together within a certain proximity to a given longitude and latitude. - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` index. - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` index. - `expireAfterSeconds`: Used to create an expiring (TTL) collection. MongoDB will automatically delete documents from this collection after seconds. The indexed field must be a UTC datetime or the data will not expire. - `partialFilterExpression`: A document that specifies a filter for a partial index. Requires MongoDB >= 3.2. 
- `collation`: An instance of :class:`~pymongo.collation.Collation` that specifies the collation to use in MongoDB >= 3.4. - `wildcardProjection`: Allows users to include or exclude specific field paths from a `wildcard index`_ using the { "$**" : 1} key pattern. Requires MongoDB >= 4.2. - `hidden`: if ``True``, this index will be hidden from the query planner and will not be evaluated as part of query plan selection. Requires MongoDB >= 4.4. See the MongoDB documentation for a full list of supported options by server version. :Parameters: - `keys`: a single key or a list of (key, direction) pairs specifying the index to create - `**kwargs` (optional): any additional index creation options (see the above list) should be passed as keyword arguments .. versionchanged:: 3.11 Added the ``hidden`` option. .. versionchanged:: 3.2 Added the ``partialFilterExpression`` option to support partial indexes. .. _wildcard index: https://docs.mongodb.com/master/core/index-wildcard/#wildcard-index-core """ keys = _index_list(keys) if "name" not in kwargs: kwargs["name"] = _gen_index_name(keys) kwargs["key"] = _index_document(keys) collation = validate_collation_or_none(kwargs.pop('collation', None)) self.__document = kwargs if collation is not None: self.__document['collation'] = collation @property def document(self): """An index document suitable for passing to the createIndexes command. """ return self.__document pymongo-3.11.0/pymongo/periodic_executor.py000066400000000000000000000134201374256237000210530ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Run a target function on a background thread.""" import threading import time import weakref from pymongo.monotonic import time as _time class PeriodicExecutor(object): def __init__(self, interval, min_interval, target, name=None): """"Run a target function periodically on a background thread. If the target's return value is false, the executor stops. :Parameters: - `interval`: Seconds between calls to `target`. - `min_interval`: Minimum seconds between calls if `wake` is called very often. - `target`: A function. - `name`: A name to give the underlying thread. """ # threading.Event and its internal condition variable are expensive # in Python 2, see PYTHON-983. Use a boolean to know when to wake. # The executor's design is constrained by several Python issues, see # "periodic_executor.rst" in this repository. self._event = False self._interval = interval self._min_interval = min_interval self._target = target self._stopped = False self._thread = None self._name = name self._skip_sleep = False self._thread_will_exit = False self._lock = threading.Lock() def __repr__(self): return '<%s(name=%s) object at 0x%x>' % ( self.__class__.__name__, self._name, id(self)) def open(self): """Start. Multiple calls have no effect. Not safe to call from multiple threads at once. 
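# A rough sketch of how this internal executor behaves: the target runs on a
# daemon thread roughly every `interval` seconds until it returns a falsey
# value or close() is called. PeriodicExecutor is internal machinery, so the
# numbers and names here are illustrative only.
import time

from pymongo.periodic_executor import PeriodicExecutor

_ticks = []

def _tick():
    _ticks.append(time.time())
    return len(_ticks) < 3  # returning False stops the executor

_executor = PeriodicExecutor(interval=0.1, min_interval=0.05,
                             target=_tick, name="demo-executor")
_executor.open()
time.sleep(1)
_executor.close()
_executor.join()
print(len(_ticks))  # expected: 3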
""" with self._lock: if self._thread_will_exit: # If the background thread has read self._stopped as True # there is a chance that it has not yet exited. The call to # join should not block indefinitely because there is no # other work done outside the while loop in self._run. try: self._thread.join() except ReferenceError: # Thread terminated. pass self._thread_will_exit = False self._stopped = False started = False try: started = self._thread and self._thread.is_alive() except ReferenceError: # Thread terminated. pass if not started: thread = threading.Thread(target=self._run, name=self._name) thread.daemon = True self._thread = weakref.proxy(thread) _register_executor(self) thread.start() def close(self, dummy=None): """Stop. To restart, call open(). The dummy parameter allows an executor's close method to be a weakref callback; see monitor.py. """ self._stopped = True def join(self, timeout=None): if self._thread is not None: try: self._thread.join(timeout) except (ReferenceError, RuntimeError): # Thread already terminated, or not yet started. pass def wake(self): """Execute the target function soon.""" self._event = True def update_interval(self, new_interval): self._interval = new_interval def skip_sleep(self): self._skip_sleep = True def __should_stop(self): with self._lock: if self._stopped: self._thread_will_exit = True return True return False def _run(self): while not self.__should_stop(): try: if not self._target(): self._stopped = True break except: with self._lock: self._stopped = True self._thread_will_exit = True raise if self._skip_sleep: self._skip_sleep = False else: deadline = _time() + self._interval while not self._stopped and _time() < deadline: time.sleep(self._min_interval) if self._event: break # Early wake. self._event = False # _EXECUTORS has a weakref to each running PeriodicExecutor. Once started, # an executor is kept alive by a strong reference from its thread and perhaps # from other objects. When the thread dies and all other referrers are freed, # the executor is freed and removed from _EXECUTORS. If any threads are # running when the interpreter begins to shut down, we try to halt and join # them to avoid spurious errors. _EXECUTORS = set() def _register_executor(executor): ref = weakref.ref(executor, _on_executor_deleted) _EXECUTORS.add(ref) def _on_executor_deleted(ref): _EXECUTORS.remove(ref) def _shutdown_executors(): if _EXECUTORS is None: return # Copy the set. Stopping threads has the side effect of removing executors. executors = list(_EXECUTORS) # First signal all executors to close... for ref in executors: executor = ref() if executor: executor.close() # ...then try to join them. for ref in executors: executor = ref() if executor: executor.join(1) executor = None pymongo-3.11.0/pymongo/pool.py000066400000000000000000001520141374256237000163130ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
import contextlib import copy import os import platform import socket import sys import threading import collections from pymongo.ssl_support import ( SSLError as _SSLError, HAS_SNI as _HAVE_SNI, IPADDR_SAFE as _IPADDR_SAFE) from bson import DEFAULT_CODEC_OPTIONS from bson.py3compat import imap, itervalues, _unicode from bson.son import SON from pymongo import auth, helpers, thread_util, __version__ from pymongo.client_session import _validate_session_write_concern from pymongo.common import (MAX_BSON_SIZE, MAX_IDLE_TIME_SEC, MAX_MESSAGE_SIZE, MAX_POOL_SIZE, MAX_WIRE_VERSION, MAX_WRITE_BATCH_SIZE, MIN_POOL_SIZE, ORDERED_TYPES, WAIT_QUEUE_TIMEOUT) from pymongo.errors import (AutoReconnect, CertificateError, ConnectionFailure, ConfigurationError, InvalidOperation, DocumentTooLarge, NetworkTimeout, NotMasterError, OperationFailure, PyMongoError) from pymongo.ismaster import IsMaster from pymongo.monotonic import time as _time from pymongo.monitoring import (ConnectionCheckOutFailedReason, ConnectionClosedReason) from pymongo.network import (command, receive_message) from pymongo.read_preferences import ReadPreference from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker # Always use our backport so we always have support for IP address matching from pymongo.ssl_match_hostname import match_hostname # For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are # not permitted for SNI hostname. try: from ipaddress import ip_address def is_ip_address(address): try: ip_address(_unicode(address)) return True except (ValueError, UnicodeError): return False except ImportError: if hasattr(socket, 'inet_pton') and socket.has_ipv6: # Most *nix, recent Windows def is_ip_address(address): try: # inet_pton rejects IPv4 literals with leading zeros # (e.g. 192.168.0.01), inet_aton does not, and we # can connect to them without issue. Use inet_aton. socket.inet_aton(address) return True except socket.error: try: socket.inet_pton(socket.AF_INET6, address) return True except socket.error: return False else: # No inet_pton def is_ip_address(address): try: socket.inet_aton(address) return True except socket.error: if ':' in address: # ':' is not a valid character for a hostname. If we get # here a few things have to be true: # - We're on a recent version of python 2.7 (2.7.9+). # Older 2.7 versions don't support SNI. # - We're on Windows XP or some unusual Unix that doesn't # have inet_pton. # - The application is using IPv6 literals with TLS, which # is pretty unusual. return True return False try: from fcntl import fcntl, F_GETFD, F_SETFD, FD_CLOEXEC def _set_non_inheritable_non_atomic(fd): """Set the close-on-exec flag on the given file descriptor.""" flags = fcntl(fd, F_GETFD) fcntl(fd, F_SETFD, flags | FD_CLOEXEC) except ImportError: # Windows, various platforms we don't claim to support # (Jython, IronPython, ...), systems that don't provide # everything we need from fcntl, etc. def _set_non_inheritable_non_atomic(dummy): """Dummy function for platforms that don't provide fcntl.""" pass _MAX_TCP_KEEPIDLE = 120 _MAX_TCP_KEEPINTVL = 10 _MAX_TCP_KEEPCNT = 9 if sys.platform == 'win32': try: import _winreg as winreg except ImportError: import winreg def _query(key, name, default): try: value, _ = winreg.QueryValueEx(key, name) # Ensure the value is a number or raise ValueError. return int(value) except (OSError, ValueError): # QueryValueEx raises OSError when the key does not exist (i.e. # the system is using the Windows default value). 
return default try: with winreg.OpenKey( winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters") as key: _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) except OSError: # We could not check the default values because winreg.OpenKey failed. # Assume the system is using the default values. _WINDOWS_TCP_IDLE_MS = 7200000 _WINDOWS_TCP_INTERVAL_MS = 1000 def _set_keepalive_times(sock): idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) if (idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS): sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) else: def _set_tcp_option(sock, tcp_option, max_value): if hasattr(socket, tcp_option): sockopt = getattr(socket, tcp_option) try: # PYTHON-1350 - NetBSD doesn't implement getsockopt for # TCP_KEEPIDLE and friends. Don't attempt to set the # values there. default = sock.getsockopt(socket.IPPROTO_TCP, sockopt) if default > max_value: sock.setsockopt(socket.IPPROTO_TCP, sockopt, max_value) except socket.error: pass def _set_keepalive_times(sock): _set_tcp_option(sock, 'TCP_KEEPIDLE', _MAX_TCP_KEEPIDLE) _set_tcp_option(sock, 'TCP_KEEPINTVL', _MAX_TCP_KEEPINTVL) _set_tcp_option(sock, 'TCP_KEEPCNT', _MAX_TCP_KEEPCNT) _METADATA = SON([ ('driver', SON([('name', 'PyMongo'), ('version', __version__)])), ]) if sys.platform.startswith('linux'): # platform.linux_distribution was deprecated in Python 3.5. if sys.version_info[:2] < (3, 5): # Distro name and version (e.g. Ubuntu 16.04 xenial) _name = ' '.join([part for part in platform.linux_distribution() if part]) else: _name = platform.system() _METADATA['os'] = SON([ ('type', platform.system()), ('name', _name), ('architecture', platform.machine()), # Kernel version (e.g. 4.4.0-17-generic). ('version', platform.release()) ]) elif sys.platform == 'darwin': _METADATA['os'] = SON([ ('type', platform.system()), ('name', platform.system()), ('architecture', platform.machine()), # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin # kernel version. ('version', platform.mac_ver()[0]) ]) elif sys.platform == 'win32': _METADATA['os'] = SON([ ('type', platform.system()), # "Windows XP", "Windows 7", "Windows 10", etc. ('name', ' '.join((platform.system(), platform.release()))), ('architecture', platform.machine()), # Windows patch level (e.g. 5.1.2600-SP3) ('version', '-'.join(platform.win32_ver()[1:3])) ]) elif sys.platform.startswith('java'): _name, _ver, _arch = platform.java_ver()[-1] _METADATA['os'] = SON([ # Linux, Windows 7, Mac OS X, etc. ('type', _name), ('name', _name), # x86, x86_64, AMD64, etc. ('architecture', _arch), # Linux kernel version, OSX version, etc. ('version', _ver) ]) else: # Get potential alias (e.g. 
SunOS 5.11 becomes Solaris 2.11) _aliased = platform.system_alias( platform.system(), platform.release(), platform.version()) _METADATA['os'] = SON([ ('type', platform.system()), ('name', ' '.join([part for part in _aliased[:2] if part])), ('architecture', platform.machine()), ('version', _aliased[2]) ]) if platform.python_implementation().startswith('PyPy'): _METADATA['platform'] = ' '.join( (platform.python_implementation(), '.'.join(imap(str, sys.pypy_version_info)), '(Python %s)' % '.'.join(imap(str, sys.version_info)))) elif sys.platform.startswith('java'): _METADATA['platform'] = ' '.join( (platform.python_implementation(), '.'.join(imap(str, sys.version_info)), '(%s)' % ' '.join((platform.system(), platform.release())))) else: _METADATA['platform'] = ' '.join( (platform.python_implementation(), '.'.join(imap(str, sys.version_info)))) # If the first getaddrinfo call of this interpreter's life is on a thread, # while the main thread holds the import lock, getaddrinfo deadlocks trying # to import the IDNA codec. Import it here, where presumably we're on the # main thread, to avoid the deadlock. See PYTHON-607. u'foo'.encode('idna') def _raise_connection_failure(address, error, msg_prefix=None): """Convert a socket.error to ConnectionFailure and raise it.""" host, port = address # If connecting to a Unix socket, port will be None. if port is not None: msg = '%s:%d: %s' % (host, port, error) else: msg = '%s: %s' % (host, error) if msg_prefix: msg = msg_prefix + msg if isinstance(error, socket.timeout): raise NetworkTimeout(msg) elif isinstance(error, _SSLError) and 'timed out' in str(error): # CPython 2.7 and PyPy 2.x do not distinguish network # timeouts from other SSLErrors (https://bugs.python.org/issue10272). # Luckily, we can work around this limitation because the phrase # 'timed out' appears in all the timeout related SSLErrors raised # on the above platforms. 
raise NetworkTimeout(msg) else: raise AutoReconnect(msg) class PoolOptions(object): __slots__ = ('__max_pool_size', '__min_pool_size', '__max_idle_time_seconds', '__connect_timeout', '__socket_timeout', '__wait_queue_timeout', '__wait_queue_multiple', '__ssl_context', '__ssl_match_hostname', '__socket_keepalive', '__event_listeners', '__appname', '__driver', '__metadata', '__compression_settings') def __init__(self, max_pool_size=MAX_POOL_SIZE, min_pool_size=MIN_POOL_SIZE, max_idle_time_seconds=MAX_IDLE_TIME_SEC, connect_timeout=None, socket_timeout=None, wait_queue_timeout=WAIT_QUEUE_TIMEOUT, wait_queue_multiple=None, ssl_context=None, ssl_match_hostname=True, socket_keepalive=True, event_listeners=None, appname=None, driver=None, compression_settings=None): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size self.__max_idle_time_seconds = max_idle_time_seconds self.__connect_timeout = connect_timeout self.__socket_timeout = socket_timeout self.__wait_queue_timeout = wait_queue_timeout self.__wait_queue_multiple = wait_queue_multiple self.__ssl_context = ssl_context self.__ssl_match_hostname = ssl_match_hostname self.__socket_keepalive = socket_keepalive self.__event_listeners = event_listeners self.__appname = appname self.__driver = driver self.__compression_settings = compression_settings self.__metadata = copy.deepcopy(_METADATA) if appname: self.__metadata['application'] = {'name': appname} # Combine the "driver" MongoClient option with PyMongo's info, like: # { # 'driver': { # 'name': 'PyMongo|MyDriver', # 'version': '3.7.0|1.2.3', # }, # 'platform': 'CPython 3.6.0|MyPlatform' # } if driver: if driver.name: self.__metadata['driver']['name'] = "%s|%s" % ( _METADATA['driver']['name'], driver.name) if driver.version: self.__metadata['driver']['version'] = "%s|%s" % ( _METADATA['driver']['version'], driver.version) if driver.platform: self.__metadata['platform'] = "%s|%s" % ( _METADATA['platform'], driver.platform) @property def non_default_options(self): """The non-default options this pool was created with. Added for CMAP's :class:`PoolCreatedEvent`. """ opts = {} if self.__max_pool_size != MAX_POOL_SIZE: opts['maxPoolSize'] = self.__max_pool_size if self.__min_pool_size != MIN_POOL_SIZE: opts['minPoolSize'] = self.__min_pool_size if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: opts['maxIdleTimeMS'] = self.__max_idle_time_seconds * 1000 if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: opts['waitQueueTimeoutMS'] = self.__wait_queue_timeout * 1000 return opts @property def max_pool_size(self): """The maximum allowable number of concurrent connections to each connected server. Requests to a server will block if there are `maxPoolSize` outstanding connections to the requested server. Defaults to 100. Cannot be 0. When a server's pool has reached `max_pool_size`, operations for that server block waiting for a socket to be returned to the pool. If ``waitQueueTimeoutMS`` is set, a blocked operation will raise :exc:`~pymongo.errors.ConnectionFailure` after a timeout. By default ``waitQueueTimeoutMS`` is not set. """ return self.__max_pool_size @property def min_pool_size(self): """The minimum required number of concurrent connections that the pool will maintain to each connected server. Default is 0. """ return self.__min_pool_size @property def max_idle_time_seconds(self): """The maximum number of seconds that a connection can remain idle in the pool before being removed and replaced. Defaults to `None` (no limit). 
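# A sketch of how the "driver" metadata merge above is reached from the public
# API: a wrapping library passes DriverInfo and PyMongo appends its pieces to
# the handshake metadata (e.g. driver name "PyMongo|MyWrapper"). The wrapper
# name, version, and platform below are illustrative.
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo

client = MongoClient(driver=DriverInfo(name="MyWrapper",
                                       version="1.2.3",
                                       platform="MyPlatform"))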
""" return self.__max_idle_time_seconds @property def connect_timeout(self): """How long a connection can take to be opened before timing out. """ return self.__connect_timeout @property def socket_timeout(self): """How long a send or receive on a socket can take before timing out. """ return self.__socket_timeout @property def wait_queue_timeout(self): """How long a thread will wait for a socket from the pool if the pool has no free sockets. """ return self.__wait_queue_timeout @property def wait_queue_multiple(self): """Multiplied by max_pool_size to give the number of threads allowed to wait for a socket at one time. """ return self.__wait_queue_multiple @property def ssl_context(self): """An SSLContext instance or None. """ return self.__ssl_context @property def ssl_match_hostname(self): """Call ssl.match_hostname if cert_reqs is not ssl.CERT_NONE. """ return self.__ssl_match_hostname @property def socket_keepalive(self): """Whether to send periodic messages to determine if a connection is closed. """ return self.__socket_keepalive @property def event_listeners(self): """An instance of pymongo.monitoring._EventListeners. """ return self.__event_listeners @property def appname(self): """The application name, for sending with ismaster in server handshake. """ return self.__appname @property def driver(self): """Driver name and version, for sending with ismaster in handshake. """ return self.__driver @property def compression_settings(self): return self.__compression_settings @property def metadata(self): """A dict of metadata about the application, driver, os, and platform. """ return self.__metadata.copy() def _negotiate_creds(all_credentials): """Return one credential that needs mechanism negotiation, if any. """ if all_credentials: for creds in all_credentials.values(): if creds.mechanism == 'DEFAULT' and creds.username: return creds return None def _speculative_context(all_credentials): """Return the _AuthContext to use for speculative auth, if any. """ if all_credentials and len(all_credentials) == 1: creds = next(itervalues(all_credentials)) return auth._AuthContext.from_credentials(creds) return None class _CancellationContext(object): def __init__(self): self._cancelled = False def cancel(self): """Cancel this context.""" self._cancelled = True @property def cancelled(self): """Was cancel called?""" return self._cancelled class SocketInfo(object): """Store a socket with some metadata. :Parameters: - `sock`: a raw socket object - `pool`: a Pool instance - `address`: the server's (host, port) - `id`: the id of this socket in it's pool """ def __init__(self, sock, pool, address, id): self.sock = sock self.address = address self.id = id self.authset = set() self.closed = False self.last_checkin_time = _time() self.performed_handshake = False self.is_writable = False self.max_wire_version = MAX_WIRE_VERSION self.max_bson_size = MAX_BSON_SIZE self.max_message_size = MAX_MESSAGE_SIZE self.max_write_batch_size = MAX_WRITE_BATCH_SIZE self.supports_sessions = False self.is_mongos = False self.op_msg_enabled = False self.listeners = pool.opts.event_listeners self.enabled_for_cmap = pool.enabled_for_cmap self.compression_settings = pool.opts.compression_settings self.compression_context = None self.socket_checker = SocketChecker() # Support for mechanism negotiation on the initial handshake. # Maps credential to saslSupportedMechs. self.negotiated_mechanisms = {} self.auth_ctx = {} # The pool's generation changes with each reset() so we can close # sockets created before the last reset. 
self.generation = pool.generation self.ready = False self.cancel_context = None if not pool.handshake: # This is a Monitor connection. self.cancel_context = _CancellationContext() self.opts = pool.opts self.more_to_come = False def ismaster(self, all_credentials=None): return self._ismaster(None, None, None, all_credentials) def _ismaster(self, cluster_time, topology_version, heartbeat_frequency, all_credentials): cmd = SON([('ismaster', 1)]) performing_handshake = not self.performed_handshake awaitable = False if performing_handshake: self.performed_handshake = True cmd['client'] = self.opts.metadata if self.compression_settings: cmd['compression'] = self.compression_settings.compressors elif topology_version is not None: cmd['topologyVersion'] = topology_version cmd['maxAwaitTimeMS'] = int(heartbeat_frequency*1000) awaitable = True # If connect_timeout is None there is no timeout. if self.opts.connect_timeout: self.sock.settimeout( self.opts.connect_timeout + heartbeat_frequency) if self.max_wire_version >= 6 and cluster_time is not None: cmd['$clusterTime'] = cluster_time # XXX: Simplify in PyMongo 4.0 when all_credentials is always a single # unchangeable value per MongoClient. creds = _negotiate_creds(all_credentials) if creds: cmd['saslSupportedMechs'] = creds.source + '.' + creds.username auth_ctx = _speculative_context(all_credentials) if auth_ctx: cmd['speculativeAuthenticate'] = auth_ctx.speculate_command() doc = self.command('admin', cmd, publish_events=False, exhaust_allowed=awaitable) ismaster = IsMaster(doc, awaitable=awaitable) self.is_writable = ismaster.is_writable self.max_wire_version = ismaster.max_wire_version self.max_bson_size = ismaster.max_bson_size self.max_message_size = ismaster.max_message_size self.max_write_batch_size = ismaster.max_write_batch_size self.supports_sessions = ( ismaster.logical_session_timeout_minutes is not None) self.is_mongos = ismaster.server_type == SERVER_TYPE.Mongos if performing_handshake and self.compression_settings: ctx = self.compression_settings.get_compression_context( ismaster.compressors) self.compression_context = ctx self.op_msg_enabled = ismaster.max_wire_version >= 6 if creds: self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs if auth_ctx: auth_ctx.parse_response(ismaster) if auth_ctx.speculate_succeeded(): self.auth_ctx[auth_ctx.credentials] = auth_ctx return ismaster def _next_reply(self): reply = self.receive_message(None) self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() response_doc = unpacked_docs[0] helpers._check_command_response(response_doc, self.max_wire_version) return response_doc def command(self, dbname, spec, slave_ok=False, read_preference=ReadPreference.PRIMARY, codec_options=DEFAULT_CODEC_OPTIONS, check=True, allowable_errors=None, check_keys=False, read_concern=None, write_concern=None, parse_write_concern_error=False, collation=None, session=None, client=None, retryable_write=False, publish_events=True, user_fields=None, exhaust_allowed=False): """Execute a command or raise an error. 
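# A rough sketch of the handshake command assembled by _ismaster() above: on a
# connection's first ismaster the client metadata (and compressor list, when
# configured) ride along. The field values below are illustrative; the real
# document comes from PoolOptions.metadata.
from bson.son import SON

handshake = SON([("ismaster", 1)])
handshake["client"] = {
    "driver": {"name": "PyMongo", "version": "3.11.0"},
    "os": {"type": "Linux"},
    "platform": "CPython 3.8.5",
    "application": {"name": "my-app"},  # present only when appname is set
}
handshake["compression"] = ["zlib"]     # present only when compression is set
print(dict(handshake))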
:Parameters: - `dbname`: name of the database on which to run the command - `spec`: a command document as a dict, SON, or mapping object - `slave_ok`: whether to set the SlaveOkay wire protocol bit - `read_preference`: a read preference - `codec_options`: a CodecOptions instance - `check`: raise OperationFailure if there are errors - `allowable_errors`: errors to ignore if `check` is True - `check_keys`: if True, check `spec` for invalid keys - `read_concern`: The read concern for this command. - `write_concern`: The write concern for this command. - `parse_write_concern_error`: Whether to parse the ``writeConcernError`` field in the command response. - `collation`: The collation for this command. - `session`: optional ClientSession instance. - `client`: optional MongoClient for gossipping $clusterTime. - `retryable_write`: True if this command is a retryable write. - `publish_events`: Should we publish events for this command? - `user_fields` (optional): Response fields that should be decoded using the TypeDecoders from codec_options, passed to bson._decode_all_selective. """ self.validate_session(client, session) session = _validate_session_write_concern(session, write_concern) # Ensure command name remains in first place. if not isinstance(spec, ORDERED_TYPES): spec = SON(spec) if (read_concern and self.max_wire_version < 4 and not read_concern.ok_for_legacy): raise ConfigurationError( 'read concern level of %s is not valid ' 'with a max wire version of %d.' % (read_concern.level, self.max_wire_version)) if not (write_concern is None or write_concern.acknowledged or collation is None): raise ConfigurationError( 'Collation is unsupported for unacknowledged writes.') if (self.max_wire_version >= 5 and write_concern and not write_concern.is_server_default): spec['writeConcern'] = write_concern.document elif self.max_wire_version < 5 and collation is not None: raise ConfigurationError( 'Must be connected to MongoDB 3.4+ to use a collation.') if session: session._apply_to(spec, retryable_write, read_preference) self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = write_concern and not write_concern.acknowledged if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: return command(self, dbname, spec, slave_ok, self.is_mongos, read_preference, codec_options, session, client, check, allowable_errors, self.address, check_keys, listeners, self.max_bson_size, read_concern, parse_write_concern_error=parse_write_concern_error, collation=collation, compression_ctx=self.compression_context, use_op_msg=self.op_msg_enabled, unacknowledged=unacknowledged, user_fields=user_fields, exhaust_allowed=exhaust_allowed) except OperationFailure: raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. except BaseException as error: self._raise_connection_failure(error) def send_message(self, message, max_doc_size): """Send a raw BSON message or raise ConnectionFailure. If a network exception is raised, the socket is closed. """ if (self.max_bson_size is not None and max_doc_size > self.max_bson_size): raise DocumentTooLarge( "BSON document too large (%d bytes) - the connected server " "supports BSON document sizes up to %d bytes." % (max_doc_size, self.max_bson_size)) try: self.sock.sendall(message) except BaseException as error: self._raise_connection_failure(error) def receive_message(self, request_id): """Receive a raw BSON message or raise ConnectionFailure. If any exception is raised, the socket is closed. 
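# Internally, Database.command() and most driver operations funnel into the
# SocketInfo.command() helper above. A sketch of the public entry point,
# assuming a reachable local server:
from pymongo import MongoClient

client = MongoClient()  # assumed local test server
print(client.admin.command("ping"))  # {'ok': 1.0, ...}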
""" try: return receive_message(self, request_id, self.max_message_size) except BaseException as error: self._raise_connection_failure(error) def _raise_if_not_writable(self, unacknowledged): """Raise NotMasterError on unacknowledged write if this socket is not writable. """ if unacknowledged and not self.is_writable: # Write won't succeed, bail as if we'd received a not master error. raise NotMasterError("not master", { "ok": 0, "errmsg": "not master", "code": 10107}) def legacy_write(self, request_id, msg, max_doc_size, with_last_error): """Send OP_INSERT, etc., optionally returning response as a dict. Can raise ConnectionFailure or OperationFailure. :Parameters: - `request_id`: an int. - `msg`: bytes, an OP_INSERT, OP_UPDATE, or OP_DELETE message, perhaps with a getlasterror command appended. - `max_doc_size`: size in bytes of the largest document in `msg`. - `with_last_error`: True if a getlasterror command is appended. """ self._raise_if_not_writable(not with_last_error) self.send_message(msg, max_doc_size) if with_last_error: reply = self.receive_message(request_id) return helpers._check_gle_response(reply.command_response(), self.max_wire_version) def write_command(self, request_id, msg): """Send "insert" etc. command, returning response as a dict. Can raise ConnectionFailure or OperationFailure. :Parameters: - `request_id`: an int. - `msg`: bytes, the command message. """ self.send_message(msg, 0) reply = self.receive_message(request_id) result = reply.command_response() # Raises NotMasterError or OperationFailure. helpers._check_command_response(result, self.max_wire_version) return result def check_auth(self, all_credentials): """Update this socket's authentication. Log in or out to bring this socket's credentials up to date with those provided. Can raise ConnectionFailure or OperationFailure. :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. """ if all_credentials or self.authset: cached = set(itervalues(all_credentials)) authset = self.authset.copy() # Logout any credentials that no longer exist in the cache. for credentials in authset - cached: auth.logout(credentials.source, self) self.authset.discard(credentials) for credentials in cached - authset: self.authenticate(credentials) # CMAP spec says to publish the ready event only after authenticating # the connection. if not self.ready: self.ready = True if self.enabled_for_cmap: self.listeners.publish_connection_ready(self.address, self.id) def authenticate(self, credentials): """Log in to the server and store these credentials in `authset`. Can raise ConnectionFailure or OperationFailure. :Parameters: - `credentials`: A MongoCredential. """ auth.authenticate(credentials, self) self.authset.add(credentials) # negotiated_mechanisms are no longer needed. self.negotiated_mechanisms.pop(credentials, None) self.auth_ctx.pop(credentials, None) def validate_session(self, client, session): """Validate this session before use with client. Raises error if this session is logged in as a different user or the client is not the one that created the session. 
""" if session: if session._client is not client: raise InvalidOperation( 'Can only use session with the MongoClient that' ' started it') if session._authset != self.authset: raise InvalidOperation( 'Cannot use session after authenticating with different' ' credentials') def close_socket(self, reason): """Close this connection with a reason.""" if self.closed: return self._close_socket() if reason and self.enabled_for_cmap: self.listeners.publish_connection_closed( self.address, self.id, reason) def _close_socket(self): """Close this connection.""" if self.closed: return self.closed = True if self.cancel_context: self.cancel_context.cancel() # Note: We catch exceptions to avoid spurious errors on interpreter # shutdown. try: self.sock.close() except Exception: pass def socket_closed(self): """Return True if we know socket has been closed, False otherwise.""" return self.socket_checker.socket_closed(self.sock) def send_cluster_time(self, command, session, client): """Add cluster time for MongoDB >= 3.6.""" if self.max_wire_version >= 6 and client: client._send_cluster_time(command, session) def update_last_checkin_time(self): self.last_checkin_time = _time() def update_is_writable(self, is_writable): self.is_writable = is_writable def idle_time_seconds(self): """Seconds since this socket was last checked into its pool.""" return _time() - self.last_checkin_time def _raise_connection_failure(self, error): # Catch *all* exceptions from socket methods and close the socket. In # regular Python, socket operations only raise socket.error, even if # the underlying cause was a Ctrl-C: a signal raised during socket.recv # is expressed as an EINTR error from poll. See internal_select_ex() in # socketmodule.c. All error codes from poll become socket.error at # first. Eventually in PyEval_EvalFrameEx the interpreter checks for # signals and throws KeyboardInterrupt into the current frame on the # main thread. # # But in Gevent and Eventlet, the polling mechanism (epoll, kqueue, # ...) is called in Python code, which experiences the signal as a # KeyboardInterrupt from the start, rather than as an initial # socket.error, so we catch that, close the socket, and reraise it. self.close_socket(ConnectionClosedReason.ERROR) # SSLError from PyOpenSSL inherits directly from Exception. if isinstance(error, (IOError, OSError, _SSLError)): _raise_connection_failure(self.address, error) else: raise def __eq__(self, other): return self.sock == other.sock def __ne__(self, other): return not self == other def __hash__(self): return hash(self.sock) def __repr__(self): return "SocketInfo(%s)%s at %s" % ( repr(self.sock), self.closed and " CLOSED" or "", id(self) ) def _create_connection(address, options): """Given (host, port) and PoolOptions, connect and return a socket object. Can raise socket.error. This is a modified version of create_connection from CPython >= 2.7. """ host, port = address # Check if dealing with a unix domain socket if host.endswith('.sock'): if not hasattr(socket, "AF_UNIX"): raise ConnectionFailure("UNIX-sockets are not supported " "on this system") sock = socket.socket(socket.AF_UNIX) # SOCK_CLOEXEC not supported for Unix sockets. _set_non_inheritable_non_atomic(sock.fileno()) try: sock.connect(host) return sock except socket.error: sock.close() raise # Don't try IPv6 if we don't support it. Also skip it if host # is 'localhost' (::1 is fine). Avoids slow connect issues # like PYTHON-356. 
family = socket.AF_INET if socket.has_ipv6 and host != 'localhost': family = socket.AF_UNSPEC err = None for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM): af, socktype, proto, dummy, sa = res # SOCK_CLOEXEC was new in CPython 3.2, and only available on a limited # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 # all file descriptors are created non-inheritable. See PEP 446. try: sock = socket.socket( af, socktype | getattr(socket, 'SOCK_CLOEXEC', 0), proto) except socket.error: # Can SOCK_CLOEXEC be defined even if the kernel doesn't support # it? sock = socket.socket(af, socktype, proto) # Fallback when SOCK_CLOEXEC isn't available. _set_non_inheritable_non_atomic(sock.fileno()) try: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.settimeout(options.connect_timeout) sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, options.socket_keepalive) if options.socket_keepalive: _set_keepalive_times(sock) sock.connect(sa) return sock except socket.error as e: err = e sock.close() if err is not None: raise err else: # This likely means we tried to connect to an IPv6 only # host with an OS/kernel or Python interpreter that doesn't # support IPv6. The test case is Jython2.5.1 which doesn't # support IPv6 at all. raise socket.error('getaddrinfo failed') def _configured_socket(address, options): """Given (host, port) and PoolOptions, return a configured socket. Can raise socket.error, ConnectionFailure, or CertificateError. Sets socket's SSL and timeout options. """ sock = _create_connection(address, options) ssl_context = options.ssl_context if ssl_context is not None: host = address[0] try: # According to RFC6066, section 3, IPv4 and IPv6 literals are # not permitted for SNI hostname. # Previous to Python 3.7 wrap_socket would blindly pass # IP addresses as SNI hostname. # https://bugs.python.org/issue32185 # We have to pass hostname / ip address to wrap_socket # to use SSLContext.check_hostname. if _HAVE_SNI and (not is_ip_address(host) or _IPADDR_SAFE): sock = ssl_context.wrap_socket(sock, server_hostname=host) else: sock = ssl_context.wrap_socket(sock) except CertificateError: sock.close() # Raise CertificateError directly like we do after match_hostname # below. raise except (IOError, OSError, _SSLError) as exc: sock.close() # We raise AutoReconnect for transient and permanent SSL handshake # failures alike. Permanent handshake failures, like protocol # mismatch, will be turned into ServerSelectionTimeoutErrors later. _raise_connection_failure(address, exc, "SSL handshake failed: ") if (ssl_context.verify_mode and not getattr(ssl_context, "check_hostname", False) and options.ssl_match_hostname): try: match_hostname(sock.getpeercert(), hostname=host) except CertificateError: sock.close() raise sock.settimeout(options.socket_timeout) return sock class _PoolClosedError(PyMongoError): """Internal error raised when a thread tries to get a connection from a closed pool. """ pass # Do *not* explicitly inherit from object or Jython won't call __del__ # http://bugs.jython.org/issue1057 class Pool: def __init__(self, address, options, handshake=True): """ :Parameters: - `address`: a (hostname, port) tuple - `options`: a PoolOptions instance - `handshake`: whether to call ismaster for each new SocketInfo """ # Check a socket's health with socket_closed() every once in a while. # Can override for testing: 0 to always check, None to never check. self._check_interval_seconds = 1 # LIFO pool. Sockets are ordered on idle time. 
Sockets claimed # and returned to pool from the left side. Stale sockets removed # from the right side. self.sockets = collections.deque() self.lock = threading.Lock() self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 self.closed = False # Track whether the sockets in this pool are writeable or not. self.is_writable = None # Keep track of resets, so we notice sockets created before the most # recent reset and close them. self.generation = 0 self.pid = os.getpid() self.address = address self.opts = options self.handshake = handshake # Don't publish events in Monitor pools. self.enabled_for_cmap = ( self.handshake and self.opts.event_listeners is not None and self.opts.event_listeners.enabled_for_cmap) if (self.opts.wait_queue_multiple is None or self.opts.max_pool_size is None): max_waiters = None else: max_waiters = ( self.opts.max_pool_size * self.opts.wait_queue_multiple) self._socket_semaphore = thread_util.create_semaphore( self.opts.max_pool_size, max_waiters) if self.enabled_for_cmap: self.opts.event_listeners.publish_pool_created( self.address, self.opts.non_default_options) def _reset(self, close): with self.lock: if self.closed: return self.generation += 1 self.pid = os.getpid() sockets, self.sockets = self.sockets, collections.deque() self.active_sockets = 0 if close: self.closed = True listeners = self.opts.event_listeners # CMAP spec says that close() MUST close sockets before publishing the # PoolClosedEvent but that reset() SHOULD close sockets *after* # publishing the PoolClearedEvent. if close: for sock_info in sockets: sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) if self.enabled_for_cmap: listeners.publish_pool_closed(self.address) else: if self.enabled_for_cmap: listeners.publish_pool_cleared(self.address) for sock_info in sockets: sock_info.close_socket(ConnectionClosedReason.STALE) def update_is_writable(self, is_writable): """Updates the is_writable attribute on all sockets currently in the Pool. """ self.is_writable = is_writable with self.lock: for socket in self.sockets: socket.update_is_writable(self.is_writable) def reset(self): self._reset(close=False) def close(self): self._reset(close=True) def remove_stale_sockets(self, reference_generation, all_credentials): """Removes stale sockets then adds new ones if pool is too small and has not been reset. The `reference_generation` argument specifies the `generation` at the point in time this operation was requested on the pool. """ if self.opts.max_idle_time_seconds is not None: with self.lock: while (self.sockets and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds): sock_info = self.sockets.pop() sock_info.close_socket(ConnectionClosedReason.IDLE) while True: with self.lock: if (len(self.sockets) + self.active_sockets >= self.opts.min_pool_size): # There are enough sockets in the pool. break # We must acquire the semaphore to respect max_pool_size. if not self._socket_semaphore.acquire(False): break try: sock_info = self.connect(all_credentials) with self.lock: # Close connection and return if the pool was reset during # socket creation or while acquiring the pool lock. if self.generation != reference_generation: sock_info.close_socket(ConnectionClosedReason.STALE) break self.sockets.appendleft(sock_info) finally: self._socket_semaphore.release() def connect(self, all_credentials=None): """Connect to Mongo and return a new SocketInfo. Can raise ConnectionFailure or CertificateError. 
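# The maintenance logic above is driven by the public pool options. A sketch
# of setting them on a client; the values are illustrative:
from pymongo import MongoClient

client = MongoClient(
    "mongodb://localhost:27017",
    maxPoolSize=50,         # cap on concurrent sockets per server
    minPoolSize=5,          # remove_stale_sockets tops the pool back up to this
    maxIdleTimeMS=30000,    # idle sockets older than this are closed
    waitQueueTimeoutMS=1000,
)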
Note that the pool does not keep a reference to the socket -- you must call return_socket() when you're done with it. """ with self.lock: conn_id = self.next_connection_id self.next_connection_id += 1 listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_created(self.address, conn_id) try: sock = _configured_socket(self.address, self.opts) except Exception as error: if self.enabled_for_cmap: listeners.publish_connection_closed( self.address, conn_id, ConnectionClosedReason.ERROR) if isinstance(error, (IOError, OSError, _SSLError)): _raise_connection_failure(self.address, error) raise sock_info = SocketInfo(sock, self, self.address, conn_id) if self.handshake: sock_info.ismaster(all_credentials) self.is_writable = sock_info.is_writable return sock_info @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): """Get a socket from the pool. Use with a "with" statement. Returns a :class:`SocketInfo` object wrapping a connected :class:`socket.socket`. This method should always be used in a with-statement:: with pool.get_socket(credentials, checkout) as socket_info: socket_info.send_message(msg) data = socket_info.receive_message(op_code, request_id) The socket is logged in or out as needed to match ``all_credentials`` using the correct authentication mechanism for the server's wire protocol version. Can raise ConnectionFailure or OperationFailure. :Parameters: - `all_credentials`: dict, maps auth source to MongoCredential. - `checkout` (optional): keep socket checked out. """ listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_check_out_started(self.address) sock_info = self._get_socket(all_credentials) if self.enabled_for_cmap: listeners.publish_connection_checked_out( self.address, sock_info.id) try: yield sock_info except: # Exception in caller. Decrement semaphore. self.return_socket(sock_info) raise else: if not checkout: self.return_socket(sock_info) def _get_socket(self, all_credentials): """Get or create a SocketInfo. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing. # See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise if self.pid != os.getpid(): self.reset() if self.closed: if self.enabled_for_cmap: self.opts.event_listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.POOL_CLOSED) raise _PoolClosedError( 'Attempted to check out a connection from closed connection ' 'pool') # Get a free socket or create one. if not self._socket_semaphore.acquire( True, self.opts.wait_queue_timeout): self._raise_wait_queue_timeout() # We've now acquired the semaphore and must release it on error. sock_info = None incremented = False try: with self.lock: self.active_sockets += 1 incremented = True while sock_info is None: try: with self.lock: sock_info = self.sockets.popleft() except IndexError: # Can raise ConnectionFailure or CertificateError. sock_info = self.connect(all_credentials) else: if self._perished(sock_info): sock_info = None sock_info.check_auth(all_credentials) except Exception: if sock_info: # We checked out a socket but authentication failed. 
sock_info.close_socket(ConnectionClosedReason.ERROR) self._socket_semaphore.release() if incremented: with self.lock: self.active_sockets -= 1 if self.enabled_for_cmap: self.opts.event_listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.CONN_ERROR) raise return sock_info def return_socket(self, sock_info): """Return the socket to the pool, or if it's closed discard it. :Parameters: - `sock_info`: The socket to check into the pool. """ listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_checked_in(self.address, sock_info.id) if self.pid != os.getpid(): self.reset() else: if self.closed: sock_info.close_socket(ConnectionClosedReason.POOL_CLOSED) elif not sock_info.closed: with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). if sock_info.generation != self.generation: sock_info.close_socket(ConnectionClosedReason.STALE) else: sock_info.update_last_checkin_time() sock_info.update_is_writable(self.is_writable) self.sockets.appendleft(sock_info) self._socket_semaphore.release() with self.lock: self.active_sockets -= 1 def _perished(self, sock_info): """Return True and close the connection if it is "perished". This side-effecty function checks if this socket has been idle for for longer than the max idle time, or if the socket has been closed by some external network error, or if the socket's generation is outdated. Checking sockets lets us avoid seeing *some* :class:`~pymongo.errors.AutoReconnect` exceptions on server hiccups, etc. We only check if the socket was closed by an external error if it has been > 1 second since the socket was checked into the pool, to keep performance reasonable - we can't avoid AutoReconnects completely anyway. """ idle_time_seconds = sock_info.idle_time_seconds() # If socket is idle, open a new one. if (self.opts.max_idle_time_seconds is not None and idle_time_seconds > self.opts.max_idle_time_seconds): sock_info.close_socket(ConnectionClosedReason.IDLE) return True if (self._check_interval_seconds is not None and ( 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds)): if sock_info.socket_closed(): sock_info.close_socket(ConnectionClosedReason.ERROR) return True if sock_info.generation != self.generation: sock_info.close_socket(ConnectionClosedReason.STALE) return True return False def _raise_wait_queue_timeout(self): listeners = self.opts.event_listeners if self.enabled_for_cmap: listeners.publish_connection_check_out_failed( self.address, ConnectionCheckOutFailedReason.TIMEOUT) raise ConnectionFailure( 'Timed out while checking out a connection from connection pool ' 'with max_size %r and wait_queue_timeout %r' % ( self.opts.max_pool_size, self.opts.wait_queue_timeout)) def __del__(self): # Avoid ResourceWarnings in Python 3 # Close all sockets without calling reset() or close() because it is # not safe to acquire a lock in __del__. for sock_info in self.sockets: sock_info.close_socket(None) pymongo-3.11.0/pymongo/pyopenssl_context.py000066400000000000000000000306541374256237000211470ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """A CPython compatible SSLContext implementation wrapping PyOpenSSL's context. """ import socket as _socket import ssl as _stdlibssl from errno import EINTR as _EINTR # service_identity requires this for py27, so it should always be available from ipaddress import ip_address as _ip_address from OpenSSL import SSL as _SSL from service_identity.pyopenssl import ( verify_hostname as _verify_hostname, verify_ip_address as _verify_ip_address) from service_identity import ( CertificateError as _SICertificateError, VerificationError as _SIVerificationError) from cryptography.hazmat.backends import default_backend as _default_backend from bson.py3compat import _unicode from pymongo.errors import CertificateError as _CertificateError from pymongo.monotonic import time as _time from pymongo.ocsp_support import ( _load_trusted_ca_certs, _ocsp_callback) from pymongo.ocsp_cache import _OCSPCache from pymongo.socket_checker import ( _errno_from_exception, SocketChecker as _SocketChecker) PROTOCOL_SSLv23 = _SSL.SSLv23_METHOD # Always available OP_NO_SSLv2 = _SSL.OP_NO_SSLv2 OP_NO_SSLv3 = _SSL.OP_NO_SSLv3 OP_NO_COMPRESSION = _SSL.OP_NO_COMPRESSION # This isn't currently documented for PyOpenSSL OP_NO_RENEGOTIATION = getattr(_SSL, "OP_NO_RENEGOTIATION", 0) # Always available HAS_SNI = True CHECK_HOSTNAME_SAFE = True IS_PYOPENSSL = True # Base Exception class SSLError = _SSL.Error # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L2995-L3002 _VERIFY_MAP = { _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE, _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER, _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT } _REVERSE_VERIFY_MAP = dict( (value, key) for key, value in _VERIFY_MAP.items()) def _is_ip_address(address): try: _ip_address(_unicode(address)) return True except (ValueError, UnicodeError): return False # According to the docs for Connection.send it can raise # WantX509LookupError and should be retried. _RETRY_ERRORS = ( _SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) def _ragged_eof(exc): """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.""" return exc.args == (-1, 'Unexpected EOF') # https://github.com/pyca/pyopenssl/issues/168 # https://github.com/pyca/pyopenssl/issues/176 # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets class _sslConn(_SSL.Connection): def __init__(self, ctx, sock, suppress_ragged_eofs): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs super(_sslConn, self).__init__(ctx, sock) def _call(self, call, *args, **kwargs): timeout = self.gettimeout() if timeout: start = _time() while True: try: return call(*args, **kwargs) except _RETRY_ERRORS: self.socket_checker.select( self, True, True, timeout) if timeout and _time() - start > timeout: raise _socket.timeout("timed out") continue def do_handshake(self, *args, **kwargs): return self._call(super(_sslConn, self).do_handshake, *args, **kwargs) def recv(self, *args, **kwargs): try: return self._call(super(_sslConn, self).recv, *args, **kwargs) except _SSL.SysCallError as exc: # Suppress ragged EOFs to match the stdlib. 
if self.suppress_ragged_eofs and _ragged_eof(exc): return b"" raise def recv_into(self, *args, **kwargs): try: return self._call(super(_sslConn, self).recv_into, *args, **kwargs) except _SSL.SysCallError as exc: # Suppress ragged EOFs to match the stdlib. if self.suppress_ragged_eofs and _ragged_eof(exc): return 0 raise def sendall(self, buf, flags=0): view = memoryview(buf) total_length = len(buf) total_sent = 0 sent = 0 while total_sent < total_length: try: sent = self._call( super(_sslConn, self).send, view[total_sent:], flags) # XXX: It's not clear if this can actually happen. PyOpenSSL # doesn't appear to have any interrupt handling, nor any interrupt # errors for OpenSSL connections. except (IOError, OSError) as exc: if _errno_from_exception(exc) == _EINTR: continue raise # https://github.com/pyca/pyopenssl/blob/19.1.0/src/OpenSSL/SSL.py#L1756 # https://www.openssl.org/docs/man1.0.2/man3/SSL_write.html if sent <= 0: raise Exception("Connection closed") total_sent += sent class _CallbackData(object): """Data class which is passed to the OCSP callback.""" def __init__(self): self.trusted_ca_certs = None self.check_ocsp_endpoint = None self.ocsp_response_cache = _OCSPCache() class SSLContext(object): """A CPython compatible SSLContext implementation wrapping PyOpenSSL's context. """ __slots__ = ('_protocol', '_ctx', '_callback_data', '_check_hostname') def __init__(self, protocol): self._protocol = protocol self._ctx = _SSL.Context(self._protocol) self._callback_data = _CallbackData() self._check_hostname = True # OCSP # XXX: Find a better place to do this someday, since this is client # side configuration and wrap_socket tries to support both client and # server side sockets. self._callback_data.check_ocsp_endpoint = True self._ctx.set_ocsp_client_callback( callback=_ocsp_callback, data=self._callback_data) @property def protocol(self): """The protocol version chosen when constructing the context. This attribute is read-only. """ return self._protocol def __get_verify_mode(self): """Whether to try to verify other peers' certificates and how to behave if verification fails. This attribute must be one of ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED. """ return _REVERSE_VERIFY_MAP[self._ctx.get_verify_mode()] def __set_verify_mode(self, value): """Setter for verify_mode.""" def _cb(connobj, x509obj, errnum, errdepth, retcode): # It seems we don't need to do anything here. Twisted doesn't, # and OpenSSL's SSL_CTX_set_verify let's you pass NULL # for the callback option. It's weird that PyOpenSSL requires # this. return retcode self._ctx.set_verify(_VERIFY_MAP[value], _cb) verify_mode = property(__get_verify_mode, __set_verify_mode) def __get_check_hostname(self): return self._check_hostname def __set_check_hostname(self, value): if not isinstance(value, bool): raise TypeError("check_hostname must be True or False") self._check_hostname = value check_hostname = property(__get_check_hostname, __set_check_hostname) def __get_check_ocsp_endpoint(self): return self._callback_data.check_ocsp_endpoint def __set_check_ocsp_endpoint(self, value): if not isinstance(value, bool): raise TypeError("check_ocsp must be True or False") self._callback_data.check_ocsp_endpoint = value check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint) def __get_options(self): # Calling set_options adds the option to the existing bitmask and # returns the new bitmask. 
# https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_options return self._ctx.set_options(0) def __set_options(self, value): # Explcitly convert to int, since newer CPython versions # use enum.IntFlag for options. The values are the same # regardless of implementation. self._ctx.set_options(int(value)) options = property(__get_options, __set_options) def load_cert_chain(self, certfile, keyfile=None, password=None): """Load a private key and the corresponding certificate. The certfile string must be the path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the certificate's authenticity. The keyfile string, if present, must point to a file containing the private key. Otherwise the private key will be taken from certfile as well. """ # Match CPython behavior # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971 # Password callback MUST be set first or it will be ignored. if password: def _pwcb(max_length, prompt_twice, user_data): # XXX:We could check the password length against what OpenSSL # tells us is the max, but we can't raise an exception, so... # warn? return password.encode('utf-8') self._ctx.set_passwd_cb(_pwcb) self._ctx.use_certificate_chain_file(certfile) self._ctx.use_privatekey_file(keyfile or certfile) self._ctx.check_privatekey() def load_verify_locations(self, cafile=None, capath=None): """Load a set of "certification authority"(CA) certificates used to validate other peers' certificates when `~verify_mode` is other than ssl.CERT_NONE. """ self._ctx.load_verify_locations(cafile, capath) self._callback_data.trusted_ca_certs = _load_trusted_ca_certs(cafile) def set_default_verify_paths(self): """Specify that the platform provided CA certificates are to be used for verification purposes.""" # Note: See PyOpenSSL's docs for limitations, which are similar # but not that same as CPython's. self._ctx.set_default_verify_paths() def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, suppress_ragged_eofs=True, server_hostname=None, session=None): """Wrap an existing Python socket sock and return a TLS socket object. """ ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs) if session: ssl_conn.set_session(session) if server_side is True: ssl_conn.set_accept_state() else: # SNI if server_hostname and not _is_ip_address(server_hostname): # XXX: Do this in a callback registered with # SSLContext.set_info_callback? See Twisted for an example. ssl_conn.set_tlsext_host_name(server_hostname.encode('idna')) if self.verify_mode != _stdlibssl.CERT_NONE: # Request a stapled OCSP response. ssl_conn.request_ocsp() ssl_conn.set_connect_state() # If this wasn't true the caller of wrap_socket would call # do_handshake() if do_handshake_on_connect: # XXX: If we do hostname checking in a callback we can get rid # of this call to do_handshake() since the handshake # will happen automatically later. ssl_conn.do_handshake() # XXX: Do this in a callback registered with # SSLContext.set_info_callback? See Twisted for an example. 
if self.check_hostname and server_hostname is not None: try: if _is_ip_address(server_hostname): _verify_ip_address(ssl_conn, _unicode(server_hostname)) else: _verify_hostname(ssl_conn, _unicode(server_hostname)) except (_SICertificateError, _SIVerificationError) as exc: raise _CertificateError(str(exc)) return ssl_conn pymongo-3.11.0/pymongo/read_concern.py000066400000000000000000000044411374256237000177640ustar00rootroot00000000000000# Copyright 2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License", # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for working with read concerns.""" from bson.py3compat import string_type class ReadConcern(object): """ReadConcern :Parameters: - `level`: (string) The read concern level specifies the level of isolation for read operations. For example, a read operation using a read concern level of ``majority`` will only return data that has been written to a majority of nodes. If the level is left unspecified, the server default will be used. .. versionadded:: 3.2 """ def __init__(self, level=None): if level is None or isinstance(level, string_type): self.__level = level else: raise TypeError( 'level must be a string or None.') @property def level(self): """The read concern level.""" return self.__level @property def ok_for_legacy(self): """Return ``True`` if this read concern is compatible with old wire protocol versions.""" return self.level is None or self.level == 'local' @property def document(self): """The document representation of this read concern. .. note:: :class:`ReadConcern` is immutable. Mutating the value of :attr:`document` does not mutate this :class:`ReadConcern`. """ doc = {} if self.__level: doc['level'] = self.level return doc def __eq__(self, other): if isinstance(other, ReadConcern): return self.document == other.document return NotImplemented def __repr__(self): if self.level: return 'ReadConcern(%s)' % self.level return 'ReadConcern()' DEFAULT_READ_CONCERN = ReadConcern() pymongo-3.11.0/pymongo/read_preferences.py000066400000000000000000000435261374256237000206450ustar00rootroot00000000000000# Copyright 2012-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License", # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
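The ReadConcern class defined above is self-contained enough to exercise without a running server. A minimal sketch (added here for illustration, not part of the original module) of its document representation and legacy-compatibility check:

from pymongo.read_concern import ReadConcern

default = ReadConcern()            # level left unspecified -> server default
majority = ReadConcern('majority')

assert default.level is None
assert default.ok_for_legacy                 # None and 'local' suit old wire protocols
assert not majority.ok_for_legacy
assert majority.document == {'level': 'majority'}
assert ReadConcern('majority') == majority   # equality compares the documents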
"""Utilities for choosing which member of a replica set to read from.""" from bson.py3compat import abc, integer_types from pymongo import max_staleness_selectors from pymongo.errors import ConfigurationError from pymongo.server_selectors import (member_with_tags_server_selector, secondary_with_tags_server_selector) _PRIMARY = 0 _PRIMARY_PREFERRED = 1 _SECONDARY = 2 _SECONDARY_PREFERRED = 3 _NEAREST = 4 _MONGOS_MODES = ( 'primary', 'primaryPreferred', 'secondary', 'secondaryPreferred', 'nearest', ) def _validate_tag_sets(tag_sets): """Validate tag sets for a MongoReplicaSetClient. """ if tag_sets is None: return tag_sets if not isinstance(tag_sets, list): raise TypeError(( "Tag sets %r invalid, must be a list") % (tag_sets,)) if len(tag_sets) == 0: raise ValueError(( "Tag sets %r invalid, must be None or contain at least one set of" " tags") % (tag_sets,)) for tags in tag_sets: if not isinstance(tags, abc.Mapping): raise TypeError( "Tag set %r invalid, must be an instance of dict, " "bson.son.SON or other type that inherits from " "collection.Mapping" % (tags,)) return tag_sets def _invalid_max_staleness_msg(max_staleness): return ("maxStalenessSeconds must be a positive integer, not %s" % max_staleness) # Some duplication with common.py to avoid import cycle. def _validate_max_staleness(max_staleness): """Validate max_staleness.""" if max_staleness == -1: return -1 if not isinstance(max_staleness, integer_types): raise TypeError(_invalid_max_staleness_msg(max_staleness)) if max_staleness <= 0: raise ValueError(_invalid_max_staleness_msg(max_staleness)) return max_staleness def _validate_hedge(hedge): """Validate hedge.""" if hedge is None: return None if not isinstance(hedge, dict): raise TypeError("hedge must be a dictionary, not %r" % (hedge,)) return hedge class _ServerMode(object): """Base class for all read preferences. """ __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") def __init__(self, mode, tag_sets=None, max_staleness=-1, hedge=None): self.__mongos_mode = _MONGOS_MODES[mode] self.__mode = mode self.__tag_sets = _validate_tag_sets(tag_sets) self.__max_staleness = _validate_max_staleness(max_staleness) self.__hedge = _validate_hedge(hedge) @property def name(self): """The name of this read preference. """ return self.__class__.__name__ @property def mongos_mode(self): """The mongos mode of this read preference. """ return self.__mongos_mode @property def document(self): """Read preference as a document. """ doc = {'mode': self.__mongos_mode} if self.__tag_sets not in (None, [{}]): doc['tags'] = self.__tag_sets if self.__max_staleness != -1: doc['maxStalenessSeconds'] = self.__max_staleness if self.__hedge not in (None, {}): doc['hedge'] = self.__hedge return doc @property def mode(self): """The mode of this read preference instance. """ return self.__mode @property def tag_sets(self): """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to read only from members whose ``dc`` tag has the value ``"ny"``. To specify a priority-order for tag sets, provide a list of tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag set, ``{}``, means "read from any member that matches the mode, ignoring tags." MongoReplicaSetClient tries each set of tags in turn until it finds a set of tags with at least one matching member. .. 
seealso:: `Data-Center Awareness `_ """ return list(self.__tag_sets) if self.__tag_sets else [{}] @property def max_staleness(self): """The maximum estimated length of time (in seconds) a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations, or -1 for no maximum.""" return self.__max_staleness @property def hedge(self): """The read preference ``hedge`` parameter. A dictionary that configures how the server will perform hedged reads. It consists of the following keys: - ``enabled``: Enables or disables hedged reads in sharded clusters. Hedged reads are automatically enabled in MongoDB 4.4+ when using a ``nearest`` read preference. To explicitly enable hedged reads, set the ``enabled`` key to ``true``:: >>> Nearest(hedge={'enabled': True}) To explicitly disable hedged reads, set the ``enabled`` key to ``False``:: >>> Nearest(hedge={'enabled': False}) .. versionadded:: 3.11 """ return self.__hedge @property def min_wire_version(self): """The wire protocol version the server must support. Some read preferences impose version requirements on all servers (e.g. maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5). All servers' maxWireVersion must be at least this read preference's `min_wire_version`, or the driver raises :exc:`~pymongo.errors.ConfigurationError`. """ return 0 if self.__max_staleness == -1 else 5 def __repr__(self): return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % ( self.name, self.__tag_sets, self.__max_staleness, self.__hedge) def __eq__(self, other): if isinstance(other, _ServerMode): return (self.mode == other.mode and self.tag_sets == other.tag_sets and self.max_staleness == other.max_staleness and self.hedge == other.hedge) return NotImplemented def __ne__(self, other): return not self == other def __getstate__(self): """Return value of object for pickling. Needed explicitly because __slots__() defined. """ return {'mode': self.__mode, 'tag_sets': self.__tag_sets, 'max_staleness': self.__max_staleness, 'hedge': self.__hedge} def __setstate__(self, value): """Restore from pickling.""" self.__mode = value['mode'] self.__mongos_mode = _MONGOS_MODES[self.__mode] self.__tag_sets = _validate_tag_sets(value['tag_sets']) self.__max_staleness = _validate_max_staleness(value['max_staleness']) self.__hedge = _validate_hedge(value['hedge']) class Primary(_ServerMode): """Primary read preference. * When directly connected to one mongod queries are allowed if the server is standalone or a replica set primary. * When connected to a mongos queries are sent to the primary of a shard. * When connected to a replica set queries are sent to the primary of the replica set. """ __slots__ = () def __init__(self): super(Primary, self).__init__(_PRIMARY) def __call__(self, selection): """Apply this read preference to a Selection.""" return selection.primary_selection def __repr__(self): return "Primary()" def __eq__(self, other): if isinstance(other, _ServerMode): return other.mode == _PRIMARY return NotImplemented class PrimaryPreferred(_ServerMode): """PrimaryPreferred read preference. * When directly connected to one mongod queries are allowed to standalone servers, to a replica set primary, or to replica set secondaries. * When connected to a mongos queries are sent to the primary of a shard if available, otherwise a shard secondary. * When connected to a replica set queries are sent to the primary if available, otherwise a secondary. 
:Parameters: - `tag_sets`: The :attr:`~tag_sets` to use if the primary is not available. - `max_staleness`: (integer, in seconds) The maximum estimated length of time a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - `hedge`: The :attr:`~hedge` to use if the primary is not available. .. versionchanged:: 3.11 Added ``hedge`` parameter. """ __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): super(PrimaryPreferred, self).__init__( _PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" if selection.primary: return selection.primary_selection else: return secondary_with_tags_server_selector( self.tag_sets, max_staleness_selectors.select( self.max_staleness, selection)) class Secondary(_ServerMode): """Secondary read preference. * When directly connected to one mongod queries are allowed to standalone servers, to a replica set primary, or to replica set secondaries. * When connected to a mongos queries are distributed among shard secondaries. An error is raised if no secondaries are available. * When connected to a replica set queries are distributed among secondaries. An error is raised if no secondaries are available. :Parameters: - `tag_sets`: The :attr:`~tag_sets` for this read preference. - `max_staleness`: (integer, in seconds) The maximum estimated length of time a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - `hedge`: The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. """ __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): super(Secondary, self).__init__( _SECONDARY, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" return secondary_with_tags_server_selector( self.tag_sets, max_staleness_selectors.select( self.max_staleness, selection)) class SecondaryPreferred(_ServerMode): """SecondaryPreferred read preference. * When directly connected to one mongod queries are allowed to standalone servers, to a replica set primary, or to replica set secondaries. * When connected to a mongos queries are distributed among shard secondaries, or the shard primary if no secondary is available. * When connected to a replica set queries are distributed among secondaries, or the primary if no secondary is available. :Parameters: - `tag_sets`: The :attr:`~tag_sets` for this read preference. - `max_staleness`: (integer, in seconds) The maximum estimated length of time a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - `hedge`: The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. 
""" __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): super(SecondaryPreferred, self).__init__( _SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" secondaries = secondary_with_tags_server_selector( self.tag_sets, max_staleness_selectors.select( self.max_staleness, selection)) if secondaries: return secondaries else: return selection.primary_selection class Nearest(_ServerMode): """Nearest read preference. * When directly connected to one mongod queries are allowed to standalone servers, to a replica set primary, or to replica set secondaries. * When connected to a mongos queries are distributed among all members of a shard. * When connected to a replica set queries are distributed among all members. :Parameters: - `tag_sets`: The :attr:`~tag_sets` for this read preference. - `max_staleness`: (integer, in seconds) The maximum estimated length of time a replica set secondary can fall behind the primary in replication before it will no longer be selected for operations. Default -1, meaning no maximum. If it is set, it must be at least 90 seconds. - `hedge`: The :attr:`~hedge` for this read preference. .. versionchanged:: 3.11 Added ``hedge`` parameter. """ __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): super(Nearest, self).__init__( _NEAREST, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" return member_with_tags_server_selector( self.tag_sets, max_staleness_selectors.select( self.max_staleness, selection)) _ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) def make_read_preference(mode, tag_sets, max_staleness=-1): if mode == _PRIMARY: if tag_sets not in (None, [{}]): raise ConfigurationError("Read preference primary " "cannot be combined with tags") if max_staleness != -1: raise ConfigurationError("Read preference primary cannot be " "combined with maxStalenessSeconds") return Primary() return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) _MODES = ( 'PRIMARY', 'PRIMARY_PREFERRED', 'SECONDARY', 'SECONDARY_PREFERRED', 'NEAREST', ) class ReadPreference(object): """An enum that defines the read preference modes supported by PyMongo. See :doc:`/examples/high_availability` for code examples. A read preference is used in three cases: :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: - ``PRIMARY``: Queries are allowed if the server is standalone or a replica set primary. - All other modes allow queries to standalone servers, to a replica set primary, or to replica set secondaries. :class:`~pymongo.mongo_client.MongoClient` initialized with the ``replicaSet`` option: - ``PRIMARY``: Read from the primary. This is the default, and provides the strongest consistency. If no primary is available, raise :class:`~pymongo.errors.AutoReconnect`. - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is none, read from a secondary. - ``SECONDARY``: Read from a secondary. If no secondary is available, raise :class:`~pymongo.errors.AutoReconnect`. - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise from the primary. - ``NEAREST``: Read from any member. :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a sharded cluster of replica sets: - ``PRIMARY``: Read from the primary of the shard, or raise :class:`~pymongo.errors.OperationFailure` if there is none. This is the default. 
- ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is none, read from a secondary of the shard. - ``SECONDARY``: Read from a secondary of the shard, or raise :class:`~pymongo.errors.OperationFailure` if there is none. - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, otherwise from the shard primary. - ``NEAREST``: Read from any shard member. """ PRIMARY = Primary() PRIMARY_PREFERRED = PrimaryPreferred() SECONDARY = Secondary() SECONDARY_PREFERRED = SecondaryPreferred() NEAREST = Nearest() def read_pref_mode_from_name(name): """Get the read preference mode from mongos/uri name. """ return _MONGOS_MODES.index(name) class MovingAverage(object): """Tracks an exponentially-weighted moving average.""" def __init__(self): self.average = None def add_sample(self, sample): if sample < 0: # Likely system time change while waiting for ismaster response # and not using time.monotonic. Ignore it, the next one will # probably be valid. return if self.average is None: self.average = sample else: # The Server Selection Spec requires an exponentially weighted # average with alpha = 0.2. self.average = 0.8 * self.average + 0.2 * sample def get(self): """Get the calculated average, or None if no samples yet.""" return self.average def reset(self): self.average = None pymongo-3.11.0/pymongo/response.py000066400000000000000000000070731374256237000172040ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Represent a response from the server.""" class Response(object): __slots__ = ('_data', '_address', '_request_id', '_duration', '_from_command', '_docs') def __init__(self, data, address, request_id, duration, from_command, docs): """Represent a response from the server. :Parameters: - `data`: A network response message. - `address`: (host, port) of the source server. - `request_id`: The request id of this operation. - `duration`: The duration of the operation. - `from_command`: if the response is the result of a db command. """ self._data = data self._address = address self._request_id = request_id self._duration = duration self._from_command = from_command self._docs = docs @property def data(self): """Server response's raw BSON bytes.""" return self._data @property def address(self): """(host, port) of the source server.""" return self._address @property def request_id(self): """The request id of this operation.""" return self._request_id @property def duration(self): """The duration of the operation.""" return self._duration @property def from_command(self): """If the response is a result from a db command.""" return self._from_command @property def docs(self): """The decoded document(s).""" return self._docs class ExhaustResponse(Response): __slots__ = ('_socket_info', '_pool') def __init__(self, data, address, socket_info, pool, request_id, duration, from_command, docs): """Represent a response to an exhaust cursor's initial query. :Parameters: - `data`: A network response message. 
- `address`: (host, port) of the source server. - `socket_info`: The SocketInfo used for the initial query. - `pool`: The Pool from which the SocketInfo came. - `request_id`: The request id of this operation. - `duration`: The duration of the operation. - `from_command`: If the response is the result of a db command. """ super(ExhaustResponse, self).__init__(data, address, request_id, duration, from_command, docs) self._socket_info = socket_info self._pool = pool @property def socket_info(self): """The SocketInfo used for the initial query. The server will send batches on this socket, without waiting for getMores from the client, until the result set is exhausted or there is an error. """ return self._socket_info @property def pool(self): """The Pool from which the SocketInfo came.""" return self._pool pymongo-3.11.0/pymongo/results.py000066400000000000000000000172121374256237000170430ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Result class definitions.""" from pymongo.errors import InvalidOperation class _WriteResult(object): """Base class for write result classes.""" __slots__ = ("__acknowledged",) def __init__(self, acknowledged): self.__acknowledged = acknowledged def _raise_if_unacknowledged(self, property_name): """Raise an exception on property access if unacknowledged.""" if not self.__acknowledged: raise InvalidOperation("A value for %s is not available when " "the write is unacknowledged. Check the " "acknowledged attribute to avoid this " "error." % (property_name,)) @property def acknowledged(self): """Is this the result of an acknowledged write operation? The :attr:`acknowledged` attribute will be ``False`` when using ``WriteConcern(w=0)``, otherwise ``True``. .. note:: If the :attr:`acknowledged` attribute is ``False`` all other attibutes of this class will raise :class:`~pymongo.errors.InvalidOperation` when accessed. Values for other attributes cannot be determined if the write operation was unacknowledged. .. seealso:: :class:`~pymongo.write_concern.WriteConcern` """ return self.__acknowledged class InsertOneResult(_WriteResult): """The return type for :meth:`~pymongo.collection.Collection.insert_one`. """ __slots__ = ("__inserted_id", "__acknowledged") def __init__(self, inserted_id, acknowledged): self.__inserted_id = inserted_id super(InsertOneResult, self).__init__(acknowledged) @property def inserted_id(self): """The inserted document's _id.""" return self.__inserted_id class InsertManyResult(_WriteResult): """The return type for :meth:`~pymongo.collection.Collection.insert_many`. """ __slots__ = ("__inserted_ids", "__acknowledged") def __init__(self, inserted_ids, acknowledged): self.__inserted_ids = inserted_ids super(InsertManyResult, self).__init__(acknowledged) @property def inserted_ids(self): """A list of _ids of the inserted documents, in the order provided. .. 
note:: If ``False`` is passed for the `ordered` parameter to :meth:`~pymongo.collection.Collection.insert_many` the server may have inserted the documents in a different order than what is presented here. """ return self.__inserted_ids class UpdateResult(_WriteResult): """The return type for :meth:`~pymongo.collection.Collection.update_one`, :meth:`~pymongo.collection.Collection.update_many`, and :meth:`~pymongo.collection.Collection.replace_one`. """ __slots__ = ("__raw_result", "__acknowledged") def __init__(self, raw_result, acknowledged): self.__raw_result = raw_result super(UpdateResult, self).__init__(acknowledged) @property def raw_result(self): """The raw result document returned by the server.""" return self.__raw_result @property def matched_count(self): """The number of documents matched for this update.""" self._raise_if_unacknowledged("matched_count") if self.upserted_id is not None: return 0 return self.__raw_result.get("n", 0) @property def modified_count(self): """The number of documents modified. .. note:: modified_count is only reported by MongoDB 2.6 and later. When connected to an earlier server version, or in certain mixed version sharding configurations, this attribute will be set to ``None``. """ self._raise_if_unacknowledged("modified_count") return self.__raw_result.get("nModified") @property def upserted_id(self): """The _id of the inserted document if an upsert took place. Otherwise ``None``. """ self._raise_if_unacknowledged("upserted_id") return self.__raw_result.get("upserted") class DeleteResult(_WriteResult): """The return type for :meth:`~pymongo.collection.Collection.delete_one` and :meth:`~pymongo.collection.Collection.delete_many`""" __slots__ = ("__raw_result", "__acknowledged") def __init__(self, raw_result, acknowledged): self.__raw_result = raw_result super(DeleteResult, self).__init__(acknowledged) @property def raw_result(self): """The raw result document returned by the server.""" return self.__raw_result @property def deleted_count(self): """The number of documents deleted.""" self._raise_if_unacknowledged("deleted_count") return self.__raw_result.get("n", 0) class BulkWriteResult(_WriteResult): """An object wrapper for bulk API write results.""" __slots__ = ("__bulk_api_result", "__acknowledged") def __init__(self, bulk_api_result, acknowledged): """Create a BulkWriteResult instance. :Parameters: - `bulk_api_result`: A result dict from the bulk API - `acknowledged`: Was this write result acknowledged? If ``False`` then all properties of this object will raise :exc:`~pymongo.errors.InvalidOperation`. """ self.__bulk_api_result = bulk_api_result super(BulkWriteResult, self).__init__(acknowledged) @property def bulk_api_result(self): """The raw bulk API result.""" return self.__bulk_api_result @property def inserted_count(self): """The number of documents inserted.""" self._raise_if_unacknowledged("inserted_count") return self.__bulk_api_result.get("nInserted") @property def matched_count(self): """The number of documents matched for an update.""" self._raise_if_unacknowledged("matched_count") return self.__bulk_api_result.get("nMatched") @property def modified_count(self): """The number of documents modified. .. note:: modified_count is only reported by MongoDB 2.6 and later. When connected to an earlier server version, or in certain mixed version sharding configurations, this attribute will be set to ``None``. 
""" self._raise_if_unacknowledged("modified_count") return self.__bulk_api_result.get("nModified") @property def deleted_count(self): """The number of documents deleted.""" self._raise_if_unacknowledged("deleted_count") return self.__bulk_api_result.get("nRemoved") @property def upserted_count(self): """The number of documents upserted.""" self._raise_if_unacknowledged("upserted_count") return self.__bulk_api_result.get("nUpserted") @property def upserted_ids(self): """A map of operation index to the _id of the upserted document.""" self._raise_if_unacknowledged("upserted_ids") if self.__bulk_api_result: return dict((upsert["index"], upsert["_id"]) for upsert in self.bulk_api_result["upserted"]) pymongo-3.11.0/pymongo/saslprep.py000066400000000000000000000102631374256237000171720ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """An implementation of RFC4013 SASLprep.""" from bson.py3compat import text_type as _text_type try: import stringprep except ImportError: HAVE_STRINGPREP = False def saslprep(data): """SASLprep dummy""" if isinstance(data, _text_type): raise TypeError( "The stringprep module is not available. Usernames and " "passwords must be ASCII strings.") return data else: HAVE_STRINGPREP = True import unicodedata # RFC4013 section 2.3 prohibited output. _PROHIBITED = ( # A strict reading of RFC 4013 requires table c12 here, but # characters from it are mapped to SPACE in the Map step. Can # normalization reintroduce them somehow? stringprep.in_table_c12, stringprep.in_table_c21_c22, stringprep.in_table_c3, stringprep.in_table_c4, stringprep.in_table_c5, stringprep.in_table_c6, stringprep.in_table_c7, stringprep.in_table_c8, stringprep.in_table_c9) def saslprep(data, prohibit_unassigned_code_points=True): """An implementation of RFC4013 SASLprep. :Parameters: - `data`: The string to SASLprep. Unicode strings (python 2.x unicode, 3.x str) are supported. Byte strings (python 2.x str, 3.x bytes) are ignored. - `prohibit_unassigned_code_points`: True / False. RFC 3454 and RFCs for various SASL mechanisms distinguish between `queries` (unassigned code points allowed) and `stored strings` (unassigned code points prohibited). Defaults to ``True`` (unassigned code points are prohibited). :Returns: The SASLprep'ed version of `data`. """ if not isinstance(data, _text_type): return data if prohibit_unassigned_code_points: prohibited = _PROHIBITED + (stringprep.in_table_a1,) else: prohibited = _PROHIBITED # RFC3454 section 2, step 1 - Map # RFC4013 section 2.1 mappings # Map Non-ASCII space characters to SPACE (U+0020). Map # commonly mapped to nothing characters to, well, nothing. 
in_table_c12 = stringprep.in_table_c12 in_table_b1 = stringprep.in_table_b1 data = u"".join( [u"\u0020" if in_table_c12(elt) else elt for elt in data if not in_table_b1(elt)]) # RFC3454 section 2, step 2 - Normalize # RFC4013 section 2.2 normalization data = unicodedata.ucd_3_2_0.normalize('NFKC', data) in_table_d1 = stringprep.in_table_d1 if in_table_d1(data[0]): if not in_table_d1(data[-1]): # RFC3454, Section 6, #3. If a string contains any # RandALCat character, the first and last characters # MUST be RandALCat characters. raise ValueError("SASLprep: failed bidirectional check") # RFC3454, Section 6, #2. If a string contains any RandALCat # character, it MUST NOT contain any LCat character. prohibited = prohibited + (stringprep.in_table_d2,) else: # RFC3454, Section 6, #3. Following the logic of #3, if # the first character is not a RandALCat, no other character # can be either. prohibited = prohibited + (in_table_d1,) # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi for char in data: if any(in_table(char) for in_table in prohibited): raise ValueError( "SASLprep: failed prohibited character check") return data pymongo-3.11.0/pymongo/server.py000066400000000000000000000176231374256237000166560ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Communicate with one MongoDB server in a topology.""" from datetime import datetime from bson import _decode_all_selective from pymongo.errors import NotMasterError, OperationFailure from pymongo.helpers import _check_command_response from pymongo.message import _convert_exception from pymongo.response import Response, ExhaustResponse from pymongo.server_type import SERVER_TYPE _CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}} class Server(object): def __init__(self, server_description, pool, monitor, topology_id=None, listeners=None, events=None): """Represent one MongoDB server.""" self._description = server_description self._pool = pool self._monitor = monitor self._topology_id = topology_id self._publish = listeners is not None and listeners.enabled_for_server self._listener = listeners self._events = None if self._publish: self._events = events() def open(self): """Start monitoring, or restart after a fork. Multiple calls have no effect. """ self._monitor.open() def reset(self): """Clear the connection pool.""" self.pool.reset() def close(self): """Clear the connection pool and stop the monitor. Reconnect with open(). """ if self._publish: self._events.put((self._listener.publish_server_closed, (self._description.address, self._topology_id))) self._monitor.close() self._pool.reset() def request_check(self): """Check the server's state soon.""" self._monitor.request_check() def run_operation_with_response( self, sock_info, operation, set_slave_okay, listeners, exhaust, unpack_res): """Run a _Query or _GetMore operation and return a Response object. This method is used only to run _Query/_GetMore operations from cursors. 
Can raise ConnectionFailure, OperationFailure, etc. :Parameters: - `operation`: A _Query or _GetMore object. - `set_slave_okay`: Pass to operation.get_message. - `all_credentials`: dict, maps auth source to MongoCredential. - `listeners`: Instance of _EventListeners or None. - `exhaust`: If True, then this is an exhaust cursor operation. - `unpack_res`: A callable that decodes the wire protocol response. """ duration = None publish = listeners.enabled_for_commands if publish: start = datetime.now() send_message = not operation.exhaust_mgr if send_message: use_cmd = operation.use_command(sock_info, exhaust) message = operation.get_message( set_slave_okay, sock_info, use_cmd) request_id, data, max_doc_size = self._split_message(message) else: use_cmd = False request_id = 0 if publish: cmd, dbn = operation.as_command(sock_info) listeners.publish_command_start( cmd, dbn, request_id, sock_info.address) start = datetime.now() try: if send_message: sock_info.send_message(data, max_doc_size) reply = sock_info.receive_message(request_id) else: reply = sock_info.receive_message(None) # Unpack and check for command errors. if use_cmd: user_fields = _CURSOR_DOC_FIELDS legacy_response = False else: user_fields = None legacy_response = True docs = unpack_res(reply, operation.cursor_id, operation.codec_options, legacy_response=legacy_response, user_fields=user_fields) if use_cmd: first = docs[0] operation.client._process_response( first, operation.session) _check_command_response( first, sock_info.max_wire_version) except Exception as exc: if publish: duration = datetime.now() - start if isinstance(exc, (NotMasterError, OperationFailure)): failure = exc.details else: failure = _convert_exception(exc) listeners.publish_command_failure( duration, failure, operation.name, request_id, sock_info.address) raise if publish: duration = datetime.now() - start # Must publish in find / getMore / explain command response # format. if use_cmd: res = docs[0] elif operation.name == "explain": res = docs[0] if docs else {} else: res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} if operation.name == "find": res["cursor"]["firstBatch"] = docs else: res["cursor"]["nextBatch"] = docs listeners.publish_command_success( duration, res, operation.name, request_id, sock_info.address) # Decrypt response. client = operation.client if client and client._encrypter: if use_cmd: decrypted = client._encrypter.decrypt( reply.raw_command_response()) docs = _decode_all_selective( decrypted, operation.codec_options, user_fields) if exhaust: response = ExhaustResponse( data=reply, address=self._description.address, socket_info=sock_info, pool=self._pool, duration=duration, request_id=request_id, from_command=use_cmd, docs=docs) else: response = Response( data=reply, address=self._description.address, duration=duration, request_id=request_id, from_command=use_cmd, docs=docs) return response def get_socket(self, all_credentials, checkout=False): return self.pool.get_socket(all_credentials, checkout) @property def description(self): return self._description @description.setter def description(self, server_description): assert server_description.address == self._description.address self._description = server_description @property def pool(self): return self._pool def _split_message(self, message): """Return request_id, data, max_doc_size. 
:Parameters: - `message`: (request_id, data, max_doc_size) or (request_id, data) """ if len(message) == 3: return message else: # get_more and kill_cursors messages don't include BSON documents. request_id, data = message return request_id, data, 0 def __repr__(self): return '<%s %r>' % (self.__class__.__name__, self._description) pymongo-3.11.0/pymongo/server_description.py000066400000000000000000000177701374256237000212640ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Represent one server the driver is connected to.""" from bson import EPOCH_NAIVE from pymongo.server_type import SERVER_TYPE from pymongo.ismaster import IsMaster from pymongo.monotonic import time as _time class ServerDescription(object): """Immutable representation of one server. :Parameters: - `address`: A (host, port) pair - `ismaster`: Optional IsMaster instance - `round_trip_time`: Optional float - `error`: Optional, the last error attempting to connect to the server """ __slots__ = ( '_address', '_server_type', '_all_hosts', '_tags', '_replica_set_name', '_primary', '_max_bson_size', '_max_message_size', '_max_write_batch_size', '_min_wire_version', '_max_wire_version', '_round_trip_time', '_me', '_is_writable', '_is_readable', '_ls_timeout_minutes', '_error', '_set_version', '_election_id', '_cluster_time', '_last_write_date', '_last_update_time', '_topology_version') def __init__( self, address, ismaster=None, round_trip_time=None, error=None): self._address = address if not ismaster: ismaster = IsMaster({}) self._server_type = ismaster.server_type self._all_hosts = ismaster.all_hosts self._tags = ismaster.tags self._replica_set_name = ismaster.replica_set_name self._primary = ismaster.primary self._max_bson_size = ismaster.max_bson_size self._max_message_size = ismaster.max_message_size self._max_write_batch_size = ismaster.max_write_batch_size self._min_wire_version = ismaster.min_wire_version self._max_wire_version = ismaster.max_wire_version self._set_version = ismaster.set_version self._election_id = ismaster.election_id self._cluster_time = ismaster.cluster_time self._is_writable = ismaster.is_writable self._is_readable = ismaster.is_readable self._ls_timeout_minutes = ismaster.logical_session_timeout_minutes self._round_trip_time = round_trip_time self._me = ismaster.me self._last_update_time = _time() self._error = error self._topology_version = ismaster.topology_version if error: if hasattr(error, 'details') and isinstance(error.details, dict): self._topology_version = error.details.get('topologyVersion') if ismaster.last_write_date: # Convert from datetime to seconds. 
delta = ismaster.last_write_date - EPOCH_NAIVE self._last_write_date = delta.total_seconds() else: self._last_write_date = None @property def address(self): """The address (host, port) of this server.""" return self._address @property def server_type(self): """The type of this server.""" return self._server_type @property def server_type_name(self): """The server type as a human readable string. .. versionadded:: 3.4 """ return SERVER_TYPE._fields[self._server_type] @property def all_hosts(self): """List of hosts, passives, and arbiters known to this server.""" return self._all_hosts @property def tags(self): return self._tags @property def replica_set_name(self): """Replica set name or None.""" return self._replica_set_name @property def primary(self): """This server's opinion about who the primary is, or None.""" return self._primary @property def max_bson_size(self): return self._max_bson_size @property def max_message_size(self): return self._max_message_size @property def max_write_batch_size(self): return self._max_write_batch_size @property def min_wire_version(self): return self._min_wire_version @property def max_wire_version(self): return self._max_wire_version @property def set_version(self): return self._set_version @property def election_id(self): return self._election_id @property def cluster_time(self): return self._cluster_time @property def election_tuple(self): return self._set_version, self._election_id @property def me(self): return self._me @property def logical_session_timeout_minutes(self): return self._ls_timeout_minutes @property def last_write_date(self): return self._last_write_date @property def last_update_time(self): return self._last_update_time @property def round_trip_time(self): """The current average latency or None.""" # This override is for unittesting only! 
if self._address in self._host_to_round_trip_time: return self._host_to_round_trip_time[self._address] return self._round_trip_time @property def error(self): """The last error attempting to connect to the server, or None.""" return self._error @property def is_writable(self): return self._is_writable @property def is_readable(self): return self._is_readable @property def mongos(self): return self._server_type == SERVER_TYPE.Mongos @property def is_server_type_known(self): return self.server_type != SERVER_TYPE.Unknown @property def retryable_writes_supported(self): """Checks if this server supports retryable writes.""" return ( self._ls_timeout_minutes is not None and self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)) @property def retryable_reads_supported(self): """Checks if this server supports retryable writes.""" return self._max_wire_version >= 6 @property def topology_version(self): return self._topology_version def to_unknown(self, error=None): unknown = ServerDescription(self.address, error=error) unknown._topology_version = self.topology_version return unknown def __eq__(self, other): if isinstance(other, ServerDescription): return ((self._address == other.address) and (self._server_type == other.server_type) and (self._min_wire_version == other.min_wire_version) and (self._max_wire_version == other.max_wire_version) and (self._me == other.me) and (self._all_hosts == other.all_hosts) and (self._tags == other.tags) and (self._replica_set_name == other.replica_set_name) and (self._set_version == other.set_version) and (self._election_id == other.election_id) and (self._primary == other.primary) and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) and (self._error == other.error)) return NotImplemented def __ne__(self, other): return not self == other def __repr__(self): errmsg = '' if self.error: errmsg = ', error=%r' % (self.error,) return "<%s %s server_type: %s, rtt: %s%s>" % ( self.__class__.__name__, self.address, self.server_type_name, self.round_trip_time, errmsg) # For unittesting only. Use under no circumstances! _host_to_round_trip_time = {} pymongo-3.11.0/pymongo/server_selectors.py000066400000000000000000000122731374256237000207350ustar00rootroot00000000000000# Copyright 2014-2016 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
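The ServerDescription class above can also be built from just an address, before any ismaster response is available. A short sketch (illustrative only, not part of the original source) of that default state:

from pymongo.server_description import ServerDescription
from pymongo.server_type import SERVER_TYPE

sd = ServerDescription(('localhost', 27017))   # no ismaster response yet
assert sd.server_type == SERVER_TYPE.Unknown
assert sd.server_type_name == 'Unknown'
assert not sd.is_server_type_known
assert not sd.retryable_writes_supported       # no logical session support known

# Marking a server unknown after an error keeps its address and topology version.
unknown = sd.to_unknown(error=None)
assert unknown.address == sd.address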
"""Criteria to select some ServerDescriptions from a TopologyDescription.""" from pymongo.server_type import SERVER_TYPE class Selection(object): """Input or output of a server selector function.""" @classmethod def from_topology_description(cls, topology_description): known_servers = topology_description.known_servers primary = None for sd in known_servers: if sd.server_type == SERVER_TYPE.RSPrimary: primary = sd break return Selection(topology_description, topology_description.known_servers, topology_description.common_wire_version, primary) def __init__(self, topology_description, server_descriptions, common_wire_version, primary): self.topology_description = topology_description self.server_descriptions = server_descriptions self.primary = primary self.common_wire_version = common_wire_version def with_server_descriptions(self, server_descriptions): return Selection(self.topology_description, server_descriptions, self.common_wire_version, self.primary) def secondary_with_max_last_write_date(self): secondaries = secondary_server_selector(self) if secondaries.server_descriptions: return max(secondaries.server_descriptions, key=lambda sd: sd.last_write_date) @property def primary_selection(self): primaries = [self.primary] if self.primary else [] return self.with_server_descriptions(primaries) @property def heartbeat_frequency(self): return self.topology_description.heartbeat_frequency @property def topology_type(self): return self.topology_description.topology_type def __bool__(self): return bool(self.server_descriptions) __nonzero__ = __bool__ # Python 2. def __getitem__(self, item): return self.server_descriptions[item] def any_server_selector(selection): return selection def readable_server_selector(selection): return selection.with_server_descriptions( [s for s in selection.server_descriptions if s.is_readable]) def writable_server_selector(selection): return selection.with_server_descriptions( [s for s in selection.server_descriptions if s.is_writable]) def secondary_server_selector(selection): return selection.with_server_descriptions( [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary]) def arbiter_server_selector(selection): return selection.with_server_descriptions( [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter]) def writable_preferred_server_selector(selection): """Like PrimaryPreferred but doesn't use tags or latency.""" return (writable_server_selector(selection) or secondary_server_selector(selection)) def apply_single_tag_set(tag_set, selection): """All servers matching one tag set. A tag set is a dict. A server matches if its tags are a superset: A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}. The empty tag set {} matches any server. """ def tags_match(server_tags): for key, value in tag_set.items(): if key not in server_tags or server_tags[key] != value: return False return True return selection.with_server_descriptions( [s for s in selection.server_descriptions if tags_match(s.tags)]) def apply_tag_sets(tag_sets, selection): """All servers match a list of tag sets. tag_sets is a list of dicts. The empty tag set {} matches any server, and may be provided at the end of the list as a fallback. So [{'a': 'value'}, {}] expresses a preference for servers tagged {'a': 'value'}, but accepts any server if none matches the first preference. 
""" for tag_set in tag_sets: with_tag_set = apply_single_tag_set(tag_set, selection) if with_tag_set: return with_tag_set return selection.with_server_descriptions([]) def secondary_with_tags_server_selector(tag_sets, selection): """All near-enough secondaries matching the tag sets.""" return apply_tag_sets(tag_sets, secondary_server_selector(selection)) def member_with_tags_server_selector(tag_sets, selection): """All near-enough members matching the tag sets.""" return apply_tag_sets(tag_sets, readable_server_selector(selection)) pymongo-3.11.0/pymongo/server_type.py000066400000000000000000000015621374256237000177120ustar00rootroot00000000000000# Copyright 2014-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Type codes for MongoDB servers.""" from collections import namedtuple SERVER_TYPE = namedtuple('ServerType', ['Unknown', 'Mongos', 'RSPrimary', 'RSSecondary', 'RSArbiter', 'RSOther', 'RSGhost', 'Standalone'])(*range(8)) pymongo-3.11.0/pymongo/settings.py000066400000000000000000000106601374256237000172020ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Represent MongoClient's configuration.""" import threading import traceback from bson.objectid import ObjectId from pymongo import common, monitor, pool from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT from pymongo.errors import ConfigurationError from pymongo.pool import PoolOptions from pymongo.server_description import ServerDescription from pymongo.topology_description import TOPOLOGY_TYPE class TopologySettings(object): def __init__(self, seeds=None, replica_set_name=None, pool_class=None, pool_options=None, monitor_class=None, condition_class=None, local_threshold_ms=LOCAL_THRESHOLD_MS, server_selection_timeout=SERVER_SELECTION_TIMEOUT, heartbeat_frequency=common.HEARTBEAT_FREQUENCY, server_selector=None, fqdn=None, direct_connection=None): """Represent MongoClient's configuration. Take a list of (host, port) pairs and optional replica set name. 
""" if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL: raise ConfigurationError( "heartbeatFrequencyMS cannot be less than %d" % ( common.MIN_HEARTBEAT_INTERVAL * 1000,)) self._seeds = seeds or [('localhost', 27017)] self._replica_set_name = replica_set_name self._pool_class = pool_class or pool.Pool self._pool_options = pool_options or PoolOptions() self._monitor_class = monitor_class or monitor.Monitor self._condition_class = condition_class or threading.Condition self._local_threshold_ms = local_threshold_ms self._server_selection_timeout = server_selection_timeout self._server_selector = server_selector self._fqdn = fqdn self._heartbeat_frequency = heartbeat_frequency if direct_connection is None: self._direct = (len(self._seeds) == 1 and not self.replica_set_name) else: self._direct = direct_connection self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. self._stack = ''.join(traceback.format_stack()) @property def seeds(self): """List of server addresses.""" return self._seeds @property def replica_set_name(self): return self._replica_set_name @property def pool_class(self): return self._pool_class @property def pool_options(self): return self._pool_options @property def monitor_class(self): return self._monitor_class @property def condition_class(self): return self._condition_class @property def local_threshold_ms(self): return self._local_threshold_ms @property def server_selection_timeout(self): return self._server_selection_timeout @property def server_selector(self): return self._server_selector @property def heartbeat_frequency(self): return self._heartbeat_frequency @property def fqdn(self): return self._fqdn @property def direct(self): """Connect directly to a single server, or use a set of servers? True if there is one seed and no replica_set_name. """ return self._direct def get_topology_type(self): if self.direct: return TOPOLOGY_TYPE.Single elif self.replica_set_name is not None: return TOPOLOGY_TYPE.ReplicaSetNoPrimary else: return TOPOLOGY_TYPE.Unknown def get_server_descriptions(self): """Initial dict of (address, ServerDescription) for all seeds.""" return dict([ (address, ServerDescription(address)) for address in self.seeds]) pymongo-3.11.0/pymongo/socket_checker.py000066400000000000000000000074371374256237000203260ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Select / poll helper""" import errno import select import sys # PYTHON-2320: Jython does not fully support poll on SSL sockets, # https://bugs.jython.org/issue2900 _HAVE_POLL = hasattr(select, "poll") and not sys.platform.startswith('java') _SelectError = getattr(select, "error", OSError) def _errno_from_exception(exc): if hasattr(exc, 'errno'): return exc.errno if exc.args: return exc.args[0] return None class SocketChecker(object): def __init__(self): if _HAVE_POLL: self._poller = select.poll() else: self._poller = None def select(self, sock, read=False, write=False, timeout=0): """Select for reads or writes with a timeout in seconds. Returns True if the socket is readable/writable, False on timeout. """ while True: try: if self._poller: mask = select.POLLERR | select.POLLHUP if read: mask = mask | select.POLLIN | select.POLLPRI if write: mask = mask | select.POLLOUT self._poller.register(sock, mask) try: # poll() timeout is in milliseconds. select() # timeout is in seconds. res = self._poller.poll(timeout * 1000) # poll returns a possibly-empty list containing # (fd, event) 2-tuples for the descriptors that have # events or errors to report. Return True if the list # is not empty. return bool(res) finally: self._poller.unregister(sock) else: rlist = [sock] if read else [] wlist = [sock] if write else [] res = select.select(rlist, wlist, [sock], timeout) # select returns a 3-tuple of lists of objects that are # ready: subsets of the first three arguments. Return # True if any of the lists are not empty. return any(res) except (_SelectError, IOError) as exc: if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN): continue raise def socket_closed(self, sock): """Return True if we know socket has been closed, False otherwise. """ try: return self.select(sock, read=True) except (RuntimeError, KeyError): # RuntimeError is raised during a concurrent poll. KeyError # is raised by unregister if the socket is not in the poller. # These errors should not be possible since we protect the # poller with a mutex. raise except ValueError: # ValueError is raised by register/unregister/select if the # socket file descriptor is negative or outside the range for # select (> 1023). return True except Exception: # Any other exceptions should be attributed to a closed # or invalid socket. return True pymongo-3.11.0/pymongo/son_manipulator.py000066400000000000000000000150531374256237000205550ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """**DEPRECATED**: Manipulators that can edit SON objects as they enter and exit a database. The :class:`~pymongo.son_manipulator.SONManipulator` API has limitations as a technique for transforming your data. Instead, it is more flexible and straightforward to transform outgoing documents in your own code before passing them to PyMongo, and transform incoming documents after receiving them from PyMongo. SON Manipulators will be removed from PyMongo in 4.0. 
PyMongo does **not** apply SON manipulators to documents passed to the modern methods :meth:`~pymongo.collection.Collection.bulk_write`, :meth:`~pymongo.collection.Collection.insert_one`, :meth:`~pymongo.collection.Collection.insert_many`, :meth:`~pymongo.collection.Collection.update_one`, or :meth:`~pymongo.collection.Collection.update_many`. SON manipulators are **not** applied to documents returned by the modern methods :meth:`~pymongo.collection.Collection.find_one_and_delete`, :meth:`~pymongo.collection.Collection.find_one_and_replace`, and :meth:`~pymongo.collection.Collection.find_one_and_update`. """ from bson.dbref import DBRef from bson.objectid import ObjectId from bson.py3compat import abc from bson.son import SON class SONManipulator(object): """A base son manipulator. This manipulator just saves and restores objects without changing them. """ def will_copy(self): """Will this SON manipulator make a copy of the incoming document? Derived classes that do need to make a copy should override this method, returning True instead of False. All non-copying manipulators will be applied first (so that the user's document will be updated appropriately), followed by copying manipulators. """ return False def transform_incoming(self, son, collection): """Manipulate an incoming SON object. :Parameters: - `son`: the SON object to be inserted into the database - `collection`: the collection the object is being inserted into """ if self.will_copy(): return SON(son) return son def transform_outgoing(self, son, collection): """Manipulate an outgoing SON object. :Parameters: - `son`: the SON object being retrieved from the database - `collection`: the collection this object was stored in """ if self.will_copy(): return SON(son) return son class ObjectIdInjector(SONManipulator): """A son manipulator that adds the _id field if it is missing. .. versionchanged:: 2.7 ObjectIdInjector is no longer used by PyMongo, but remains in this module for backwards compatibility. """ def transform_incoming(self, son, collection): """Add an _id field if it is missing. """ if not "_id" in son: son["_id"] = ObjectId() return son # This is now handled during BSON encoding (for performance reasons), # but I'm keeping this here as a reference for those implementing new # SONManipulators. class ObjectIdShuffler(SONManipulator): """A son manipulator that moves _id to the first position. """ def will_copy(self): """We need to copy to be sure that we are dealing with SON, not a dict. """ return True def transform_incoming(self, son, collection): """Move _id to the front if it's there. """ if not "_id" in son: return son transformed = SON({"_id": son["_id"]}) transformed.update(son) return transformed class NamespaceInjector(SONManipulator): """A son manipulator that adds the _ns field. """ def transform_incoming(self, son, collection): """Add the _ns field to the incoming object """ son["_ns"] = collection.name return son class AutoReference(SONManipulator): """Transparently reference and de-reference already saved embedded objects. This manipulator should probably only be used when the NamespaceInjector is also being used, otherwise it doesn't make too much sense - documents can only be auto-referenced if they have an *_ns* field. NOTE: this will behave poorly if you have a circular reference. TODO: this only works for documents that are in the same database. To fix this we'll need to add a DatabaseInjector that adds *_db* and then make use of the optional *database* support for DBRefs. 
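    For example (illustrative only; ``some_id`` stands for the ``ObjectId``
    of a document that has already been saved), with
    :class:`NamespaceInjector` also installed::

        incoming:  {'author': {'_id': some_id, '_ns': 'users', 'name': 'Ann'}}
        stored as: {'author': DBRef('users', some_id)}

    and the DBRef is expanded back into the embedded document when the
    parent document is retrieved.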
""" def __init__(self, db): self.database = db def will_copy(self): """We need to copy so the user's document doesn't get transformed refs. """ return True def transform_incoming(self, son, collection): """Replace embedded documents with DBRefs. """ def transform_value(value): if isinstance(value, abc.MutableMapping): if "_id" in value and "_ns" in value: return DBRef(value["_ns"], transform_value(value["_id"])) else: return transform_dict(SON(value)) elif isinstance(value, list): return [transform_value(v) for v in value] return value def transform_dict(object): for (key, value) in object.items(): object[key] = transform_value(value) return object return transform_dict(SON(son)) def transform_outgoing(self, son, collection): """Replace DBRefs with embedded documents. """ def transform_value(value): if isinstance(value, DBRef): return self.database.dereference(value) elif isinstance(value, list): return [transform_value(v) for v in value] elif isinstance(value, abc.MutableMapping): return transform_dict(SON(value)) return value def transform_dict(object): for (key, value) in object.items(): object[key] = transform_value(value) return object return transform_dict(SON(son)) pymongo-3.11.0/pymongo/srv_resolver.py000066400000000000000000000071461374256237000201020ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Support for resolving hosts and options from mongodb+srv:// URIs.""" try: from dns import resolver _HAVE_DNSPYTHON = True except ImportError: _HAVE_DNSPYTHON = False from bson.py3compat import PY3 from pymongo.common import CONNECT_TIMEOUT from pymongo.errors import ConfigurationError if PY3: # dnspython can return bytes or str from various parts # of its API depending on version. We always want str. def maybe_decode(text): if isinstance(text, bytes): return text.decode() return text else: def maybe_decode(text): return text class _SrvResolver(object): def __init__(self, fqdn, connect_timeout=None): self.__fqdn = fqdn self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT # Validate the fully qualified domain name. try: self.__plist = self.__fqdn.split(".")[1:] except Exception: raise ConfigurationError("Invalid URI host: %s" % (fqdn,)) self.__slen = len(self.__plist) if self.__slen < 2: raise ConfigurationError("Invalid URI host: %s" % (fqdn,)) def get_options(self): try: results = resolver.query(self.__fqdn, 'TXT', lifetime=self.__connect_timeout) except (resolver.NoAnswer, resolver.NXDOMAIN): # No TXT records return None except Exception as exc: raise ConfigurationError(str(exc)) if len(results) > 1: raise ConfigurationError('Only one TXT record is supported') return ( b'&'.join([b''.join(res.strings) for res in results])).decode( 'utf-8') def _resolve_uri(self, encapsulate_errors): try: results = resolver.query('_mongodb._tcp.' + self.__fqdn, 'SRV', lifetime=self.__connect_timeout) except Exception as exc: if not encapsulate_errors: # Raise the original error. raise # Else, raise all errors as ConfigurationError. 
raise ConfigurationError(str(exc)) return results def _get_srv_response_and_hosts(self, encapsulate_errors): results = self._resolve_uri(encapsulate_errors) # Construct address tuples nodes = [ (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results] # Validate hosts for node in nodes: try: nlist = node[0].split(".")[1:][-self.__slen:] except Exception: raise ConfigurationError("Invalid SRV host: %s" % (node[0],)) if self.__plist != nlist: raise ConfigurationError("Invalid SRV host: %s" % (node[0],)) return results, nodes def get_hosts(self): _, nodes = self._get_srv_response_and_hosts(True) return nodes def get_hosts_and_min_ttl(self): results, nodes = self._get_srv_response_and_hosts(False) return nodes, results.rrset.ttl pymongo-3.11.0/pymongo/ssl_context.py000066400000000000000000000124251374256237000177100ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """A fake SSLContext implementation.""" import ssl as _ssl import sys as _sys # PROTOCOL_TLS_CLIENT is Python 3.6+ PROTOCOL_SSLv23 = getattr(_ssl, "PROTOCOL_TLS_CLIENT", _ssl.PROTOCOL_SSLv23) # Python 2.7.9+ OP_NO_SSLv2 = getattr(_ssl, "OP_NO_SSLv2", 0) # Python 2.7.9+ OP_NO_SSLv3 = getattr(_ssl, "OP_NO_SSLv3", 0) # Python 2.7.9+, OpenSSL 1.0.0+ OP_NO_COMPRESSION = getattr(_ssl, "OP_NO_COMPRESSION", 0) # Python 3.7+, OpenSSL 1.1.0h+ OP_NO_RENEGOTIATION = getattr(_ssl, "OP_NO_RENEGOTIATION", 0) # Python 2.7.9+ HAS_SNI = getattr(_ssl, "HAS_SNI", False) IS_PYOPENSSL = False # Base Exception class SSLError = _ssl.SSLError try: # CPython 2.7.9+ from ssl import SSLContext if hasattr(_ssl, "VERIFY_CRL_CHECK_LEAF"): from ssl import VERIFY_CRL_CHECK_LEAF # Python 3.7 uses OpenSSL's hostname matching implementation # making it the obvious version to start using SSLConext.check_hostname. # Python 3.6 might have been a good version, but it suffers # from https://bugs.python.org/issue32185. # We'll use our bundled match_hostname for older Python # versions, which also supports IP address matching # with Python < 3.5. CHECK_HOSTNAME_SAFE = _sys.version_info[:2] >= (3, 7) except ImportError: from pymongo.errors import ConfigurationError class SSLContext(object): """A fake SSLContext. This implements an API similar to ssl.SSLContext from python 3.2 but does not implement methods or properties that would be incompatible with ssl.wrap_socket from python 2.7 < 2.7.9. You must pass protocol which must be one of the PROTOCOL_* constants defined in the ssl module. ssl.PROTOCOL_SSLv23 is recommended for maximum interoperability. """ __slots__ = ('_cafile', '_certfile', '_keyfile', '_protocol', '_verify_mode') def __init__(self, protocol): self._cafile = None self._certfile = None self._keyfile = None self._protocol = protocol self._verify_mode = _ssl.CERT_NONE @property def protocol(self): """The protocol version chosen when constructing the context. This attribute is read-only. 
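        A small sketch using this fallback class directly (illustrative
        only; on modern Pythons the standard library
        :class:`ssl.SSLContext` is used instead)::

            ctx = SSLContext(PROTOCOL_SSLv23)
            assert ctx.protocol == PROTOCOL_SSLv23
            assert ctx.verify_mode == _ssl.CERT_NONE  # default until changed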
""" return self._protocol def __get_verify_mode(self): """Whether to try to verify other peers' certificates and how to behave if verification fails. This attribute must be one of ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED. """ return self._verify_mode def __set_verify_mode(self, value): """Setter for verify_mode.""" self._verify_mode = value verify_mode = property(__get_verify_mode, __set_verify_mode) def load_cert_chain(self, certfile, keyfile=None, password=None): """Load a private key and the corresponding certificate. The certfile string must be the path to a single file in PEM format containing the certificate as well as any number of CA certificates needed to establish the certificate's authenticity. The keyfile string, if present, must point to a file containing the private key. Otherwise the private key will be taken from certfile as well. """ if password is not None: raise ConfigurationError( "Support for ssl_pem_passphrase requires " "python 2.7.9+ (pypy 2.5.1+), python 3 or " "PyOpenSSL") self._certfile = certfile self._keyfile = keyfile def load_verify_locations(self, cafile=None, dummy=None): """Load a set of "certification authority"(CA) certificates used to validate other peers' certificates when `~verify_mode` is other than ssl.CERT_NONE. """ self._cafile = cafile def wrap_socket(self, sock, server_side=False, do_handshake_on_connect=True, suppress_ragged_eofs=True, dummy=None): """Wrap an existing Python socket sock and return an ssl.SSLSocket object. """ return _ssl.wrap_socket(sock, keyfile=self._keyfile, certfile=self._certfile, server_side=server_side, cert_reqs=self._verify_mode, ssl_version=self._protocol, ca_certs=self._cafile, do_handshake_on_connect=do_handshake_on_connect, suppress_ragged_eofs=suppress_ragged_eofs) pymongo-3.11.0/pymongo/ssl_match_hostname.py000066400000000000000000000111021374256237000212050ustar00rootroot00000000000000# Backport of the match_hostname logic from python 3.5, with small # changes to support IP address matching on python 2.7 and 3.4. import re import sys try: # Python 3.4+, or the ipaddress module from pypi. from ipaddress import ip_address except ImportError: ip_address = lambda address: None # ipaddress.ip_address requires unicode if sys.version_info[0] < 3: _unicode = unicode else: _unicode = lambda value: value from pymongo.errors import CertificateError def _dnsname_match(dn, hostname, max_wildcards=1): """Matching according to RFC 6125, section 6.4.3 http://tools.ietf.org/html/rfc6125#section-6.4.3 """ pats = [] if not dn: return False parts = dn.split(r'.') leftmost = parts[0] remainder = parts[1:] wildcards = leftmost.count('*') if wildcards > max_wildcards: # Issue #17980: avoid denials of service by refusing more # than one wildcard per fragment. A survey of established # policy among SSL implementations showed it to be a # reasonable choice. raise CertificateError( "too many wildcards in certificate DNS name: " + repr(dn)) # speed up common case w/o wildcards if not wildcards: return dn.lower() == hostname.lower() # RFC 6125, section 6.4.3, subitem 1. # The client SHOULD NOT attempt to match a presented identifier in which # the wildcard character comprises a label other than the left-most label. if leftmost == '*': # When '*' is a fragment by itself, it matches a non-empty dotless # fragment. pats.append('[^.]+') elif leftmost.startswith('xn--') or hostname.startswith('xn--'): # RFC 6125, section 6.4.3, subitem 3. 
# The client SHOULD NOT attempt to match a presented identifier # where the wildcard character is embedded within an A-label or # U-label of an internationalized domain name. pats.append(re.escape(leftmost)) else: # Otherwise, '*' matches any dotless string, e.g. www* pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) # add the remaining fragments, ignore any wildcards for frag in remainder: pats.append(re.escape(frag)) pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) return pat.match(hostname) def _ipaddress_match(ipname, host_ip): """Exact matching of IP addresses. RFC 6125 explicitly doesn't define an algorithm for this (section 1.7.2 - "Out of Scope"). """ # OpenSSL may add a trailing newline to a subjectAltName's IP address ip = ip_address(_unicode(ipname).rstrip()) return ip == host_ip def match_hostname(cert, hostname): """Verify that *cert* (in decoded format as returned by SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125 rules are followed. CertificateError is raised on failure. On success, the function returns nothing. """ if not cert: raise ValueError("empty or no certificate, match_hostname needs a " "SSL socket or SSL context with either " "CERT_OPTIONAL or CERT_REQUIRED") try: host_ip = ip_address(_unicode(hostname)) except (ValueError, UnicodeError): # Not an IP address (common case) host_ip = None dnsnames = [] san = cert.get('subjectAltName', ()) for key, value in san: if key == 'DNS': if host_ip is None and _dnsname_match(value, hostname): return dnsnames.append(value) elif key == 'IP Address': if host_ip is not None and _ipaddress_match(value, host_ip): return dnsnames.append(value) if not dnsnames: # The subject is only checked when there is no dNSName entry # in subjectAltName for sub in cert.get('subject', ()): for key, value in sub: # XXX according to RFC 2818, the most specific Common Name # must be used. if key == 'commonName': if _dnsname_match(value, hostname): return dnsnames.append(value) if len(dnsnames) > 1: raise CertificateError("hostname %r " "doesn't match either of %s" % (hostname, ', '.join(map(repr, dnsnames)))) elif len(dnsnames) == 1: raise CertificateError("hostname %r " "doesn't match %r" % (hostname, dnsnames[0])) else: raise CertificateError("no appropriate commonName or " "subjectAltName fields were found") pymongo-3.11.0/pymongo/ssl_support.py000066400000000000000000000155121374256237000177400ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Support for SSL in PyMongo.""" import atexit import sys import threading from bson.py3compat import string_type from pymongo.errors import ConfigurationError HAVE_SSL = True try: import pymongo.pyopenssl_context as _ssl except ImportError: try: import pymongo.ssl_context as _ssl except ImportError: HAVE_SSL = False HAVE_CERTIFI = False try: import certifi HAVE_CERTIFI = True except ImportError: pass HAVE_WINCERTSTORE = False try: from wincertstore import CertFile HAVE_WINCERTSTORE = True except ImportError: pass _WINCERTSLOCK = threading.Lock() _WINCERTS = None if HAVE_SSL: # Note: The validate* functions below deal with users passing # CPython ssl module constants to configure certificate verification # at a high level. This is legacy behavior, but requires us to # import the ssl module even if we're only using it for this purpose. import ssl as _stdlibssl from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED HAS_SNI = _ssl.HAS_SNI IPADDR_SAFE = _ssl.IS_PYOPENSSL or sys.version_info[:2] >= (3, 7) SSLError = _ssl.SSLError def validate_cert_reqs(option, value): """Validate the cert reqs are valid. It must be None or one of the three values ``ssl.CERT_NONE``, ``ssl.CERT_OPTIONAL`` or ``ssl.CERT_REQUIRED``. """ if value is None: return value if isinstance(value, string_type) and hasattr(_stdlibssl, value): value = getattr(_stdlibssl, value) if value in (CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED): return value raise ValueError("The value of %s must be one of: " "`ssl.CERT_NONE`, `ssl.CERT_OPTIONAL` or " "`ssl.CERT_REQUIRED`" % (option,)) def validate_allow_invalid_certs(option, value): """Validate the option to allow invalid certificates is valid.""" # Avoid circular import. from pymongo.common import validate_boolean_or_string boolean_cert_reqs = validate_boolean_or_string(option, value) if boolean_cert_reqs: return CERT_NONE return CERT_REQUIRED def _load_wincerts(): """Set _WINCERTS to an instance of wincertstore.Certfile.""" global _WINCERTS certfile = CertFile() certfile.addstore("CA") certfile.addstore("ROOT") atexit.register(certfile.close) _WINCERTS = certfile def get_ssl_context(*args): """Create and return an SSLContext object.""" (certfile, keyfile, passphrase, ca_certs, cert_reqs, crlfile, match_hostname, check_ocsp_endpoint) = args verify_mode = CERT_REQUIRED if cert_reqs is None else cert_reqs ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23) # SSLContext.check_hostname was added in CPython 2.7.9 and 3.4. if hasattr(ctx, "check_hostname"): if _ssl.CHECK_HOSTNAME_SAFE and verify_mode != CERT_NONE: ctx.check_hostname = match_hostname else: ctx.check_hostname = False if hasattr(ctx, "check_ocsp_endpoint"): ctx.check_ocsp_endpoint = check_ocsp_endpoint if hasattr(ctx, "options"): # Explicitly disable SSLv2, SSLv3 and TLS compression. Note that # up to date versions of MongoDB 2.4 and above already disable # SSLv2 and SSLv3, python disables SSLv2 by default in >= 2.7.7 # and >= 3.3.4 and SSLv3 in >= 3.4.3. 
ctx.options |= _ssl.OP_NO_SSLv2 ctx.options |= _ssl.OP_NO_SSLv3 ctx.options |= _ssl.OP_NO_COMPRESSION ctx.options |= _ssl.OP_NO_RENEGOTIATION if certfile is not None: try: ctx.load_cert_chain(certfile, keyfile, passphrase) except _ssl.SSLError as exc: raise ConfigurationError( "Private key doesn't match certificate: %s" % (exc,)) if crlfile is not None: if _ssl.IS_PYOPENSSL: raise ConfigurationError( "ssl_crlfile cannot be used with PyOpenSSL") if not hasattr(ctx, "verify_flags"): raise ConfigurationError( "Support for ssl_crlfile requires " "python 2.7.9+ (pypy 2.5.1+) or 3.4+") # Match the server's behavior. ctx.verify_flags = getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0) ctx.load_verify_locations(crlfile) if ca_certs is not None: ctx.load_verify_locations(ca_certs) elif cert_reqs != CERT_NONE: # CPython >= 2.7.9 or >= 3.4.0, pypy >= 2.5.1 if hasattr(ctx, "load_default_certs"): ctx.load_default_certs() # Python >= 3.2.0, useless on Windows. elif (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")): ctx.set_default_verify_paths() elif sys.platform == "win32" and HAVE_WINCERTSTORE: with _WINCERTSLOCK: if _WINCERTS is None: _load_wincerts() ctx.load_verify_locations(_WINCERTS.name) elif HAVE_CERTIFI: ctx.load_verify_locations(certifi.where()) else: raise ConfigurationError( "`ssl_cert_reqs` is not ssl.CERT_NONE and no system " "CA certificates could be loaded. `ssl_ca_certs` is " "required.") ctx.verify_mode = verify_mode return ctx else: class SSLError(Exception): pass HAS_SNI = False IPADDR_SAFE = False def validate_cert_reqs(option, dummy): """No ssl module, raise ConfigurationError.""" raise ConfigurationError("The value of %s is set but can't be " "validated. The ssl module is not available" % (option,)) def validate_allow_invalid_certs(option, dummy): """No ssl module, raise ConfigurationError.""" return validate_cert_reqs(option, dummy) def get_ssl_context(*dummy): """No ssl module, raise ConfigurationError.""" raise ConfigurationError("The ssl module is not available.") pymongo-3.11.0/pymongo/thread_util.py000066400000000000000000000075651374256237000176600ustar00rootroot00000000000000# Copyright 2012-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Utilities for multi-threading support.""" import threading try: from time import monotonic as _time except ImportError: from time import time as _time from pymongo.monotonic import time as _time from pymongo.errors import ExceededMaxWaiters ### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire class Semaphore: # After Tim Peters' semaphore class, but not quite the same (no maximum) def __init__(self, value=1): if value < 0: raise ValueError("semaphore initial value must be >= 0") self._cond = threading.Condition(threading.Lock()) self._value = value def acquire(self, blocking=True, timeout=None): if not blocking and timeout is not None: raise ValueError("can't specify timeout for non-blocking acquire") rc = False endtime = None with self._cond: while self._value == 0: if not blocking: break if timeout is not None: if endtime is None: endtime = _time() + timeout else: timeout = endtime - _time() if timeout <= 0: break self._cond.wait(timeout) else: self._value = self._value - 1 rc = True return rc __enter__ = acquire def release(self): with self._cond: self._value = self._value + 1 self._cond.notify() def __exit__(self, t, v, tb): self.release() @property def counter(self): return self._value class BoundedSemaphore(Semaphore): """Semaphore that checks that # releases is <= # acquires""" def __init__(self, value=1): Semaphore.__init__(self, value) self._initial_value = value def release(self): if self._value >= self._initial_value: raise ValueError("Semaphore released too many times") return Semaphore.release(self) ### End backport from CPython 3.2 class DummySemaphore(object): def __init__(self, value=None): pass def acquire(self, blocking=True, timeout=None): return True def release(self): pass class MaxWaitersBoundedSemaphore(object): def __init__(self, semaphore_class, value=1, max_waiters=1): self.waiter_semaphore = semaphore_class(max_waiters) self.semaphore = semaphore_class(value) def acquire(self, blocking=True, timeout=None): if not self.waiter_semaphore.acquire(False): raise ExceededMaxWaiters() try: return self.semaphore.acquire(blocking, timeout) finally: self.waiter_semaphore.release() def __getattr__(self, name): return getattr(self.semaphore, name) class MaxWaitersBoundedSemaphoreThread(MaxWaitersBoundedSemaphore): def __init__(self, value=1, max_waiters=1): MaxWaitersBoundedSemaphore.__init__( self, BoundedSemaphore, value, max_waiters) def create_semaphore(max_size, max_waiters): if max_size is None: return DummySemaphore() else: if max_waiters is None: return BoundedSemaphore(max_size) else: return MaxWaitersBoundedSemaphoreThread(max_size, max_waiters) pymongo-3.11.0/pymongo/topology.py000066400000000000000000000771631374256237000172310ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Internal class to monitor a topology of one or more servers.""" import os import random import threading import warnings import weakref from bson.py3compat import itervalues, PY3 if PY3: import queue as Queue else: import Queue from pymongo import (common, helpers, periodic_executor) from pymongo.pool import PoolOptions from pymongo.topology_description import (updated_topology_description, _updated_topology_description_srv_polling, TopologyDescription, SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE) from pymongo.errors import (ConnectionFailure, ConfigurationError, NetworkTimeout, NotMasterError, OperationFailure, ServerSelectionTimeoutError) from pymongo.monitor import SrvMonitor from pymongo.monotonic import time as _time from pymongo.server import Server from pymongo.server_description import ServerDescription from pymongo.server_selectors import (any_server_selector, arbiter_server_selector, secondary_server_selector, readable_server_selector, writable_server_selector, Selection) from pymongo.client_session import _ServerSessionPool def process_events_queue(queue_ref): q = queue_ref() if not q: return False # Cancel PeriodicExecutor. while True: try: event = q.get_nowait() except Queue.Empty: break else: fn, args = event fn(*args) return True # Continue PeriodicExecutor. class Topology(object): """Monitor a topology of one or more servers.""" def __init__(self, topology_settings): self._topology_id = topology_settings._topology_id self._listeners = topology_settings._pool_options.event_listeners pub = self._listeners is not None self._publish_server = pub and self._listeners.enabled_for_server self._publish_tp = pub and self._listeners.enabled_for_topology # Create events queue if there are publishers. self._events = None self.__events_executor = None if self._publish_server or self._publish_tp: self._events = Queue.Queue(maxsize=100) if self._publish_tp: self._events.put((self._listeners.publish_topology_opened, (self._topology_id,))) self._settings = topology_settings topology_description = TopologyDescription( topology_settings.get_topology_type(), topology_settings.get_server_descriptions(), topology_settings.replica_set_name, None, None, topology_settings) self._description = topology_description if self._publish_tp: initial_td = TopologyDescription(TOPOLOGY_TYPE.Unknown, {}, None, None, None, self._settings) self._events.put(( self._listeners.publish_topology_description_changed, (initial_td, self._description, self._topology_id))) for seed in topology_settings.seeds: if self._publish_server: self._events.put((self._listeners.publish_server_opened, (seed, self._topology_id))) # Store the seed list to help diagnose errors in _error_message(). self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._lock = threading.Lock() self._condition = self._settings.condition_class(self._lock) self._servers = {} self._pid = None self._max_cluster_time = None self._session_pool = _ServerSessionPool() if self._publish_server or self._publish_tp: def target(): return process_events_queue(weak) executor = periodic_executor.PeriodicExecutor( interval=common.EVENTS_QUEUE_FREQUENCY, min_interval=0.5, target=target, name="pymongo_events_thread") # We strongly reference the executor and it weakly references # the queue via this closure. When the topology is freed, stop # the executor soon. 
weak = weakref.ref(self._events, executor.close) self.__events_executor = executor executor.open() self._srv_monitor = None if self._settings.fqdn is not None: self._srv_monitor = SrvMonitor(self, self._settings) def open(self): """Start monitoring, or restart after a fork. No effect if called multiple times. .. warning:: Topology is shared among multiple threads and is protected by mutual exclusion. Using Topology from a process other than the one that initialized it will emit a warning and may result in deadlock. To prevent this from happening, MongoClient must be created after any forking. """ if self._pid is None: self._pid = os.getpid() else: if os.getpid() != self._pid: warnings.warn( "MongoClient opened before fork. Create MongoClient only " "after forking. See PyMongo's documentation for details: " "https://pymongo.readthedocs.io/en/stable/faq.html#" "is-pymongo-fork-safe") with self._lock: # Reset the session pool to avoid duplicate sessions in # the child process. self._session_pool.reset() with self._lock: self._ensure_opened() def select_servers(self, selector, server_selection_timeout=None, address=None): """Return a list of Servers matching selector, or time out. :Parameters: - `selector`: function that takes a list of Servers and returns a subset of them. - `server_selection_timeout` (optional): maximum seconds to wait. If not provided, the default value common.SERVER_SELECTION_TIMEOUT is used. - `address`: optional server address to select. Calls self.open() if needed. Raises exc:`ServerSelectionTimeoutError` after `server_selection_timeout` if no matching servers are found. """ if server_selection_timeout is None: server_timeout = self._settings.server_selection_timeout else: server_timeout = server_selection_timeout with self._lock: server_descriptions = self._select_servers_loop( selector, server_timeout, address) return [self.get_server_by_address(sd.address) for sd in server_descriptions] def _select_servers_loop(self, selector, timeout, address): """select_servers() guts. Hold the lock when calling this.""" now = _time() end_time = now + timeout server_descriptions = self._description.apply_selector( selector, address, custom_selector=self._settings.server_selector) while not server_descriptions: # No suitable servers. if timeout == 0 or now > end_time: raise ServerSelectionTimeoutError( "%s, Timeout: %ss, Topology Description: %r" % (self._error_message(selector), timeout, self.description)) self._ensure_opened() self._request_check_all() # Release the lock and wait for the topology description to # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = _time() server_descriptions = self._description.apply_selector( selector, address, custom_selector=self._settings.server_selector) self._description.check_compatible() return server_descriptions def select_server(self, selector, server_selection_timeout=None, address=None): """Like select_servers, but choose a random server if several match.""" return random.choice(self.select_servers(selector, server_selection_timeout, address)) def select_server_by_address(self, address, server_selection_timeout=None): """Return a Server for "address", reconnecting if necessary. If the server's type is not known, request an immediate check of all servers. Time out after "server_selection_timeout" if the server cannot be reached. 
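        A sketch of typical use from driver internals (illustrative only;
        ``topology`` is assumed to be an opened :class:`Topology`)::

            server = topology.select_server_by_address(
                ('localhost', 27017), server_selection_timeout=5)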
:Parameters: - `address`: A (host, port) pair. - `server_selection_timeout` (optional): maximum seconds to wait. If not provided, the default value common.SERVER_SELECTION_TIMEOUT is used. Calls self.open() if needed. Raises exc:`ServerSelectionTimeoutError` after `server_selection_timeout` if no matching servers are found. """ return self.select_server(any_server_selector, server_selection_timeout, address) def _process_change(self, server_description, reset_pool=False): """Process a new ServerDescription on an opened topology. Hold the lock when calling this. """ td_old = self._description sd_old = td_old._server_descriptions[server_description.address] if _is_stale_server_description(sd_old, server_description): # This is a stale isMaster response. Ignore it. return suppress_event = ((self._publish_server or self._publish_tp) and sd_old == server_description) if self._publish_server and not suppress_event: self._events.put(( self._listeners.publish_server_description_changed, (sd_old, server_description, server_description.address, self._topology_id))) self._description = updated_topology_description( self._description, server_description) self._update_servers() self._receive_cluster_time_no_lock(server_description.cluster_time) if self._publish_tp and not suppress_event: self._events.put(( self._listeners.publish_topology_description_changed, (td_old, self._description, self._topology_id))) # Shutdown SRV polling for unsupported cluster types. # This is only applicable if the old topology was Unknown, and the # new one is something other than Unknown or Sharded. if self._srv_monitor and (td_old.topology_type == TOPOLOGY_TYPE.Unknown and self._description.topology_type not in SRV_POLLING_TOPOLOGIES): self._srv_monitor.close() # Clear the pool from a failed heartbeat. if reset_pool: server = self._servers.get(server_description.address) if server: server.pool.reset() # Wake waiters in select_servers(). self._condition.notify_all() def on_change(self, server_description, reset_pool=False): """Process a new ServerDescription after an ismaster call completes.""" # We do no I/O holding the lock. with self._lock: # Monitors may continue working on ismaster calls for some time # after a call to Topology.close, so this method may be called at # any time. Ensure the topology is open before processing the # change. # Any monitored server was definitely in the topology description # once. Check if it's still in the description or if some state- # change removed it. E.g., we got a host list from the primary # that didn't include this server. if (self._opened and self._description.has_server(server_description.address)): self._process_change(server_description, reset_pool) def _process_srv_update(self, seedlist): """Process a new seedlist on an opened topology. Hold the lock when calling this. """ td_old = self._description self._description = _updated_topology_description_srv_polling( self._description, seedlist) self._update_servers() if self._publish_tp: self._events.put(( self._listeners.publish_topology_description_changed, (td_old, self._description, self._topology_id))) def on_srv_update(self, seedlist): """Process a new list of nodes obtained from scanning SRV records.""" # We do no I/O holding the lock. with self._lock: if self._opened: self._process_srv_update(seedlist) def get_server_by_address(self, address): """Get a Server or None. Returns the current version of the server immediately, even if it's Unknown or absent from the topology. Only use this in unittests. 
In driver code, use select_server_by_address, since then you're assured a recent view of the server's type and wire protocol version. """ return self._servers.get(address) def has_server(self, address): return address in self._servers def get_primary(self): """Return primary's address or None.""" # Implemented here in Topology instead of MongoClient, so it can lock. with self._lock: topology_type = self._description.topology_type if topology_type != TOPOLOGY_TYPE.ReplicaSetWithPrimary: return None return writable_server_selector(self._new_selection())[0].address def _get_replica_set_members(self, selector): """Return set of replica set member addresses.""" # Implemented here in Topology instead of MongoClient, so it can lock. with self._lock: topology_type = self._description.topology_type if topology_type not in (TOPOLOGY_TYPE.ReplicaSetWithPrimary, TOPOLOGY_TYPE.ReplicaSetNoPrimary): return set() return set([sd.address for sd in selector(self._new_selection())]) def get_secondaries(self): """Return set of secondary addresses.""" return self._get_replica_set_members(secondary_server_selector) def get_arbiters(self): """Return set of arbiter addresses.""" return self._get_replica_set_members(arbiter_server_selector) def max_cluster_time(self): """Return a document, the highest seen $clusterTime.""" return self._max_cluster_time def _receive_cluster_time_no_lock(self, cluster_time): # Driver Sessions Spec: "Whenever a driver receives a cluster time from # a server it MUST compare it to the current highest seen cluster time # for the deployment. If the new cluster time is higher than the # highest seen cluster time it MUST become the new highest seen cluster # time. Two cluster times are compared using only the BsonTimestamp # value of the clusterTime embedded field." if cluster_time: # ">" uses bson.timestamp.Timestamp's comparison operator. if (not self._max_cluster_time or cluster_time['clusterTime'] > self._max_cluster_time['clusterTime']): self._max_cluster_time = cluster_time def receive_cluster_time(self, cluster_time): with self._lock: self._receive_cluster_time_no_lock(cluster_time) def request_check_all(self, wait_time=5): """Wake all monitors, wait for at least one to check its server.""" with self._lock: self._request_check_all() self._condition.wait(wait_time) def handle_getlasterror(self, address, error_msg): """Clear our pool for a server, mark it Unknown, and check it soon.""" error = NotMasterError(error_msg, {'code': 10107, 'errmsg': error_msg}) with self._lock: server = self._servers.get(address) if server: self._process_change( ServerDescription(address, error=error), True) server.request_check() def update_pool(self, all_credentials): # Remove any stale sockets and add new sockets if pool is too small. servers = [] with self._lock: for server in self._servers.values(): servers.append((server, server._pool.generation)) for server, generation in servers: server._pool.remove_stale_sockets(generation, all_credentials) def close(self): """Clear pools and terminate monitors. Topology reopens on demand.""" with self._lock: for server in self._servers.values(): server.close() # Mark all servers Unknown. self._description = self._description.reset() for address, sd in self._description.server_descriptions().items(): if address in self._servers: self._servers[address].description = sd # Stop SRV polling thread. if self._srv_monitor: self._srv_monitor.close() self._opened = False # Publish only after releasing the lock. 
if self._publish_tp: self._events.put((self._listeners.publish_topology_closed, (self._topology_id,))) if self._publish_server or self._publish_tp: self.__events_executor.close() @property def description(self): return self._description def pop_all_sessions(self): """Pop all session ids from the pool.""" with self._lock: return self._session_pool.pop_all() def get_server_session(self): """Start or resume a server session, or raise ConfigurationError.""" with self._lock: session_timeout = self._description.logical_session_timeout_minutes if session_timeout is None: # Maybe we need an initial scan? Can raise ServerSelectionError. if self._description.topology_type == TOPOLOGY_TYPE.Single: if not self._description.has_known_servers: self._select_servers_loop( any_server_selector, self._settings.server_selection_timeout, None) elif not self._description.readable_servers: self._select_servers_loop( readable_server_selector, self._settings.server_selection_timeout, None) session_timeout = self._description.logical_session_timeout_minutes if session_timeout is None: raise ConfigurationError( "Sessions are not supported by this MongoDB deployment") return self._session_pool.get_server_session(session_timeout) def return_server_session(self, server_session, lock): if lock: with self._lock: session_timeout = \ self._description.logical_session_timeout_minutes if session_timeout is not None: self._session_pool.return_server_session(server_session, session_timeout) else: # Called from a __del__ method, can't use a lock. self._session_pool.return_server_session_no_lock(server_session) def _new_selection(self): """A Selection object, initially including all known servers. Hold the lock when calling this. """ return Selection.from_topology_description(self._description) def _ensure_opened(self): """Start monitors, or restart after a fork. Hold the lock when calling this. """ if not self._opened: self._opened = True self._update_servers() # Start or restart the events publishing thread. if self._publish_tp or self._publish_server: self.__events_executor.open() # Start the SRV polling thread. if self._srv_monitor and (self.description.topology_type in SRV_POLLING_TOPOLOGIES): self._srv_monitor.open() # Ensure that the monitors are open. for server in itervalues(self._servers): server.open() def _is_stale_error(self, address, err_ctx): server = self._servers.get(address) if server is None: # Another thread removed this server from the topology. return True if err_ctx.sock_generation != server._pool.generation: # This is an outdated error from a previous pool version. return True # topologyVersion check, ignore error when cur_tv >= error_tv: cur_tv = server.description.topology_version error = err_ctx.error error_tv = None if error and hasattr(error, 'details'): if isinstance(error.details, dict): error_tv = error.details.get('topologyVersion') return _is_stale_error_topology_version(cur_tv, error_tv) def _handle_error(self, address, err_ctx): if self._is_stale_error(address, err_ctx): return server = self._servers[address] error = err_ctx.error exc_type = type(error) if (issubclass(exc_type, NetworkTimeout) and err_ctx.completed_handshake): # The socket has been closed. Don't reset the server. # Server Discovery And Monitoring Spec: "When an application # operation fails because of any network error besides a socket # timeout...." 
return elif issubclass(exc_type, NotMasterError): # As per the SDAM spec if: # - the server sees a "not master" error, and # - the server is not shutting down, and # - the server version is >= 4.2, then # we keep the existing connection pool, but mark the server type # as Unknown and request an immediate check of the server. # Otherwise, we clear the connection pool, mark the server as # Unknown and request an immediate check of the server. err_code = error.details.get('code', -1) is_shutting_down = err_code in helpers._SHUTDOWN_CODES # Mark server Unknown, clear the pool, and request check. self._process_change(ServerDescription(address, error=error)) if is_shutting_down or (err_ctx.max_wire_version <= 7): # Clear the pool. server.reset() server.request_check() elif issubclass(exc_type, ConnectionFailure): # "Client MUST replace the server's description with type Unknown # ... MUST NOT request an immediate check of the server." self._process_change(ServerDescription(address, error=error)) # Clear the pool. server.reset() # "When a client marks a server Unknown from `Network error when # reading or writing`_, clients MUST cancel the isMaster check on # that server and close the current monitoring connection." server._monitor.cancel_check() elif issubclass(exc_type, OperationFailure): # Do not request an immediate check since the server is likely # shutting down. if error.code in helpers._NOT_MASTER_CODES: self._process_change(ServerDescription(address, error=error)) # Clear the pool. server.reset() def handle_error(self, address, err_ctx): """Handle an application error. May reset the server to Unknown, clear the pool, and request an immediate check depending on the error and the context. """ with self._lock: self._handle_error(address, err_ctx) def _request_check_all(self): """Wake all monitors. Hold the lock when calling this.""" for server in self._servers.values(): server.request_check() def _update_servers(self): """Sync our Servers from TopologyDescription.server_descriptions. Hold the lock while calling this. """ for address, sd in self._description.server_descriptions().items(): if address not in self._servers: monitor = self._settings.monitor_class( server_description=sd, topology=self, pool=self._create_pool_for_monitor(address), topology_settings=self._settings) weak = None if self._publish_server: weak = weakref.ref(self._events) server = Server( server_description=sd, pool=self._create_pool_for_server(address), monitor=monitor, topology_id=self._topology_id, listeners=self._listeners, events=weak) self._servers[address] = server server.open() else: # Cache old is_writable value. was_writable = self._servers[address].description.is_writable # Update server description. self._servers[address].description = sd # Update is_writable value of the pool, if it changed. if was_writable != sd.is_writable: self._servers[address].pool.update_is_writable( sd.is_writable) for address, server in list(self._servers.items()): if not self._description.has_server(address): server.close() self._servers.pop(address) def _create_pool_for_server(self, address): return self._settings.pool_class(address, self._settings.pool_options) def _create_pool_for_monitor(self, address): options = self._settings.pool_options # According to the Server Discovery And Monitoring Spec, monitors use # connect_timeout for both connect_timeout and socket_timeout. The # pool only has one socket so maxPoolSize and so on aren't needed. 
monitor_pool_options = PoolOptions( connect_timeout=options.connect_timeout, socket_timeout=options.connect_timeout, ssl_context=options.ssl_context, ssl_match_hostname=options.ssl_match_hostname, event_listeners=options.event_listeners, appname=options.appname, driver=options.driver) return self._settings.pool_class(address, monitor_pool_options, handshake=False) def _error_message(self, selector): """Format an error message if server selection fails. Hold the lock when calling this. """ is_replica_set = self._description.topology_type in ( TOPOLOGY_TYPE.ReplicaSetWithPrimary, TOPOLOGY_TYPE.ReplicaSetNoPrimary) if is_replica_set: server_plural = 'replica set members' elif self._description.topology_type == TOPOLOGY_TYPE.Sharded: server_plural = 'mongoses' else: server_plural = 'servers' if self._description.known_servers: # We've connected, but no servers match the selector. if selector is writable_server_selector: if is_replica_set: return 'No primary available for writes' else: return 'No %s available for writes' % server_plural else: return 'No %s match selector "%s"' % (server_plural, selector) else: addresses = list(self._description.server_descriptions()) servers = list(self._description.server_descriptions().values()) if not servers: if is_replica_set: # We removed all servers because of the wrong setName? return 'No %s available for replica set name "%s"' % ( server_plural, self._settings.replica_set_name) else: return 'No %s available' % server_plural # 1 or more servers, all Unknown. Are they unknown for one reason? error = servers[0].error same = all(server.error == error for server in servers[1:]) if same: if error is None: # We're still discovering. return 'No %s found yet' % server_plural if (is_replica_set and not set(addresses).intersection(self._seed_addresses)): # We replaced our seeds with new hosts but can't reach any. return ( 'Could not reach any servers in %s. Replica set is' ' configured with internal hostnames or IPs?' % addresses) return str(error) else: return ','.join(str(server.error) for server in servers if server.error) def __repr__(self): msg = '' if not self._opened: msg = 'CLOSED ' return '<%s %s%r>' % (self.__class__.__name__, msg, self._description) class _ErrorContext(object): """An error with context for SDAM error handling.""" def __init__(self, error, max_wire_version, sock_generation, completed_handshake): self.error = error self.max_wire_version = max_wire_version self.sock_generation = sock_generation self.completed_handshake = completed_handshake def _is_stale_error_topology_version(current_tv, error_tv): """Return True if the error's topologyVersion is <= current.""" if current_tv is None or error_tv is None: return False if current_tv['processId'] != error_tv['processId']: return False return current_tv['counter'] >= error_tv['counter'] def _is_stale_server_description(current_sd, new_sd): """Return True if the new topologyVersion is < current.""" current_tv, new_tv = current_sd.topology_version, new_sd.topology_version if current_tv is None or new_tv is None: return False if current_tv['processId'] != new_tv['processId']: return False return current_tv['counter'] > new_tv['counter'] pymongo-3.11.0/pymongo/topology_description.py000066400000000000000000000551211374256237000216220ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. 
You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Represent a deployment of MongoDB servers.""" from collections import namedtuple from pymongo import common from pymongo.errors import ConfigurationError from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection from pymongo.server_type import SERVER_TYPE # Enumeration for various kinds of MongoDB cluster topologies. TOPOLOGY_TYPE = namedtuple('TopologyType', ['Single', 'ReplicaSetNoPrimary', 'ReplicaSetWithPrimary', 'Sharded', 'Unknown'])(*range(5)) # Topologies compatible with SRV record polling. SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) class TopologyDescription(object): def __init__(self, topology_type, server_descriptions, replica_set_name, max_set_version, max_election_id, topology_settings): """Representation of a deployment of MongoDB servers. :Parameters: - `topology_type`: initial type - `server_descriptions`: dict of (address, ServerDescription) for all seeds - `replica_set_name`: replica set name or None - `max_set_version`: greatest setVersion seen from a primary, or None - `max_election_id`: greatest electionId seen from a primary, or None - `topology_settings`: a TopologySettings """ self._topology_type = topology_type self._replica_set_name = replica_set_name self._server_descriptions = server_descriptions self._max_set_version = max_set_version self._max_election_id = max_election_id # The heartbeat_frequency is used in staleness estimates. self._topology_settings = topology_settings # Is PyMongo compatible with all servers' wire protocols? self._incompatible_err = None for s in self._server_descriptions.values(): if not s.is_server_type_known: continue # s.min/max_wire_version is the server's wire protocol. # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. server_too_new = ( # Server too new. s.min_wire_version is not None and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION) server_too_old = ( # Server too old. s.max_wire_version is not None and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION) if server_too_new: self._incompatible_err = ( "Server at %s:%d requires wire version %d, but this " "version of PyMongo only supports up to %d." % (s.address[0], s.address[1], s.min_wire_version, common.MAX_SUPPORTED_WIRE_VERSION)) elif server_too_old: self._incompatible_err = ( "Server at %s:%d reports wire version %d, but this " "version of PyMongo requires at least %d (MongoDB %s)." % (s.address[0], s.address[1], s.max_wire_version, common.MIN_SUPPORTED_WIRE_VERSION, common.MIN_SUPPORTED_SERVER_VERSION)) break # Server Discovery And Monitoring Spec: Whenever a client updates the # TopologyDescription from an ismaster response, it MUST set # TopologyDescription.logicalSessionTimeoutMinutes to the smallest # logicalSessionTimeoutMinutes value among ServerDescriptions of all # data-bearing server types. If any have a null # logicalSessionTimeoutMinutes, then # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. 
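        # For example (illustrative only): if the readable servers report
        # logicalSessionTimeoutMinutes of 30 and 20, the topology-wide value
        # becomes 20; if any readable server reports None, it becomes None.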
readable_servers = self.readable_servers if not readable_servers: self._ls_timeout_minutes = None elif any(s.logical_session_timeout_minutes is None for s in readable_servers): self._ls_timeout_minutes = None else: self._ls_timeout_minutes = min(s.logical_session_timeout_minutes for s in readable_servers) def check_compatible(self): """Raise ConfigurationError if any server is incompatible. A server is incompatible if its wire protocol version range does not overlap with PyMongo's. """ if self._incompatible_err: raise ConfigurationError(self._incompatible_err) def has_server(self, address): return address in self._server_descriptions def reset_server(self, address): """A copy of this description, with one server marked Unknown.""" unknown_sd = self._server_descriptions[address].to_unknown() return updated_topology_description(self, unknown_sd) def reset(self): """A copy of this description, with all servers marked Unknown.""" if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary else: topology_type = self._topology_type # The default ServerDescription's type is Unknown. sds = dict((address, ServerDescription(address)) for address in self._server_descriptions) return TopologyDescription( topology_type, sds, self._replica_set_name, self._max_set_version, self._max_election_id, self._topology_settings) def server_descriptions(self): """Dict of (address, :class:`~pymongo.server_description.ServerDescription`).""" return self._server_descriptions.copy() @property def topology_type(self): """The type of this topology.""" return self._topology_type @property def topology_type_name(self): """The topology type as a human readable string. .. versionadded:: 3.4 """ return TOPOLOGY_TYPE._fields[self._topology_type] @property def replica_set_name(self): """The replica set name.""" return self._replica_set_name @property def max_set_version(self): """Greatest setVersion seen from a primary, or None.""" return self._max_set_version @property def max_election_id(self): """Greatest electionId seen from a primary, or None.""" return self._max_election_id @property def logical_session_timeout_minutes(self): """Minimum logical session timeout, or None.""" return self._ls_timeout_minutes @property def known_servers(self): """List of Servers of types besides Unknown.""" return [s for s in self._server_descriptions.values() if s.is_server_type_known] @property def has_known_servers(self): """Whether there are any Servers of types besides Unknown.""" return any(s for s in self._server_descriptions.values() if s.is_server_type_known) @property def readable_servers(self): """List of readable Servers.""" return [s for s in self._server_descriptions.values() if s.is_readable] @property def common_wire_version(self): """Minimum of all servers' max wire versions, or None.""" servers = self.known_servers if servers: return min(s.max_wire_version for s in self.known_servers) return None @property def heartbeat_frequency(self): return self._topology_settings.heartbeat_frequency def apply_selector(self, selector, address, custom_selector=None): def apply_local_threshold(selection): if not selection: return [] settings = self._topology_settings # Round trip time in seconds. 
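            # Illustrative example (hedged, hypothetical numbers): with
            # localThresholdMS=15 and candidate round trip times of 0.010s,
            # 0.020s and 0.040s, `fastest` is 0.010s and the threshold is
            # 0.015s, so only the 0.010s and 0.020s servers are kept below.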
fastest = min( s.round_trip_time for s in selection.server_descriptions) threshold = settings.local_threshold_ms / 1000.0 return [s for s in selection.server_descriptions if (s.round_trip_time - fastest) <= threshold] if getattr(selector, 'min_wire_version', 0): common_wv = self.common_wire_version if common_wv and common_wv < selector.min_wire_version: raise ConfigurationError( "%s requires min wire version %d, but topology's min" " wire version is %d" % (selector, selector.min_wire_version, common_wv)) if self.topology_type == TOPOLOGY_TYPE.Single: # Ignore selectors for standalone. return self.known_servers elif address: # Ignore selectors when explicit address is requested. description = self.server_descriptions().get(address) return [description] if description else [] elif self.topology_type == TOPOLOGY_TYPE.Sharded: # Ignore read preference. selection = Selection.from_topology_description(self) else: selection = selector(Selection.from_topology_description(self)) # Apply custom selector followed by localThresholdMS. if custom_selector is not None and selection: selection = selection.with_server_descriptions( custom_selector(selection.server_descriptions)) return apply_local_threshold(selection) def has_readable_server(self, read_preference=ReadPreference.PRIMARY): """Does this topology have any readable servers available matching the given read preference? :Parameters: - `read_preference`: an instance of a read preference from :mod:`~pymongo.read_preferences`. Defaults to :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. .. note:: When connected directly to a single server this method always returns ``True``. .. versionadded:: 3.4 """ common.validate_read_preference("read_preference", read_preference) return any(self.apply_selector(read_preference, None)) def has_writable_server(self): """Does this topology have a writable server available? .. note:: When connected directly to a single server this method always returns ``True``. .. versionadded:: 3.4 """ return self.has_readable_server(ReadPreference.PRIMARY) def __repr__(self): # Sort the servers by address. servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) return "<%s id: %s, topology_type: %s, servers: %r>" % ( self.__class__.__name__, self._topology_settings._topology_id, self.topology_type_name, servers) # If topology type is Unknown and we receive an ismaster response, what should # the new topology type be? _SERVER_TYPE_TO_TOPOLOGY_TYPE = { SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, } def updated_topology_description(topology_description, server_description): """Return an updated copy of a TopologyDescription. :Parameters: - `topology_description`: the current TopologyDescription - `server_description`: a new ServerDescription that resulted from an ismaster call Called after attempting (successfully or not) to call ismaster on the server at server_description.address. Does not modify topology_description. """ address = server_description.address # These values will be updated, if necessary, to form the new # TopologyDescription. 
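    # Illustrative example (hedged, hypothetical): starting from an Unknown
    # topology, a response typed RSPrimary switches topology_type to
    # ReplicaSetWithPrimary via _SERVER_TYPE_TO_TOPOLOGY_TYPE, whereas a
    # Standalone response received with multiple configured seeds simply
    # removes that server from the description.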
topology_type = topology_description.topology_type set_name = topology_description.replica_set_name max_set_version = topology_description.max_set_version max_election_id = topology_description.max_election_id server_type = server_description.server_type # Don't mutate the original dict of server descriptions; copy it. sds = topology_description.server_descriptions() # Replace this server's description with the new one. sds[address] = server_description if topology_type == TOPOLOGY_TYPE.Single: # Set server type to Unknown if replica set name does not match. if (set_name is not None and set_name != server_description.replica_set_name): error = ConfigurationError( "client is configured to connect to a replica set named " "'%s' but this node belongs to a set named '%s'" % ( set_name, server_description.replica_set_name)) sds[address] = server_description.to_unknown(error=error) # Single type never changes. return TopologyDescription( TOPOLOGY_TYPE.Single, sds, set_name, max_set_version, max_election_id, topology_description._topology_settings) if topology_type == TOPOLOGY_TYPE.Unknown: if server_type == SERVER_TYPE.Standalone: if len(topology_description._topology_settings.seeds) == 1: topology_type = TOPOLOGY_TYPE.Single else: # Remove standalone from Topology when given multiple seeds. sds.pop(address) elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] if topology_type == TOPOLOGY_TYPE.Sharded: if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): sds.pop(address) elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): sds.pop(address) elif server_type == SERVER_TYPE.RSPrimary: (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(sds, set_name, server_description, max_set_version, max_election_id) elif server_type in ( SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): topology_type, set_name = _update_rs_no_primary_from_member( sds, set_name, server_description) elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): sds.pop(address) topology_type = _check_has_primary(sds) elif server_type == SERVER_TYPE.RSPrimary: (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(sds, set_name, server_description, max_set_version, max_election_id) elif server_type in ( SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): topology_type = _update_rs_with_primary_from_member( sds, set_name, server_description) else: # Server type is Unknown or RSGhost: did we just lose the primary? topology_type = _check_has_primary(sds) # Return updated copy. return TopologyDescription(topology_type, sds, set_name, max_set_version, max_election_id, topology_description._topology_settings) def _updated_topology_description_srv_polling(topology_description, seedlist): """Return an updated copy of a TopologyDescription. :Parameters: - `topology_description`: the current TopologyDescription - `seedlist`: a list of new seeds new ServerDescription that resulted from an ismaster call """ # Create a copy of the server descriptions. sds = topology_description.server_descriptions() # If seeds haven't changed, don't do anything. if set(sds.keys()) == set(seedlist): return topology_description # Add SDs corresponding to servers recently added to the SRV record. 
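    # Illustrative example (hedged, hypothetical hosts): if the SRV poll now
    # returns [('host1', 27017), ('host3', 27017)] while sds holds host1 and
    # host2, an Unknown ServerDescription is created for host3 here and host2
    # is dropped by the removal loop that follows.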
for address in seedlist: if address not in sds: sds[address] = ServerDescription(address) # Remove SDs corresponding to servers no longer part of the SRV record. for address in list(sds.keys()): if address not in seedlist: sds.pop(address) return TopologyDescription( topology_description.topology_type, sds, topology_description.replica_set_name, topology_description.max_set_version, topology_description.max_election_id, topology_description._topology_settings) def _update_rs_from_primary( sds, replica_set_name, server_description, max_set_version, max_election_id): """Update topology description from a primary's ismaster response. Pass in a dict of ServerDescriptions, current replica set name, the ServerDescription we are processing, and the TopologyDescription's max_set_version and max_election_id if any. Returns (new topology type, new replica_set_name, new max_set_version, new max_election_id). """ if replica_set_name is None: replica_set_name = server_description.replica_set_name elif replica_set_name != server_description.replica_set_name: # We found a primary but it doesn't have the replica_set_name # provided by the user. sds.pop(server_description.address) return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) max_election_tuple = max_set_version, max_election_id if None not in server_description.election_tuple: if (None not in max_election_tuple and max_election_tuple > server_description.election_tuple): # Stale primary, set to type Unknown. sds[server_description.address] = server_description.to_unknown() return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) max_election_id = server_description.election_id if (server_description.set_version is not None and (max_set_version is None or server_description.set_version > max_set_version)): max_set_version = server_description.set_version # We've heard from the primary. Is it the same primary as before? for server in sds.values(): if (server.server_type is SERVER_TYPE.RSPrimary and server.address != server_description.address): # Reset old primary's type to Unknown. sds[server.address] = server.to_unknown() # There can be only one prior primary. break # Discover new hosts from this primary's response. for new_address in server_description.all_hosts: if new_address not in sds: sds[new_address] = ServerDescription(new_address) # Remove hosts not in the response. for addr in set(sds) - server_description.all_hosts: sds.pop(addr) # If the host list differs from the seed list, we may not have a primary # after all. return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) def _update_rs_with_primary_from_member( sds, replica_set_name, server_description): """RS with known primary. Process a response from a non-primary. Pass in a dict of ServerDescriptions, current replica set name, and the ServerDescription we are processing. Returns new topology type. """ assert replica_set_name is not None if replica_set_name != server_description.replica_set_name: sds.pop(server_description.address) elif (server_description.me and server_description.address != server_description.me): sds.pop(server_description.address) # Had this member been the primary? return _check_has_primary(sds) def _update_rs_no_primary_from_member( sds, replica_set_name, server_description): """RS without known primary. Update from a non-primary's response. Pass in a dict of ServerDescriptions, current replica set name, and the ServerDescription we are processing. 
Returns (new topology type, new replica_set_name). """ topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary if replica_set_name is None: replica_set_name = server_description.replica_set_name elif replica_set_name != server_description.replica_set_name: sds.pop(server_description.address) return topology_type, replica_set_name # This isn't the primary's response, so don't remove any servers # it doesn't report. Only add new servers. for address in server_description.all_hosts: if address not in sds: sds[address] = ServerDescription(address) if (server_description.me and server_description.address != server_description.me): sds.pop(server_description.address) return topology_type, replica_set_name def _check_has_primary(sds): """Current topology type is ReplicaSetWithPrimary. Is primary still known? Pass in a dict of ServerDescriptions. Returns new topology type. """ for s in sds.values(): if s.server_type == SERVER_TYPE.RSPrimary: return TOPOLOGY_TYPE.ReplicaSetWithPrimary else: return TOPOLOGY_TYPE.ReplicaSetNoPrimary pymongo-3.11.0/pymongo/uri_parser.py000066400000000000000000000467671374256237000175360ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. """Tools to parse and validate a MongoDB URI.""" import re import warnings from bson.py3compat import string_type, PY3 if PY3: from urllib.parse import unquote_plus else: from urllib import unquote_plus from pymongo.common import ( get_validated_options, INTERNAL_URI_OPTION_NAME_MAP, URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary) from pymongo.errors import ConfigurationError, InvalidURI from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver SCHEME = 'mongodb://' SCHEME_LEN = len(SCHEME) SRV_SCHEME = 'mongodb+srv://' SRV_SCHEME_LEN = len(SRV_SCHEME) DEFAULT_PORT = 27017 def parse_userinfo(userinfo): """Validates the format of user information in a MongoDB URI. Reserved characters like ':', '/', '+' and '@' must be escaped following RFC 3986. Returns a 2-tuple containing the unescaped username followed by the unescaped password. :Paramaters: - `userinfo`: A string of the form : .. versionchanged:: 2.2 Now uses `urllib.unquote_plus` so `+` characters must be escaped. """ if '@' in userinfo or userinfo.count(':') > 1: if PY3: quote_fn = "urllib.parse.quote_plus" else: quote_fn = "urllib.quote_plus" raise InvalidURI("Username and password must be escaped according to " "RFC 3986, use %s()." % quote_fn) user, _, passwd = userinfo.partition(":") # No password is expected with GSSAPI authentication. if not user: raise InvalidURI("The empty string is not valid username.") return unquote_plus(user), unquote_plus(passwd) def parse_ipv6_literal_host(entity, default_port): """Validates an IPv6 literal host:port string. Returns a 2-tuple of IPv6 literal followed by port where port is default_port if it wasn't specified in entity. :Parameters: - `entity`: A string that represents an IPv6 literal enclosed in braces (e.g. '[::1]' or '[::1]:27017'). 
- `default_port`: The port number to use when one wasn't specified in entity. """ if entity.find(']') == -1: raise ValueError("an IPv6 address literal must be " "enclosed in '[' and ']' according " "to RFC 2732.") i = entity.find(']:') if i == -1: return entity[1:-1], default_port return entity[1: i], entity[i + 2:] def parse_host(entity, default_port=DEFAULT_PORT): """Validates a host string Returns a 2-tuple of host followed by port where port is default_port if it wasn't specified in the string. :Parameters: - `entity`: A host or host:port string where host could be a hostname or IP address. - `default_port`: The port number to use when one wasn't specified in entity. """ host = entity port = default_port if entity[0] == '[': host, port = parse_ipv6_literal_host(entity, default_port) elif entity.endswith(".sock"): return entity, default_port elif entity.find(':') != -1: if entity.count(':') > 1: raise ValueError("Reserved characters such as ':' must be " "escaped according RFC 2396. An IPv6 " "address literal must be enclosed in '[' " "and ']' according to RFC 2732.") host, port = host.split(':', 1) if isinstance(port, string_type): if not port.isdigit() or int(port) > 65535 or int(port) <= 0: raise ValueError("Port must be an integer between 0 and 65535: %s" % (port,)) port = int(port) # Normalize hostname to lowercase, since DNS is case-insensitive: # http://tools.ietf.org/html/rfc4343 # This prevents useless rediscovery if "foo.com" is in the seed list but # "FOO.com" is in the ismaster response. return host.lower(), port # Options whose values are implicitly determined by tlsInsecure. _IMPLICIT_TLSINSECURE_OPTS = { "tlsallowinvalidcertificates", "tlsallowinvalidhostnames", "tlsdisableocspendpointcheck",} # Options that cannot be specified when tlsInsecure is also specified. _TLSINSECURE_EXCLUDE_OPTS = ( {k for k in _IMPLICIT_TLSINSECURE_OPTS} | {INTERNAL_URI_OPTION_NAME_MAP[k] for k in _IMPLICIT_TLSINSECURE_OPTS}) def _parse_options(opts, delim): """Helper method for split_options which creates the options dict. Also handles the creation of a list for the URI tag_sets/ readpreferencetags portion, and the use of a unicode options string.""" options = _CaseInsensitiveDictionary() for uriopt in opts.split(delim): key, value = uriopt.split("=") if key.lower() == 'readpreferencetags': options.setdefault(key, []).append(value) else: if key in options: warnings.warn("Duplicate URI option '%s'." % (key,)) if key.lower() == 'authmechanismproperties': val = value else: val = unquote_plus(value) options[key] = val return options def _handle_security_options(options): """Raise appropriate errors when conflicting TLS options are present in the options dictionary. :Parameters: - `options`: Instance of _CaseInsensitiveDictionary containing MongoDB URI options. """ tlsinsecure = options.get('tlsinsecure') if tlsinsecure is not None: for opt in _TLSINSECURE_EXCLUDE_OPTS: if opt in options: err_msg = ("URI options %s and %s cannot be specified " "simultaneously.") raise InvalidURI(err_msg % ( options.cased_key('tlsinsecure'), options.cased_key(opt))) # Convenience function to retrieve option values based on public or private names. def _getopt(opt): return (options.get(opt) or options.get(INTERNAL_URI_OPTION_NAME_MAP[opt])) # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. 
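    # Illustrative example (hedged, hypothetical URI options): specifying both
    # tlsAllowInvalidCertificates and tlsDisableOCSPEndpointCheck raises
    # InvalidURI below, while tlsAllowInvalidCertificates=true on its own
    # implies tlsdisableocspendpointcheck=True.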
tlsallowinvalidcerts = _getopt('tlsallowinvalidcertificates') if tlsallowinvalidcerts is not None: if 'tlsdisableocspendpointcheck' in options: err_msg = ("URI options %s and %s cannot be specified " "simultaneously.") raise InvalidURI(err_msg % ( 'tlsallowinvalidcertificates', options.cased_key( 'tlsdisableocspendpointcheck'))) if tlsallowinvalidcerts is True: options['tlsdisableocspendpointcheck'] = True # Handle co-occurence of CRL and OCSP-related options. tlscrlfile = _getopt('tlscrlfile') if tlscrlfile is not None: for opt in ('tlsinsecure', 'tlsallowinvalidcertificates', 'tlsdisableocspendpointcheck'): if options.get(opt) is True: err_msg = ("URI option %s=True cannot be specified when " "CRL checking is enabled.") raise InvalidURI(err_msg % (opt,)) if 'ssl' in options and 'tls' in options: def truth_value(val): if val in ('true', 'false'): return val == 'true' if isinstance(val, bool): return val return val if truth_value(options.get('ssl')) != truth_value(options.get('tls')): err_msg = ("Can not specify conflicting values for URI options %s " "and %s.") raise InvalidURI(err_msg % ( options.cased_key('ssl'), options.cased_key('tls'))) return options def _handle_option_deprecations(options): """Issue appropriate warnings when deprecated options are present in the options dictionary. Removes deprecated option key, value pairs if the options dictionary is found to also have the renamed option. :Parameters: - `options`: Instance of _CaseInsensitiveDictionary containing MongoDB URI options. """ for optname in list(options): if optname in URI_OPTIONS_DEPRECATION_MAP: mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] if mode == 'renamed': newoptname = message if newoptname in options: warn_msg = ("Deprecated option '%s' ignored in favor of " "'%s'.") warnings.warn( warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), DeprecationWarning, stacklevel=2) options.pop(optname) continue warn_msg = "Option '%s' is deprecated, use '%s' instead." warnings.warn( warn_msg % (options.cased_key(optname), newoptname), DeprecationWarning, stacklevel=2) elif mode == 'removed': warn_msg = "Option '%s' is deprecated. %s." warnings.warn( warn_msg % (options.cased_key(optname), message), DeprecationWarning, stacklevel=2) return options def _normalize_options(options): """Normalizes option names in the options dictionary by converting them to their internally-used names. Also handles use of the tlsInsecure option. :Parameters: - `options`: Instance of _CaseInsensitiveDictionary containing MongoDB URI options. """ tlsinsecure = options.get('tlsinsecure') if tlsinsecure is not None: for opt in _IMPLICIT_TLSINSECURE_OPTS: intname = INTERNAL_URI_OPTION_NAME_MAP[opt] # Internal options are logical inverse of public options. options[intname] = not tlsinsecure for optname in list(options): intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) if intname is not None: options[intname] = options.pop(optname) return options def validate_options(opts, warn=False): """Validates and normalizes options passed in a MongoDB URI. Returns a new dictionary of validated and normalized options. If warn is False then errors will be thrown for invalid options, otherwise they will be ignored and a warning will be issued. :Parameters: - `opts`: A dict of MongoDB URI options. - `warn` (optional): If ``True`` then warnings will be logged and invalid options will be ignored. Otherwise invalid options will cause errors. 
""" return get_validated_options(opts, warn) def split_options(opts, validate=True, warn=False, normalize=True): """Takes the options portion of a MongoDB URI, validates each option and returns the options in a dictionary. :Parameters: - `opt`: A string representing MongoDB URI options. - `validate`: If ``True`` (the default), validate and normalize all options. - `warn`: If ``False`` (the default), suppress all warnings raised during validation of options. - `normalize`: If ``True`` (the default), renames all options to their internally-used names. """ and_idx = opts.find("&") semi_idx = opts.find(";") try: if and_idx >= 0 and semi_idx >= 0: raise InvalidURI("Can not mix '&' and ';' for option separators.") elif and_idx >= 0: options = _parse_options(opts, "&") elif semi_idx >= 0: options = _parse_options(opts, ";") elif opts.find("=") != -1: options = _parse_options(opts, None) else: raise ValueError except ValueError: raise InvalidURI("MongoDB URI options are key=value pairs.") options = _handle_security_options(options) options = _handle_option_deprecations(options) if validate: options = validate_options(options, warn) if options.get('authsource') == '': raise InvalidURI( "the authSource database cannot be an empty string") if normalize: options = _normalize_options(options) return options def split_hosts(hosts, default_port=DEFAULT_PORT): """Takes a string of the form host1[:port],host2[:port]... and splits it into (host, port) tuples. If [:port] isn't present the default_port is used. Returns a set of 2-tuples containing the host name (or IP) followed by port number. :Parameters: - `hosts`: A string of the form host1[:port],host2[:port],... - `default_port`: The port number to use when one wasn't specified for a host. """ nodes = [] for entity in hosts.split(','): if not entity: raise ConfigurationError("Empty host " "(or extra comma in host list).") port = default_port # Unix socket entities don't have ports if entity.endswith('.sock'): port = None nodes.append(parse_host(entity, port)) return nodes # Prohibited characters in database name. DB names also can't have ".", but for # backward-compat we allow "db.collection" in URI. _BAD_DB_CHARS = re.compile('[' + re.escape(r'/ "$') + ']') _ALLOWED_TXT_OPTS = frozenset( ['authsource', 'authSource', 'replicaset', 'replicaSet']) def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, normalize=True, connect_timeout=None): """Parse and validate a MongoDB URI. Returns a dict of the form:: { 'nodelist': , 'username': or None, 'password': or None, 'database': or None, 'collection': or None, 'options': , 'fqdn': or None } If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done to build nodelist and options. :Parameters: - `uri`: The MongoDB URI to parse. - `default_port`: The port number to use when one wasn't specified for a host in the URI. - `validate` (optional): If ``True`` (the default), validate and normalize all options. Default: ``True``. - `warn` (optional): When validating, if ``True`` then will warn the user then ignore any invalid options or values. If ``False``, validation will error when options are unsupported or values are invalid. Default: ``False``. - `normalize` (optional): If ``True``, convert names of URI options to their internally-used names. Default: ``True``. - `connect_timeout` (optional): The maximum time in milliseconds to wait for a response from the DNS server. .. versionchanged:: 3.9 Added the ``normalize`` parameter. .. 
versionchanged:: 3.6 Added support for mongodb+srv:// URIs. .. versionchanged:: 3.5 Return the original value of the ``readPreference`` MongoDB URI option instead of the validated read preference mode. .. versionchanged:: 3.1 ``warn`` added so invalid options can be ignored. """ if uri.startswith(SCHEME): is_srv = False scheme_free = uri[SCHEME_LEN:] elif uri.startswith(SRV_SCHEME): if not _HAVE_DNSPYTHON: raise ConfigurationError('The "dnspython" module must be ' 'installed to use mongodb+srv:// URIs') is_srv = True scheme_free = uri[SRV_SCHEME_LEN:] else: raise InvalidURI("Invalid URI scheme: URI must " "begin with '%s' or '%s'" % (SCHEME, SRV_SCHEME)) if not scheme_free: raise InvalidURI("Must provide at least one hostname or IP.") user = None passwd = None dbase = None collection = None options = _CaseInsensitiveDictionary() host_part, _, path_part = scheme_free.partition('/') if not host_part: host_part = path_part path_part = "" if not path_part and '?' in host_part: raise InvalidURI("A '/' is required between " "the host list and any options.") if path_part: dbase, _, opts = path_part.partition('?') if dbase: dbase = unquote_plus(dbase) if '.' in dbase: dbase, collection = dbase.split('.', 1) if _BAD_DB_CHARS.search(dbase): raise InvalidURI('Bad database name "%s"' % dbase) else: dbase = None if opts: options.update(split_options(opts, validate, warn, normalize)) if '@' in host_part: userinfo, _, hosts = host_part.rpartition('@') user, passwd = parse_userinfo(userinfo) else: hosts = host_part if '/' in hosts: raise InvalidURI("Any '/' in a unix domain socket must be" " percent-encoded: %s" % host_part) hosts = unquote_plus(hosts) fqdn = None if is_srv: if options.get('directConnection'): raise ConfigurationError( "Cannot specify directConnection=true with " "%s URIs" % (SRV_SCHEME,)) nodes = split_hosts(hosts, default_port=None) if len(nodes) != 1: raise InvalidURI( "%s URIs must include one, " "and only one, hostname" % (SRV_SCHEME,)) fqdn, port = nodes[0] if port is not None: raise InvalidURI( "%s URIs must not include a port number" % (SRV_SCHEME,)) # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") dns_resolver = _SrvResolver(fqdn, connect_timeout=connect_timeout) nodes = dns_resolver.get_hosts() dns_options = dns_resolver.get_options() if dns_options: parsed_dns_options = split_options( dns_options, validate, warn, normalize) if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: raise ConfigurationError( "Only authSource and replicaSet are supported from DNS") for opt, val in parsed_dns_options.items(): if opt not in options: options[opt] = val if "ssl" not in options: options["ssl"] = True if validate else 'true' else: nodes = split_hosts(hosts, default_port=default_port) if len(nodes) > 1 and options.get('directConnection'): raise ConfigurationError( "Cannot specify multiple hosts with directConnection=true") return { 'nodelist': nodes, 'username': user, 'password': passwd, 'database': dbase, 'collection': collection, 'options': options, 'fqdn': fqdn } if __name__ == '__main__': import pprint import sys try: pprint.pprint(parse_uri(sys.argv[1])) except InvalidURI as exc: print(exc) sys.exit(0)pymongo-3.11.0/pymongo/write_concern.py000066400000000000000000000116101374256237000201770ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for working with write concerns.""" from bson.py3compat import integer_types, string_type from pymongo.errors import ConfigurationError class WriteConcern(object): """WriteConcern :Parameters: - `w`: (integer or string) Used with replication, write operations will block until they have been replicated to the specified number or tagged set of servers. `w=` always includes the replica set primary (e.g. w=3 means write to the primary and wait until replicated to **two** secondaries). **w=0 disables acknowledgement of write operations and can not be used with other write concern options.** - `wtimeout`: (integer) Used in conjunction with `w`. Specify a value in milliseconds to control how long to wait for write propagation to complete. If replication does not complete in the given timeframe, a timeout exception is raised. - `j`: If ``True`` block until write operations have been committed to the journal. Cannot be used in combination with `fsync`. Prior to MongoDB 2.6 this option was ignored if the server was running without journaling. Starting with MongoDB 2.6 write operations will fail with an exception if this option is used when the server is running without journaling. - `fsync`: If ``True`` and the server is running without journaling, blocks until the server has synced all data files to disk. If the server is running with journaling, this acts the same as the `j` option, blocking until write operations have been committed to the journal. Cannot be used in combination with `j`. """ __slots__ = ("__document", "__acknowledged", "__server_default") def __init__(self, w=None, wtimeout=None, j=None, fsync=None): self.__document = {} self.__acknowledged = True if wtimeout is not None: if not isinstance(wtimeout, integer_types): raise TypeError("wtimeout must be an integer") if wtimeout < 0: raise ValueError("wtimeout cannot be less than 0") self.__document["wtimeout"] = wtimeout if j is not None: if not isinstance(j, bool): raise TypeError("j must be True or False") self.__document["j"] = j if fsync is not None: if not isinstance(fsync, bool): raise TypeError("fsync must be True or False") if j and fsync: raise ConfigurationError("Can't set both j " "and fsync at the same time") self.__document["fsync"] = fsync if w == 0 and j is True: raise ConfigurationError("Cannot set w to 0 and j to True") if w is not None: if isinstance(w, integer_types): if w < 0: raise ValueError("w cannot be less than 0") self.__acknowledged = w > 0 elif not isinstance(w, string_type): raise TypeError("w must be an integer or string") self.__document["w"] = w self.__server_default = not self.__document @property def is_server_default(self): """Does this WriteConcern match the server default.""" return self.__server_default @property def document(self): """The document representation of this write concern. .. note:: :class:`WriteConcern` is immutable. Mutating the value of :attr:`document` does not mutate this :class:`WriteConcern`. 
""" return self.__document.copy() @property def acknowledged(self): """If ``True`` write operations will wait for acknowledgement before returning. """ return self.__acknowledged def __repr__(self): return ("WriteConcern(%s)" % ( ", ".join("%s=%s" % kvt for kvt in self.__document.items()),)) def __eq__(self, other): if isinstance(other, WriteConcern): return self.__document == other.document return NotImplemented def __ne__(self, other): if isinstance(other, WriteConcern): return self.__document != other.document return NotImplemented DEFAULT_WRITE_CONCERN = WriteConcern() pymongo-3.11.0/setup.cfg000066400000000000000000000000461374256237000151160ustar00rootroot00000000000000[egg_info] tag_build = tag_date = 0 pymongo-3.11.0/setup.py000077500000000000000000000375731374256237000150310ustar00rootroot00000000000000import os import platform import re import sys import warnings if sys.version_info[:2] < (2, 7): raise RuntimeError("Python version >= 2.7 required.") # Hack to silence atexit traceback in some Python versions try: import multiprocessing except ImportError: pass # Don't force people to install setuptools unless # we have to. try: from setuptools import setup, __version__ as _setuptools_version except ImportError: from ez_setup import use_setuptools use_setuptools() from setuptools import setup, __version__ as _setuptools_version from distutils.cmd import Command from distutils.command.build_ext import build_ext from distutils.errors import CCompilerError, DistutilsOptionError from distutils.errors import DistutilsPlatformError, DistutilsExecError from distutils.core import Extension _HAVE_SPHINX = True try: from sphinx.cmd import build as sphinx except ImportError: try: import sphinx except ImportError: _HAVE_SPHINX = False version = "3.11.0" f = open("README.rst") try: try: readme_content = f.read() except: readme_content = "" finally: f.close() # PYTHON-654 - Clang doesn't support -mno-fused-madd but the pythons Apple # ships are built with it. This is a problem starting with Xcode 5.1 # since clang 3.4 errors out when it encounters unrecognized compiler # flags. This hack removes -mno-fused-madd from the CFLAGS automatically # generated by distutils for Apple provided pythons, allowing C extension # builds to complete without error. The inspiration comes from older # versions of distutils.sysconfig.get_config_vars. if sys.platform == 'darwin' and 'clang' in platform.python_compiler().lower(): from distutils.sysconfig import get_config_vars res = get_config_vars() for key in ('CFLAGS', 'PY_CFLAGS'): if key in res: flags = res[key] flags = re.sub('-mno-fused-madd', '', flags) res[key] = flags class test(Command): description = "run the tests" user_options = [ ("test-module=", "m", "Discover tests in specified module"), ("test-suite=", "s", "Test suite to run (e.g. 
'some_module.test_suite')"), ("failfast", "f", "Stop running tests on first failure or error"), ("xunit-output=", "x", "Generate a results directory with XUnit XML format") ] def initialize_options(self): self.test_module = None self.test_suite = None self.failfast = False self.xunit_output = None def finalize_options(self): if self.test_suite is None and self.test_module is None: self.test_module = 'test' elif self.test_module is not None and self.test_suite is not None: raise DistutilsOptionError( "You may specify a module or suite, but not both" ) def run(self): # Installing required packages, running egg_info and build_ext are # part of normal operation for setuptools.command.test.test if self.distribution.install_requires: self.distribution.fetch_build_eggs( self.distribution.install_requires) if self.distribution.tests_require: self.distribution.fetch_build_eggs(self.distribution.tests_require) if self.xunit_output: self.distribution.fetch_build_eggs(["unittest-xml-reporting"]) self.run_command('egg_info') build_ext_cmd = self.reinitialize_command('build_ext') build_ext_cmd.inplace = 1 self.run_command('build_ext') # Construct a TextTestRunner directly from the unittest imported from # test, which creates a TestResult that supports the 'addSkip' method. # setuptools will by default create a TextTestRunner that uses the old # TestResult class. from test import unittest, PymongoTestRunner, test_cases if self.test_suite is None: all_tests = unittest.defaultTestLoader.discover(self.test_module) suite = unittest.TestSuite() suite.addTests(sorted(test_cases(all_tests), key=lambda x: x.__module__)) else: suite = unittest.defaultTestLoader.loadTestsFromName( self.test_suite) if self.xunit_output: from test import PymongoXMLTestRunner runner = PymongoXMLTestRunner(verbosity=2, failfast=self.failfast, output=self.xunit_output) else: runner = PymongoTestRunner(verbosity=2, failfast=self.failfast) result = runner.run(suite) sys.exit(not result.wasSuccessful()) class doc(Command): description = "generate or test documentation" user_options = [("test", "t", "run doctests instead of generating documentation")] boolean_options = ["test"] def initialize_options(self): self.test = False def finalize_options(self): pass def run(self): if not _HAVE_SPHINX: raise RuntimeError( "You must install Sphinx to build or test the documentation.") if sys.version_info[0] >= 3: import doctest from doctest import OutputChecker as _OutputChecker # Match u or U (possibly followed by r or R), removing it. # r/R can follow u/U but not precede it. Don't match the # single character string 'u' or 'U'. _u_literal_re = re.compile( r"(\W|^)(?=17.2.0", "requests<3.0.0", "service_identity>=18.1.0"] extras_require = { 'encryption': ['pymongocrypt<2.0.0'], 'ocsp': pyopenssl_reqs, 'snappy': ['python-snappy'], 'tls': [], 'zstd': ['zstandard'], 'aws': ['pymongo-auth-aws<2.0.0'], } # https://jira.mongodb.org/browse/PYTHON-2117 # Environment marker support didn't settle down until version 20.10 # https://setuptools.readthedocs.io/en/latest/history.html#v20-10-0 _use_env_markers = tuple(map(int, _setuptools_version.split('.')[:2])) > (20, 9) # TLS and DNS extras # We install PyOpenSSL and service_identity for Python < 2.7.9 to # get support for SNI, which is required to connection to Altas # free and shared tier. 
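# Illustrative note (hedged): the intent of the branch below is that, e.g.,
# `pip install "pymongo[tls]"` on a Python 2.7 interpreter older than 2.7.9
# pulls in the PyOpenSSL-based requirements above (plus wincertstore on
# Windows or certifi elsewhere), while interpreters whose ssl module already
# supports SNI do not need these extra packages.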
if sys.version_info[0] == 2: if _use_env_markers: # For building wheels on Python versions >= 2.7.9 for req in pyopenssl_reqs: extras_require['tls'].append( "%s ; python_full_version < '2.7.9'" % (req,)) if sys.platform == 'win32': extras_require['tls'].append( "wincertstore>=0.2 ; python_full_version < '2.7.9'") else: extras_require['tls'].append( "certifi ; python_full_version < '2.7.9'") elif sys.version_info < (2, 7, 9): # For installing from source or egg files on Python versions # older than 2.7.9, or systems that have setuptools versions # older than 20.10. extras_require['tls'].extend(pyopenssl_reqs) if sys.platform == 'win32': extras_require['tls'].append("wincertstore>=0.2") else: extras_require['tls'].append("certifi") extras_require.update({'srv': ["dnspython>=1.16.0,<1.17.0"]}) extras_require.update({'tls': ["ipaddress"]}) else: extras_require.update({'srv': ["dnspython>=1.16.0,<2.0.0"]}) # GSSAPI extras if sys.platform == 'win32': extras_require['gssapi'] = ["winkerberos>=0.5.0"] else: extras_require['gssapi'] = ["pykerberos"] extra_opts = { "packages": ["bson", "pymongo", "gridfs"] } if "--no_ext" in sys.argv: sys.argv.remove("--no_ext") elif (sys.platform.startswith("java") or sys.platform == "cli" or "PyPy" in sys.version): sys.stdout.write(""" *****************************************************\n The optional C extensions are currently not supported\n by this python implementation.\n *****************************************************\n """) else: extra_opts['ext_modules'] = ext_modules setup( name="pymongo", version=version, description="Python driver for MongoDB ", long_description=readme_content, author="Mike Dirolf", author_email="mongodb-user@googlegroups.com", maintainer="Bernie Hackett", maintainer_email="bernie@mongodb.com", url="http://github.com/mongodb/mongo-python-driver", keywords=["mongo", "mongodb", "pymongo", "gridfs", "bson"], install_requires=[], license="Apache License, Version 2.0", python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*", classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "License :: OSI Approved :: Apache Software License", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX", "Programming Language :: Python :: 2", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Database"], cmdclass={"build_ext": custom_build_ext, "doc": doc, "test": test}, extras_require=extras_require, **extra_opts ) pymongo-3.11.0/test/000077500000000000000000000000001374256237000142545ustar00rootroot00000000000000pymongo-3.11.0/test/__init__.py000066400000000000000000000774651374256237000164100ustar00rootroot00000000000000# Copyright 2010-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Test suite for pymongo, bson, and gridfs. """ import gc import os import socket import sys import threading import time import unittest import warnings try: from xmlrunner import XMLTestRunner HAVE_XML = True # ValueError is raised when version 3+ is installed on Jython 2.7. except (ImportError, ValueError): HAVE_XML = False try: import ipaddress HAVE_IPADDRESS = True except ImportError: HAVE_IPADDRESS = False from contextlib import contextmanager from functools import wraps from unittest import SkipTest import pymongo import pymongo.errors from bson.son import SON from pymongo import common, message from pymongo.common import partition_node from pymongo.ssl_support import HAVE_SSL, validate_cert_reqs from test.version import Version if HAVE_SSL: import ssl try: # Enable the fault handler to dump the traceback of each running thread # after a segfault. import faulthandler faulthandler.enable() except ImportError: pass # Enable debug output for uncollectable objects. PyPy does not have set_debug. if hasattr(gc, 'set_debug'): gc.set_debug( gc.DEBUG_UNCOLLECTABLE | getattr(gc, 'DEBUG_OBJECTS', 0) | getattr(gc, 'DEBUG_INSTANCES', 0)) # The host and port of a single mongod or mongos, or the seed host # for a replica set. host = os.environ.get("DB_IP", 'localhost') port = int(os.environ.get("DB_PORT", 27017)) db_user = os.environ.get("DB_USER", "user") db_pwd = os.environ.get("DB_PASSWORD", "password") CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'certificates') CLIENT_PEM = os.environ.get('CLIENT_PEM', os.path.join(CERT_PATH, 'client.pem')) CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem')) TLS_OPTIONS = dict(tls=True) if CLIENT_PEM: TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM if CA_PEM: TLS_OPTIONS['tlsCAFile'] = CA_PEM COMPRESSORS = os.environ.get("COMPRESSORS") def is_server_resolvable(): """Returns True if 'server' is resolvable.""" socket_timeout = socket.getdefaulttimeout() socket.setdefaulttimeout(1) try: try: socket.gethostbyname('server') return True except socket.error: return False finally: socket.setdefaulttimeout(socket_timeout) def _create_user(authdb, user, pwd=None, roles=None, **kwargs): cmd = SON([('createUser', user)]) # X509 doesn't use a password if pwd: cmd['pwd'] = pwd cmd['roles'] = roles or ['root'] cmd.update(**kwargs) return authdb.command(cmd) class client_knobs(object): def __init__( self, heartbeat_frequency=None, min_heartbeat_interval=None, kill_cursor_frequency=None, events_queue_frequency=None): self.heartbeat_frequency = heartbeat_frequency self.min_heartbeat_interval = min_heartbeat_interval self.kill_cursor_frequency = kill_cursor_frequency self.events_queue_frequency = events_queue_frequency self.old_heartbeat_frequency = None self.old_min_heartbeat_interval = None self.old_kill_cursor_frequency = None self.old_events_queue_frequency = None def enable(self): self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY if self.heartbeat_frequency is not None: common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency if self.min_heartbeat_interval is not None: common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval if self.kill_cursor_frequency is not None: common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency if 
self.events_queue_frequency is not None: common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency def __enter__(self): self.enable() def disable(self): common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency def __exit__(self, exc_type, exc_val, exc_tb): self.disable() def _all_users(db): return set(u['user'] for u in db.command('usersInfo').get('users', [])) class ClientContext(object): def __init__(self): """Create a client and grab essential information from the server.""" self.connection_attempts = [] self.connected = False self.w = None self.nodes = set() self.replica_set_name = None self.cmd_line = None self.server_status = None self.version = Version(-1) # Needs to be comparable with Version self.auth_enabled = False self.test_commands_enabled = False self.is_mongos = False self.mongoses = [] self.is_rs = False self.has_ipv6 = False self.tls = False self.ssl_certfile = False self.server_is_resolvable = is_server_resolvable() self.default_client_options = {} self.sessions_enabled = False self.client = None self.conn_lock = threading.Lock() if COMPRESSORS: self.default_client_options["compressors"] = COMPRESSORS @property def ismaster(self): return self.client.admin.command('isMaster') def _connect(self, host, port, **kwargs): # Jython takes a long time to connect. if sys.platform.startswith('java'): timeout_ms = 10000 else: timeout_ms = 5000 if COMPRESSORS: kwargs["compressors"] = COMPRESSORS client = pymongo.MongoClient( host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs) try: try: client.admin.command('isMaster') # Can we connect? except pymongo.errors.OperationFailure as exc: # SERVER-32063 self.connection_attempts.append( 'connected client %r, but isMaster failed: %s' % ( client, exc)) else: self.connection_attempts.append( 'successfully connected client %r' % (client,)) # If connected, then return client with default timeout return pymongo.MongoClient(host, port, **kwargs) except pymongo.errors.ConnectionFailure as exc: self.connection_attempts.append( 'failed to connect client %r: %s' % (client, exc)) return None finally: client.close() def _init_client(self): self.client = self._connect(host, port) if HAVE_SSL and not self.client: # Is MongoDB configured for SSL? self.client = self._connect(host, port, **TLS_OPTIONS) if self.client: self.tls = True self.default_client_options.update(TLS_OPTIONS) self.ssl_certfile = True if self.client: self.connected = True try: self.cmd_line = self.client.admin.command('getCmdLineOpts') except pymongo.errors.OperationFailure as e: msg = e.details.get('errmsg', '') if e.code == 13 or 'unauthorized' in msg or 'login' in msg: # Unauthorized. self.auth_enabled = True else: raise else: self.auth_enabled = self._server_started_with_auth() if self.auth_enabled: # See if db_user already exists. if not self._check_user_provided(): _create_user(self.client.admin, db_user, db_pwd) self.client = self._connect( host, port, username=db_user, password=db_pwd, replicaSet=self.replica_set_name, **self.default_client_options) # May not have this if OperationFailure was raised earlier. self.cmd_line = self.client.admin.command('getCmdLineOpts') self.server_status = self.client.admin.command('serverStatus') if self.storage_engine == "mmapv1": # MMAPv1 does not support retryWrites=True. 
self.default_client_options['retryWrites'] = False ismaster = self.ismaster self.sessions_enabled = 'logicalSessionTimeoutMinutes' in ismaster if 'setName' in ismaster: self.replica_set_name = str(ismaster['setName']) self.is_rs = True if self.auth_enabled: # It doesn't matter which member we use as the seed here. self.client = pymongo.MongoClient( host, port, username=db_user, password=db_pwd, replicaSet=self.replica_set_name, **self.default_client_options) else: self.client = pymongo.MongoClient( host, port, replicaSet=self.replica_set_name, **self.default_client_options) # Get the authoritative ismaster result from the primary. ismaster = self.ismaster nodes = [partition_node(node.lower()) for node in ismaster.get('hosts', [])] nodes.extend([partition_node(node.lower()) for node in ismaster.get('passives', [])]) nodes.extend([partition_node(node.lower()) for node in ismaster.get('arbiters', [])]) self.nodes = set(nodes) else: self.nodes = set([(host, port)]) self.w = len(ismaster.get("hosts", [])) or 1 self.version = Version.from_client(self.client) if 'enableTestCommands=1' in self.cmd_line['argv']: self.test_commands_enabled = True elif 'parsed' in self.cmd_line: params = self.cmd_line['parsed'].get('setParameter', []) if 'enableTestCommands=1' in params: self.test_commands_enabled = True else: params = self.cmd_line['parsed'].get('setParameter', {}) if params.get('enableTestCommands') == '1': self.test_commands_enabled = True self.is_mongos = (self.ismaster.get('msg') == 'isdbgrid') self.has_ipv6 = self._server_started_with_ipv6() if self.is_mongos: # Check for another mongos on the next port. address = self.client.address next_address = address[0], address[1] + 1 self.mongoses.append(address) mongos_client = self._connect(*next_address, **self.default_client_options) if mongos_client: ismaster = mongos_client.admin.command('ismaster') if ismaster.get('msg') == 'isdbgrid': self.mongoses.append(next_address) def init(self): with self.conn_lock: if not self.client and not self.connection_attempts: self._init_client() def connection_attempt_info(self): return '\n'.join(self.connection_attempts) @property def host(self): if self.is_rs: primary = self.client.primary return str(primary[0]) if primary is not None else host return host @property def port(self): if self.is_rs: primary = self.client.primary return primary[1] if primary is not None else port return port @property def pair(self): return "%s:%d" % (self.host, self.port) @property def has_secondaries(self): if not self.client: return False return bool(len(self.client.secondaries)) @property def storage_engine(self): try: return self.server_status.get("storageEngine", {}).get("name") except AttributeError: # Raised if self.server_status is None. return None def _check_user_provided(self): """Return True if db_user/db_password is already an admin user.""" client = pymongo.MongoClient( host, port, username=db_user, password=db_pwd, serverSelectionTimeoutMS=100, **self.default_client_options) try: return db_user in _all_users(client.admin) except pymongo.errors.OperationFailure as e: msg = e.details.get('errmsg', '') if e.code == 18 or 'auth fails' in msg: # Auth failed. 
return False else: raise def _server_started_with_auth(self): # MongoDB >= 2.0 if 'parsed' in self.cmd_line: parsed = self.cmd_line['parsed'] # MongoDB >= 2.6 if 'security' in parsed: security = parsed['security'] # >= rc3 if 'authorization' in security: return security['authorization'] == 'enabled' # < rc3 return (security.get('auth', False) or bool(security.get('keyFile'))) return parsed.get('auth', False) or bool(parsed.get('keyFile')) # Legacy argv = self.cmd_line['argv'] return '--auth' in argv or '--keyFile' in argv def _server_started_with_ipv6(self): if not socket.has_ipv6: return False if 'parsed' in self.cmd_line: if not self.cmd_line['parsed'].get('net', {}).get('ipv6'): return False else: if '--ipv6' not in self.cmd_line['argv']: return False # The server was started with --ipv6. Is there an IPv6 route to it? try: for info in socket.getaddrinfo(self.host, self.port): if info[0] == socket.AF_INET6: return True except socket.error: pass return False def _require(self, condition, msg, func=None): def make_wrapper(f): @wraps(f) def wrap(*args, **kwargs): self.init() # Always raise SkipTest if we can't connect to MongoDB if not self.connected: raise SkipTest( "Cannot connect to MongoDB on %s" % (self.pair,)) if condition(): return f(*args, **kwargs) raise SkipTest(msg) return wrap if func is None: def decorate(f): return make_wrapper(f) return decorate return make_wrapper(func) def create_user(self, dbname, user, pwd=None, roles=None, **kwargs): kwargs['writeConcern'] = {'w': self.w} return _create_user(self.client[dbname], user, pwd, roles, **kwargs) def drop_user(self, dbname, user): self.client[dbname].command( 'dropUser', user, writeConcern={'w': self.w}) def require_connection(self, func): """Run a test only if we can connect to MongoDB.""" return self._require( lambda: True, # _require checks if we're connected "Cannot connect to MongoDB on %s" % (self.pair,), func=func) def require_no_mmap(self, func): """Run a test only if the server is not using the MMAPv1 storage engine. Only works for standalone and replica sets; tests are run regardless of storage engine on sharded clusters. 
""" def is_not_mmap(): if self.is_mongos: return True return self.storage_engine != 'mmapv1' return self._require( is_not_mmap, "Storage engine must not be MMAPv1", func=func) def require_version_min(self, *ver): """Run a test only if the server version is at least ``version``.""" other_version = Version(*ver) return self._require(lambda: self.version >= other_version, "Server version must be at least %s" % str(other_version)) def require_version_max(self, *ver): """Run a test only if the server version is at most ``version``.""" other_version = Version(*ver) return self._require(lambda: self.version <= other_version, "Server version must be at most %s" % str(other_version)) def require_auth(self, func): """Run a test only if the server is running with auth enabled.""" return self.check_auth_with_sharding( self._require(lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func)) def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require(lambda: not self.auth_enabled, "Authentication must not be enabled on the server", func=func) def require_replica_set(self, func): """Run a test only if the client is connected to a replica set.""" return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func) def require_secondaries_count(self, count): """Run a test only if the client is connected to a replica set that has `count` secondaries. """ def sec_count(): return 0 if not self.client else len(self.client.secondaries) return self._require(lambda: sec_count() >= count, "Not enough secondaries available") def require_no_replica_set(self, func): """Run a test if the client is *not* connected to a replica set.""" return self._require( lambda: not self.is_rs, "Connected to a replica set, not a standalone mongod", func=func) def require_ipv6(self, func): """Run a test only if the client can connect to a server via IPv6.""" return self._require(lambda: self.has_ipv6, "No IPv6", func=func) def require_no_mongos(self, func): """Run a test only if the client is not connected to a mongos.""" return self._require(lambda: not self.is_mongos, "Must be connected to a mongod, not a mongos", func=func) def require_mongos(self, func): """Run a test only if the client is connected to a mongos.""" return self._require(lambda: self.is_mongos, "Must be connected to a mongos", func=func) def require_multiple_mongoses(self, func): """Run a test only if the client is connected to a sharded cluster that has 2 mongos nodes.""" return self._require(lambda: len(self.mongoses) > 1, "Must have multiple mongoses available", func=func) def require_standalone(self, func): """Run a test only if the client is connected to a standalone.""" return self._require(lambda: not (self.is_mongos or self.is_rs), "Must be connected to a standalone", func=func) def require_no_standalone(self, func): """Run a test only if the client is not connected to a standalone.""" return self._require(lambda: self.is_mongos or self.is_rs, "Must be connected to a replica set or mongos", func=func) def check_auth_with_sharding(self, func): """Skip a test when connected to mongos < 2.0 and running with auth.""" condition = lambda: not (self.auth_enabled and self.is_mongos and self.version < (2,)) return self._require(condition, "Auth with sharding requires MongoDB >= 2.0.0", func=func) def is_topology_type(self, topologies): if 'single' in topologies and not (self.is_mongos or self.is_rs): return True if 'replicaset' in topologies and self.is_rs: 
return True if 'sharded' in topologies and self.is_mongos: return True return False def require_cluster_type(self, topologies=[]): """Run a test only if the client is connected to a cluster that conforms to one of the specified topologies. Acceptable topologies are 'single', 'replicaset', and 'sharded'.""" def _is_valid_topology(): return self.is_topology_type(topologies) return self._require( _is_valid_topology, "Cluster type not in %s" % (topologies)) def require_test_commands(self, func): """Run a test only if the server has test commands enabled.""" return self._require(lambda: self.test_commands_enabled, "Test commands must be enabled", func=func) def require_failCommand_fail_point(self, func): """Run a test only if the server supports the failCommand fail point.""" return self._require(lambda: self.supports_failCommand_fail_point, "failCommand fail point must be supported", func=func) def require_failCommand_appName(self, func): """Run a test only if the server supports the failCommand appName.""" # SERVER-47195 return self._require(lambda: (self.test_commands_enabled and self.version >= (4, 4, -1)), "failCommand appName must be supported", func=func) def require_tls(self, func): """Run a test only if the client can connect over TLS.""" return self._require(lambda: self.tls, "Must be able to connect via TLS", func=func) def require_no_tls(self, func): """Run a test only if the client can connect without TLS.""" return self._require(lambda: not self.tls, "Must be able to connect without TLS", func=func) def require_ssl_certfile(self, func): """Run a test only if the client can connect with ssl_certfile.""" return self._require(lambda: self.ssl_certfile, "Must be able to connect with ssl_certfile", func=func) def require_server_resolvable(self, func): """Run a test only if the hostname 'server' is resolvable.""" return self._require(lambda: self.server_is_resolvable, "No hosts entry for 'server'. Cannot validate " "hostname in the certificate", func=func) def require_sessions(self, func): """Run a test only if the deployment supports sessions.""" return self._require(lambda: self.sessions_enabled, "Sessions not supported", func=func) def supports_transactions(self): if self.storage_engine == 'mmapv1': return False if self.version.at_least(4, 1, 8): return self.is_mongos or self.is_rs if self.version.at_least(4, 0): return self.is_rs return False def require_transactions(self, func): """Run a test only if the deployment might support transactions. *Might* because this does not test the storage engine or FCV. """ return self._require(self.supports_transactions, "Transactions are not supported", func=func) def mongos_seeds(self): return ','.join('%s:%s' % address for address in self.mongoses) @property def supports_reindex(self): """Does the connected server support reindex?""" return not ((self.version.at_least(4, 1, 0) and self.is_mongos) or (self.version.at_least(4, 5, 0) and ( self.is_mongos or self.is_rs))) @property def supports_getpreverror(self): """Does the connected server support getpreverror?""" return not (self.version.at_least(4, 1, 0) or self.is_mongos) @property def supports_failCommand_fail_point(self): """Does the server support the failCommand fail point?""" if self.is_mongos: return (self.version.at_least(4, 1, 5) and self.test_commands_enabled) else: return (self.version.at_least(4, 0) and self.test_commands_enabled) @property def requires_hint_with_min_max_queries(self): """Does the server require a hint with min/max queries?""" # Changed in SERVER-39567.
return self.version.at_least(4, 1, 10) # Reusable client context client_context = ClientContext() def sanitize_cmd(cmd): cp = cmd.copy() cp.pop('$clusterTime', None) cp.pop('$db', None) cp.pop('$readPreference', None) cp.pop('lsid', None) # OP_MSG encoding may move the payload type one field to the # end of the command. Do the same here. name = next(iter(cp)) try: identifier = message._FIELD_MAP[name] docs = cp.pop(identifier) cp[identifier] = docs except KeyError: pass return cp def sanitize_reply(reply): cp = reply.copy() cp.pop('$clusterTime', None) cp.pop('operationTime', None) return cp class PyMongoTestCase(unittest.TestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) def assertEqualReply(self, expected, actual, msg=None): self.assertEqual(sanitize_reply(expected), sanitize_reply(actual), msg) @contextmanager def fail_point(self, command_args): cmd_on = SON([('configureFailPoint', 'failCommand')]) cmd_on.update(command_args) client_context.client.admin.command(cmd_on) try: yield finally: client_context.client.admin.command( 'configureFailPoint', cmd_on['configureFailPoint'], mode='off') class IntegrationTest(PyMongoTestCase): """Base class for TestCases that need a connection to MongoDB to pass.""" @classmethod @client_context.require_connection def setUpClass(cls): cls.client = client_context.client cls.db = cls.client.pymongo_test if client_context.auth_enabled: cls.credentials = {'username': db_user, 'password': db_pwd} else: cls.credentials = {} # Use assertRaisesRegex if available, otherwise use Python 2.7's # deprecated assertRaisesRegexp, with a 'p'. if not hasattr(unittest.TestCase, 'assertRaisesRegex'): unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp class MockClientTest(unittest.TestCase): """Base class for TestCases that use MockClient. This class is *not* an IntegrationTest: if properly written, MockClient tests do not require a running server. The class temporarily overrides HEARTBEAT_FREQUENCY to speed up tests. """ def setUp(self): super(MockClientTest, self).setUp() self.client_knobs = client_knobs( heartbeat_frequency=0.001, min_heartbeat_interval=0.001) self.client_knobs.enable() def tearDown(self): self.client_knobs.disable() super(MockClientTest, self).tearDown() def setup(): client_context.init() warnings.resetwarnings() warnings.simplefilter("always") def _get_executors(topology): executors = [] for server in topology._servers.values(): # Some MockMonitor do not have an _executor. if hasattr(server._monitor, '_executor'): executors.append(server._monitor._executor) if hasattr(server._monitor, '_rtt_monitor'): executors.append(server._monitor._rtt_monitor._executor) executors.append(topology._Topology__events_executor) if topology._srv_monitor: executors.append(topology._srv_monitor._executor) return [e for e in executors if e is not None] def all_executors_stopped(topology): running = [e for e in _get_executors(topology) if not e._stopped] if running: print(' Topology %s has THREADS RUNNING: %s, created at: %s' % ( topology, running, topology._settings._stack)) return False return True def print_unclosed_clients(): from pymongo.topology import Topology processed = set() # Call collect to manually cleanup any would-be gc'd clients to avoid # false positives. gc.collect() for obj in gc.get_objects(): try: if isinstance(obj, Topology): # Avoid printing the same Topology multiple times. 
if obj._topology_id in processed: continue all_executors_stopped(obj) processed.add(obj._topology_id) except ReferenceError: pass def teardown(): garbage = [] for g in gc.garbage: garbage.append('GARBAGE: %r' % (g,)) garbage.append(' gc.get_referents: %r' % (gc.get_referents(g),)) garbage.append(' gc.get_referrers: %r' % (gc.get_referrers(g),)) if garbage: assert False, '\n'.join(garbage) c = client_context.client if c: c.drop_database("pymongo-pooling-tests") c.drop_database("pymongo_test") c.drop_database("pymongo_test1") c.drop_database("pymongo_test2") c.drop_database("pymongo_test_mike") c.drop_database("pymongo_test_bernie") c.close() # Jython does not support gc.get_objects. if not sys.platform.startswith('java'): print_unclosed_clients() class PymongoTestRunner(unittest.TextTestRunner): def run(self, test): setup() result = super(PymongoTestRunner, self).run(test) teardown() return result if HAVE_XML: class PymongoXMLTestRunner(XMLTestRunner): def run(self, test): setup() result = super(PymongoXMLTestRunner, self).run(test) teardown() return result def test_cases(suite): """Iterator over all TestCases within a TestSuite.""" for suite_or_case in suite._tests: if isinstance(suite_or_case, unittest.TestCase): # unittest.TestCase yield suite_or_case else: # unittest.TestSuite for case in test_cases(suite_or_case): yield case # Helper method to workaround https://bugs.python.org/issue21724 def clear_warning_registry(): """Clear the __warningregistry__ for all modules.""" for name, module in list(sys.modules.items()): if hasattr(module, "__warningregistry__"): setattr(module, "__warningregistry__", {}) pymongo-3.11.0/test/atlas/000077500000000000000000000000001374256237000153605ustar00rootroot00000000000000pymongo-3.11.0/test/atlas/test_connection.py000066400000000000000000000033571374256237000211400ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
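# A minimal sketch of how the ClientContext requirement decorators and the
# IntegrationTest base class defined in test/__init__.py above are typically
# combined: each decorator skips the test unless the connected deployment
# satisfies the requirement, and PyMongoTestCase.fail_point() toggles the
# failCommand fail point for the duration of a block. The class name,
# collection name, and error code used here are hypothetical illustrations,
# not names taken from the original suite.

from pymongo.errors import OperationFailure
from test import IntegrationTest, client_context


class TestRequirementDecoratorsExample(IntegrationTest):

    @client_context.require_replica_set
    @client_context.require_version_min(4, 0)
    @client_context.require_transactions
    def test_transaction_commit(self):
        coll = self.db.example
        # Create the collection first; collections cannot be created
        # implicitly inside a transaction on older servers.
        coll.insert_one({})
        with self.client.start_session() as session:
            with session.start_transaction():
                coll.insert_one({'x': 1}, session=session)

    @client_context.require_failCommand_fail_point
    def test_insert_fails_under_fail_point(self):
        # Make the next insert command fail with a non-retryable error code,
        # then verify the driver surfaces it as OperationFailure.
        with self.fail_point({'mode': {'times': 1},
                              'data': {'failCommands': ['insert'],
                                       'errorCode': 121}}):
            with self.assertRaises(OperationFailure):
                self.db.example.insert_one({'x': 1})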
"""Test connections to various Atlas cluster types.""" import os import sys import unittest sys.path[0:0] = [""] import pymongo from pymongo.ssl_support import HAS_SNI _REPL = os.environ.get("ATLAS_REPL") _SHRD = os.environ.get("ATLAS_SHRD") _FREE = os.environ.get("ATLAS_FREE") _TLS11 = os.environ.get("ATLAS_TLS11") _TLS12 = os.environ.get("ATLAS_TLS12") def _connect(uri): client = pymongo.MongoClient(uri) # No TLS error client.admin.command('ismaster') # No auth error client.test.test.count_documents({}) class TestAtlasConnect(unittest.TestCase): @classmethod def setUpClass(cls): if not all([_REPL, _SHRD, _FREE]): raise Exception( "Must set ATLAS_REPL/SHRD/FREE env variables to test.") def test_replica_set(self): _connect(_REPL) def test_sharded_cluster(self): _connect(_SHRD) def test_free_tier(self): if not HAS_SNI: raise unittest.SkipTest("Free tier requires SNI support.") _connect(_FREE) def test_tls_11(self): _connect(_TLS11) def test_tls_12(self): _connect(_TLS12) if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/auth_aws/000077500000000000000000000000001374256237000160675ustar00rootroot00000000000000pymongo-3.11.0/test/auth_aws/test_auth_aws.py000066400000000000000000000036241374256237000213200ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test MONGODB-AWS Authentication.""" import os import sys import unittest sys.path[0:0] = [""] from pymongo import MongoClient from pymongo.errors import OperationFailure from pymongo.uri_parser import parse_uri if not hasattr(unittest.TestCase, 'assertRaisesRegex'): unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp class TestAuthAWS(unittest.TestCase): @classmethod def setUpClass(cls): cls.uri = os.environ['MONGODB_URI'] def test_should_fail_without_credentials(self): if '@' not in self.uri: self.skipTest('MONGODB_URI already has no credentials') hosts = ['%s:%s' % addr for addr in parse_uri(self.uri)['nodelist']] self.assertTrue(hosts) with MongoClient(hosts) as client: with self.assertRaises(OperationFailure): client.aws.test.find_one() def test_should_fail_incorrect_credentials(self): with MongoClient(self.uri, username='fake', password='fake', authMechanism='MONGODB-AWS') as client: with self.assertRaises(OperationFailure): client.get_database().test.find_one() def test_connect_uri(self): with MongoClient(self.uri) as client: client.get_database().test.find_one() if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/barrier.py000066400000000000000000000155701374256237000162640ustar00rootroot00000000000000# Backport of the threading.Barrier class from python 3.8, with small # changes to support python 2.7. 
# https://github.com/python/cpython/blob/v3.8.2/Lib/threading.py#L562-L728 from threading import (Condition, Lock) from pymongo.monotonic import time as _time # Backport Condition.wait_for from 3.8.2 # https://github.com/python/cpython/blob/v3.8.2/Lib/threading.py#L318-L339 def wait_for(condition, predicate, timeout=None): """Wait until a condition evaluates to True. predicate should be a callable which result will be interpreted as a boolean value. A timeout may be provided giving the maximum time to wait. """ endtime = None waittime = timeout result = predicate() while not result: if waittime is not None: if endtime is None: endtime = _time() + waittime else: waittime = endtime - _time() if waittime <= 0: break condition.wait(waittime) result = predicate() return result # A barrier class. Inspired in part by the pthread_barrier_* api and # the CyclicBarrier class from Java. See # http://sourceware.org/pthreads-win32/manual/pthread_barrier_init.html and # http://java.sun.com/j2se/1.5.0/docs/api/java/util/concurrent/ # CyclicBarrier.html # for information. # We maintain two main states, 'filling' and 'draining' enabling the barrier # to be cyclic. Threads are not allowed into it until it has fully drained # since the previous cycle. In addition, a 'resetting' state exists which is # similar to 'draining' except that threads leave with a BrokenBarrierError, # and a 'broken' state in which all threads get the exception. class Barrier(object): """Implements a Barrier. Useful for synchronizing a fixed number of threads at known synchronization points. Threads block on 'wait()' and are simultaneously awoken once they have all made that call. """ def __init__(self, parties, action=None, timeout=None): """Create a barrier, initialised to 'parties' threads. 'action' is a callable which, when supplied, will be called by one of the threads after they have all entered the barrier and just prior to releasing them all. If a 'timeout' is provided, it is used as the default for all subsequent 'wait()' calls. """ self._cond = Condition(Lock()) self._action = action self._timeout = timeout self._parties = parties self._state = 0 #0 filling, 1, draining, -1 resetting, -2 broken self._count = 0 def wait(self, timeout=None): """Wait for the barrier. When the specified number of threads have started waiting, they are all simultaneously awoken. If an 'action' was provided for the barrier, one of the threads will have executed that callback prior to returning. Returns an individual index number from 0 to 'parties-1'. """ if timeout is None: timeout = self._timeout with self._cond: self._enter() # Block while the barrier drains. index = self._count self._count += 1 try: if index + 1 == self._parties: # We release the barrier self._release() else: # We wait until someone releases us self._wait(timeout) return index finally: self._count -= 1 # Wake up any threads waiting for barrier to drain. self._exit() # Block until the barrier is ready for us, or raise an exception # if it is broken. def _enter(self): while self._state in (-1, 1): # It is draining or resetting, wait until done self._cond.wait() #see if the barrier is in a broken state if self._state < 0: raise BrokenBarrierError assert self._state == 0 # Optionally run the 'action' and release the threads waiting # in the barrier. def _release(self): try: if self._action: self._action() # enter draining state self._state = 1 self._cond.notify_all() except: #an exception during the _action handler. 
Break and reraise self._break() raise # Wait in the barrier until we are released. Raise an exception # if the barrier is reset or broken. def _wait(self, timeout): if not wait_for(self._cond, lambda : self._state != 0, timeout): #timed out. Break the barrier self._break() raise BrokenBarrierError if self._state < 0: raise BrokenBarrierError assert self._state == 1 # If we are the last thread to exit the barrier, signal any threads # waiting for the barrier to drain. def _exit(self): if self._count == 0: if self._state in (-1, 1): #resetting or draining self._state = 0 self._cond.notify_all() def reset(self): """Reset the barrier to the initial state. Any threads currently waiting will get the BrokenBarrier exception raised. """ with self._cond: if self._count > 0: if self._state == 0: #reset the barrier, waking up threads self._state = -1 elif self._state == -2: #was broken, set it to reset state #which clears when the last thread exits self._state = -1 else: self._state = 0 self._cond.notify_all() def abort(self): """Place the barrier into a 'broken' state. Useful in case of error. Any currently waiting threads and threads attempting to 'wait()' will have BrokenBarrierError raised. """ with self._cond: self._break() def _break(self): # An internal error was detected. The barrier is set to # a broken state all parties awakened. self._state = -2 self._cond.notify_all() @property def parties(self): """Return the number of threads required to trip the barrier.""" return self._parties @property def n_waiting(self): """Return the number of threads currently waiting at the barrier.""" # We don't need synchronization here since this is an ephemeral result # anyway. It returns the correct value in the steady state. if self._state == 0: return self._count return 0 @property def broken(self): """Return True if the barrier is in a broken state.""" return self._state == -2 # exception raised by the Barrier class class BrokenBarrierError(RuntimeError): pass pymongo-3.11.0/test/certificates/000077500000000000000000000000001374256237000167215ustar00rootroot00000000000000pymongo-3.11.0/test/certificates/ca.pem000066400000000000000000000023701374256237000200110ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIDfzCCAmegAwIBAgIDB1MGMA0GCSqGSIb3DQEBCwUAMHkxGzAZBgNVBAMTEkRy aXZlcnMgVGVzdGluZyBDQTEQMA4GA1UECxMHRHJpdmVyczEQMA4GA1UEChMHTW9u Z29EQjEWMBQGA1UEBxMNTmV3IFlvcmsgQ2l0eTERMA8GA1UECBMITmV3IFlvcmsx CzAJBgNVBAYTAlVTMB4XDTE5MDUyMjIwMjMxMVoXDTM5MDUyMjIwMjMxMVoweTEb MBkGA1UEAxMSRHJpdmVycyBUZXN0aW5nIENBMRAwDgYDVQQLEwdEcml2ZXJzMRAw DgYDVQQKEwdNb25nb0RCMRYwFAYDVQQHEw1OZXcgWW9yayBDaXR5MREwDwYDVQQI EwhOZXcgWW9yazELMAkGA1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw ggEKAoIBAQCl7VN+WsQfHlwapcOpTLZVoeMAl1LTbWTFuXSAavIyy0W1Ytky1UP/ bxCSW0mSWwCgqoJ5aXbAvrNRp6ArWu3LsTQIEcD3pEdrFIVQhYzWUs9fXqPyI9k+ QNNQ+MRFKeGteTPYwF2eVEtPzUHU5ws3+OKp1m6MCLkwAG3RBFUAfddUnLvGoZiT pd8/eNabhgHvdrCw+tYFCWvSjz7SluEVievpQehrSEPKe8DxJq/IM3tSl3tdylzT zeiKNO7c7LuQrgjAfrZl7n2SriHIlNmqiDR/kdd8+TxBuxjFlcf2WyHCO3lIcIgH KXTlhUCg50KfHaxHu05Qw0x8869yIzqbAgMBAAGjEDAOMAwGA1UdEwQFMAMBAf8w DQYJKoZIhvcNAQELBQADggEBAEHuhTL8KQZcKCTSJbYA9MgZj7U32arMGBbc1hiq VBREwvdVz4+9tIyWMzN9R/YCKmUTnCq8z3wTlC8kBtxYn/l4Tj8nJYcgLJjQ0Fwe gT564CmvkUat8uXPz6olOCdwkMpJ9Sj62i0mpgXJdBfxKQ6TZ9yGz6m3jannjZpN LchB7xSAEWtqUgvNusq0dApJsf4n7jZ+oBZVaQw2+tzaMfaLqHgMwcu1FzA8UKCD sxCgIsZUs8DdxaD418Ot6nPfheOTqe24n+TTa+Z6O0W0QtnofJBx7tmAo1aEc57i 77s89pfwIJetpIlhzNSMKurCAocFCJMJLAASJFuu6dyDvPo= -----END 
CERTIFICATE-----pymongo-3.11.0/test/certificates/client.pem000066400000000000000000000056141374256237000207100ustar00rootroot00000000000000-----BEGIN RSA PRIVATE KEY----- MIIEpAIBAAKCAQEAsNS8UEuin7/K29jXfIOLpIoh1jEyWVqxiie2Onx7uJJKcoKo khA3XeUnVN0k6X5MwYWcN52xcns7LYtyt06nRpTG2/emoV44w9uKTuHsvUbiOwSV m/ToKQQ4FUFZoqorXH+ZmJuIpJNfoW+3CkE1vEDCIecIq6BNg5ySsPtvSuSJHGjp mc7/5ZUDvFE2aJ8QbJU3Ws0HXiEb6ymi048LlzEL2VKX3w6mqqh+7dcZGAy7qYk2 5FZ9ktKvCeQau7mTyU1hsPrKFiKtMN8Q2ZAItX13asw5/IeSTq2LgLFHlbj5Kpq4 GmLdNCshzH5X7Ew3IYM8EHmsX8dmD6mhv7vpVwIDAQABAoIBABOdpb4qhcG+3twA c/cGCKmaASLnljQ/UU6IFTjrsjXJVKTbRaPeVKX/05sgZQXZ0t3s2mV5AsQ2U1w8 Cd+3w+qaemzQThW8hAOGCROzEDX29QWi/o2sX0ydgTMqaq0Wv3SlWv6I0mGfT45y /BURIsrdTCvCmz2erLqa1dL4MWJXRFjT9UTs5twlecIOM2IHKoGGagFhymRK4kDe wTRC9fpfoAgyfus3pCO/wi/F8yKGPDEwY+zgkhrJQ+kSeki7oKdGD1H540vB8gRt EIqssE0Y6rEYf97WssQlxJgvoJBDSftOijS6mwvoasDUwfFqyyPiirawXWWhHXkc DjIi/XECgYEA5xfjilw9YyM2UGQNESbNNunPcj7gDZbN347xJwmYmi9AUdPLt9xN 3XaMqqR22k1DUOxC/5hH0uiXir7mDfqmC+XS/ic/VOsa3CDWejkEnyGLiwSHY502 wD/xWgHwUiGVAG9HY64vnDGm6L3KGXA2oqxanL4V0+0+Ht49pZ16i8sCgYEAw+Ox CHGtpkzjCP/z8xr+1VTSdpc/4CP2HONnYopcn48KfQnf7Nale69/1kZpypJlvQSG eeA3jMGigNJEkb8/kaVoRLCisXcwLc0XIfCTeiK6FS0Ka30D/84Qm8UsHxRdpGkM kYITAa2r64tgRL8as4/ukeXBKE+oOhX43LeEfyUCgYBkf7IX2Ndlhsm3GlvIarxy NipeP9PGdR/hKlPbq0OvQf9R1q7QrcE7H7Q6/b0mYNV2mtjkOQB7S2WkFDMOP0P5 BqDEoKLdNkV/F9TOYH+PCNKbyYNrodJOt0Ap6Y/u1+Xpw3sjcXwJDFrO+sKqX2+T PStG4S+y84jBedsLbDoAEwKBgQCTz7/KC11o2yOFqv09N+WKvBKDgeWlD/2qFr3w UU9K5viXGVhqshz0k5z25vL09Drowf1nAZVpFMO2SPOMtq8VC6b+Dfr1xmYIaXVH Gu1tf77CM9Zk/VSDNc66e7GrUgbHBK2DLo+A+Ld9aRIfTcSsMbNnS+LQtCrQibvb cG7+MQKBgQCY11oMT2dUekoZEyW4no7W5D74lR8ztMjp/fWWTDo/AZGPBY6cZoZF IICrzYtDT/5BzB0Jh1f4O9ZQkm5+OvlFbmoZoSbMzHL3oJCBOY5K0/kdGXL46WWh IRJSYakNU6VIS7SjDpKgm9D8befQqZeoSggSjIIULIiAtYgS80vmGA== -----END RSA PRIVATE KEY----- -----BEGIN CERTIFICATE----- MIIDgzCCAmugAwIBAgIDAxOUMA0GCSqGSIb3DQEBCwUAMHkxGzAZBgNVBAMTEkRy aXZlcnMgVGVzdGluZyBDQTEQMA4GA1UECxMHRHJpdmVyczEQMA4GA1UEChMHTW9u Z29EQjEWMBQGA1UEBxMNTmV3IFlvcmsgQ2l0eTERMA8GA1UECBMITmV3IFlvcmsx CzAJBgNVBAYTAlVTMB4XDTE5MDUyMjIzNTU1NFoXDTM5MDUyMjIzNTU1NFowaTEP MA0GA1UEAxMGY2xpZW50MRAwDgYDVQQLEwdEcml2ZXJzMQwwCgYDVQQKEwNNREIx FjAUBgNVBAcTDU5ldyBZb3JrIENpdHkxETAPBgNVBAgTCE5ldyBZb3JrMQswCQYD VQQGEwJVUzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBALDUvFBLop+/ ytvY13yDi6SKIdYxMllasYontjp8e7iSSnKCqJIQN13lJ1TdJOl+TMGFnDedsXJ7 Oy2LcrdOp0aUxtv3pqFeOMPbik7h7L1G4jsElZv06CkEOBVBWaKqK1x/mZibiKST X6FvtwpBNbxAwiHnCKugTYOckrD7b0rkiRxo6ZnO/+WVA7xRNmifEGyVN1rNB14h G+spotOPC5cxC9lSl98Opqqofu3XGRgMu6mJNuRWfZLSrwnkGru5k8lNYbD6yhYi rTDfENmQCLV9d2rMOfyHkk6ti4CxR5W4+SqauBpi3TQrIcx+V+xMNyGDPBB5rF/H Zg+pob+76VcCAwEAAaMkMCIwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUF BwMCMA0GCSqGSIb3DQEBCwUAA4IBAQAqRcLAGvYMaGYOV4HJTzNotT2qE0I9THNQ wOV1fBg69x6SrUQTQLjJEptpOA288Wue6Jt3H+p5qAGV5GbXjzN/yjCoItggSKxG Xg7279nz6/C5faoIKRjpS9R+MsJGlttP9nUzdSxrHvvqm62OuSVFjjETxD39DupE YPFQoHOxdFTtBQlc/zIKxVdd20rs1xJeeU2/L7jtRBSPuR/Sk8zot7G2/dQHX49y kHrq8qz12kj1T6XDXf8KZawFywXaz0/Ur+fUYKmkVk1T0JZaNtF4sKqDeNE4zcns p3xLVDSl1Q5Gwj7bgph9o4Hxs9izPwiqjmNaSjPimGYZ399zcurY -----END CERTIFICATE----- pymongo-3.11.0/test/certificates/crl.pem000066400000000000000000000013171374256237000202060ustar00rootroot00000000000000-----BEGIN X509 CRL----- MIIB6jCB0wIBATANBgkqhkiG9w0BAQsFADB5MRswGQYDVQQDExJEcml2ZXJzIFRl c3RpbmcgQ0ExEDAOBgNVBAsTB0RyaXZlcnMxEDAOBgNVBAoTB01vbmdvREIxFjAU BgNVBAcTDU5ldyBZb3JrIENpdHkxETAPBgNVBAgTCE5ldyBZb3JrMQswCQYDVQQG EwJVUxcNMTkwNTIyMjI0NTUzWhcNMTkwNjIxMjI0NTUzWjAVMBMCAncVFw0xOTA1 
MjIyMjQ1MzJaoA8wDTALBgNVHRQEBAICEAAwDQYJKoZIhvcNAQELBQADggEBACwQ W9OF6ExJSzzYbpCRroznkfdLG7ghNSxIpBQUGtcnYbkP4em6TdtAj5K3yBjcKn4a hnUoa5EJGr2Xgg0QascV/1GuWEJC9rsYYB9boVi95l1CrkS0pseaunM086iItZ4a hRVza8qEMBc3rdsracA7hElYMKdFTRLpIGciJehXzv40yT5XFBHGy/HIT0CD50O7 BDOHzA+rCFCvxX8UY9myDfb1r1zUW7Gzjn241VT7bcIJmhFE9oV0popzDyqr6GvP qB2t5VmFpbnSwkuc4ie8Jizip1P8Hg73lut3oVAHACFGPpfaNIAp4GcSH61zJmff 9UBe3CJ1INwqyiuqGeA= -----END X509 CRL----- pymongo-3.11.0/test/certificates/password_protected.pem000066400000000000000000000060771374256237000233510ustar00rootroot00000000000000-----BEGIN ENCRYPTED PRIVATE KEY----- MIIFHzBJBgkqhkiG9w0BBQ0wPDAbBgkqhkiG9w0BBQwwDgQIC8as6PDVhwECAggA MB0GCWCGSAFlAwQBAgQQTYOgCJcRqUI7dsgqNojv/ASCBNCG9fiu642V4AuFK34c Q42lvy/cR0CIXLq/rDXN1L685kdeKex7AfDuRtnjY2+7CLJiJimgQNJXDJPHab/k MBHbwbBs38fg6eSYX8V08/IyyTege5EJMhYxmieHDC3DXKt0gyHk6hA/r5+Mr49h HeVGwqBLJEQ3gVIeHaOleZYspsXXWqOPHnFiqnk/biaJS0+LkDDEiQgTLEYSnOjP lexxUc4BV/TN0Z920tZCMfwx7IXD/C+0AkV/Iqq4LALmT702EccB3indaIJ8biGR radqDLR32Q+vT9uZHgT8EFiUsISMqhob2mnyTfFV/s9ghWwogjSz0HrRcq6fxdg7 oeyT9K0ET53AGTGmV0206byPu6qCj1eNvtn+t1Ob+d5hecaTugRMVheWPlc5frsz AcewDNa0pv4pZItjAGMqOPJHfzEDnzTJXpLqGYhg044H1+OCY8+1YK7U0u8dO+/3 f5AoDMq18ipDVTFTooJURej4/Wjbrfad3ZFjp86nxfHPeWM1YjC9+IlLtK1wr0/U V8TjGqCkw8yHayz01A86iA8X53YQBg+tyMGjxmivo6LgFGKa9mXGvDkN+B+0+OcA PqldAuH/TJhnkqzja767e4n9kcr+TmV19Hn1hcJPTDrRU8+sSqQFsWN4pvHazAYB UdWie+EXI0eU2Av9JFgrVcpRipXjB48BaPwuBw8hm+VStCH7ynF4lJy6/3esjYwk Mx+NUf8+pp1DRzpzuJa2vAutzqia5r58+zloQMxkgTZtJkQU6OCRoUhHGVk7WNb1 nxsibOSzyVSP9ZNbHIHAn43vICFGrPubRs200Kc4CdXsOSEWoP0XYebhiNJgGtQs KoISsV4dFRLwhaJhIlayTBQz6w6Ph87WbtuiAqoLiuqdXhUGz/79j/6JZqCH8t/H eZs4Dhu+HdD/wZKJDYAS+JBsiwYWnI3y/EowZYgLdOMI4u6xYDejhxwEw20LW445 qjJ7pV/iX2uavazHgC91Bfd4zodfXIQ1IDyTmb51UFwx0ARzG6enntduO6xtcYU9 MXwfrEpuZ/MkWTLkR0PHPbIPcR1MiVwPKdvrLk42Bzj/urtXYrAFUckMFMzEh+uv 0lix2hbq/Xwj4dXcY4w9hnC6QQDCJTf9S6MU6OisrZHKk0qZ2Vb4aU/eBcBsHBwo X/QGcDHneHxlrrs2eLX26Vh8Odc5h8haeIxnfaa1t+Yv56OKHuAztPMnJOUL7KtQ A556LxT0b5IGx0RcfUcbG8XbxEHseACptoDOoguh9923IBI0uXmpi8q0P815LPUu 0AsE47ATDMGPnXbopejRDicfgMGjykJn8vKO8r/Ia3Fpnomx4iJNCXGqomL+GMpZ IhQbKNrRG6XZMlx5kVCT0Qr1nOWMiOTSDCQ5vrG3c1Viu+0bctvidEvs+LCm98tb 7ty8F0uOno0rYGNQz18OEE1Tj+E19Vauz1U35Z5SsgJJ/GfzhSJ79Srmdg2PsAzk AUNTKXux1GLf1cMjTiiU5g+tCEtUL9Me7lsv3L6aFdrCyRbhXUQfJh4NAG8+3Pvh EaprThBzKsVvbOfU81mOaH9YMmUgmxG86vxDiNtaWd4v6c1k+HGspJr/q49pcXZP ltBMuS9AihstZ1sHJsyQCmNXkA== -----END ENCRYPTED PRIVATE KEY----- -----BEGIN CERTIFICATE----- MIIDgzCCAmugAwIBAgIDBXUHMA0GCSqGSIb3DQEBCwUAMHkxGzAZBgNVBAMTEkRy aXZlcnMgVGVzdGluZyBDQTEQMA4GA1UECxMHRHJpdmVyczEQMA4GA1UEChMHTW9u Z29EQjEWMBQGA1UEBxMNTmV3IFlvcmsgQ2l0eTERMA8GA1UECBMITmV3IFlvcmsx CzAJBgNVBAYTAlVTMB4XDTE5MDUyMzAwMDEyOVoXDTM5MDUyMzAwMDEyOVowaTEP MA0GA1UEAxMGY2xpZW50MRAwDgYDVQQLEwdEcml2ZXJzMQwwCgYDVQQKEwNNREIx FjAUBgNVBAcTDU5ldyBZb3JrIENpdHkxETAPBgNVBAgTCE5ldyBZb3JrMQswCQYD VQQGEwJVUzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOqCb0Lo4XsV W327Wlnqc5rwWa5Elw0rFuehSfViRIcYfuFWAPXoOj3fIDsYz6d41G8hp6tkF88p swlbzDF8Fc7mXDhauwwl2F/NrWYUXwCT8fKju4DtGd2JlDMi1TRDeofkYCGVPp70 vNqd0H8iDWWs8OmiNrdBLJwNiGaf9y15ena4ImQGitXLFn+qNSXYJ1Rs8p7Y2PTr L+dff5gJCVbANwGII1rjMAsrMACPVmr8c1Lxoq4fSdJiLweosrv2Lk0WWGsO0Seg ZY71dNHEyNjItE+VtFEtslJ5L261i3BfF/FqNnH2UmKXzShwfwxyHT8o84gSAltQ 5/lVJ4QQKosCAwEAAaMkMCIwCwYDVR0PBAQDAgeAMBMGA1UdJQQMMAoGCCsGAQUF BwMCMA0GCSqGSIb3DQEBCwUAA4IBAQBOAlKxIMFcTZ+4k8NJv97RSf+zOb5Wu2ct uxSZxzgKTxLFUuEM8XQiEz1iHQ3XG+uV1fzA74YLQiKjjLrU0mx54eM1vaRtOXvF sJlzZU8Z2+523FVPx4HBPyObQrfXmIoAiHoQ4VUeepkPRpXxpifgWd/OCWhLDr2/ 
0Kgcb0ybaGVDpA0UD9uVIwgFjRu6id7wG+lVcdRxJYskTOOaN2o1hMdAKkrpFQbd zNRfEoBPUYR3QAmAKP2HBjpgp4ktOHoOKMlfeAuuMCUocSnmPKc3xJaH/6O7rHcf /Rm0X411RH8JfoXYsSiPsd601kZefhuWvJH0sJLibRDvT7zs8C1v -----END CERTIFICATE----- pymongo-3.11.0/test/certificates/server.pem000066400000000000000000000057221374256237000207400ustar00rootroot00000000000000-----BEGIN RSA PRIVATE KEY----- MIIEogIBAAKCAQEAhNrB0E6GY/kFSd8/vNpu/t952tbnOsD5drV0XPvmuy7SgKDY a/S+xb/jPnlZKKehdBnH7qP/gYbv34ZykzcDFZscjPLiGc2cRGP+NQCSFK0d2/7d y15zSD3zhj14G8+MkpAejTU+0/qFNZMc5neDvGanTe0+8aWa0DXssM0MuTxIv7j6 CtsMWeqLLofN7a1Kw2UvmieCHfHMuA/08pJwRnV/+5T9WONBPJja2ZQRrG1BjpI4 81zSPUZesIqi8yDlExdvgNaRZIEHi/njREqwVgJOZomUY57zmKypiMzbz48dDTsV gUStxrEqbaP+BEjQYPX5+QQk4GdMjkLf52LR6QIDAQABAoIBAHSs+hHLJNOf2zkp S3y8CUblVMsQeTpsR6otaehPgi9Zy50TpX4KD5D0GMrBH8BIl86y5Zd7h+VlcDzK gs0vPxI2izhuBovKuzaE6rf5rFFkSBjxGDCG3o/PeJOoYFdsS3RcBbjVzju0hFCs xnDQ/Wz0anJRrTnjyraY5SnQqx/xuhLXkj/lwWoWjP2bUqDprnuLOj16soNu60Um JziWbmWx9ty0wohkI/8DPBl9FjSniEEUi9pnZXPElFN6kwPkgdfT5rY/TkMH4lsu ozOUc5xgwlkT6kVjXHcs3fleuT/mOfVXLPgNms85JKLucfd6KiV7jYZkT/bXIjQ+ 7CZEn0ECgYEA5QiKZgsfJjWvZpt21V/i7dPje2xdwHtZ8F9NjX7ZUFA7mUPxUlwe GiXxmy6RGzNdnLOto4SF0/7ebuF3koO77oLup5a2etL+y/AnNAufbu4S5D72sbiz wdLzr3d5JQ12xeaEH6kQNk2SD5/ShctdS6GmTgQPiJIgH0MIdi9F3v0CgYEAlH84 hMWcC+5b4hHUEexeNkT8kCXwHVcUjGRaYFdSHgovvWllApZDHSWZ+vRcMBdlhNPu 09Btxo99cjOZwGYJyt20QQLGc/ZyiOF4ximQzabTeFgLkTH3Ox6Mh2Rx9yIruYoX nE3UfMDkYELanEJUv0zenKpZHw7tTt5yXXSlEF0CgYBSsEOvVcKYO/eoluZPYQAA F2jgzZ4HeUFebDoGpM52lZD+463Dq2hezmYtPaG77U6V3bUJ/TWH9VN/Or290vvN v83ECcC2FWlSXdD5lFyqYx/E8gqE3YdgqfW62uqM+xBvoKsA9zvYLydVpsEN9v8m 6CSvs/2btA4O21e5u5WBTQKBgGtAb6vFpe0gHRDs24SOeYUs0lWycPhf+qFjobrP lqnHpa9iPeheat7UV6BfeW3qmBIVl/s4IPE2ld4z0qqZiB0Tf6ssu/TpXNPsNXS6 dLFz+myC+ufFdNEoQUtQitd5wKbjTCZCOGRaVRgJcSdG6Tq55Fa22mOKPm+mTmed ZdKpAoGAFsTYBAHPxs8nzkCJCl7KLa4/zgbgywO6EcQgA7tfelB8bc8vcAMG5o+8 YqAfwxrzhVSVbJx0fibTARXROmbh2pn010l2wj3+qUajM8NiskCPFbSjGy7HSUze P8Kt1uMDJdj55gATzn44au31QBioZY2zXleorxF21cr+BZCJgfA= -----END RSA PRIVATE KEY----- -----BEGIN CERTIFICATE----- MIIDlTCCAn2gAwIBAgICdxUwDQYJKoZIhvcNAQELBQAweTEbMBkGA1UEAxMSRHJp dmVycyBUZXN0aW5nIENBMRAwDgYDVQQLEwdEcml2ZXJzMRAwDgYDVQQKEwdNb25n b0RCMRYwFAYDVQQHEw1OZXcgWW9yayBDaXR5MREwDwYDVQQIEwhOZXcgWW9yazEL MAkGA1UEBhMCVVMwHhcNMTkwNTIyMjIzMjU2WhcNMzkwNTIyMjIzMjU2WjBwMRIw EAYDVQQDEwlsb2NhbGhvc3QxEDAOBgNVBAsTB0RyaXZlcnMxEDAOBgNVBAoTB01v bmdvREIxFjAUBgNVBAcTDU5ldyBZb3JrIENpdHkxETAPBgNVBAgTCE5ldyBZb3Jr MQswCQYDVQQGEwJVUzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAITa wdBOhmP5BUnfP7zabv7fedrW5zrA+Xa1dFz75rsu0oCg2Gv0vsW/4z55WSinoXQZ x+6j/4GG79+GcpM3AxWbHIzy4hnNnERj/jUAkhStHdv+3ctec0g984Y9eBvPjJKQ Ho01PtP6hTWTHOZ3g7xmp03tPvGlmtA17LDNDLk8SL+4+grbDFnqiy6Hze2tSsNl L5ongh3xzLgP9PKScEZ1f/uU/VjjQTyY2tmUEaxtQY6SOPNc0j1GXrCKovMg5RMX b4DWkWSBB4v540RKsFYCTmaJlGOe85isqYjM28+PHQ07FYFErcaxKm2j/gRI0GD1 +fkEJOBnTI5C3+di0ekCAwEAAaMwMC4wLAYDVR0RBCUwI4IJbG9jYWxob3N0hwR/ AAABhxAAAAAAAAAAAAAAAAAAAAABMA0GCSqGSIb3DQEBCwUAA4IBAQBol8+YH7MA HwnIh7KcJ8h87GkCWsjOJCDJWiYBJArQ0MmgDO0qdx+QEtvLMn3XNtP05ZfK0WyX or4cWllAkMFYaFbyB2hYazlD1UAAG+22Rku0UP6pJMLbWe6pnqzx+RL68FYdbZhN fCW2xiiKsdPoo2VEY7eeZKrNr/0RFE5EKXgzmobpTBQT1Dl3Ve4aWLoTy9INlQ/g z40qS7oq1PjjPLgxINhf4ncJqfmRXugYTOnyFiVXLZTys5Pb9SMKdToGl3NTYWLL 2AZdjr6bKtT+WtXyHqO0cQ8CkAW0M6VOlMluACllcJxfrtdlQS2S4lUIj76QKBdZ khBHXq/b8MFX -----END CERTIFICATE----- pymongo-3.11.0/test/certificates/trusted-ca.pem000066400000000000000000000112331374256237000214770ustar00rootroot00000000000000# CA bundle file used to test tlsCAFile loading for OCSP. 
# Copied from the server: # https://github.com/mongodb/mongo/blob/r4.3.4/jstests/libs/trusted-ca.pem # Autogenerated file, do not edit. # Generate using jstests/ssl/x509/mkcert.py --config jstests/ssl/x509/certs.yml trusted-ca.pem # # CA for alternate client/server certificate chain. -----BEGIN CERTIFICATE----- MIIDojCCAooCBG585gswDQYJKoZIhvcNAQELBQAwfDELMAkGA1UEBhMCVVMxETAP BgNVBAgMCE5ldyBZb3JrMRYwFAYDVQQHDA1OZXcgWW9yayBDaXR5MRAwDgYDVQQK DAdNb25nb0RCMQ8wDQYDVQQLDAZLZXJuZWwxHzAdBgNVBAMMFlRydXN0ZWQgS2Vy bmVsIFRlc3QgQ0EwHhcNMTkwOTI1MjMyNzQxWhcNMzkwOTI3MjMyNzQxWjB8MQsw CQYDVQQGEwJVUzERMA8GA1UECAwITmV3IFlvcmsxFjAUBgNVBAcMDU5ldyBZb3Jr IENpdHkxEDAOBgNVBAoMB01vbmdvREIxDzANBgNVBAsMBktlcm5lbDEfMB0GA1UE AwwWVHJ1c3RlZCBLZXJuZWwgVGVzdCBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEP ADCCAQoCggEBANlRxtpMeCGhkotkjHQqgqvO6O6hoRoAGGJlDaTVtqrjmC8nwySz 1nAFndqUHttxS3A5j4enOabvffdOcV7+Z6vDQmREF6QZmQAk81pmazSc3wOnRiRs AhXjld7i+rhB50CW01oYzQB50rlBFu+ONKYj32nBjD+1YN4AZ2tuRlbxfx2uf8Bo Zowfr4n9nHVcWXBLFmaQLn+88WFO/wuwYUOn6Di1Bvtkvqum0or5QeAF0qkJxfhg 3a4vBnomPdwEXCgAGLvHlB41CWG09EuAjrnE3HPPi5vII8pjY2dKKMomOEYmA+KJ AC1NlTWdN0TtsoaKnyhMMhLWs3eTyXL7kbkCAwEAAaMxMC8wDAYDVR0TBAUwAwEB /zAfBgNVHREEGDAWgglsb2NhbGhvc3SCCTEyNy4wLjAuMTANBgkqhkiG9w0BAQsF AAOCAQEAQk56MO9xAhtO077COCqIYe6pYv3uzOplqjXpJ7Cph7GXwQqdFWfKls7B cLfF/fhIUZIu5itStEkY+AIwht4mBr1F5+hZUp9KZOed30/ewoBXAUgobLipJV66 FKg8NRtmJbiZrrC00BSO+pKfQThU8k0zZjBmNmpjxnbKZZSFWUKtbhHV1vujver6 SXZC7R6692vLwRBMoZxhgy/FkYRdiN0U9wpluKd63eo/O02Nt6OEMyeiyl+Z3JWi 8g5iHNrBYGBbGSnDOnqV6tjEY3eq600JDWiodpA1OQheLi78pkc/VQZwof9dyBCm 6BoCskTjip/UB+vIhdPFT9sgUdgDTg== -----END CERTIFICATE----- -----BEGIN PRIVATE KEY----- MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDZUcbaTHghoZKL ZIx0KoKrzujuoaEaABhiZQ2k1baq45gvJ8Mks9ZwBZ3alB7bcUtwOY+Hpzmm7333 TnFe/merw0JkRBekGZkAJPNaZms0nN8Dp0YkbAIV45Xe4vq4QedAltNaGM0AedK5 QRbvjjSmI99pwYw/tWDeAGdrbkZW8X8drn/AaGaMH6+J/Zx1XFlwSxZmkC5/vPFh Tv8LsGFDp+g4tQb7ZL6rptKK+UHgBdKpCcX4YN2uLwZ6Jj3cBFwoABi7x5QeNQlh tPRLgI65xNxzz4ubyCPKY2NnSijKJjhGJgPiiQAtTZU1nTdE7bKGip8oTDIS1rN3 k8ly+5G5AgMBAAECggEAS7GjLKgT88reSzUTgubHquYf1fZwMak01RjTnsVdoboy aMJVwzPsjgo2yEptUQvuNcGmz54cg5vJaVlmPaspGveg6WGaRmswEo/MP4GK98Fo IFKkKM2CEHO74O14XLN/w8yFA02+IdtM3X/haEFE71VxXNmwawRXIBxN6Wp4j5Fb mPLKIspnWQ/Y/Fn799sCFAzX5mKkbCt1IEgKssgQQEm1UkvmCkcZE+mdO/ErYP8A COO0LpM+TK6WQY2LKiteeCCiosTZFb1GO7MkXrRP5uOBZKaW5kq1R0b6PcopJPCM OcYF0Zli6KB7oiQLdXgU2jCaxYOnuRb6RYh2l7NvAQKBgQD6CZ9TKOn/EUQtukyw pvYTyt1hoLXqYGcbRtLc1gcC+Z2BD28hd3eD/mEUv+g/8bq/OP4wYV9X+VRvR8xN MmfAG/sJeOCOClz1A1TyNeA+G0GZ25qWHyHQ2W4WlSG1CXQgxGzU6wo/t6wiVW5R O4jplFVEOXznf4vmVfBJK50R2QKBgQDegGxm23jF2N5sIYDZ14oxms8bbjPz8zH6 tiIRYNGbSzI7J4KFGY2HiBwtf1yxS22HBL69Y1WrEzGm1vm4aZG/GUwBzI79QZAO +YFIGaIrdlv12Zm6lpJMmAWlOs9XFirC17oQEwOQFweOdQSt7F/+HMZOigdikRBV pK+8Kfay4QKBgQDarDevHwUmkg8yftA7Xomv3aenjkoK5KzH6jTX9kbDj1L0YG8s sbLQuVRmNUAFTH+qZUnJPh+IbQIvIHfIu+CI3u+55QFeuCl8DqHoAr5PEr9Ys/qK eEe2w7HIBj0oe1AYqDEWNUkNWLEuhdCpMowW3CeGN1DJlX7gvyAang4MYQKBgHwM aWNnFQxo/oiWnTnWm2tQfgszA7AMdF7s0E2UBwhnghfMzU3bkzZuwhbznQATp3rR QG5iRU7dop7717ni0akTN3cBTu8PcHuIy3UhJXLJyDdnG/gVHnepgew+v340E58R muB/WUsqK8JWp0c4M8R+0mjTN47ShaLZ8EgdtTbBAoGBAKOcpuDfFEMI+YJgn8zX h0nFT60LX6Lx+zcSDY9+6J6a4n5NhC+weYCDFOGlsLka1SwHcg1xanfrLVjpH7Ok HDJGLrSh1FP2Rq/oFxZ/OKCjonHLa8IulqD/AA+sqYRbysKNsT3Pi0554F2xFEqQ z/C84nlT1R2uTCWIxvrnpU2h -----END PRIVATE KEY----- # Pre Oct 2019 trusted-ca.pem # Transitional pending BUILD update. 
-----BEGIN CERTIFICATE----- MIIDpjCCAo6gAwIBAgIDAghHMA0GCSqGSIb3DQEBBQUAMHwxHzAdBgNVBAMTFlRy dXN0ZWQgS2VybmVsIFRlc3QgQ0ExDzANBgNVBAsTBktlcm5lbDEQMA4GA1UEChMH TW9uZ29EQjEWMBQGA1UEBxMNTmV3IFlvcmsgQ2l0eTERMA8GA1UECBMITmV3IFlv cmsxCzAJBgNVBAYTAlVTMB4XDTE2MDMzMTE0NTY1NVoXDTM2MDMzMTE0NTY1NVow fDEfMB0GA1UEAxMWVHJ1c3RlZCBLZXJuZWwgVGVzdCBDQTEPMA0GA1UECxMGS2Vy bmVsMRAwDgYDVQQKEwdNb25nb0RCMRYwFAYDVQQHEw1OZXcgWW9yayBDaXR5MREw DwYDVQQIEwhOZXcgWW9yazELMAkGA1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUA A4IBDwAwggEKAoIBAQCePFHZTydC96SlSHSyu73vw//ddaE33kPllBB9DP2L7yRF 6D/blFmno9fSM+Dfg64VfGV+0pCXPIZbpH29nzJu0DkvHzKiWK7P1zUj8rAHaX++ d6k0yeTLFM9v+7YE9rHoANVn22aOyDvTgAyMmA0CLn+SmUy6WObwMIf9cZn97Znd lww7IeFNyK8sWtfsVN4yRBnjr7kKN2Qo0QmWeFa7jxVQptMJQrY8k1PcyVUOgOjQ ocJLbWLlm9k0/OMEQSwQHJ+d9weUbKjlZ9ExOrm4QuuA2tJhb38baTdAYw3Jui4f yD6iBAGD0Jkpc+3YaWv6CBmK8NEFkYJD/gn+lJ75AgMBAAGjMTAvMAwGA1UdEwQF MAMBAf8wHwYDVR0RBBgwFoIJbG9jYWxob3N0ggkxMjcuMC4wLjEwDQYJKoZIhvcN AQEFBQADggEBADYikjB6iwAUs6sglwkE4rOkeMkJdRCNwK/5LpFJTWrDjBvBQCdA Y5hlAVq8PfIYeh+wEuSvsEHXmx7W29X2+p4VuJ95/xBA6NLapwtzuiijRj2RBAOG 1EGuyFQUPTL27DR3+tfayNykDclsVDNN8+l7nt56j8HojP74P5OMHtn+6HX5+mtF FfZMTy0mWguCsMOkZvjAskm6s4U5gEC8pYEoC0ZRbfUdyYsxZe/nrXIFguVlVPCB XnfB/0iG9t+VH5cUVj1LP9skXTW4kXfhQmljUuo+EVBNR6n2nfTnpoC65WeAgHV4 V+s9mJsUv2x72KtKYypqEVT0gaJ1WIN9N1s= -----END CERTIFICATE----- pymongo-3.11.0/test/mod_wsgi_test/000077500000000000000000000000001374256237000171235ustar00rootroot00000000000000pymongo-3.11.0/test/mod_wsgi_test/test_client.py000066400000000000000000000113221374256237000220110ustar00rootroot00000000000000# Copyright 2012-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test client for mod_wsgi application, see bug PYTHON-353. """ import sys import threading import time from optparse import OptionParser try: from urllib2 import urlopen except ImportError: # Python 3. from urllib.request import urlopen try: import thread except ImportError: # Python 3. import _thread as thread def parse_args(): parser = OptionParser("""usage: %prog [options] mode url mode:\tparallel or serial""") # Should be enough that any connection leak will exhaust available file # descriptors. parser.add_option( "-n", "--nrequests", type="int", dest="nrequests", default=50 * 1000, help="Number of times to GET the URL, in total") parser.add_option( "-t", "--nthreads", type="int", dest="nthreads", default=100, help="Number of threads with mode 'parallel'") parser.add_option( "-q", "--quiet", action="store_false", dest="verbose", default=True, help="Don't print status messages to stdout") parser.add_option( "-c", "--continue", action="store_true", dest="continue_", default=False, help="Continue after HTTP errors") try: options, (mode, url) = parser.parse_args() except ValueError: parser.print_usage() sys.exit(1) if mode not in ('parallel', 'serial'): parser.print_usage() sys.exit(1) return options, mode, url def get(url): urlopen(url).read().strip() class URLGetterThread(threading.Thread): # Class variables. 
counter_lock = threading.Lock() counter = 0 def __init__(self, options, url, nrequests_per_thread): super(URLGetterThread, self).__init__() self.options = options self.url = url self.nrequests_per_thread = nrequests_per_thread self.errors = 0 def run(self): for i in range(self.nrequests_per_thread): try: get(url) except Exception as e: print(e) if not options.continue_: thread.interrupt_main() thread.exit() self.errors += 1 with URLGetterThread.counter_lock: URLGetterThread.counter += 1 counter = URLGetterThread.counter should_print = options.verbose and not counter % 1000 if should_print: print(counter) def main(options, mode, url): start_time = time.time() errors = 0 if mode == 'parallel': nrequests_per_thread = options.nrequests // options.nthreads if options.verbose: print ( 'Getting %s %s times total in %s threads, ' '%s times per thread' % ( url, nrequests_per_thread * options.nthreads, options.nthreads, nrequests_per_thread)) threads = [ URLGetterThread(options, url, nrequests_per_thread) for _ in range(options.nthreads) ] for t in threads: t.start() for t in threads: t.join() errors = sum([t.errors for t in threads]) nthreads_with_errors = len([t for t in threads if t.errors]) if nthreads_with_errors: print('%d threads had errors! %d errors in total' % ( nthreads_with_errors, errors)) else: assert mode == 'serial' if options.verbose: print('Getting %s %s times in one thread' % ( url, options.nrequests )) for i in range(1, options.nrequests + 1): try: get(url) except Exception as e: print(e) if not options.continue_: sys.exit(1) errors += 1 if options.verbose and not i % 1000: print(i) if errors: print('%d errors!' % errors) if options.verbose: print('Completed in %.2f seconds' % (time.time() - start_time)) if errors: # Failure sys.exit(1) if __name__ == '__main__': options, mode, url = parse_args() main(options, mode, url) pymongo-3.11.0/test/ocsp/000077500000000000000000000000001374256237000152205ustar00rootroot00000000000000pymongo-3.11.0/test/ocsp/test_ocsp.py000066400000000000000000000044021374256237000175750ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test OCSP.""" import logging import os import sys import unittest sys.path[0:0] = [""] import pymongo from pymongo.errors import ServerSelectionTimeoutError CA_FILE = os.environ.get("CA_FILE") OCSP_TLS_SHOULD_SUCCEED = (os.environ.get('OCSP_TLS_SHOULD_SUCCEED') == 'true') # Enable logs in this format: # 2020-06-08 23:49:35,982 DEBUG ocsp_support Peer did not staple an OCSP response FORMAT = '%(asctime)s %(levelname)s %(module)s %(message)s' logging.basicConfig(format=FORMAT, level=logging.DEBUG) if sys.platform == 'win32': # The non-stapled OCSP endpoint check is slow on Windows. 
TIMEOUT_MS = 5000 else: TIMEOUT_MS = 500 def _connect(options): uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS=%s" "&tlsCAFile=%s&%s") % (TIMEOUT_MS, CA_FILE, options) print(uri) client = pymongo.MongoClient(uri) client.admin.command('ismaster') if not hasattr(unittest.TestCase, 'assertRaisesRegex'): unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp class TestOCSP(unittest.TestCase): def test_tls_insecure(self): # Should always succeed options = "tls=true&tlsInsecure=true" _connect(options) def test_allow_invalid_certificates(self): # Should always succeed options = "tls=true&tlsAllowInvalidCertificates=true" _connect(options) def test_tls(self): options = "tls=true" if not OCSP_TLS_SHOULD_SUCCEED: self.assertRaisesRegex( ServerSelectionTimeoutError, "invalid status response", _connect, options) else: _connect(options) if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/performance/000077500000000000000000000000001374256237000165555ustar00rootroot00000000000000pymongo-3.11.0/test/performance/perf_test.py000066400000000000000000000346011374256237000211260ustar00rootroot00000000000000# Copyright 2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the MongoDB Driver Performance Benchmarking Spec.""" import multiprocessing as mp import os import sys import tempfile import warnings try: import simplejson as json except ImportError: import json sys.path[0:0] = [""] from bson import decode, encode from bson.json_util import loads from gridfs import GridFSBucket from pymongo import MongoClient from pymongo.monotonic import time from test import client_context, host, port, unittest NUM_ITERATIONS = 100 MAX_ITERATION_TIME = 300 NUM_DOCS = 10000 TEST_PATH = os.environ.get('TEST_PATH', os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join('data'))) OUTPUT_FILE = os.environ.get('OUTPUT_FILE') result_data = [] def tearDownModule(): output = json.dumps({ 'results': result_data }, indent=4) if OUTPUT_FILE: with open(OUTPUT_FILE, 'w') as opf: opf.write(output) else: print(output) class Timer(object): def __enter__(self): self.start = time() return self def __exit__(self, *args): self.end = time() self.interval = self.end - self.start class PerformanceTest(object): @classmethod def setUpClass(cls): client_context.init() def setUp(self): pass def tearDown(self): name = self.__class__.__name__ median = self.percentile(50) result = self.data_size / median print('Running %s. 
MEDIAN=%s' % (self.__class__.__name__, self.percentile(50))) result_data.append({ 'name': name, 'results': { '1': { 'ops_per_sec': result } } }) def before(self): pass def after(self): pass def percentile(self, percentile): if hasattr(self, 'results'): sorted_results = sorted(self.results) percentile_index = int(len(sorted_results) * percentile / 100) - 1 return sorted_results[percentile_index] else: self.fail('Test execution failed') def runTest(self): results = [] start = time() self.max_iterations = NUM_ITERATIONS for i in range(NUM_ITERATIONS): if time() - start > MAX_ITERATION_TIME: warnings.warn('Test timed out, completed %s iterations.' % i) break self.before() with Timer() as timer: self.do_task() self.after() results.append(timer.interval) self.results = results # BSON MICRO-BENCHMARKS class BsonEncodingTest(PerformanceTest): def setUp(self): # Location of test data. with open( os.path.join(TEST_PATH, os.path.join('extended_bson', self.dataset))) as data: self.document = loads(data.read()) def do_task(self): for _ in range(NUM_DOCS): encode(self.document) class BsonDecodingTest(PerformanceTest): def setUp(self): # Location of test data. with open( os.path.join(TEST_PATH, os.path.join('extended_bson', self.dataset))) as data: self.document = encode(json.loads(data.read())) def do_task(self): for _ in range(NUM_DOCS): decode(self.document) class TestFlatEncoding(BsonEncodingTest, unittest.TestCase): dataset = 'flat_bson.json' data_size = 75310000 class TestFlatDecoding(BsonDecodingTest, unittest.TestCase): dataset = 'flat_bson.json' data_size = 75310000 class TestDeepEncoding(BsonEncodingTest, unittest.TestCase): dataset = 'deep_bson.json' data_size = 19640000 class TestDeepDecoding(BsonDecodingTest, unittest.TestCase): dataset = 'deep_bson.json' data_size = 19640000 class TestFullEncoding(BsonEncodingTest, unittest.TestCase): dataset = 'full_bson.json' data_size = 57340000 class TestFullDecoding(BsonDecodingTest, unittest.TestCase): dataset = 'full_bson.json' data_size = 57340000 # SINGLE-DOC BENCHMARKS class TestRunCommand(PerformanceTest, unittest.TestCase): data_size = 160000 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') def do_task(self): command = self.client.perftest.command for _ in range(NUM_DOCS): command("ismaster") class TestDocument(PerformanceTest): def setUp(self): # Location of test data. 
with open( os.path.join( TEST_PATH, os.path.join( 'single_and_multi_document', self.dataset)), 'r') as data: self.document = json.loads(data.read()) self.client = client_context.client self.client.drop_database('perftest') def tearDown(self): super(TestDocument, self).tearDown() self.client.drop_database('perftest') def before(self): self.corpus = self.client.perftest.create_collection('corpus') def after(self): self.client.perftest.drop_collection('corpus') class TestFindOneByID(TestDocument, unittest.TestCase): data_size = 16220000 def setUp(self): self.dataset = 'tweet.json' super(TestFindOneByID, self).setUp() documents = [self.document.copy() for _ in range(NUM_DOCS)] self.corpus = self.client.perftest.corpus result = self.corpus.insert_many(documents) self.inserted_ids = result.inserted_ids def do_task(self): find_one = self.corpus.find_one for _id in self.inserted_ids: find_one({'_id': _id}) def before(self): pass def after(self): pass class TestSmallDocInsertOne(TestDocument, unittest.TestCase): data_size = 2750000 def setUp(self): self.dataset = 'small_doc.json' super(TestSmallDocInsertOne, self).setUp() self.documents = [self.document.copy() for _ in range(NUM_DOCS)] def do_task(self): insert_one = self.corpus.insert_one for doc in self.documents: insert_one(doc) class TestLargeDocInsertOne(TestDocument, unittest.TestCase): data_size = 27310890 def setUp(self): self.dataset = 'large_doc.json' super(TestLargeDocInsertOne, self).setUp() self.documents = [self.document.copy() for _ in range(10)] def do_task(self): insert_one = self.corpus.insert_one for doc in self.documents: insert_one(doc) # MULTI-DOC BENCHMARKS class TestFindManyAndEmptyCursor(TestDocument, unittest.TestCase): data_size = 16220000 def setUp(self): self.dataset = 'tweet.json' super(TestFindManyAndEmptyCursor, self).setUp() for _ in range(10): self.client.perftest.command( 'insert', 'corpus', documents=[self.document] * 1000) self.corpus = self.client.perftest.corpus def do_task(self): list(self.corpus.find()) def before(self): pass def after(self): pass class TestSmallDocBulkInsert(TestDocument, unittest.TestCase): data_size = 2750000 def setUp(self): self.dataset = 'small_doc.json' super(TestSmallDocBulkInsert, self).setUp() self.documents = [self.document.copy() for _ in range(NUM_DOCS)] def before(self): self.corpus = self.client.perftest.create_collection('corpus') def do_task(self): self.corpus.insert_many(self.documents, ordered=True) class TestLargeDocBulkInsert(TestDocument, unittest.TestCase): data_size = 27310890 def setUp(self): self.dataset = 'large_doc.json' super(TestLargeDocBulkInsert, self).setUp() self.documents = [self.document.copy() for _ in range(10)] def before(self): self.corpus = self.client.perftest.create_collection('corpus') def do_task(self): self.corpus.insert_many(self.documents, ordered=True) class TestGridFsUpload(PerformanceTest, unittest.TestCase): data_size = 52428800 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') gridfs_path = os.path.join( TEST_PATH, os.path.join('single_and_multi_document', 'gridfs_large.bin')) with open(gridfs_path, 'rb') as data: self.document = data.read() self.bucket = GridFSBucket(self.client.perftest) def tearDown(self): super(TestGridFsUpload, self).tearDown() self.client.drop_database('perftest') def before(self): self.bucket.upload_from_stream('init', b'x') def do_task(self): self.bucket.upload_from_stream('gridfstest', self.document) class TestGridFsDownload(PerformanceTest, unittest.TestCase): data_size = 
52428800 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') gridfs_path = os.path.join( TEST_PATH, os.path.join('single_and_multi_document', 'gridfs_large.bin')) self.bucket = GridFSBucket(self.client.perftest) with open(gridfs_path, 'rb') as gfile: self.uploaded_id = self.bucket.upload_from_stream( 'gridfstest', gfile) def tearDown(self): super(TestGridFsDownload, self).tearDown() self.client.drop_database('perftest') def do_task(self): self.bucket.open_download_stream(self.uploaded_id).read() proc_client = None def proc_init(*dummy): global proc_client proc_client = MongoClient(host, port) # PARALLEL BENCHMARKS def mp_map(map_func, files): pool = mp.Pool(initializer=proc_init) pool.map(map_func, files) pool.close() def insert_json_file(filename): with open(filename, 'r') as data: coll = proc_client.perftest.corpus coll.insert_many([json.loads(line) for line in data]) def insert_json_file_with_file_id(filename): documents = [] with open(filename, 'r') as data: for line in data: doc = json.loads(line) doc['file'] = filename documents.append(doc) coll = proc_client.perftest.corpus coll.insert_many(documents) def read_json_file(filename): coll = proc_client.perftest.corpus temp = tempfile.TemporaryFile() try: temp.writelines( [json.dumps(doc) + '\n' for doc in coll.find({'file': filename}, {'_id': False})]) finally: temp.close() def insert_gridfs_file(filename): bucket = GridFSBucket(proc_client.perftest) with open(filename, 'rb') as gfile: bucket.upload_from_stream(filename, gfile) def read_gridfs_file(filename): bucket = GridFSBucket(proc_client.perftest) temp = tempfile.TemporaryFile() try: bucket.download_to_stream_by_name(filename, temp) finally: temp.close() class TestJsonMultiImport(PerformanceTest, unittest.TestCase): data_size = 565000000 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') def before(self): self.client.perftest.command({'create': 'corpus'}) self.corpus = self.client.perftest.corpus ldjson_path = os.path.join( TEST_PATH, os.path.join('parallel', 'ldjson_multi')) self.files = [os.path.join( ldjson_path, s) for s in os.listdir(ldjson_path)] def do_task(self): mp_map(insert_json_file, self.files) def after(self): self.client.perftest.drop_collection('corpus') def tearDown(self): super(TestJsonMultiImport, self).tearDown() self.client.drop_database('perftest') class TestJsonMultiExport(PerformanceTest, unittest.TestCase): data_size = 565000000 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') self.client.perftest.corpus.create_index('file') ldjson_path = os.path.join( TEST_PATH, os.path.join('parallel', 'ldjson_multi')) self.files = [os.path.join( ldjson_path, s) for s in os.listdir(ldjson_path)] mp_map(insert_json_file_with_file_id, self.files) def do_task(self): mp_map(read_json_file, self.files) def tearDown(self): super(TestJsonMultiExport, self).tearDown() self.client.drop_database('perftest') class TestGridFsMultiFileUpload(PerformanceTest, unittest.TestCase): data_size = 262144000 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') def before(self): self.client.perftest.drop_collection('fs.files') self.client.perftest.drop_collection('fs.chunks') self.bucket = GridFSBucket(self.client.perftest) gridfs_path = os.path.join( TEST_PATH, os.path.join('parallel', 'gridfs_multi')) self.files = [os.path.join( gridfs_path, s) for s in os.listdir(gridfs_path)] def do_task(self): mp_map(insert_gridfs_file, self.files)
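    # A minimal sketch of the pattern the parallel benchmarks above rely on:
    # mp_map() builds a multiprocessing.Pool whose initializer (proc_init)
    # creates one MongoClient per worker process, because MongoClient
    # instances are not fork-safe and must not be shared across fork().
    # The helper names and collection below are hypothetical, shown only
    # to illustrate the per-process client idea:
    #
    #   import multiprocessing as mp
    #   from pymongo import MongoClient
    #
    #   worker_client = None
    #
    #   def init_worker():
    #       global worker_client
    #       worker_client = MongoClient()  # fresh client in each child process
    #
    #   def count_corpus_docs(collection_name):
    #       return worker_client.perftest[collection_name].count_documents({})
    #
    #   pool = mp.Pool(initializer=init_worker)
    #   totals = pool.map(count_corpus_docs, ['corpus'])
    #   pool.close()
    #   pool.join()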
def tearDown(self): super(TestGridFsMultiFileUpload, self).tearDown() self.client.drop_database('perftest') class TestGridFsMultiFileDownload(PerformanceTest, unittest.TestCase): data_size = 262144000 def setUp(self): self.client = client_context.client self.client.drop_database('perftest') bucket = GridFSBucket(self.client.perftest) gridfs_path = os.path.join( TEST_PATH, os.path.join('parallel', 'gridfs_multi')) self.files = [os.path.join( gridfs_path, s) for s in os.listdir(gridfs_path)] for fname in self.files: with open(fname, 'rb') as gfile: bucket.upload_from_stream(fname, gfile) def do_task(self): mp_map(read_gridfs_file, self.files) def tearDown(self): super(TestGridFsMultiFileDownload, self).tearDown() self.client.drop_database('perftest') if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/pymongo_mocks.py000066400000000000000000000160571374256237000175230ustar00rootroot00000000000000# Copyright 2013-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tools for mocking parts of PyMongo to test other parts.""" import contextlib from functools import partial import weakref from pymongo import common from pymongo import MongoClient from pymongo.errors import AutoReconnect, NetworkTimeout from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from pymongo.pool import Pool from pymongo.server_description import ServerDescription from test import client_context class MockPool(Pool): def __init__(self, client, pair, *args, **kwargs): # MockPool gets a 'client' arg, regular pools don't. Weakref it to # avoid cycle with __del__, causing ResourceWarnings in Python 3.3. self.client = weakref.proxy(client) self.mock_host, self.mock_port = pair # Actually connect to the default server. Pool.__init__(self, (client_context.host, client_context.port), *args, **kwargs) @contextlib.contextmanager def get_socket(self, all_credentials, checkout=False): client = self.client host_and_port = '%s:%s' % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: raise AutoReconnect('mock error') assert host_and_port in ( client.mock_standalones + client.mock_members + client.mock_mongoses), "bad host: %s" % host_and_port with Pool.get_socket(self, all_credentials, checkout) as sock_info: sock_info.mock_host = self.mock_host sock_info.mock_port = self.mock_port yield sock_info class MockMonitor(Monitor): def __init__( self, client, server_description, topology, pool, topology_settings): # MockMonitor gets a 'client' arg, regular monitors don't. Weakref it # to avoid cycles. 
self.client = weakref.proxy(client) Monitor.__init__( self, server_description, topology, pool, topology_settings) def _check_once(self): client = self.client address = self._server_description.address response, rtt = client.mock_is_master('%s:%d' % address) return ServerDescription(address, IsMaster(response), rtt) class MockClient(MongoClient): def __init__( self, standalones, members, mongoses, ismaster_hosts=None, *args, **kwargs): """A MongoClient connected to the default server, with a mock topology. standalones, members, mongoses determine the configuration of the topology. They are formatted like ['a:1', 'b:2']. ismaster_hosts provides an alternative host list for the server's mocked ismaster response; see test_connect_with_internal_ips. """ self.mock_standalones = standalones[:] self.mock_members = members[:] if self.mock_members: self.mock_primary = self.mock_members[0] else: self.mock_primary = None if ismaster_hosts is not None: self.mock_ismaster_hosts = ismaster_hosts else: self.mock_ismaster_hosts = members[:] self.mock_mongoses = mongoses[:] # Hosts that should raise socket errors. self.mock_down_hosts = [] # Hostname -> (min wire version, max wire version) self.mock_wire_versions = {} # Hostname -> max write batch size self.mock_max_write_batch_sizes = {} # Hostname -> round trip time self.mock_rtts = {} kwargs['_pool_class'] = partial(MockPool, self) kwargs['_monitor_class'] = partial(MockMonitor, self) client_options = client_context.default_client_options.copy() client_options.update(kwargs) super(MockClient, self).__init__(*args, **client_options) def kill_host(self, host): """Host is like 'a:1'.""" self.mock_down_hosts.append(host) def revive_host(self, host): """Host is like 'a:1'.""" self.mock_down_hosts.remove(host) def set_wire_version_range(self, host, min_version, max_version): self.mock_wire_versions[host] = (min_version, max_version) def set_max_write_batch_size(self, host, size): self.mock_max_write_batch_sizes[host] = size def mock_is_master(self, host): """Return mock ismaster response (a dict) and round trip time.""" if host in self.mock_wire_versions: min_wire_version, max_wire_version = self.mock_wire_versions[host] else: min_wire_version = common.MIN_SUPPORTED_WIRE_VERSION max_wire_version = common.MAX_SUPPORTED_WIRE_VERSION max_write_batch_size = self.mock_max_write_batch_sizes.get( host, common.MAX_WRITE_BATCH_SIZE) rtt = self.mock_rtts.get(host, 0) # host is like 'a:1'. if host in self.mock_down_hosts: raise NetworkTimeout('mock timeout') elif host in self.mock_standalones: response = { 'ok': 1, 'ismaster': True, 'minWireVersion': min_wire_version, 'maxWireVersion': max_wire_version, 'maxWriteBatchSize': max_write_batch_size} elif host in self.mock_members: ismaster = (host == self.mock_primary) # Simulate a replica set member. response = { 'ok': 1, 'ismaster': ismaster, 'secondary': not ismaster, 'setName': 'rs', 'hosts': self.mock_ismaster_hosts, 'minWireVersion': min_wire_version, 'maxWireVersion': max_wire_version, 'maxWriteBatchSize': max_write_batch_size} if self.mock_primary: response['primary'] = self.mock_primary elif host in self.mock_mongoses: response = { 'ok': 1, 'ismaster': True, 'minWireVersion': min_wire_version, 'maxWireVersion': max_wire_version, 'msg': 'isdbgrid', 'maxWriteBatchSize': max_write_batch_size} else: # In test_internal_ips(), we try to connect to a host listed # in ismaster['hosts'] but not publicly accessible. 
raise AutoReconnect('Unknown host: %s' % host) return response, rtt def _process_periodic_tasks(self): # Avoid the background thread causing races, e.g. a surprising # reconnect while we're trying to test a disconnected client. pass pymongo-3.11.0/test/qcheck.py000066400000000000000000000167641374256237000161020ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import random import traceback import datetime import re import sys sys.path[0:0] = [""] from bson.binary import Binary from bson.dbref import DBRef from bson.objectid import ObjectId from bson.py3compat import MAXSIZE, PY3, iteritems from bson.son import SON if PY3: unichr = chr gen_target = 100 reduction_attempts = 10 examples = 5 def lift(value): return lambda: value def choose_lifted(generator_list): return lambda: random.choice(generator_list) def my_map(generator, function): return lambda: function(generator()) def choose(list): return lambda: random.choice(list)() def gen_range(start, stop): return lambda: random.randint(start, stop) def gen_int(): max_int = 2147483647 return lambda: random.randint(-max_int - 1, max_int) def gen_float(): return lambda: (random.random() - 0.5) * MAXSIZE def gen_boolean(): return lambda: random.choice([True, False]) def gen_printable_char(): return lambda: chr(random.randint(32, 126)) def gen_printable_string(gen_length): return lambda: "".join(gen_list(gen_printable_char(), gen_length)()) if PY3: def gen_char(set=None): return lambda: bytes([random.randint(0, 255)]) else: def gen_char(set=None): return lambda: chr(random.randint(0, 255)) def gen_string(gen_length): return lambda: b"".join(gen_list(gen_char(), gen_length)()) def gen_unichar(): return lambda: unichr(random.randint(1, 0xFFF)) def gen_unicode(gen_length): return lambda: u"".join([x for x in gen_list(gen_unichar(), gen_length)() if x not in ".$"]) def gen_list(generator, gen_length): return lambda: [generator() for _ in range(gen_length())] def gen_datetime(): return lambda: datetime.datetime(random.randint(1970, 2037), random.randint(1, 12), random.randint(1, 28), random.randint(0, 23), random.randint(0, 59), random.randint(0, 59), random.randint(0, 999) * 1000) def gen_dict(gen_key, gen_value, gen_length): def a_dict(gen_key, gen_value, length): result = {} for _ in range(length): result[gen_key()] = gen_value() return result return lambda: a_dict(gen_key, gen_value, gen_length()) def gen_regexp(gen_length): # TODO our patterns only consist of one letter. # this is because of a bug in CPython's regex equality testing, # which I haven't quite tracked down, so I'm just ignoring it... 
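    # Consequently every generated pattern is just the letter "a" repeated a
    # random number of times; only the compile flags chosen below vary.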
pattern = lambda: u"".join(gen_list(choose_lifted(u"a"), gen_length)()) def gen_flags(): flags = 0 if random.random() > 0.5: flags = flags | re.IGNORECASE if random.random() > 0.5: flags = flags | re.MULTILINE if random.random() > 0.5: flags = flags | re.VERBOSE return flags return lambda: re.compile(pattern(), gen_flags()) def gen_objectid(): return lambda: ObjectId() def gen_dbref(): collection = gen_unicode(gen_range(0, 20)) return lambda: DBRef(collection(), gen_mongo_value(1, True)()) def gen_mongo_value(depth, ref): bintype = Binary if PY3: # If we used Binary in python3 tests would fail since we # decode BSON binary subtype 0 to bytes. Testing this with # bytes in python3 makes a lot more sense. bintype = bytes choices = [gen_unicode(gen_range(0, 50)), gen_printable_string(gen_range(0, 50)), my_map(gen_string(gen_range(0, 1000)), bintype), gen_int(), gen_float(), gen_boolean(), gen_datetime(), gen_objectid(), lift(None)] if ref: choices.append(gen_dbref()) if depth > 0: choices.append(gen_mongo_list(depth, ref)) choices.append(gen_mongo_dict(depth, ref)) return choose(choices) def gen_mongo_list(depth, ref): return gen_list(gen_mongo_value(depth - 1, ref), gen_range(0, 10)) def gen_mongo_dict(depth, ref=True): return my_map(gen_dict(gen_unicode(gen_range(0, 20)), gen_mongo_value(depth - 1, ref), gen_range(0, 10)), SON) def simplify(case): # TODO this is a hack if isinstance(case, SON) and "$ref" not in case: simplified = SON(case) # make a copy! if random.choice([True, False]): # delete simplified_keys = list(simplified) if not len(simplified_keys): return (False, case) simplified.pop(random.choice(simplified_keys)) return (True, simplified) else: # simplify a value simplified_items = list(iteritems(simplified)) if not len(simplified_items): return (False, case) (key, value) = random.choice(simplified_items) (success, value) = simplify(value) simplified[key] = value return (success, success and simplified or case) if isinstance(case, list): simplified = list(case) if random.choice([True, False]): # delete if not len(simplified): return (False, case) simplified.pop(random.randrange(len(simplified))) return (True, simplified) else: # simplify an item if not len(simplified): return (False, case) index = random.randrange(len(simplified)) (success, value) = simplify(simplified[index]) simplified[index] = value return (success, success and simplified or case) return (False, case) def reduce(case, predicate, reductions=0): for _ in range(reduction_attempts): (reduced, simplified) = simplify(case) if reduced and not predicate(simplified): return reduce(simplified, predicate, reductions + 1) return (reductions, case) def isnt(predicate): return lambda x: not predicate(x) def check(predicate, generator): counter_examples = [] for _ in range(gen_target): case = generator() try: if not predicate(case): reduction = reduce(case, predicate) counter_examples.append("after %s reductions: %r" % reduction) except: counter_examples.append("%r : %s" % (case, traceback.format_exc())) return counter_examples def check_unittest(test, predicate, generator): counter_examples = check(predicate, generator) if counter_examples: failures = len(counter_examples) message = "\n".join([" -> %s" % f for f in counter_examples[:examples]]) message = ("found %d counter examples, displaying first %d:\n%s" % (failures, min(failures, examples), message)) test.fail(message) pymongo-3.11.0/test/test_auth.py000066400000000000000000000717521374256237000166420ustar00rootroot00000000000000# Copyright 2013-present MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Authentication Tests.""" import os import sys import threading try: from urllib.parse import quote_plus except ImportError: # Python 2 from urllib import quote_plus sys.path[0:0] = [""] from pymongo import MongoClient, monitoring from pymongo.auth import HAVE_KERBEROS, _build_credentials_tuple from pymongo.errors import OperationFailure from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP from test import client_context, SkipTest, unittest, Version from test.utils import (delay, ignore_deprecations, single_client, rs_or_single_client, rs_or_single_client_noauth, single_client_noauth, WhiteListEventListener) # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. GSSAPI_HOST = os.environ.get('GSSAPI_HOST') GSSAPI_PORT = int(os.environ.get('GSSAPI_PORT', '27017')) GSSAPI_PRINCIPAL = os.environ.get('GSSAPI_PRINCIPAL') GSSAPI_SERVICE_NAME = os.environ.get('GSSAPI_SERVICE_NAME', 'mongodb') GSSAPI_CANONICALIZE = os.environ.get('GSSAPI_CANONICALIZE', 'false') GSSAPI_SERVICE_REALM = os.environ.get('GSSAPI_SERVICE_REALM') GSSAPI_PASS = os.environ.get('GSSAPI_PASS') GSSAPI_DB = os.environ.get('GSSAPI_DB', 'test') SASL_HOST = os.environ.get('SASL_HOST') SASL_PORT = int(os.environ.get('SASL_PORT', '27017')) SASL_USER = os.environ.get('SASL_USER') SASL_PASS = os.environ.get('SASL_PASS') SASL_DB = os.environ.get('SASL_DB', '$external') class AutoAuthenticateThread(threading.Thread): """Used in testing threaded authentication. This does collection.find_one() with a 1-second delay to ensure it must check out and authenticate multiple sockets from the pool concurrently. :Parameters: `collection`: An auth-protected collection containing one document. """ def __init__(self, collection): super(AutoAuthenticateThread, self).__init__() self.collection = collection self.success = False def run(self): assert self.collection.find_one({'$where': delay(1)}) is not None self.success = True class DBAuthenticateThread(threading.Thread): """Used in testing threaded authentication. This does db.test.find_one() with a 1-second delay to ensure it must check out and authenticate multiple sockets from the pool concurrently. :Parameters: `db`: An auth-protected db with a 'test' collection containing one document. 
""" def __init__(self, db, username, password): super(DBAuthenticateThread, self).__init__() self.db = db self.username = username self.password = password self.success = False def run(self): self.db.authenticate(self.username, self.password) assert self.db.test.find_one({'$where': delay(1)}) is not None self.success = True class TestGSSAPI(unittest.TestCase): @classmethod def setUpClass(cls): if not HAVE_KERBEROS: raise SkipTest('Kerberos module not available.') if not GSSAPI_HOST or not GSSAPI_PRINCIPAL: raise SkipTest( 'Must set GSSAPI_HOST and GSSAPI_PRINCIPAL to test GSSAPI') cls.service_realm_required = ( GSSAPI_SERVICE_REALM is not None and GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL) mech_properties = 'SERVICE_NAME:%s' % (GSSAPI_SERVICE_NAME,) mech_properties += ( ',CANONICALIZE_HOST_NAME:%s' % (GSSAPI_CANONICALIZE,)) if GSSAPI_SERVICE_REALM is not None: mech_properties += ',SERVICE_REALM:%s' % (GSSAPI_SERVICE_REALM,) cls.mech_properties = mech_properties def test_credentials_hashing(self): # GSSAPI credentials are properly hashed. creds0 = _build_credentials_tuple( 'GSSAPI', None, 'user', 'pass', {}, None) creds1 = _build_credentials_tuple( 'GSSAPI', None, 'user', 'pass', {'authmechanismproperties': {'SERVICE_NAME': 'A'}}, None) creds2 = _build_credentials_tuple( 'GSSAPI', None, 'user', 'pass', {'authmechanismproperties': {'SERVICE_NAME': 'A'}}, None) creds3 = _build_credentials_tuple( 'GSSAPI', None, 'user', 'pass', {'authmechanismproperties': {'SERVICE_NAME': 'B'}}, None) self.assertEqual(1, len(set([creds1, creds2]))) self.assertEqual(3, len(set([creds0, creds1, creds2, creds3]))) @ignore_deprecations def test_gssapi_simple(self): if GSSAPI_PASS is not None: uri = ('mongodb://%s:%s@%s:%d/?authMechanism=' 'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL), GSSAPI_PASS, GSSAPI_HOST, GSSAPI_PORT)) else: uri = ('mongodb://%s@%s:%d/?authMechanism=' 'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL), GSSAPI_HOST, GSSAPI_PORT)) if not self.service_realm_required: # Without authMechanismProperties. client = MongoClient(GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, password=GSSAPI_PASS, authMechanism='GSSAPI') client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. client = MongoClient(uri) client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. client = MongoClient(GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, password=GSSAPI_PASS, authMechanism='GSSAPI', authMechanismProperties=self.mech_properties) client[GSSAPI_DB].collection.find_one() # Log in using URI, with authMechanismProperties. 
mech_uri = uri + '&authMechanismProperties=%s' % (self.mech_properties,) client = MongoClient(mech_uri) client[GSSAPI_DB].collection.find_one() set_name = client.admin.command('ismaster').get('setName') if set_name: if not self.service_realm_required: # Without authMechanismProperties client = MongoClient(GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, password=GSSAPI_PASS, authMechanism='GSSAPI', replicaSet=set_name) client[GSSAPI_DB].list_collection_names() uri = uri + '&replicaSet=%s' % (str(set_name),) client = MongoClient(uri) client[GSSAPI_DB].list_collection_names() # With authMechanismProperties client = MongoClient(GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, password=GSSAPI_PASS, authMechanism='GSSAPI', authMechanismProperties=self.mech_properties, replicaSet=set_name) client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + '&replicaSet=%s' % (str(set_name),) client = MongoClient(mech_uri) client[GSSAPI_DB].list_collection_names() @ignore_deprecations def test_gssapi_threaded(self): client = MongoClient(GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, password=GSSAPI_PASS, authMechanism='GSSAPI', authMechanismProperties=self.mech_properties) # Authentication succeeded? client.server_info() db = client[GSSAPI_DB] # Need one document in the collection. AutoAuthenticateThread does # collection.find_one with a 1-second delay, forcing it to check out # multiple sockets from the pool concurrently, proving that # auto-authentication works with GSSAPI. collection = db.test if not collection.count_documents({}): try: collection.drop() collection.insert_one({'_id': 1}) except OperationFailure: raise SkipTest("User must be able to write.") threads = [] for _ in range(4): threads.append(AutoAuthenticateThread(collection)) for thread in threads: thread.start() for thread in threads: thread.join() self.assertTrue(thread.success) set_name = client.admin.command('ismaster').get('setName') if set_name: client = MongoClient(GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, password=GSSAPI_PASS, authMechanism='GSSAPI', authMechanismProperties=self.mech_properties, replicaSet=set_name) # Succeeded? 
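            # server_info() issues a real command, so it forces the client to
            # connect and complete GSSAPI authentication before the worker
            # threads below are started.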
client.server_info() threads = [] for _ in range(4): threads.append(AutoAuthenticateThread(collection)) for thread in threads: thread.start() for thread in threads: thread.join() self.assertTrue(thread.success) class TestSASLPlain(unittest.TestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest('Must set SASL_HOST, ' 'SASL_USER, and SASL_PASS to test SASL') def test_sasl_plain(self): client = MongoClient(SASL_HOST, SASL_PORT, username=SASL_USER, password=SASL_PASS, authSource=SASL_DB, authMechanism='PLAIN') client.ldap.test.find_one() uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' 'authSource=%s' % (quote_plus(SASL_USER), quote_plus(SASL_PASS), SASL_HOST, SASL_PORT, SASL_DB)) client = MongoClient(uri) client.ldap.test.find_one() set_name = client.admin.command('ismaster').get('setName') if set_name: client = MongoClient(SASL_HOST, SASL_PORT, replicaSet=set_name, username=SASL_USER, password=SASL_PASS, authSource=SASL_DB, authMechanism='PLAIN') client.ldap.test.find_one() uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' 'authSource=%s;replicaSet=%s' % (quote_plus(SASL_USER), quote_plus(SASL_PASS), SASL_HOST, SASL_PORT, SASL_DB, str(set_name))) client = MongoClient(uri) client.ldap.test.find_one() def test_sasl_plain_bad_credentials(self): with ignore_deprecations(): client = MongoClient(SASL_HOST, SASL_PORT) # Bad username self.assertRaises(OperationFailure, client.ldap.authenticate, 'not-user', SASL_PASS, SASL_DB, 'PLAIN') self.assertRaises(OperationFailure, client.ldap.test.find_one) self.assertRaises(OperationFailure, client.ldap.test.insert_one, {"failed": True}) # Bad password self.assertRaises(OperationFailure, client.ldap.authenticate, SASL_USER, 'not-pwd', SASL_DB, 'PLAIN') self.assertRaises(OperationFailure, client.ldap.test.find_one) self.assertRaises(OperationFailure, client.ldap.test.insert_one, {"failed": True}) def auth_string(user, password): uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' 'authSource=%s' % (quote_plus(user), quote_plus(password), SASL_HOST, SASL_PORT, SASL_DB)) return uri bad_user = MongoClient(auth_string('not-user', SASL_PASS)) bad_pwd = MongoClient(auth_string(SASL_USER, 'not-pwd')) # OperationFailure raised upon connecting. self.assertRaises(OperationFailure, bad_user.admin.command, 'ismaster') self.assertRaises(OperationFailure, bad_pwd.admin.command, 'ismaster') class TestSCRAMSHA1(unittest.TestCase): @client_context.require_auth @client_context.require_version_min(2, 7, 2) def setUp(self): # Before 2.7.7, SCRAM-SHA-1 had to be enabled from the command line. 
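        # (For example, such a server would have been started with something
        # like: mongod --setParameter
        # authenticationMechanisms=SCRAM-SHA-1,MONGODB-CR.)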
        if client_context.version < Version(2, 7, 7):
            cmd_line = client_context.cmd_line
            if 'SCRAM-SHA-1' not in cmd_line.get(
                    'parsed', {}).get('setParameter', {}).get(
                        'authenticationMechanisms', ''):
                raise SkipTest('SCRAM-SHA-1 mechanism not enabled')

        client_context.create_user(
            'pymongo_test', 'user', 'pass',
            roles=['userAdmin', 'readWrite'])

    def tearDown(self):
        client_context.drop_user('pymongo_test', 'user')

    def test_scram_sha1(self):
        host, port = client_context.host, client_context.port

        with ignore_deprecations():
            client = rs_or_single_client_noauth()
            self.assertTrue(client.pymongo_test.authenticate(
                'user', 'pass', mechanism='SCRAM-SHA-1'))
            client.pymongo_test.command('dbstats')

        client = rs_or_single_client_noauth(
            'mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1'
            % (host, port))
        client.pymongo_test.command('dbstats')

        if client_context.is_rs:
            uri = ('mongodb://user:pass'
                   '@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1'
                   '&replicaSet=%s' % (host, port,
                                       client_context.replica_set_name))
            client = single_client_noauth(uri)
            client.pymongo_test.command('dbstats')
            db = client.get_database(
                'pymongo_test', read_preference=ReadPreference.SECONDARY)
            db.command('dbstats')


class TestSCRAM(unittest.TestCase):

    @client_context.require_auth
    @client_context.require_version_min(3, 7, 2)
    def setUp(self):
        self._SENSITIVE_COMMANDS = monitoring._SENSITIVE_COMMANDS
        monitoring._SENSITIVE_COMMANDS = set([])
        self.listener = WhiteListEventListener("saslStart")

    def tearDown(self):
        monitoring._SENSITIVE_COMMANDS = self._SENSITIVE_COMMANDS
        client_context.client.testscram.command("dropAllUsersFromDatabase")
        client_context.client.drop_database("testscram")

    def test_scram_skip_empty_exchange(self):
        listener = WhiteListEventListener("saslStart", "saslContinue")
        client_context.create_user(
            'testscram', 'sha256', 'pwd',
            roles=['dbOwner'], mechanisms=['SCRAM-SHA-256'])

        client = rs_or_single_client_noauth(
            username='sha256', password='pwd', authSource='testscram',
            event_listeners=[listener])
        client.testscram.command('dbstats')

        if client_context.version < (4, 4, -1):
            # Assert we sent the skipEmptyExchange option.
            first_event = listener.results['started'][0]
            self.assertEqual(first_event.command_name, 'saslStart')
            self.assertEqual(
                first_event.command['options'], {'skipEmptyExchange': True})

        # Assert the third exchange was skipped on servers that support it.
        # Note that the first exchange occurs on the connection handshake.
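        # On 4.4+ saslStart rides on the handshake (speculative auth) and the
        # empty final exchange is skipped, so only a single saslContinue is
        # observed as a separate command; older servers show the full
        # saslStart plus two saslContinue round trips.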
started = listener.started_command_names() if client_context.version.at_least(4, 4, -1): self.assertEqual(started, ['saslContinue']) else: self.assertEqual( started, ['saslStart', 'saslContinue', 'saslContinue']) @ignore_deprecations def test_scram(self): host, port = client_context.host, client_context.port client_context.create_user( 'testscram', 'sha1', 'pwd', roles=['dbOwner'], mechanisms=['SCRAM-SHA-1']) client_context.create_user( 'testscram', 'sha256', 'pwd', roles=['dbOwner'], mechanisms=['SCRAM-SHA-256']) client_context.create_user( 'testscram', 'both', 'pwd', roles=['dbOwner'], mechanisms=['SCRAM-SHA-1', 'SCRAM-SHA-256']) client = rs_or_single_client_noauth( event_listeners=[self.listener]) self.assertTrue( client.testscram.authenticate('sha1', 'pwd')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate( 'sha1', 'pwd', mechanism='SCRAM-SHA-1')) client.testscram.command('dbstats') client.testscram.logout() self.assertRaises( OperationFailure, client.testscram.authenticate, 'sha1', 'pwd', mechanism='SCRAM-SHA-256') self.assertTrue( client.testscram.authenticate('sha256', 'pwd')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate( 'sha256', 'pwd', mechanism='SCRAM-SHA-256')) client.testscram.command('dbstats') client.testscram.logout() self.assertRaises( OperationFailure, client.testscram.authenticate, 'sha256', 'pwd', mechanism='SCRAM-SHA-1') self.listener.results.clear() self.assertTrue( client.testscram.authenticate('both', 'pwd')) started = self.listener.results['started'][0] self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256') client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate( 'both', 'pwd', mechanism='SCRAM-SHA-256')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate( 'both', 'pwd', mechanism='SCRAM-SHA-1')) client.testscram.command('dbstats') client.testscram.logout() self.assertRaises( OperationFailure, client.testscram.authenticate, 'not-a-user', 'pwd') if HAVE_STRINGPREP: # Test the use of SASLprep on passwords. For example, # saslprep(u'\u2136') becomes u'IV' and saslprep(u'I\u00ADX') # becomes u'IX'. SASLprep is only supported when the standard # library provides stringprep. 
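            # The users created below use the precomposed Roman numerals
            # u'\u2168' (IX) and u'\u2163' (IV), so logging in with either the
            # original characters or their SASLprepped ASCII forms must work.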
client_context.create_user( 'testscram', u'\u2168', u'\u2163', roles=['dbOwner'], mechanisms=['SCRAM-SHA-256']) client_context.create_user( 'testscram', u'IX', u'IX', roles=['dbOwner'], mechanisms=['SCRAM-SHA-256']) self.assertTrue( client.testscram.authenticate(u'\u2168', u'\u2163')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate( u'\u2168', u'\u2163', mechanism='SCRAM-SHA-256')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate(u'\u2168', u'IV')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate(u'IX', u'I\u00ADX')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate( u'IX', u'I\u00ADX', mechanism='SCRAM-SHA-256')) client.testscram.command('dbstats') client.testscram.logout() self.assertTrue( client.testscram.authenticate(u'IX', u'IX')) client.testscram.command('dbstats') client.testscram.logout() client = rs_or_single_client_noauth( u'mongodb://\u2168:\u2163@%s:%d/testscram' % (host, port)) client.testscram.command('dbstats') client = rs_or_single_client_noauth( u'mongodb://\u2168:IV@%s:%d/testscram' % (host, port)) client.testscram.command('dbstats') client = rs_or_single_client_noauth( u'mongodb://IX:I\u00ADX@%s:%d/testscram' % (host, port)) client.testscram.command('dbstats') client = rs_or_single_client_noauth( u'mongodb://IX:IX@%s:%d/testscram' % (host, port)) client.testscram.command('dbstats') self.listener.results.clear() client = rs_or_single_client_noauth( 'mongodb://both:pwd@%s:%d/testscram' % (host, port), event_listeners=[self.listener]) client.testscram.command('dbstats') if client_context.version.at_least(4, 4, -1): # Speculative authentication in 4.4+ sends saslStart with the # handshake. self.assertEqual(self.listener.results['started'], []) else: started = self.listener.results['started'][0] self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256') client = rs_or_single_client_noauth( 'mongodb://both:pwd@%s:%d/testscram?authMechanism=SCRAM-SHA-1' % (host, port)) client.testscram.command('dbstats') client = rs_or_single_client_noauth( 'mongodb://both:pwd@%s:%d/testscram?authMechanism=SCRAM-SHA-256' % (host, port)) client.testscram.command('dbstats') if client_context.is_rs: uri = ('mongodb://both:pwd@%s:%d/testscram' '?replicaSet=%s' % (host, port, client_context.replica_set_name)) client = single_client_noauth(uri) client.testscram.command('dbstats') db = client.get_database( 'testscram', read_preference=ReadPreference.SECONDARY) db.command('dbstats') def test_cache(self): client = single_client() # Force authentication. 
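        # Running any command triggers the handshake and the SCRAM
        # conversation; the derived client key, server key, salt, and
        # iteration count are then cached on the credentials (verified
        # below) so later connections can reuse them.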
client.admin.command('ismaster') all_credentials = client._MongoClient__all_credentials credentials = all_credentials.get('admin') cache = credentials.cache self.assertIsNotNone(cache) data = cache.data self.assertIsNotNone(data) self.assertEqual(len(data), 4) ckey, skey, salt, iterations = data self.assertIsInstance(ckey, bytes) self.assertIsInstance(skey, bytes) self.assertIsInstance(salt, bytes) self.assertIsInstance(iterations, int) pool = next(iter(client._topology._servers.values()))._pool with pool.get_socket(all_credentials) as sock_info: authset = sock_info.authset cached = set(all_credentials.values()) self.assertEqual(len(cached), 1) self.assertFalse(authset - cached) self.assertFalse(cached - authset) sock_credentials = next(iter(authset)) sock_cache = sock_credentials.cache self.assertIsNotNone(sock_cache) self.assertEqual(sock_cache.data, data) def test_scram_threaded(self): coll = client_context.client.db.test coll.drop() coll.insert_one({'_id': 1}) # The first thread to call find() will authenticate coll = rs_or_single_client().db.test threads = [] for _ in range(4): threads.append(AutoAuthenticateThread(coll)) for thread in threads: thread.start() for thread in threads: thread.join() self.assertTrue(thread.success) class TestThreadedAuth(unittest.TestCase): @client_context.require_auth def test_db_authenticate_threaded(self): db = client_context.client.db coll = db.test coll.drop() coll.insert_one({'_id': 1}) client_context.create_user( 'db', 'user', 'pass', roles=['dbOwner']) self.addCleanup(db.command, 'dropUser', 'user') db = rs_or_single_client_noauth().db db.authenticate('user', 'pass') # No error. db.authenticate('user', 'pass') db = rs_or_single_client_noauth().db threads = [] for _ in range(4): threads.append(DBAuthenticateThread(db, 'user', 'pass')) for thread in threads: thread.start() for thread in threads: thread.join() self.assertTrue(thread.success) class TestAuthURIOptions(unittest.TestCase): @client_context.require_auth def setUp(self): client_context.create_user('admin', 'admin', 'pass') client_context.create_user( 'pymongo_test', 'user', 'pass', ['userAdmin', 'readWrite']) def tearDown(self): client_context.drop_user('pymongo_test', 'user') client_context.drop_user('admin', 'admin') def test_uri_options(self): # Test default to admin host, port = client_context.host, client_context.port client = rs_or_single_client_noauth( 'mongodb://admin:pass@%s:%d' % (host, port)) self.assertTrue(client.admin.command('dbstats')) if client_context.is_rs: uri = ('mongodb://admin:pass@%s:%d/?replicaSet=%s' % ( host, port, client_context.replica_set_name)) client = single_client_noauth(uri) self.assertTrue(client.admin.command('dbstats')) db = client.get_database( 'admin', read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command('dbstats')) # Test explicit database uri = 'mongodb://user:pass@%s:%d/pymongo_test' % (host, port) client = rs_or_single_client_noauth(uri) self.assertRaises(OperationFailure, client.admin.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) if client_context.is_rs: uri = ('mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s' % ( host, port, client_context.replica_set_name)) client = single_client_noauth(uri) self.assertRaises(OperationFailure, client.admin.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) db = client.get_database( 'pymongo_test', read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command('dbstats')) # Test authSource uri = ('mongodb://user:pass@%s:%d' 
'/pymongo_test2?authSource=pymongo_test' % (host, port)) client = rs_or_single_client_noauth(uri) self.assertRaises(OperationFailure, client.pymongo_test2.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) if client_context.is_rs: uri = ('mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=' '%s;authSource=pymongo_test' % ( host, port, client_context.replica_set_name)) client = single_client_noauth(uri) self.assertRaises(OperationFailure, client.pymongo_test2.command, 'dbstats') self.assertTrue(client.pymongo_test.command('dbstats')) db = client.get_database( 'pymongo_test', read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command('dbstats')) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_auth_spec.py000066400000000000000000000101131374256237000176340ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Run the auth spec tests.""" import glob import json import os import sys sys.path[0:0] = [""] from pymongo import MongoClient from test import unittest _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'auth') class TestAuthSpec(unittest.TestCase): pass def create_test(test_case): def run_test(self): uri = test_case['uri'] valid = test_case['valid'] credential = test_case.get('credential') if not valid: self.assertRaises(Exception, MongoClient, uri, connect=False) else: client = MongoClient(uri, connect=False) credentials = client._MongoClient__options.credentials if credential is None: self.assertIsNone(credentials) else: self.assertIsNotNone(credentials) self.assertEqual(credentials.username, credential['username']) self.assertEqual(credentials.password, credential['password']) self.assertEqual(credentials.source, credential['source']) if credential['mechanism'] is not None: self.assertEqual( credentials.mechanism, credential['mechanism']) else: self.assertEqual(credentials.mechanism, 'DEFAULT') expected = credential['mechanism_properties'] if expected is not None: actual = credentials.mechanism_properties for key, val in expected.items(): if 'SERVICE_NAME' in expected: self.assertEqual( actual.service_name, expected['SERVICE_NAME']) elif 'CANONICALIZE_HOST_NAME' in expected: self.assertEqual( actual.canonicalize_host_name, expected['CANONICALIZE_HOST_NAME']) elif 'SERVICE_REALM' in expected: self.assertEqual( actual.service_realm, expected['SERVICE_REALM']) elif 'AWS_SESSION_TOKEN' in expected: self.assertEqual( actual.aws_session_token, expected['AWS_SESSION_TOKEN']) else: self.fail('Unhandled property: %s' % (key,)) else: if credential['mechanism'] == 'MONGODB-AWS': self.assertIsNone( credentials.mechanism_properties.aws_session_token) else: self.assertIsNone(credentials.mechanism_properties) return run_test def create_tests(): for filename in glob.glob(os.path.join(_TEST_PATH, '*.json')): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with open(filename) as auth_tests: test_cases = json.load(auth_tests)['tests'] for test_case in test_cases: if 
test_case.get('optional', False): continue test_method = create_test(test_case) name = str(test_case['description'].lower().replace(' ', '_')) setattr( TestAuthSpec, 'test_%s_%s' % (test_suffix, name), test_method) create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_binary.py000066400000000000000000000546421374256237000171640ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the Binary wrapper.""" import array import base64 import copy import pickle import platform import sys import uuid sys.path[0:0] = [""] import bson from bson import decode, encode from bson.binary import * from bson.codec_options import CodecOptions from bson.py3compat import PY3 from bson.son import SON from pymongo.common import validate_uuid_representation from pymongo.mongo_client import MongoClient from pymongo.write_concern import WriteConcern from test import client_context, unittest, IntegrationTest from test.utils import ignore_deprecations class TestBinary(unittest.TestCase): @classmethod def setUpClass(cls): # Generated by the Java driver from_java = ( b'bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu' b'Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND' b'ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+' b'XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1' b'aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR' b'jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA' b'AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z' b'DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf' b'aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx' b'29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My' b'1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB' b'W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp' b'bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc' b'0MQAA') cls.java_data = base64.b64decode(from_java) # Generated by the .net driver from_csharp = ( b'ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl' b'iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2' b'ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V' b'pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl' b'AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A' b'ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z' b'oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU' b'zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn' b'dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA' b'CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT' b'QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP' b'MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00' b'ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=') cls.csharp_data = base64.b64decode(from_csharp) def test_binary(self): a_string = "hello world" a_binary = Binary(b"hello world") self.assertTrue(a_binary.startswith(b"hello")) self.assertTrue(a_binary.endswith(b"world")) self.assertTrue(isinstance(a_binary, Binary)) self.assertFalse(isinstance(a_string, 
Binary)) def test_exceptions(self): self.assertRaises(TypeError, Binary, None) self.assertRaises(TypeError, Binary, 5) self.assertRaises(TypeError, Binary, 10.2) self.assertRaises(TypeError, Binary, b"hello", None) self.assertRaises(TypeError, Binary, b"hello", "100") self.assertRaises(ValueError, Binary, b"hello", -1) self.assertRaises(ValueError, Binary, b"hello", 256) self.assertTrue(Binary(b"hello", 0)) self.assertTrue(Binary(b"hello", 255)) if platform.python_implementation() != "Jython": # Jython's memoryview accepts unicode strings... # https://bugs.jython.org/issue2784 self.assertRaises(TypeError, Binary, u"hello") def test_subtype(self): one = Binary(b"hello") self.assertEqual(one.subtype, 0) two = Binary(b"hello", 2) self.assertEqual(two.subtype, 2) three = Binary(b"hello", 100) self.assertEqual(three.subtype, 100) def test_equality(self): two = Binary(b"hello") three = Binary(b"hello", 100) self.assertNotEqual(two, three) self.assertEqual(three, Binary(b"hello", 100)) self.assertEqual(two, Binary(b"hello")) self.assertNotEqual(two, Binary(b"hello ")) self.assertNotEqual(b"hello", Binary(b"hello")) # Explicitly test inequality self.assertFalse(three != Binary(b"hello", 100)) self.assertFalse(two != Binary(b"hello")) def test_repr(self): one = Binary(b"hello world") self.assertEqual(repr(one), "Binary(%s, 0)" % (repr(b"hello world"),)) two = Binary(b"hello world", 2) self.assertEqual(repr(two), "Binary(%s, 2)" % (repr(b"hello world"),)) three = Binary(b"\x08\xFF") self.assertEqual(repr(three), "Binary(%s, 0)" % (repr(b"\x08\xFF"),)) four = Binary(b"\x08\xFF", 2) self.assertEqual(repr(four), "Binary(%s, 2)" % (repr(b"\x08\xFF"),)) five = Binary(b"test", 100) self.assertEqual(repr(five), "Binary(%s, 100)" % (repr(b"test"),)) def test_hash(self): one = Binary(b"hello world") two = Binary(b"hello world", 42) self.assertEqual(hash(Binary(b"hello world")), hash(one)) self.assertNotEqual(hash(one), hash(two)) self.assertEqual(hash(Binary(b"hello world", 42)), hash(two)) def test_uuid_subtype_4(self): """uuid_representation should be ignored when decoding subtype 4 for all UuidRepresentation values except UNSPECIFIED.""" expected_uuid = uuid.uuid4() doc = {"uuid": Binary(expected_uuid.bytes, 4)} encoded = encode(doc) for uuid_representation in (set(ALL_UUID_REPRESENTATIONS) - {UuidRepresentation.UNSPECIFIED}): options = CodecOptions(uuid_representation=uuid_representation) self.assertEqual(expected_uuid, decode(encoded, options)["uuid"]) def test_legacy_java_uuid(self): # Test decoding data = self.java_data docs = bson.decode_all(data, CodecOptions(SON, False, PYTHON_LEGACY)) for d in docs: self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) docs = bson.decode_all(data, CodecOptions(SON, False, STANDARD)) for d in docs: self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) docs = bson.decode_all(data, CodecOptions(SON, False, CSHARP_LEGACY)) for d in docs: self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) docs = bson.decode_all(data, CodecOptions(SON, False, JAVA_LEGACY)) for d in docs: self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) # Test encoding encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=PYTHON_LEGACY)) for doc in docs]) self.assertNotEqual(data, encoded) encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=STANDARD)) for doc in docs]) self.assertNotEqual(data, encoded) encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=CSHARP_LEGACY)) for doc in docs]) 
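        # Re-encoding with CSHARP_LEGACY (above) also fails to reproduce the
        # Java driver's bytes; only the JAVA_LEGACY encoding below
        # round-trips the original data exactly.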
self.assertNotEqual(data, encoded) encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=JAVA_LEGACY)) for doc in docs]) self.assertEqual(data, encoded) @client_context.require_connection def test_legacy_java_uuid_roundtrip(self): data = self.java_data docs = bson.decode_all(data, CodecOptions(SON, False, JAVA_LEGACY)) client_context.client.pymongo_test.drop_collection('java_uuid') db = client_context.client.pymongo_test coll = db.get_collection( 'java_uuid', CodecOptions(uuid_representation=JAVA_LEGACY)) coll.insert_many(docs) self.assertEqual(5, coll.count_documents({})) for d in coll.find(): self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) coll = db.get_collection( 'java_uuid', CodecOptions(uuid_representation=PYTHON_LEGACY)) for d in coll.find(): self.assertNotEqual(d['newguid'], d['newguidstring']) client_context.client.pymongo_test.drop_collection('java_uuid') def test_legacy_csharp_uuid(self): data = self.csharp_data # Test decoding docs = bson.decode_all(data, CodecOptions(SON, False, PYTHON_LEGACY)) for d in docs: self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) docs = bson.decode_all(data, CodecOptions(SON, False, STANDARD)) for d in docs: self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) docs = bson.decode_all(data, CodecOptions(SON, False, JAVA_LEGACY)) for d in docs: self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) docs = bson.decode_all(data, CodecOptions(SON, False, CSHARP_LEGACY)) for d in docs: self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) # Test encoding encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=PYTHON_LEGACY)) for doc in docs]) self.assertNotEqual(data, encoded) encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=STANDARD)) for doc in docs]) self.assertNotEqual(data, encoded) encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=JAVA_LEGACY)) for doc in docs]) self.assertNotEqual(data, encoded) encoded = b''.join([ encode(doc, False, CodecOptions(uuid_representation=CSHARP_LEGACY)) for doc in docs]) self.assertEqual(data, encoded) @client_context.require_connection def test_legacy_csharp_uuid_roundtrip(self): data = self.csharp_data docs = bson.decode_all(data, CodecOptions(SON, False, CSHARP_LEGACY)) client_context.client.pymongo_test.drop_collection('csharp_uuid') db = client_context.client.pymongo_test coll = db.get_collection( 'csharp_uuid', CodecOptions(uuid_representation=CSHARP_LEGACY)) coll.insert_many(docs) self.assertEqual(5, coll.count_documents({})) for d in coll.find(): self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) coll = db.get_collection( 'csharp_uuid', CodecOptions(uuid_representation=PYTHON_LEGACY)) for d in coll.find(): self.assertNotEqual(d['newguid'], d['newguidstring']) client_context.client.pymongo_test.drop_collection('csharp_uuid') def test_uri_to_uuid(self): uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" client = MongoClient(uri, connect=False) self.assertEqual( client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) @client_context.require_connection @ignore_deprecations def test_uuid_queries(self): db = client_context.client.pymongo_test coll = db.test coll.drop() uu = uuid.uuid4() coll.insert_one({'uuid': Binary(uu.bytes, 3)}) self.assertEqual(1, coll.count_documents({})) # Test UUIDLegacy queries. 
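        # The stored value is a legacy Binary subtype 3, so with the STANDARD
        # representation a plain uuid.UUID query matches nothing; the value
        # has to be wrapped in UUIDLegacy to be found.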
coll = db.get_collection( "test", CodecOptions( uuid_representation=UuidRepresentation.STANDARD)) self.assertEqual(0, coll.find({'uuid': uu}).count()) cur = coll.find({'uuid': UUIDLegacy(uu)}) self.assertEqual(1, cur.count()) retrieved = next(cur) self.assertEqual(uu, retrieved['uuid']) # Test regular UUID queries (using subtype 4). coll.insert_one({'uuid': uu}) self.assertEqual(2, coll.count_documents({})) cur = coll.find({'uuid': uu}) self.assertEqual(1, cur.count()) retrieved = next(cur) self.assertEqual(uu, retrieved['uuid']) # Test both. predicate = {'uuid': {'$in': [uu, UUIDLegacy(uu)]}} self.assertEqual(2, coll.count_documents(predicate)) cur = coll.find(predicate) self.assertEqual(2, cur.count()) coll.drop() def test_pickle(self): b1 = Binary(b'123', 2) # For testing backwards compatibility with pre-2.4 pymongo if PY3: p = (b"\x80\x03cbson.binary\nBinary\nq\x00C\x03123q\x01\x85q" b"\x02\x81q\x03}q\x04X\x10\x00\x00\x00_Binary__subtypeq" b"\x05K\x02sb.") else: p = (b"ccopy_reg\n_reconstructor\np0\n(cbson.binary\nBinary\np1\nc" b"__builtin__\nstr\np2\nS'123'\np3\ntp4\nRp5\n(dp6\nS'_Binary" b"__subtype'\np7\nI2\nsb.") if not sys.version.startswith('3.0'): self.assertEqual(b1, pickle.loads(p)) for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.assertEqual(b1, pickle.loads(pickle.dumps(b1, proto))) uu = uuid.uuid4() uul = UUIDLegacy(uu) self.assertEqual(uul, copy.copy(uul)) self.assertEqual(uul, copy.deepcopy(uul)) for proto in range(pickle.HIGHEST_PROTOCOL + 1): self.assertEqual(uul, pickle.loads(pickle.dumps(uul, proto))) def test_buffer_protocol(self): b0 = Binary(b'123', 2) self.assertEqual(b0, Binary(memoryview(b'123'), 2)) self.assertEqual(b0, Binary(bytearray(b'123'), 2)) # mmap.mmap and array.array only expose the # buffer interface in python 3.x if PY3: # No mmap module in Jython import mmap with mmap.mmap(-1, len(b'123')) as mm: mm.write(b'123') mm.seek(0) self.assertEqual(b0, Binary(mm, 2)) self.assertEqual(b0, Binary(array.array('B', b'123'), 2)) class TestUuidSpecExplicitCoding(unittest.TestCase): @classmethod def setUpClass(cls): super(TestUuidSpecExplicitCoding, cls).setUpClass() cls.uuid = uuid.UUID("00112233445566778899AABBCCDDEEFF") @staticmethod def _hex_to_bytes(hexstring): if PY3: return bytes.fromhex(hexstring) return hexstring.decode("hex") # Explicit encoding prose test #1 def test_encoding_1(self): obj = Binary.from_uuid(self.uuid) expected_obj = Binary( self._hex_to_bytes("00112233445566778899AABBCCDDEEFF"), 4) self.assertEqual(obj, expected_obj) def _test_encoding_w_uuid_rep( self, uuid_rep, expected_hexstring, expected_subtype): obj = Binary.from_uuid(self.uuid, uuid_rep) expected_obj = Binary( self._hex_to_bytes(expected_hexstring), expected_subtype) self.assertEqual(obj, expected_obj) # Explicit encoding prose test #2 def test_encoding_2(self): self._test_encoding_w_uuid_rep( UuidRepresentation.STANDARD, "00112233445566778899AABBCCDDEEFF", 4) # Explicit encoding prose test #3 def test_encoding_3(self): self._test_encoding_w_uuid_rep( UuidRepresentation.JAVA_LEGACY, "7766554433221100FFEEDDCCBBAA9988", 3) # Explicit encoding prose test #4 def test_encoding_4(self): self._test_encoding_w_uuid_rep( UuidRepresentation.CSHARP_LEGACY, "33221100554477668899AABBCCDDEEFF", 3) # Explicit encoding prose test #5 def test_encoding_5(self): self._test_encoding_w_uuid_rep( UuidRepresentation.PYTHON_LEGACY, "00112233445566778899AABBCCDDEEFF", 3) # Explicit encoding prose test #6 def test_encoding_6(self): with self.assertRaises(ValueError): Binary.from_uuid(self.uuid, 
UuidRepresentation.UNSPECIFIED) # Explicit decoding prose test #1 def test_decoding_1(self): obj = Binary( self._hex_to_bytes("00112233445566778899AABBCCDDEEFF"), 4) # Case i: self.assertEqual(obj.as_uuid(), self.uuid) # Case ii: self.assertEqual(obj.as_uuid(UuidRepresentation.STANDARD), self.uuid) # Cases iii-vi: for uuid_rep in (UuidRepresentation.JAVA_LEGACY, UuidRepresentation.CSHARP_LEGACY, UuidRepresentation.PYTHON_LEGACY): with self.assertRaises(ValueError): obj.as_uuid(uuid_rep) def _test_decoding_legacy(self, hexstring, uuid_rep): obj = Binary(self._hex_to_bytes(hexstring), 3) # Case i: with self.assertRaises(ValueError): obj.as_uuid() # Cases ii-iii: for rep in (UuidRepresentation.STANDARD, UuidRepresentation.UNSPECIFIED): with self.assertRaises(ValueError): obj.as_uuid(rep) # Case iv: self.assertEqual(obj.as_uuid(uuid_rep), self.uuid) # Explicit decoding prose test #2 def test_decoding_2(self): self._test_decoding_legacy( "7766554433221100FFEEDDCCBBAA9988", UuidRepresentation.JAVA_LEGACY) # Explicit decoding prose test #3 def test_decoding_3(self): self._test_decoding_legacy( "33221100554477668899AABBCCDDEEFF", UuidRepresentation.CSHARP_LEGACY) # Explicit decoding prose test #4 def test_decoding_4(self): self._test_decoding_legacy( "00112233445566778899AABBCCDDEEFF", UuidRepresentation.PYTHON_LEGACY) class TestUuidSpecImplicitCoding(IntegrationTest): @classmethod def setUpClass(cls): super(TestUuidSpecImplicitCoding, cls).setUpClass() cls.uuid = uuid.UUID("00112233445566778899AABBCCDDEEFF") @staticmethod def _hex_to_bytes(hexstring): if PY3: return bytes.fromhex(hexstring) return hexstring.decode("hex") def _get_coll_w_uuid_rep(self, uuid_rep): codec_options = self.client.codec_options.with_options( uuid_representation=validate_uuid_representation(None, uuid_rep)) coll = self.db.get_collection( 'pymongo_test', codec_options=codec_options, write_concern=WriteConcern("majority")) return coll def _test_encoding(self, uuid_rep, expected_hexstring, expected_subtype): coll = self._get_coll_w_uuid_rep(uuid_rep) coll.delete_many({}) coll.insert_one({'_id': self.uuid}) self.assertTrue( coll.find_one({"_id": Binary( self._hex_to_bytes(expected_hexstring), expected_subtype)})) # Implicit encoding prose test #1 def test_encoding_1(self): self._test_encoding( "javaLegacy", "7766554433221100FFEEDDCCBBAA9988", 3) # Implicit encoding prose test #2 def test_encoding_2(self): self._test_encoding( "csharpLegacy", "33221100554477668899AABBCCDDEEFF", 3) # Implicit encoding prose test #3 def test_encoding_3(self): self._test_encoding( "pythonLegacy", "00112233445566778899AABBCCDDEEFF", 3) # Implicit encoding prose test #4 def test_encoding_4(self): self._test_encoding( "standard", "00112233445566778899AABBCCDDEEFF", 4) # Implicit encoding prose test #5 def test_encoding_5(self): with self.assertRaises(ValueError): self._test_encoding( "unspecifed", "dummy", -1) def _test_decoding(self, client_uuid_representation_string, legacy_field_uuid_representation, expected_standard_field_value, expected_legacy_field_value): coll = self._get_coll_w_uuid_rep(client_uuid_representation_string) coll.drop() standard_val = Binary.from_uuid(self.uuid, UuidRepresentation.STANDARD) legacy_val = Binary.from_uuid(self.uuid, legacy_field_uuid_representation) coll.insert_one({'standard': standard_val, 'legacy': legacy_val}) doc = coll.find_one() self.assertEqual(doc['standard'], expected_standard_field_value) self.assertEqual(doc['legacy'], expected_legacy_field_value) # Implicit decoding prose test #1 def 
test_decoding_1(self):
        # TODO: these assertions will change after PYTHON-2245. Specifically,
        # the 'standard' field will be decoded as a Binary subtype 4.
        binary_value = Binary.from_uuid(
            self.uuid, UuidRepresentation.PYTHON_LEGACY)
        self._test_decoding(
            "javaLegacy", UuidRepresentation.JAVA_LEGACY,
            self.uuid, self.uuid)
        self._test_decoding(
            "csharpLegacy", UuidRepresentation.CSHARP_LEGACY,
            self.uuid, self.uuid)
        self._test_decoding(
            "pythonLegacy", UuidRepresentation.PYTHON_LEGACY,
            self.uuid, self.uuid)

    # Implicit decoding prose test #2
    def test_decoding_2(self):
        # TODO: these assertions will change after PYTHON-2245. Specifically,
        # the 'legacy' field will be decoded as a Binary subtype 3.
        binary_value = Binary.from_uuid(
            self.uuid, UuidRepresentation.PYTHON_LEGACY)
        self._test_decoding(
            "standard", UuidRepresentation.PYTHON_LEGACY, self.uuid,
            binary_value.as_uuid(UuidRepresentation.PYTHON_LEGACY))

    # Implicit decoding prose test #3
    def test_decoding_3(self):
        expected_standard_value = Binary.from_uuid(
            self.uuid, UuidRepresentation.STANDARD)
        for legacy_uuid_rep in (UuidRepresentation.PYTHON_LEGACY,
                                UuidRepresentation.CSHARP_LEGACY,
                                UuidRepresentation.JAVA_LEGACY):
            expected_legacy_value = Binary.from_uuid(
                self.uuid, legacy_uuid_rep)
            self._test_decoding(
                "unspecified", legacy_uuid_rep,
                expected_standard_value, expected_legacy_value)


if __name__ == "__main__":
    unittest.main()

pymongo-3.11.0/test/test_bson.py000066400000000000000000001277071374256237000166340ustar00rootroot00000000000000
# -*- coding: utf-8 -*-
#
# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
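
# A minimal sketch of the round trip exercised throughout this module
# (encode/decode are the top-level helpers imported below):
#
#     >>> from bson import encode, decode
#     >>> raw = encode({"hello": "world"})
#     >>> raw
#     b'\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00'
#     >>> decode(raw)
#     {'hello': 'world'}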
"""Test the bson module.""" import collections import datetime import os import re import sys import tempfile import uuid sys.path[0:0] = [""] import bson from bson import (BSON, decode, decode_all, decode_file_iter, decode_iter, encode, EPOCH_AWARE, is_valid, Regex) from bson.binary import Binary, UUIDLegacy from bson.code import Code from bson.codec_options import CodecOptions from bson.int64 import Int64 from bson.objectid import ObjectId from bson.dbref import DBRef from bson.py3compat import abc, iteritems, PY3, StringIO, text_type from bson.son import SON from bson.timestamp import Timestamp from bson.errors import (InvalidBSON, InvalidDocument, InvalidStringData) from bson.max_key import MaxKey from bson.min_key import MinKey from bson.tz_util import (FixedOffset, utc) from test import qcheck, SkipTest, unittest from test.utils import ExceptionCatchingThread if PY3: long = int class NotADict(abc.MutableMapping): """Non-dict type that implements the mapping protocol.""" def __init__(self, initial=None): if not initial: self._dict = {} else: self._dict = initial def __iter__(self): return iter(self._dict) def __getitem__(self, item): return self._dict[item] def __delitem__(self, item): del self._dict[item] def __setitem__(self, item, value): self._dict[item] = value def __len__(self): return len(self._dict) def __eq__(self, other): if isinstance(other, abc.Mapping): return all(self.get(k) == other.get(k) for k in self) return NotImplemented def __repr__(self): return "NotADict(%s)" % repr(self._dict) class DSTAwareTimezone(datetime.tzinfo): def __init__(self, offset, name, dst_start_month, dst_end_month): self.__offset = offset self.__dst_start_month = dst_start_month self.__dst_end_month = dst_end_month self.__name = name def _is_dst(self, dt): return self.__dst_start_month <= dt.month <= self.__dst_end_month def utcoffset(self, dt): return datetime.timedelta(minutes=self.__offset) + self.dst(dt) def dst(self, dt): if self._is_dst(dt): return datetime.timedelta(hours=1) return datetime.timedelta(0) def tzname(self, dt): return self.__name class TestBSON(unittest.TestCase): def assertInvalid(self, data): self.assertRaises(InvalidBSON, decode, data) def check_encode_then_decode(self, doc_class=dict, decoder=decode, encoder=encode): # Work around http://bugs.jython.org/issue1728 if sys.platform.startswith('java'): doc_class = SON def helper(doc): self.assertEqual(doc, (decoder(encoder(doc_class(doc))))) self.assertEqual(doc, decoder(encoder(doc))) helper({}) helper({"test": u"hello"}) self.assertTrue(isinstance(decoder(encoder( {"hello": "world"}))["hello"], text_type)) helper({"mike": -10120}) helper({"long": Int64(10)}) helper({"really big long": 2147483648}) helper({u"hello": 0.0013109}) helper({"something": True}) helper({"false": False}) helper({"an array": [1, True, 3.8, u"world"]}) helper({"an object": doc_class({"test": u"something"})}) helper({"a binary": Binary(b"test", 100)}) helper({"a binary": Binary(b"test", 128)}) helper({"a binary": Binary(b"test", 254)}) helper({"another binary": Binary(b"test", 2)}) helper(SON([(u'test dst', datetime.datetime(1993, 4, 4, 2))])) helper(SON([(u'test negative dst', datetime.datetime(1, 1, 1, 1, 1, 1))])) helper({"big float": float(10000000000)}) helper({"ref": DBRef("coll", 5)}) helper({"ref": DBRef("coll", 5, foo="bar", bar=4)}) helper({"ref": DBRef("coll", 5, "foo")}) helper({"ref": DBRef("coll", 5, "foo", foo="bar")}) helper({"ref": Timestamp(1, 2)}) helper({"foo": MinKey()}) helper({"foo": MaxKey()}) helper({"$field": 
Code("function(){ return true; }")}) helper({"$field": Code("return function(){ return x; }", scope={'x': False})}) def encode_then_decode(doc): return doc_class(doc) == decoder(encode(doc), CodecOptions( document_class=doc_class)) qcheck.check_unittest(self, encode_then_decode, qcheck.gen_mongo_dict(3)) def test_encode_then_decode(self): self.check_encode_then_decode() def test_encode_then_decode_any_mapping(self): self.check_encode_then_decode(doc_class=NotADict) def test_encode_then_decode_legacy(self): self.check_encode_then_decode( encoder=BSON.encode, decoder=lambda *args: BSON(args[0]).decode(*args[1:])) def test_encode_then_decode_any_mapping_legacy(self): self.check_encode_then_decode( doc_class=NotADict, encoder=BSON.encode, decoder=lambda *args: BSON(args[0]).decode(*args[1:])) def test_encoding_defaultdict(self): dct = collections.defaultdict(dict, [('foo', 'bar')]) encode(dct) self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')])) def test_basic_validation(self): self.assertRaises(TypeError, is_valid, 100) self.assertRaises(TypeError, is_valid, u"test") self.assertRaises(TypeError, is_valid, 10.4) self.assertInvalid(b"test") # the simplest valid BSON document self.assertTrue(is_valid(b"\x05\x00\x00\x00\x00")) self.assertTrue(is_valid(BSON(b"\x05\x00\x00\x00\x00"))) # failure cases self.assertInvalid(b"\x04\x00\x00\x00\x00") self.assertInvalid(b"\x05\x00\x00\x00\x01") self.assertInvalid(b"\x05\x00\x00\x00") self.assertInvalid(b"\x05\x00\x00\x00\x00\x00") self.assertInvalid(b"\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12") self.assertInvalid(b"\x09\x00\x00\x00\x10a\x00\x05\x00") self.assertInvalid(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") self.assertInvalid(b"\x13\x00\x00\x00\x02foo\x00" b"\x04\x00\x00\x00bar\x00\x00") self.assertInvalid(b"\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00" b"\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00") self.assertInvalid(b"\x15\x00\x00\x00\x03foo\x00\x0c" b"\x00\x00\x00\x08bar\x00\x01\x00\x00") self.assertInvalid(b"\x1c\x00\x00\x00\x03foo\x00" b"\x12\x00\x00\x00\x02bar\x00" b"\x05\x00\x00\x00baz\x00\x00\x00") self.assertInvalid(b"\x10\x00\x00\x00\x02a\x00" b"\x04\x00\x00\x00abc\xff\x00") def test_bad_string_lengths(self): self.assertInvalid( b"\x0c\x00\x00\x00\x02\x00" b"\x00\x00\x00\x00\x00\x00") self.assertInvalid( b"\x12\x00\x00\x00\x02\x00" b"\xff\xff\xff\xfffoobar\x00\x00") self.assertInvalid( b"\x0c\x00\x00\x00\x0e\x00" b"\x00\x00\x00\x00\x00\x00") self.assertInvalid( b"\x12\x00\x00\x00\x0e\x00" b"\xff\xff\xff\xfffoobar\x00\x00") self.assertInvalid( b"\x18\x00\x00\x00\x0c\x00" b"\x00\x00\x00\x00\x00RY\xb5j" b"\xfa[\xd8A\xd6X]\x99\x00") self.assertInvalid( b"\x1e\x00\x00\x00\x0c\x00" b"\xff\xff\xff\xfffoobar\x00" b"RY\xb5j\xfa[\xd8A\xd6X]\x99\x00") self.assertInvalid( b"\x0c\x00\x00\x00\r\x00" b"\x00\x00\x00\x00\x00\x00") self.assertInvalid( b"\x0c\x00\x00\x00\r\x00" b"\xff\xff\xff\xff\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x00\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\x01\x00\x00" b"\x00\x00\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\xff\xff" b"\xff\xff\x00\x0c\x00\x00" b"\x00\x02\x00\x01\x00\x00" b"\x00\x00\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x01\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\x00\x00\x00" b"\x00\x00\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x01\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\xff\xff\xff" b"\xff\x00\x00\x00") def test_random_data_is_not_bson(self): 
qcheck.check_unittest(self, qcheck.isnt(is_valid), qcheck.gen_string(qcheck.gen_range(0, 40))) def test_basic_decode(self): self.assertEqual({"test": u"hello world"}, decode(b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74\x00\x0C" b"\x00\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F" b"\x72\x6C\x64\x00\x00")) self.assertEqual([{"test": u"hello world"}, {}], decode_all(b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00")) self.assertEqual([{"test": u"hello world"}, {}], list(decode_iter( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00"))) self.assertEqual([{"test": u"hello world"}, {}], list(decode_file_iter(StringIO( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00")))) def test_decode_all_buffer_protocol(self): docs = [{'foo': 'bar'}, {}] bs = b"".join(map(encode, docs)) self.assertEqual(docs, decode_all(bytearray(bs))) self.assertEqual(docs, decode_all(memoryview(bs))) self.assertEqual(docs, decode_all(memoryview(b'1' + bs + b'1')[1:-1])) if PY3: import array import mmap self.assertEqual(docs, decode_all(array.array('B', bs))) with mmap.mmap(-1, len(bs)) as mm: mm.write(bs) mm.seek(0) self.assertEqual(docs, decode_all(mm)) def test_decode_buffer_protocol(self): doc = {'foo': 'bar'} bs = encode(doc) self.assertEqual(doc, decode(bs)) self.assertEqual(doc, decode(bytearray(bs))) self.assertEqual(doc, decode(memoryview(bs))) self.assertEqual(doc, decode(memoryview(b'1' + bs + b'1')[1:-1])) if PY3: import array import mmap self.assertEqual(doc, decode(array.array('B', bs))) with mmap.mmap(-1, len(bs)) as mm: mm.write(bs) mm.seek(0) self.assertEqual(doc, decode(mm)) def test_invalid_decodes(self): # Invalid object size (not enough bytes in document for even # an object size of first object. # NOTE: decode_all and decode_iter don't care, not sure if they should? self.assertRaises(InvalidBSON, list, decode_file_iter(StringIO(b"\x1B"))) bad_bsons = [ # An object size that's too small to even include the object size, # but is correctly encoded, along with a correct EOO (and no data). b"\x01\x00\x00\x00\x00", # One object, but with object size listed smaller than it is in the # data. (b"\x1A\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00"), # One object, missing the EOO at the end. (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00"), # One object, sized correctly, with a spot for an EOO, but the EOO # isn't 0x00. 
(b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\xFF"), ] for i, data in enumerate(bad_bsons): msg = "bad_bson[{}]".format(i) with self.assertRaises(InvalidBSON, msg=msg): decode_all(data) with self.assertRaises(InvalidBSON, msg=msg): list(decode_iter(data)) with self.assertRaises(InvalidBSON, msg=msg): list(decode_file_iter(StringIO(data))) with tempfile.TemporaryFile() as scratch: scratch.write(data) scratch.seek(0, os.SEEK_SET) with self.assertRaises(InvalidBSON, msg=msg): list(decode_file_iter(scratch)) def test_data_timestamp(self): self.assertEqual({"test": Timestamp(4, 20)}, decode(b"\x13\x00\x00\x00\x11\x74\x65\x73\x74\x00\x14" b"\x00\x00\x00\x04\x00\x00\x00\x00")) def test_basic_encode(self): self.assertRaises(TypeError, encode, 100) self.assertRaises(TypeError, encode, "hello") self.assertRaises(TypeError, encode, None) self.assertRaises(TypeError, encode, []) self.assertEqual(encode({}), BSON(b"\x05\x00\x00\x00\x00")) self.assertEqual(encode({}), b"\x05\x00\x00\x00\x00") self.assertEqual(encode({"test": u"hello world"}), b"\x1B\x00\x00\x00\x02\x74\x65\x73\x74\x00\x0C\x00" b"\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F\x72\x6C" b"\x64\x00\x00") self.assertEqual(encode({u"mike": 100}), b"\x0F\x00\x00\x00\x10\x6D\x69\x6B\x65\x00\x64\x00" b"\x00\x00\x00") self.assertEqual(encode({"hello": 1.5}), b"\x14\x00\x00\x00\x01\x68\x65\x6C\x6C\x6F\x00\x00" b"\x00\x00\x00\x00\x00\xF8\x3F\x00") self.assertEqual(encode({"true": True}), b"\x0C\x00\x00\x00\x08\x74\x72\x75\x65\x00\x01\x00") self.assertEqual(encode({"false": False}), b"\x0D\x00\x00\x00\x08\x66\x61\x6C\x73\x65\x00\x00" b"\x00") self.assertEqual(encode({"empty": []}), b"\x11\x00\x00\x00\x04\x65\x6D\x70\x74\x79\x00\x05" b"\x00\x00\x00\x00\x00") self.assertEqual(encode({"none": {}}), b"\x10\x00\x00\x00\x03\x6E\x6F\x6E\x65\x00\x05\x00" b"\x00\x00\x00\x00") self.assertEqual(encode({"test": Binary(b"test", 0)}), b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00" b"\x00\x00\x00\x74\x65\x73\x74\x00") self.assertEqual(encode({"test": Binary(b"test", 2)}), b"\x18\x00\x00\x00\x05\x74\x65\x73\x74\x00\x08\x00" b"\x00\x00\x02\x04\x00\x00\x00\x74\x65\x73\x74\x00") self.assertEqual(encode({"test": Binary(b"test", 128)}), b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00" b"\x00\x00\x80\x74\x65\x73\x74\x00") self.assertEqual(encode({"test": None}), b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00") self.assertEqual(encode({"date": datetime.datetime(2007, 1, 8, 0, 30, 11)}), b"\x13\x00\x00\x00\x09\x64\x61\x74\x65\x00\x38\xBE" b"\x1C\xFF\x0F\x01\x00\x00\x00") self.assertEqual(encode({"regex": re.compile(b"a*b", re.IGNORECASE)}), b"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61" b"\x2A\x62\x00\x69\x00\x00") self.assertEqual(encode({"$where": Code("test")}), b"\x16\x00\x00\x00\r$where\x00\x05\x00\x00\x00test" b"\x00\x00") self.assertEqual(encode({"$field": Code("function(){ return true;}", scope=None)}), b"+\x00\x00\x00\r$field\x00\x1a\x00\x00\x00" b"function(){ return true;}\x00\x00") self.assertEqual(encode({"$field": Code("return function(){ return x; }", scope={'x': False})}), b"=\x00\x00\x00\x0f$field\x000\x00\x00\x00\x1f\x00" b"\x00\x00return function(){ return x; }\x00\t\x00" b"\x00\x00\x08x\x00\x00\x00\x00") unicode_empty_scope = Code(u"function(){ return 'héllo';}", {}) self.assertEqual(encode({'$field': unicode_empty_scope}), b"8\x00\x00\x00\x0f$field\x00+\x00\x00\x00\x1e\x00" b"\x00\x00function(){ return 'h\xc3\xa9llo';}\x00\x05" b"\x00\x00\x00\x00\x00") a 
= ObjectId(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B") self.assertEqual(encode({"oid": a}), b"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x00\x01\x02" b"\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00") self.assertEqual(encode({"ref": DBRef("coll", a)}), b"\x2F\x00\x00\x00\x03ref\x00\x25\x00\x00\x00\x02" b"$ref\x00\x05\x00\x00\x00coll\x00\x07$id\x00\x00" b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00" b"\x00") def test_unknown_type(self): # Repr value differs with major python version part = "type %r for fieldname 'foo'" % (b'\x14',) docs = [ b'\x0e\x00\x00\x00\x14foo\x00\x01\x00\x00\x00\x00', (b'\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x140' b'\x00\x01\x00\x00\x00\x00\x00'), (b' \x00\x00\x00\x04bar\x00\x16\x00\x00\x00\x030\x00\x0e\x00\x00' b'\x00\x14foo\x00\x01\x00\x00\x00\x00\x00\x00')] for bs in docs: try: decode(bs) except Exception as exc: self.assertTrue(isinstance(exc, InvalidBSON)) self.assertTrue(part in str(exc)) else: self.fail("Failed to raise an exception.") def test_dbpointer(self): # *Note* - DBPointer and DBRef are *not* the same thing. DBPointer # is a deprecated BSON type. DBRef is a convention that does not # exist in the BSON spec, meant to replace DBPointer. PyMongo does # not support creation of the DBPointer type, but will decode # DBPointer to DBRef. bs = (b"\x18\x00\x00\x00\x0c\x00\x01\x00\x00" b"\x00\x00RY\xb5j\xfa[\xd8A\xd6X]\x99\x00") self.assertEqual({'': DBRef('', ObjectId('5259b56afa5bd841d6585d99'))}, decode(bs)) def test_bad_dbref(self): ref_only = {'ref': {'$ref': 'collection'}} id_only = {'ref': {'$id': ObjectId()}} self.assertEqual(DBRef('collection', id=None), decode(encode(ref_only))['ref']) self.assertEqual(id_only, decode(encode(id_only))) def test_bytes_as_keys(self): doc = {b"foo": 'bar'} # Since `bytes` are stored as Binary you can't use them # as keys in python 3.x. Using binary data as a key makes # no sense in BSON anyway and little sense in python. if PY3: self.assertRaises(InvalidDocument, encode, doc) else: self.assertTrue(encode(doc)) def test_datetime_encode_decode(self): # Negative timestamps dt1 = datetime.datetime(1, 1, 1, 1, 1, 1, 111000) dt2 = decode(encode({"date": dt1}))["date"] self.assertEqual(dt1, dt2) dt1 = datetime.datetime(1959, 6, 25, 12, 16, 59, 999000) dt2 = decode(encode({"date": dt1}))["date"] self.assertEqual(dt1, dt2) # Positive timestamps dt1 = datetime.datetime(9999, 12, 31, 23, 59, 59, 999000) dt2 = decode(encode({"date": dt1}))["date"] self.assertEqual(dt1, dt2) dt1 = datetime.datetime(2011, 6, 14, 10, 47, 53, 444000) dt2 = decode(encode({"date": dt1}))["date"] self.assertEqual(dt1, dt2) def test_large_datetime_truncation(self): # Ensure that a large datetime is truncated correctly. dt1 = datetime.datetime(9999, 1, 1, 1, 1, 1, 999999) dt2 = decode(encode({"date": dt1}))["date"] self.assertEqual(dt2.microsecond, 999000) self.assertEqual(dt2.second, dt1.second) def test_aware_datetime(self): aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc), as_utc) after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))[ "date"] self.assertEqual(utc, after.tzinfo) self.assertEqual(as_utc, after) def test_local_datetime(self): # Timezone -60 minutes of UTC, with DST between April and July. tz = DSTAwareTimezone(60, "sixty-minutes", 4, 7) # It's not DST. 
local = datetime.datetime(year=2025, month=12, hour=2, day=1, tzinfo=tz) options = CodecOptions(tz_aware=True, tzinfo=tz) # Encode with this timezone, then decode to UTC. encoded = encode({'date': local}, codec_options=options) self.assertEqual(local.replace(hour=1, tzinfo=None), decode(encoded)['date']) # It's DST. local = datetime.datetime(year=2025, month=4, hour=1, day=1, tzinfo=tz) encoded = encode({'date': local}, codec_options=options) self.assertEqual(local.replace(month=3, day=31, hour=23, tzinfo=None), decode(encoded)['date']) # Encode UTC, then decode in a different timezone. encoded = encode({'date': local.replace(tzinfo=utc)}) decoded = decode(encoded, options)['date'] self.assertEqual(local.replace(hour=3), decoded) self.assertEqual(tz, decoded.tzinfo) # Test round-tripping. self.assertEqual( local, decode(encode( {'date': local}, codec_options=options), options)['date']) # Test around the Unix Epoch. epochs = ( EPOCH_AWARE, EPOCH_AWARE.astimezone(FixedOffset(120, 'one twenty')), EPOCH_AWARE.astimezone(FixedOffset(-120, 'minus one twenty')) ) utc_co = CodecOptions(tz_aware=True) for epoch in epochs: doc = {'epoch': epoch} # We always retrieve datetimes in UTC unless told to do otherwise. self.assertEqual( EPOCH_AWARE, decode(encode(doc), codec_options=utc_co)['epoch']) # Round-trip the epoch. local_co = CodecOptions(tz_aware=True, tzinfo=epoch.tzinfo) self.assertEqual( epoch, decode(encode(doc), codec_options=local_co)['epoch']) def test_naive_decode(self): aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None) self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc) after = decode(encode({"date": aware}))["date"] self.assertEqual(None, after.tzinfo) self.assertEqual(naive_utc, after) def test_dst(self): d = {"x": datetime.datetime(1993, 4, 4, 2)} self.assertEqual(d, decode(encode(d))) def test_bad_encode(self): if not PY3: # Python3 treats this as a unicode string which won't raise # an exception. If we passed the string as bytes instead we # still wouldn't get an error since we store bytes as BSON # binary subtype 0. self.assertRaises(InvalidStringData, encode, {"lalala": '\xf4\xe0\xf0\xe1\xc0 Color Touch'}) # Work around what seems like a regression in python 3.5.0. 
# See http://bugs.python.org/issue25222 if sys.version_info[:2] < (3, 5): evil_list = {'a': []} evil_list['a'].append(evil_list) evil_dict = {} evil_dict['a'] = evil_dict for evil_data in [evil_dict, evil_list]: self.assertRaises(Exception, encode, evil_data) def test_overflow(self): self.assertTrue(encode({"x": long(9223372036854775807)})) self.assertRaises(OverflowError, encode, {"x": long(9223372036854775808)}) self.assertTrue(encode({"x": long(-9223372036854775808)})) self.assertRaises(OverflowError, encode, {"x": long(-9223372036854775809)}) def test_small_long_encode_decode(self): encoded1 = encode({'x': 256}) decoded1 = decode(encoded1)['x'] self.assertEqual(256, decoded1) self.assertEqual(type(256), type(decoded1)) encoded2 = encode({'x': Int64(256)}) decoded2 = decode(encoded2)['x'] expected = Int64(256) self.assertEqual(expected, decoded2) self.assertEqual(type(expected), type(decoded2)) self.assertNotEqual(type(decoded1), type(decoded2)) def test_tuple(self): self.assertEqual({"tuple": [1, 2]}, decode(encode({"tuple": (1, 2)}))) def test_uuid(self): id = uuid.uuid4() transformed_id = decode(encode({"id": id}))["id"] self.assertTrue(isinstance(transformed_id, uuid.UUID)) self.assertEqual(id, transformed_id) self.assertNotEqual(uuid.uuid4(), transformed_id) def test_uuid_legacy(self): id = uuid.uuid4() legacy = UUIDLegacy(id) self.assertEqual(3, legacy.subtype) transformed = decode(encode({"uuid": legacy}))["uuid"] self.assertTrue(isinstance(transformed, uuid.UUID)) self.assertEqual(id, transformed) self.assertNotEqual(UUIDLegacy(uuid.uuid4()), UUIDLegacy(transformed)) # The C extension was segfaulting on unicode RegExs, so we have this test # that doesn't really test anything but the lack of a segfault. def test_unicode_regex(self): regex = re.compile(u'revisi\xf3n') decode(encode({"regex": regex})) def test_non_string_keys(self): self.assertRaises(InvalidDocument, encode, {8.9: "test"}) def test_utf8(self): w = {u"aéあ": u"aéあ"} self.assertEqual(w, decode(encode(w))) # b'a\xe9' == u"aé".encode("iso-8859-1") iso8859_bytes = b'a\xe9' y = {"hello": iso8859_bytes} if PY3: # Stored as BSON binary subtype 0. out = decode(encode(y)) self.assertTrue(isinstance(out['hello'], bytes)) self.assertEqual(out['hello'], iso8859_bytes) else: # Python 2. try: encode(y) except InvalidStringData as e: self.assertTrue(repr(iso8859_bytes) in str(e)) # The next two tests only make sense in python 2.x since # you can't use `bytes` type as document keys in python 3.x. x = {u"aéあ".encode("utf-8"): u"aéあ".encode("utf-8")} self.assertEqual(w, decode(encode(x))) z = {iso8859_bytes: "hello"} self.assertRaises(InvalidStringData, encode, z) def test_null_character(self): doc = {"a": "\x00"} self.assertEqual(doc, decode(encode(doc))) # This test doesn't make much sense in Python2 # since {'a': '\x00'} == {'a': u'\x00'}. 
# Decoding here actually returns {'a': '\x00'} doc = {"a": u"\x00"} self.assertEqual(doc, decode(encode(doc))) self.assertRaises(InvalidDocument, encode, {b"\x00": "a"}) self.assertRaises(InvalidDocument, encode, {u"\x00": "a"}) self.assertRaises(InvalidDocument, encode, {"a": re.compile(b"ab\x00c")}) self.assertRaises(InvalidDocument, encode, {"a": re.compile(u"ab\x00c")}) def test_move_id(self): self.assertEqual(b"\x19\x00\x00\x00\x02_id\x00\x02\x00\x00\x00a\x00" b"\x02a\x00\x02\x00\x00\x00a\x00\x00", encode(SON([("a", "a"), ("_id", "a")]))) self.assertEqual(b"\x2c\x00\x00\x00" b"\x02_id\x00\x02\x00\x00\x00b\x00" b"\x03b\x00" b"\x19\x00\x00\x00\x02a\x00\x02\x00\x00\x00a\x00" b"\x02_id\x00\x02\x00\x00\x00a\x00\x00\x00", encode(SON([("b", SON([("a", "a"), ("_id", "a")])), ("_id", "b")]))) def test_dates(self): doc = {"early": datetime.datetime(1686, 5, 5), "late": datetime.datetime(2086, 5, 5)} try: self.assertEqual(doc, decode(encode(doc))) except ValueError: # Ignore ValueError when no C ext, since it's probably # a problem w/ 32-bit Python - we work around this in the # C ext, though. if bson.has_c(): raise def test_custom_class(self): self.assertIsInstance(decode(encode({})), dict) self.assertNotIsInstance(decode(encode({})), SON) self.assertIsInstance( decode(encode({}), CodecOptions(document_class=SON)), SON) self.assertEqual( 1, decode(encode({"x": 1}), CodecOptions(document_class=SON))["x"]) x = encode({"x": [{"y": 1}]}) self.assertIsInstance( decode(x, CodecOptions(document_class=SON))["x"][0], SON) def test_subclasses(self): # make sure we can serialize subclasses of native Python types. class _myint(int): pass class _myfloat(float): pass class _myunicode(text_type): pass d = {'a': _myint(42), 'b': _myfloat(63.9), 'c': _myunicode('hello world') } d2 = decode(encode(d)) for key, value in iteritems(d2): orig_value = d[key] orig_type = orig_value.__class__.__bases__[0] self.assertEqual(type(value), orig_type) self.assertEqual(value, orig_type(value)) def test_ordered_dict(self): try: from collections import OrderedDict except ImportError: raise SkipTest("No OrderedDict") d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)]) self.assertEqual( d, decode(encode(d), CodecOptions(document_class=OrderedDict))) def test_bson_regex(self): # Invalid Python regex, though valid PCRE. bson_re1 = Regex(r'[\w-\.]') self.assertEqual(r'[\w-\.]', bson_re1.pattern) self.assertEqual(0, bson_re1.flags) doc1 = {'r': bson_re1} doc1_bson = ( b'\x11\x00\x00\x00' # document length b'\x0br\x00[\\w-\\.]\x00\x00' # r: regex b'\x00') # document terminator self.assertEqual(doc1_bson, encode(doc1)) self.assertEqual(doc1, decode(doc1_bson)) # Valid Python regex, with flags. 
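# BSON stores regex options as characters in alphabetical order, so the combined re.I | re.M | re.S | re.U | re.X flags serialize to the option string "imsux", visible in doc2_bson below.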
re2 = re.compile(u'.*', re.I | re.M | re.S | re.U | re.X) bson_re2 = Regex(u'.*', re.I | re.M | re.S | re.U | re.X) doc2_with_re = {'r': re2} doc2_with_bson_re = {'r': bson_re2} doc2_bson = ( b"\x11\x00\x00\x00" # document length b"\x0br\x00.*\x00imsux\x00" # r: regex b"\x00") # document terminator self.assertEqual(doc2_bson, encode(doc2_with_re)) self.assertEqual(doc2_bson, encode(doc2_with_bson_re)) self.assertEqual(re2.pattern, decode(doc2_bson)['r'].pattern) self.assertEqual(re2.flags, decode(doc2_bson)['r'].flags) def test_regex_from_native(self): self.assertEqual('.*', Regex.from_native(re.compile('.*')).pattern) self.assertEqual(0, Regex.from_native(re.compile(b'')).flags) regex = re.compile(b'', re.I | re.L | re.M | re.S | re.X) self.assertEqual( re.I | re.L | re.M | re.S | re.X, Regex.from_native(regex).flags) unicode_regex = re.compile('', re.U) self.assertEqual(re.U, Regex.from_native(unicode_regex).flags) def test_regex_hash(self): self.assertRaises(TypeError, hash, Regex('hello')) def test_regex_comparison(self): re1 = Regex('a') re2 = Regex('b') self.assertNotEqual(re1, re2) re1 = Regex('a', re.I) re2 = Regex('a', re.M) self.assertNotEqual(re1, re2) re1 = Regex('a', re.I) re2 = Regex('a', re.I) self.assertEqual(re1, re2) def test_exception_wrapping(self): # No matter what exception is raised while trying to decode BSON, # the final exception always matches InvalidBSON. # {'s': '\xff'}, will throw attempting to decode utf-8. bad_doc = b'\x0f\x00\x00\x00\x02s\x00\x03\x00\x00\x00\xff\x00\x00\x00' with self.assertRaises(InvalidBSON) as context: decode_all(bad_doc) self.assertIn("codec can't decode byte 0xff", str(context.exception)) def test_minkey_maxkey_comparison(self): # MinKey's <, <=, >, >=, !=, and ==. self.assertTrue(MinKey() < None) self.assertTrue(MinKey() < 1) self.assertTrue(MinKey() <= 1) self.assertTrue(MinKey() <= MinKey()) self.assertFalse(MinKey() > None) self.assertFalse(MinKey() > 1) self.assertFalse(MinKey() >= 1) self.assertTrue(MinKey() >= MinKey()) self.assertTrue(MinKey() != 1) self.assertFalse(MinKey() == 1) self.assertTrue(MinKey() == MinKey()) # MinKey compared to MaxKey. self.assertTrue(MinKey() < MaxKey()) self.assertTrue(MinKey() <= MaxKey()) self.assertFalse(MinKey() > MaxKey()) self.assertFalse(MinKey() >= MaxKey()) self.assertTrue(MinKey() != MaxKey()) self.assertFalse(MinKey() == MaxKey()) # MaxKey's <, <=, >, >=, !=, and ==. self.assertFalse(MaxKey() < None) self.assertFalse(MaxKey() < 1) self.assertFalse(MaxKey() <= 1) self.assertTrue(MaxKey() <= MaxKey()) self.assertTrue(MaxKey() > None) self.assertTrue(MaxKey() > 1) self.assertTrue(MaxKey() >= 1) self.assertTrue(MaxKey() >= MaxKey()) self.assertTrue(MaxKey() != 1) self.assertFalse(MaxKey() == 1) self.assertTrue(MaxKey() == MaxKey()) # MaxKey compared to MinKey. self.assertFalse(MaxKey() < MinKey()) self.assertFalse(MaxKey() <= MinKey()) self.assertTrue(MaxKey() > MinKey()) self.assertTrue(MaxKey() >= MinKey()) self.assertTrue(MaxKey() != MinKey()) self.assertFalse(MaxKey() == MinKey()) def test_minkey_maxkey_hash(self): self.assertEqual(hash(MaxKey()), hash(MaxKey())) self.assertEqual(hash(MinKey()), hash(MinKey())) self.assertNotEqual(hash(MaxKey()), hash(MinKey())) def test_timestamp_comparison(self): # Timestamp is initialized with time, inc. Time is the more # significant comparand. 
self.assertTrue(Timestamp(1, 0) < Timestamp(2, 17)) self.assertTrue(Timestamp(2, 0) > Timestamp(1, 0)) self.assertTrue(Timestamp(1, 7) <= Timestamp(2, 0)) self.assertTrue(Timestamp(2, 0) >= Timestamp(1, 1)) self.assertTrue(Timestamp(2, 0) <= Timestamp(2, 0)) self.assertTrue(Timestamp(2, 0) >= Timestamp(2, 0)) self.assertFalse(Timestamp(1, 0) > Timestamp(2, 0)) # Comparison by inc. self.assertTrue(Timestamp(1, 0) < Timestamp(1, 1)) self.assertTrue(Timestamp(1, 1) > Timestamp(1, 0)) self.assertTrue(Timestamp(1, 0) <= Timestamp(1, 0)) self.assertTrue(Timestamp(1, 0) <= Timestamp(1, 1)) self.assertFalse(Timestamp(1, 0) >= Timestamp(1, 1)) self.assertTrue(Timestamp(1, 0) >= Timestamp(1, 0)) self.assertTrue(Timestamp(1, 1) >= Timestamp(1, 0)) self.assertFalse(Timestamp(1, 1) <= Timestamp(1, 0)) self.assertTrue(Timestamp(1, 0) <= Timestamp(1, 0)) self.assertFalse(Timestamp(1, 0) > Timestamp(1, 0)) def test_timestamp_highorder_bits(self): doc = {'a': Timestamp(0xFFFFFFFF, 0xFFFFFFFF)} doc_bson = (b'\x10\x00\x00\x00' b'\x11a\x00\xff\xff\xff\xff\xff\xff\xff\xff' b'\x00') self.assertEqual(doc_bson, encode(doc)) self.assertEqual(doc, decode(doc_bson)) def test_bad_id_keys(self): self.assertRaises(InvalidDocument, encode, {"_id": {"$bad": 123}}, True) self.assertRaises(InvalidDocument, encode, {"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}, True) encode({"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}) def test_bson_encode_thread_safe(self): def target(i): for j in range(1000): my_int = type('MyInt_%s_%s' % (i, j), (int,), {}) bson.encode({'my_int': my_int()}) threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() for t in threads: self.assertIsNone(t.exc) def test_raise_invalid_document(self): class Wrapper(object): def __init__(self, val): self.val = val def __repr__(self): return repr(self.val) self.assertEqual('1', repr(Wrapper(1))) with self.assertRaisesRegex( InvalidDocument, "cannot encode object: 1, of type: " + repr(Wrapper)): encode({'t': Wrapper(1)}) class TestCodecOptions(unittest.TestCase): def test_document_class(self): self.assertRaises(TypeError, CodecOptions, document_class=object) self.assertIs(SON, CodecOptions(document_class=SON).document_class) def test_tz_aware(self): self.assertRaises(TypeError, CodecOptions, tz_aware=1) self.assertFalse(CodecOptions().tz_aware) self.assertTrue(CodecOptions(tz_aware=True).tz_aware) def test_uuid_representation(self): self.assertRaises(ValueError, CodecOptions, uuid_representation=7) self.assertRaises(ValueError, CodecOptions, uuid_representation=2) def test_tzinfo(self): self.assertRaises(TypeError, CodecOptions, tzinfo='pacific') tz = FixedOffset(42, 'forty-two') self.assertRaises(ValueError, CodecOptions, tzinfo=tz) self.assertEqual(tz, CodecOptions(tz_aware=True, tzinfo=tz).tzinfo) def test_codec_options_repr(self): r = ("CodecOptions(document_class=dict, tz_aware=False, " "uuid_representation=UuidRepresentation.PYTHON_LEGACY, " "unicode_decode_error_handler='strict', " "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], " "fallback_encoder=None))") self.assertEqual(r, repr(CodecOptions())) def test_decode_all_defaults(self): # Test decode_all()'s default document_class is dict and tz_aware is # False. The default uuid_representation is PYTHON_LEGACY but this # decodes same as STANDARD, so all this test proves about UUID decoding # is that it's not CSHARP_LEGACY or JAVA_LEGACY. 
doc = {'sub_document': {}, 'uuid': uuid.uuid4(), 'dt': datetime.datetime.utcnow()} decoded = bson.decode_all(bson.encode(doc))[0] self.assertIsInstance(decoded['sub_document'], dict) self.assertEqual(decoded['uuid'], doc['uuid']) self.assertIsNone(decoded['dt'].tzinfo) def test_unicode_decode_error_handler(self): enc = encode({"keystr": "foobar"}) # Test handling of bad key value. invalid_key = enc[:7] + b'\xe9' + enc[8:] replaced_key = b'ke\xe9str'.decode('utf-8', 'replace') ignored_key = b'ke\xe9str'.decode('utf-8', 'ignore') dec = decode(invalid_key, CodecOptions(unicode_decode_error_handler="replace")) self.assertEqual(dec, {replaced_key: u"foobar"}) dec = decode(invalid_key, CodecOptions(unicode_decode_error_handler="ignore")) self.assertEqual(dec, {ignored_key: u"foobar"}) self.assertRaises(InvalidBSON, decode, invalid_key, CodecOptions( unicode_decode_error_handler="strict")) self.assertRaises(InvalidBSON, decode, invalid_key, CodecOptions()) self.assertRaises(InvalidBSON, decode, invalid_key) # Test handling of bad string value. invalid_val = BSON(enc[:18] + b'\xe9' + enc[19:]) replaced_val = b'fo\xe9bar'.decode('utf-8', 'replace') ignored_val = b'fo\xe9bar'.decode('utf-8', 'ignore') dec = decode(invalid_val, CodecOptions(unicode_decode_error_handler="replace")) self.assertEqual(dec, {u"keystr": replaced_val}) dec = decode(invalid_val, CodecOptions(unicode_decode_error_handler="ignore")) self.assertEqual(dec, {u"keystr": ignored_val}) self.assertRaises(InvalidBSON, decode, invalid_val, CodecOptions( unicode_decode_error_handler="strict")) self.assertRaises(InvalidBSON, decode, invalid_val, CodecOptions()) self.assertRaises(InvalidBSON, decode, invalid_val) # Test handling bad key + bad value. invalid_both = enc[:7] + b'\xe9' + enc[8:18] + b'\xe9' + enc[19:] dec = decode(invalid_both, CodecOptions(unicode_decode_error_handler="replace")) self.assertEqual(dec, {replaced_key: replaced_val}) dec = decode(invalid_both, CodecOptions(unicode_decode_error_handler="ignore")) self.assertEqual(dec, {ignored_key: ignored_val}) self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions( unicode_decode_error_handler="strict")) self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions()) self.assertRaises(InvalidBSON, decode, invalid_both) # Test handling bad error mode. dec = decode(enc, CodecOptions(unicode_decode_error_handler="junk")) self.assertEqual(dec, {"keystr": "foobar"}) self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions( unicode_decode_error_handler="junk")) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_bson_corpus.py000066400000000000000000000214301374256237000202210ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License.
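# Orientation for the corpus runner below: each JSON file in test/bson_corpus
# feeds create_test(). Based on the keys this module reads, a "valid" case
# looks roughly like the following (illustrative sketch, not a literal corpus
# entry):
#
#     {
#         "description": "Int32 type",
#         "bson_type": "0x10",
#         "test_key": "i",
#         "valid": [
#             {"description": "MinValue",
#              "canonical_bson": "0C0000001069000000008000",
#              "canonical_extjson": "{\"i\": {\"$numberInt\": \"-2147483648\"}}"}
#         ],
#         "decodeErrors": [...],
#         "parseErrors": [...]
#     }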
"""Run the BSON corpus specification tests.""" import binascii import codecs import functools import glob import json import os import sys from decimal import DecimalException sys.path[0:0] = [""] from bson import decode, encode, json_util from bson.binary import STANDARD from bson.codec_options import CodecOptions from bson.decimal128 import Decimal128 from bson.dbref import DBRef from bson.errors import InvalidBSON, InvalidId from bson.json_util import JSONMode from bson.py3compat import text_type, b from bson.son import SON from test import unittest _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'bson_corpus') _TESTS_TO_SKIP = set([ # Python cannot decode dates after year 9999. 'Y10K', ]) _NON_PARSE_ERRORS = set([ # {"$date": } is our legacy format which we still need to parse. 'Bad $date (number, not string or hash)', # This variant of $numberLong may have been generated by an old version # of mongoexport. 'Bad $numberLong (number, not string)', ]) _DEPRECATED_BSON_TYPES = { # Symbol '0x0E': text_type, # Undefined '0x06': type(None), # DBPointer '0x0C': DBRef } # Need to set tz_aware=True in order to use "strict" dates in extended JSON. codec_options = CodecOptions(tz_aware=True, document_class=SON) # We normally encode UUID as binary subtype 0x03, # but we'll need to encode to subtype 0x04 for one of the tests. codec_options_uuid_04 = codec_options._replace(uuid_representation=STANDARD) json_options_uuid_04 = json_util.JSONOptions(json_mode=JSONMode.CANONICAL, uuid_representation=STANDARD) json_options_iso8601 = json_util.JSONOptions( datetime_representation=json_util.DatetimeRepresentation.ISO8601) to_extjson = functools.partial(json_util.dumps, json_options=json_util.CANONICAL_JSON_OPTIONS) to_extjson_uuid_04 = functools.partial(json_util.dumps, json_options=json_options_uuid_04) to_extjson_iso8601 = functools.partial(json_util.dumps, json_options=json_options_iso8601) to_relaxed_extjson = functools.partial( json_util.dumps, json_options=json_util.RELAXED_JSON_OPTIONS) to_bson_uuid_04 = functools.partial(encode, codec_options=codec_options_uuid_04) to_bson = functools.partial(encode, codec_options=codec_options) decode_bson = lambda bbytes: decode(bbytes, codec_options=codec_options) decode_extjson = functools.partial( json_util.loads, json_options=json_util.JSONOptions(json_mode=JSONMode.CANONICAL, document_class=SON)) loads = functools.partial(json.loads, object_pairs_hook=SON) class TestBSONCorpus(unittest.TestCase): def assertJsonEqual(self, first, second, msg=None): """Fail if the two json strings are unequal. Normalize json by parsing it with the built-in json library. This accounts for discrepancies in spacing. """ self.assertEqual(loads(first), loads(second), msg=msg) def create_test(case_spec): bson_type = case_spec['bson_type'] # Test key is absent when testing top-level documents. test_key = case_spec.get('test_key') deprecated = case_spec.get('deprecated') def run_test(self): for valid_case in case_spec.get('valid', []): description = valid_case['description'] if description in _TESTS_TO_SKIP: continue # Special case for testing encoding UUID as binary subtype 0x04. 
if description == 'subtype 0x04': encode_extjson = to_extjson_uuid_04 encode_bson = to_bson_uuid_04 else: encode_extjson = to_extjson encode_bson = to_bson cB = binascii.unhexlify(b(valid_case['canonical_bson'])) cEJ = valid_case['canonical_extjson'] rEJ = valid_case.get('relaxed_extjson') dEJ = valid_case.get('degenerate_extjson') lossy = valid_case.get('lossy') decoded_bson = decode_bson(cB) if not lossy: # Make sure we can parse the legacy (default) JSON format. legacy_json = json_util.dumps( decoded_bson, json_options=json_util.LEGACY_JSON_OPTIONS) self.assertEqual(decode_extjson(legacy_json), decoded_bson) if deprecated: if 'converted_bson' in valid_case: converted_bson = binascii.unhexlify( b(valid_case['converted_bson'])) self.assertEqual(encode_bson(decoded_bson), converted_bson) self.assertJsonEqual( encode_extjson(decode_bson(converted_bson)), valid_case['converted_extjson']) # Make sure we can decode the type. self.assertEqual(decoded_bson, decode_extjson(cEJ)) if test_key is not None: self.assertIsInstance(decoded_bson[test_key], _DEPRECATED_BSON_TYPES[bson_type]) continue # Jython can't handle NaN with a payload from # struct.(un)pack if endianness is specified in the format string. if not (sys.platform.startswith("java") and description == 'NaN with payload'): # Test round-tripping canonical bson. self.assertEqual(encode_bson(decoded_bson), cB) self.assertJsonEqual(encode_extjson(decoded_bson), cEJ) # Test round-tripping canonical extended json. decoded_json = decode_extjson(cEJ) self.assertJsonEqual(encode_extjson(decoded_json), cEJ) if not lossy: self.assertEqual(encode_bson(decoded_json), cB) # Test round-tripping degenerate bson. if 'degenerate_bson' in valid_case: dB = binascii.unhexlify(b(valid_case['degenerate_bson'])) self.assertEqual(encode_bson(decode_bson(dB)), cB) # Test round-tripping degenerate extended json. if dEJ is not None: decoded_json = decode_extjson(dEJ) self.assertJsonEqual(encode_extjson(decoded_json), cEJ) if not lossy: self.assertEqual(encode_bson(decoded_json), cB) # Test round-tripping relaxed extended json. if rEJ is not None: self.assertJsonEqual(to_relaxed_extjson(decoded_bson), rEJ) decoded_json = decode_extjson(rEJ) self.assertJsonEqual(to_relaxed_extjson(decoded_json), rEJ) for decode_error_case in case_spec.get('decodeErrors', []): with self.assertRaises(InvalidBSON): decode_bson( binascii.unhexlify(b(decode_error_case['bson']))) for parse_error_case in case_spec.get('parseErrors', []): if bson_type == '0x13': self.assertRaises( DecimalException, Decimal128, parse_error_case['string']) elif bson_type == '0x00': description = parse_error_case['description'] if description in _NON_PARSE_ERRORS: decode_extjson(parse_error_case['string']) else: try: decode_extjson(parse_error_case['string']) raise AssertionError('exception not raised for test ' 'case: ' + description) except (ValueError, KeyError, TypeError, InvalidId): pass else: raise AssertionError('cannot test parseErrors for type ' + bson_type) return run_test def create_tests(): for filename in glob.glob(os.path.join(_TEST_PATH, '*.json')): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with codecs.open(filename, encoding='utf-8') as bson_test_file: test_method = create_test(json.load(bson_test_file)) setattr(TestBSONCorpus, 'test_' + test_suffix, test_method) create_tests() if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_bulk.py000066400000000000000000000355121374256237000166300ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the bulk API.""" import sys sys.path[0:0] = [""] from bson.objectid import ObjectId from pymongo.operations import * from pymongo.errors import (ConfigurationError, InvalidOperation, OperationFailure) from pymongo.write_concern import WriteConcern from test import (client_context, unittest, IntegrationTest) from test.utils import (remove_all_users, rs_or_single_client_noauth) class BulkTestBase(IntegrationTest): @classmethod def setUpClass(cls): super(BulkTestBase, cls).setUpClass() cls.coll = cls.db.test ismaster = client_context.client.admin.command('ismaster') cls.has_write_commands = (ismaster.get("maxWireVersion", 0) > 1) def setUp(self): super(BulkTestBase, self).setUp() self.coll.drop() def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" for key, value in expected.items(): if key == 'nModified': if self.has_write_commands: self.assertEqual(value, actual['nModified']) else: # Legacy servers don't include nModified in the response. self.assertFalse('nModified' in actual) elif key == 'upserted': expected_upserts = value actual_upserts = actual['upserted'] self.assertEqual( len(expected_upserts), len(actual_upserts), 'Expected %d elements in "upserted", got %d' % ( len(expected_upserts), len(actual_upserts))) for e, a in zip(expected_upserts, actual_upserts): self.assertEqualUpsert(e, a) elif key == 'writeErrors': expected_errors = value actual_errors = actual['writeErrors'] self.assertEqual( len(expected_errors), len(actual_errors), 'Expected %d elements in "writeErrors", got %d' % ( len(expected_errors), len(actual_errors))) for e, a in zip(expected_errors, actual_errors): self.assertEqualWriteError(e, a) else: self.assertEqual( actual.get(key), value, '%r value of %r does not match expected %r' % (key, actual.get(key), value)) def assertEqualUpsert(self, expected, actual): """Compare bulk.execute()['upserts'] to expected value. Like: {'index': 0, '_id': ObjectId()} """ self.assertEqual(expected['index'], actual['index']) if expected['_id'] == '...': # Unspecified value. self.assertTrue('_id' in actual) else: self.assertEqual(expected['_id'], actual['_id']) def assertEqualWriteError(self, expected, actual): """Compare bulk.execute()['writeErrors'] to expected value. Like: {'index': 0, 'code': 123, 'errmsg': '...', 'op': { ... }} """ self.assertEqual(expected['index'], actual['index']) self.assertEqual(expected['code'], actual['code']) if expected['errmsg'] == '...': # Unspecified value. self.assertTrue('errmsg' in actual) else: self.assertEqual(expected['errmsg'], actual['errmsg']) expected_op = expected['op'].copy() actual_op = actual['op'].copy() if expected_op.get('_id') == '...': # Unspecified _id. 
self.assertTrue('_id' in actual_op) actual_op.pop('_id') expected_op.pop('_id') self.assertEqual(expected_op, actual_op) class TestBulk(BulkTestBase): def test_empty(self): self.assertRaises(InvalidOperation, self.coll.bulk_write, []) def test_insert(self): expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 1, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } result = self.coll.bulk_write([InsertOne({})]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.inserted_count) self.assertEqual(1, self.coll.count_documents({})) def _test_update_many(self, update): expected = { 'nMatched': 2, 'nModified': 2, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) result = self.coll.bulk_write([UpdateMany({}, update)]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(2, result.matched_count) self.assertTrue(result.modified_count in (2, None)) def test_update_many(self): self._test_update_many({'$set': {'foo': 'bar'}}) @client_context.require_version_min(4, 1, 11) def test_update_many_pipeline(self): self._test_update_many([{'$set': {'foo': 'bar'}}]) @client_context.require_version_max(3, 5, 5) def test_array_filters_unsupported(self): requests = [ UpdateMany( {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]), UpdateOne( {}, {'$set': {"y.$[i].b": 2}}, array_filters=[{'i.b': 3}]) ] for bulk_op in requests: self.assertRaises( ConfigurationError, self.coll.bulk_write, [bulk_op]) def test_array_filters_validation(self): self.assertRaises(TypeError, UpdateMany, {}, {}, array_filters={}) self.assertRaises(TypeError, UpdateOne, {}, {}, array_filters={}) def test_array_filters_unacknowledged(self): coll = self.coll.with_options(write_concern=WriteConcern(w=0)) update_one = UpdateOne( {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) update_many = UpdateMany( {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) self.assertRaises(ConfigurationError, coll.bulk_write, [update_one]) self.assertRaises(ConfigurationError, coll.bulk_write, [update_many]) def _test_update_one(self, update): expected = { 'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) result = self.coll.bulk_write([UpdateOne({}, update)]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (1, None)) def test_update_one(self): self._test_update_one({'$set': {'foo': 'bar'}}) @client_context.require_version_min(4, 1, 11) def test_update_one_pipeline(self): self._test_update_one([{'$set': {'foo': 'bar'}}]) def test_replace_one(self): expected = { 'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) result = self.coll.bulk_write([ReplaceOne({}, {'foo': 'bar'})]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (1, None)) def test_remove(self): # Test removing all documents, ordered. 
expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 2, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) result = self.coll.bulk_write([DeleteMany({})]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(2, result.deleted_count) def test_remove_one(self): # Test removing one document, empty selector. self.coll.insert_many([{}, {}]) expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 1, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } result = self.coll.bulk_write([DeleteOne({})]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.deleted_count) self.assertEqual(self.coll.count_documents({}), 1) def test_upsert(self): expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 1, 'nInserted': 0, 'nRemoved': 0, 'upserted': [{'index': 0, '_id': '...'}] } result = self.coll.bulk_write([ReplaceOne({}, {'foo': 'bar'}, upsert=True)]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.upserted_count) self.assertEqual(1, len(result.upserted_ids)) self.assertTrue(isinstance(result.upserted_ids.get(0), ObjectId)) self.assertEqual(self.coll.count_documents({'foo': 'bar'}), 1) def test_numerous_inserts(self): # Ensure we don't exceed server's 1000-document batch size limit. n_docs = 2100 requests = [InsertOne({}) for _ in range(n_docs)] result = self.coll.bulk_write(requests, ordered=False) self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, self.coll.count_documents({})) # Same with ordered bulk. self.coll.drop() result = self.coll.bulk_write(requests) self.assertEqual(n_docs, result.inserted_count) self.assertEqual(n_docs, self.coll.count_documents({})) @client_context.require_version_min(3, 6) def test_bulk_max_message_size(self): self.coll.delete_many({}) self.addCleanup(self.coll.delete_many, {}) _16_MB = 16 * 1000 * 1000 # Generate a list of documents such that the first batched OP_MSG is # as close as possible to the 48MB limit. docs = [ {'_id': 1, 'l': 's' * _16_MB}, {'_id': 2, 'l': 's' * _16_MB}, {'_id': 3, 'l': 's' * (_16_MB - 10000)}, ] # Fill in the remaining ~10000 bytes with small documents. for i in range(4, 10000): docs.append({'_id': i}) result = self.coll.insert_many(docs) self.assertEqual(len(docs), len(result.inserted_ids)) def test_generator_insert(self): def gen(): yield {'a': 1, 'b': 1} yield {'a': 1, 'b': 2} yield {'a': 2, 'b': 3} yield {'a': 3, 'b': 5} yield {'a': 5, 'b': 8} result = self.coll.insert_many(gen()) self.assertEqual(5, len(result.inserted_ids)) def test_bulk_write_no_results(self): coll = self.coll.with_options(write_concern=WriteConcern(w=0)) result = coll.bulk_write([InsertOne({})]) self.assertFalse(result.acknowledged) self.assertRaises(InvalidOperation, lambda: result.inserted_count) self.assertRaises(InvalidOperation, lambda: result.matched_count) self.assertRaises(InvalidOperation, lambda: result.modified_count) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertRaises(InvalidOperation, lambda: result.upserted_count) self.assertRaises(InvalidOperation, lambda: result.upserted_ids) def test_bulk_write_invalid_arguments(self): # The requests argument must be a list. generator = (InsertOne({}) for _ in range(10)) with self.assertRaises(TypeError): self.coll.bulk_write(generator) # Document is not wrapped in a bulk write operation. 
with self.assertRaises(TypeError): self.coll.bulk_write([{}]) class BulkAuthorizationTestBase(BulkTestBase): @classmethod @client_context.require_auth def setUpClass(cls): super(BulkAuthorizationTestBase, cls).setUpClass() def setUp(self): super(BulkAuthorizationTestBase, self).setUp() client_context.create_user( self.db.name, 'readonly', 'pw', ['read']) self.db.command( 'createRole', 'noremove', privileges=[{ 'actions': ['insert', 'update', 'find'], 'resource': {'db': 'pymongo_test', 'collection': 'test'} }], roles=[]) client_context.create_user(self.db.name, 'noremove', 'pw', ['noremove']) def tearDown(self): self.db.command('dropRole', 'noremove') remove_all_users(self.db) class TestBulkAuthorization(BulkAuthorizationTestBase): def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. cli = rs_or_single_client_noauth(username='readonly', password='pw', authSource='pymongo_test') coll = cli.pymongo_test.test coll.find_one() self.assertRaises(OperationFailure, coll.bulk_write, [InsertOne({'x': 1})]) def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. cli = rs_or_single_client_noauth(username='noremove', password='pw', authSource='pymongo_test') coll = cli.pymongo_test.test coll.find_one() requests = [ InsertOne({'x': 1}), ReplaceOne({'x': 2}, {'x': 2}, upsert=True), DeleteMany({}), # Prohibited. InsertOne({'x': 3}), # Never attempted. ] self.assertRaises(OperationFailure, coll.bulk_write, requests) self.assertEqual(set([1, 2]), set(self.coll.distinct('x'))) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_change_stream.py000066400000000000000000001515701374256237000204760ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
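# Orientation for the tests below: they exercise the ChangeStream cursor
# returned by watch(). A minimal usage sketch (the database, collection, and
# pipeline names here are illustrative only):
#
#     with client.db.collection.watch(
#             [{'$match': {'operationType': 'insert'}}],
#             max_await_time_ms=1000) as stream:
#         for change in stream:
#             print(change['operationType'], change.get('fullDocument'))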
"""Test the change_stream module.""" import random import os import re import sys import string import threading import time import uuid from contextlib import contextmanager from itertools import product sys.path[0:0] = [''] from bson import ObjectId, SON, Timestamp, encode, json_util from bson.binary import (ALL_UUID_REPRESENTATIONS, Binary, STANDARD, PYTHON_LEGACY) from bson.py3compat import iteritems from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument from pymongo import MongoClient from pymongo.command_cursor import CommandCursor from pymongo.errors import (InvalidOperation, OperationFailure, ServerSelectionTimeoutError) from pymongo.message import _CursorAddress from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern from test import client_context, unittest, IntegrationTest from test.utils import ( EventListener, WhiteListEventListener, rs_or_single_client, wait_until) class TestChangeStreamBase(IntegrationTest): def change_stream_with_client(self, client, *args, **kwargs): """Create a change stream using the given client and return it.""" raise NotImplementedError def change_stream(self, *args, **kwargs): """Create a change stream using the default client and return it.""" return self.change_stream_with_client(self.client, *args, **kwargs) def client_with_listener(self, *commands): """Return a client with a WhiteListEventListener.""" listener = WhiteListEventListener(*commands) client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener def watched_collection(self, *args, **kwargs): """Return a collection that is watched by self.change_stream().""" # Construct a unique collection for each test. collname = '.'.join(self.id().rsplit('.', 2)[1:]) return self.db.get_collection(collname, *args, **kwargs) def generate_invalidate_event(self, change_stream): """Cause a change stream invalidate event.""" raise NotImplementedError def generate_unique_collnames(self, numcolls): """Generate numcolls collection names unique to a test.""" collnames = [] for idx in range(1, numcolls + 1): collnames.append(self.id() + '_' + str(idx)) return collnames def get_resume_token(self, invalidate=False): """Get a resume token to use for starting a change stream.""" # Ensure targeted collection exists before starting. coll = self.watched_collection(write_concern=WriteConcern('majority')) coll.insert_one({}) if invalidate: with self.change_stream( [{'$match': {'operationType': 'invalidate'}}]) as cs: if isinstance(cs._target, MongoClient): self.skipTest( "cluster-level change streams cannot be invalidated") self.generate_invalidate_event(cs) return cs.next()['_id'] else: with self.change_stream() as cs: coll.insert_one({'data': 1}) return cs.next()['_id'] def get_start_at_operation_time(self): """Get an operationTime. 
Advances the operation clock beyond the most recently returned timestamp.""" optime = self.client.admin.command("ping")["operationTime"] return Timestamp(optime.time, optime.inc + 1) def insert_one_and_check(self, change_stream, doc): """Insert a document and check that it shows up in the change stream.""" raise NotImplementedError def kill_change_stream_cursor(self, change_stream): """Cause a cursor not found error on the next getMore.""" cursor = change_stream._cursor address = _CursorAddress(cursor.address, cursor._CommandCursor__ns) client = self.watched_collection().database.client client._close_cursor_now(cursor.cursor_id, address) class APITestsMixin(object): def test_watch(self): with self.change_stream( [{'$project': {'foo': 0}}], full_document='updateLookup', max_await_time_ms=1000, batch_size=100) as change_stream: self.assertEqual([{'$project': {'foo': 0}}], change_stream._pipeline) self.assertEqual('updateLookup', change_stream._full_document) self.assertEqual(1000, change_stream._max_await_time_ms) self.assertEqual(100, change_stream._batch_size) self.assertIsInstance(change_stream._cursor, CommandCursor) self.assertEqual( 1000, change_stream._cursor._CommandCursor__max_await_time_ms) self.watched_collection( write_concern=WriteConcern("majority")).insert_one({}) _ = change_stream.next() resume_token = change_stream.resume_token with self.assertRaises(TypeError): self.change_stream(pipeline={}) with self.assertRaises(TypeError): self.change_stream(full_document={}) # No Error. with self.change_stream(resume_after=resume_token): pass def test_try_next(self): # ChangeStreams only read majority committed data so use w:majority. coll = self.watched_collection().with_options( write_concern=WriteConcern("majority")) coll.drop() coll.insert_one({}) self.addCleanup(coll.drop) with self.change_stream(max_await_time_ms=250) as stream: self.assertIsNone(stream.try_next()) # No changes initially. coll.insert_one({}) # Generate a change. # On sharded clusters, even majority-committed changes only show # up once an event that sorts after it shows up on the other # shard. So, we wait on try_next to eventually return changes. wait_until(lambda: stream.try_next() is not None, "get change from try_next") def test_try_next_runs_one_getmore(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command('ping') listener.results.clear() # ChangeStreams only read majority committed data so use w:majority. coll = self.watched_collection().with_options( write_concern=WriteConcern("majority")) coll.drop() # Create the watched collection before starting the change stream to # skip any "create" events. coll.insert_one({'_id': 1}) self.addCleanup(coll.drop) with self.change_stream_with_client( client, max_await_time_ms=250) as stream: self.assertEqual(listener.started_command_names(), ["aggregate"]) listener.results.clear() # Confirm that only a single getMore is run even when no documents # are returned. self.assertIsNone(stream.try_next()) self.assertEqual(listener.started_command_names(), ["getMore"]) listener.results.clear() self.assertIsNone(stream.try_next()) self.assertEqual(listener.started_command_names(), ["getMore"]) listener.results.clear() # Get at least one change before resuming. coll.insert_one({'_id': 2}) wait_until(lambda: stream.try_next() is not None, "get change from try_next") listener.results.clear() # Cause the next request to initiate the resume process. 
self.kill_change_stream_cursor(stream) listener.results.clear() # The sequence should be: # - getMore, fail # - resume with aggregate command # - no results, return immediately without another getMore self.assertIsNone(stream.try_next()) self.assertEqual( listener.started_command_names(), ["getMore", "aggregate"]) listener.results.clear() # Stream still works after a resume. coll.insert_one({'_id': 3}) wait_until(lambda: stream.try_next() is not None, "get change from try_next") self.assertEqual(set(listener.started_command_names()), set(["getMore"])) self.assertIsNone(stream.try_next()) def test_batch_size_is_honored(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command('ping') listener.results.clear() # ChangeStreams only read majority committed data so use w:majority. coll = self.watched_collection().with_options( write_concern=WriteConcern("majority")) coll.drop() # Create the watched collection before starting the change stream to # skip any "create" events. coll.insert_one({'_id': 1}) self.addCleanup(coll.drop) # Expected batchSize. expected = {'batchSize': 23} with self.change_stream_with_client( client, max_await_time_ms=250, batch_size=23) as stream: # Confirm that batchSize is honored for initial batch. cmd = listener.results['started'][0].command self.assertEqual(cmd['cursor'], expected) listener.results.clear() # Confirm that batchSize is honored by getMores. self.assertIsNone(stream.try_next()) cmd = listener.results['started'][0].command key = next(iter(expected)) self.assertEqual(expected[key], cmd[key]) # $changeStream.startAtOperationTime was added in 4.0.0. @client_context.require_version_min(4, 0, 0) def test_start_at_operation_time(self): optime = self.get_start_at_operation_time() coll = self.watched_collection( write_concern=WriteConcern("majority")) ndocs = 3 coll.insert_many([{"data": i} for i in range(ndocs)]) with self.change_stream(start_at_operation_time=optime) as cs: for i in range(ndocs): cs.next() def _test_full_pipeline(self, expected_cs_stage): client, listener = self.client_with_listener("aggregate") results = listener.results with self.change_stream_with_client( client, [{'$project': {'foo': 0}}]) as _: pass self.assertEqual(1, len(results['started'])) command = results['started'][0] self.assertEqual('aggregate', command.command_name) self.assertEqual([ {'$changeStream': expected_cs_stage}, {'$project': {'foo': 0}}], command.command['pipeline']) def test_full_pipeline(self): """$changeStream must be the first stage in a change stream pipeline sent to the server. """ self._test_full_pipeline({}) def test_iteration(self): with self.change_stream(batch_size=2) as change_stream: num_inserted = 10 self.watched_collection().insert_many( [{} for _ in range(num_inserted)]) inserts_received = 0 for change in change_stream: self.assertEqual(change['operationType'], 'insert') inserts_received += 1 if inserts_received == num_inserted: break self._test_invalidate_stops_iteration(change_stream) def _test_next_blocks(self, change_stream): inserted_doc = {'_id': ObjectId()} changes = [] t = threading.Thread( target=lambda: changes.append(change_stream.next())) t.start() # Sleep for a bit to prove that the call to next() blocks. time.sleep(1) self.assertTrue(t.is_alive()) self.assertFalse(changes) self.watched_collection().insert_one(inserted_doc) # Join with large timeout to give the server time to return the change, # in particular for shard clusters. 
t.join(30) self.assertFalse(t.is_alive()) self.assertEqual(1, len(changes)) self.assertEqual(changes[0]['operationType'], 'insert') self.assertEqual(changes[0]['fullDocument'], inserted_doc) def test_next_blocks(self): """Test that next blocks until a change is readable""" # Use a short await time to speed up the test. with self.change_stream(max_await_time_ms=250) as change_stream: self._test_next_blocks(change_stream) def test_aggregate_cursor_blocks(self): """Test that an aggregate cursor blocks until a change is readable.""" with self.watched_collection().aggregate( [{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream: self._test_next_blocks(change_stream) def test_concurrent_close(self): """Ensure a ChangeStream can be closed from another thread.""" # Use a short await time to speed up the test. with self.change_stream(max_await_time_ms=250) as change_stream: def iterate_cursor(): for _ in change_stream: pass t = threading.Thread(target=iterate_cursor) t.start() self.watched_collection().insert_one({}) time.sleep(1) change_stream.close() t.join(3) self.assertFalse(t.is_alive()) def test_unknown_full_document(self): """Must rely on the server to raise an error on unknown fullDocument. """ try: with self.change_stream(full_document='notValidatedByPyMongo'): pass except OperationFailure: pass def test_change_operations(self): """Test each operation type.""" expected_ns = {'db': self.watched_collection().database.name, 'coll': self.watched_collection().name} with self.change_stream() as change_stream: # Insert. inserted_doc = {'_id': ObjectId(), 'foo': 'bar'} self.watched_collection().insert_one(inserted_doc) change = change_stream.next() self.assertTrue(change['_id']) self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['ns'], expected_ns) self.assertEqual(change['fullDocument'], inserted_doc) # Update. update_spec = {'$set': {'new': 1}, '$unset': {'foo': 1}} self.watched_collection().update_one(inserted_doc, update_spec) change = change_stream.next() self.assertTrue(change['_id']) self.assertEqual(change['operationType'], 'update') self.assertEqual(change['ns'], expected_ns) self.assertNotIn('fullDocument', change) self.assertEqual({'updatedFields': {'new': 1}, 'removedFields': ['foo']}, change['updateDescription']) # Replace. self.watched_collection().replace_one({'new': 1}, {'foo': 'bar'}) change = change_stream.next() self.assertTrue(change['_id']) self.assertEqual(change['operationType'], 'replace') self.assertEqual(change['ns'], expected_ns) self.assertEqual(change['fullDocument'], inserted_doc) # Delete. self.watched_collection().delete_one({'foo': 'bar'}) change = change_stream.next() self.assertTrue(change['_id']) self.assertEqual(change['operationType'], 'delete') self.assertEqual(change['ns'], expected_ns) self.assertNotIn('fullDocument', change) # Invalidate. self._test_get_invalidate_event(change_stream) @client_context.require_version_min(4, 1, 1) def test_start_after(self): resume_token = self.get_resume_token(invalidate=True) # resume_after cannot resume after invalidate. with self.assertRaises(OperationFailure): self.change_stream(resume_after=resume_token) # start_after can resume after invalidate. 
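# (start_after requires MongoDB 4.1.1+, per the decorator above; unlike resume_after it may point at an invalidate event's token, so the insert below is the first change this stream reports.)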
with self.change_stream(start_after=resume_token) as change_stream: self.watched_collection().insert_one({'_id': 2}) change = change_stream.next() self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['fullDocument'], {'_id': 2}) @client_context.require_version_min(4, 1, 1) def test_start_after_resume_process_with_changes(self): resume_token = self.get_resume_token(invalidate=True) with self.change_stream(start_after=resume_token, max_await_time_ms=250) as change_stream: self.watched_collection().insert_one({'_id': 2}) change = change_stream.next() self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['fullDocument'], {'_id': 2}) self.assertIsNone(change_stream.try_next()) self.kill_change_stream_cursor(change_stream) self.watched_collection().insert_one({'_id': 3}) change = change_stream.next() self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['fullDocument'], {'_id': 3}) @client_context.require_no_mongos # Remove after SERVER-41196 @client_context.require_version_min(4, 1, 1) def test_start_after_resume_process_without_changes(self): resume_token = self.get_resume_token(invalidate=True) with self.change_stream(start_after=resume_token, max_await_time_ms=250) as change_stream: self.assertIsNone(change_stream.try_next()) self.kill_change_stream_cursor(change_stream) self.watched_collection().insert_one({'_id': 2}) change = change_stream.next() self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['fullDocument'], {'_id': 2}) class ProseSpecTestsMixin(object): def _client_with_listener(self, *commands): listener = WhiteListEventListener(*commands) client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3): self.watched_collection().insert_many( [{"data": k} for k in range(batch_size)]) for _ in range(batch_size): change = next(change_stream) return change def _get_expected_resume_token_legacy(self, stream, listener, previous_change=None): """Predicts what the resume token should currently be for server versions that don't support postBatchResumeToken. Assumes the stream has never returned any changes if previous_change is None.""" if previous_change is None: agg_cmd = listener.results['started'][0] stage = agg_cmd.command["pipeline"][0]["$changeStream"] return stage.get("resumeAfter") or stage.get("startAfter") return previous_change['_id'] def _get_expected_resume_token(self, stream, listener, previous_change=None): """Predicts what the resume token should currently be for server versions that support postBatchResumeToken. Assumes the stream has never returned any changes if previous_change is None. Assumes listener is a WhiteListEventListener that listens for aggregate and getMore commands.""" if previous_change is None or stream._cursor._has_next(): token = self._get_expected_resume_token_legacy( stream, listener, previous_change) if token is not None: return token response = listener.results['succeeded'][-1].reply return response['cursor']['postBatchResumeToken'] def _test_raises_error_on_missing_id(self, expected_exception): """ChangeStream will raise an exception if the server response is missing the resume token. """ with self.change_stream([{'$project': {'_id': 0}}]) as change_stream: self.watched_collection().insert_one({}) with self.assertRaises(expected_exception): next(change_stream) # The cursor should now be closed. 
with self.assertRaises(StopIteration): next(change_stream) def _test_update_resume_token(self, expected_rt_getter): """ChangeStream must continuously track the last seen resumeToken.""" client, listener = self._client_with_listener("aggregate", "getMore") coll = self.watched_collection(write_concern=WriteConcern('majority')) with self.change_stream_with_client(client) as change_stream: self.assertEqual( change_stream.resume_token, expected_rt_getter(change_stream, listener)) for _ in range(3): coll.insert_one({}) change = next(change_stream) self.assertEqual( change_stream.resume_token, expected_rt_getter(change_stream, listener, change)) # Prose test no. 1 @client_context.require_version_min(4, 0, 7) def test_update_resume_token(self): self._test_update_resume_token(self._get_expected_resume_token) # Prose test no. 1 @client_context.require_version_max(4, 0, 7) def test_update_resume_token_legacy(self): self._test_update_resume_token(self._get_expected_resume_token_legacy) # Prose test no. 2 @client_context.require_version_max(4, 3, 3) # PYTHON-2120 @client_context.require_version_min(4, 1, 8) def test_raises_error_on_missing_id_418plus(self): # Server returns an error on 4.1.8+ self._test_raises_error_on_missing_id(OperationFailure) # Prose test no. 2 @client_context.require_version_max(4, 1, 8) def test_raises_error_on_missing_id_418minus(self): # PyMongo raises an error self._test_raises_error_on_missing_id(InvalidOperation) # Prose test no. 3 def test_resume_on_error(self): with self.change_stream() as change_stream: self.insert_one_and_check(change_stream, {'_id': 1}) # Cause a cursor not found error on the next getMore. self.kill_change_stream_cursor(change_stream) self.insert_one_and_check(change_stream, {'_id': 2}) # Prose test no. 4 @client_context.require_failCommand_fail_point def test_no_resume_attempt_if_aggregate_command_fails(self): # Set non-retryable error on aggregate command. fail_point = {'mode': {'times': 1}, 'data': {'errorCode': 2, 'failCommands': ['aggregate']}} client, listener = self._client_with_listener("aggregate", "getMore") with self.fail_point(fail_point): try: _ = self.change_stream_with_client(client) except OperationFailure: pass # Driver should have attempted aggregate command only once. self.assertEqual(len(listener.results['started']), 1) self.assertEqual(listener.results['started'][0].command_name, 'aggregate') # Prose test no. 5 - REMOVED # Prose test no. 6 - SKIPPED # Reason: readPreference is not configurable using the watch() helpers # so we can skip this test. Also, PyMongo performs server selection for # each operation which ensure compliance with this prose test. # Prose test no. 7 def test_initial_empty_batch(self): with self.change_stream() as change_stream: # The first batch should be empty. self.assertFalse(change_stream._cursor._has_next()) cursor_id = change_stream._cursor.cursor_id self.assertTrue(cursor_id) self.insert_one_and_check(change_stream, {}) # Make sure we're still using the same cursor. self.assertEqual(cursor_id, change_stream._cursor.cursor_id) # Prose test no. 8 def test_kill_cursors(self): def raise_error(): raise ServerSelectionTimeoutError('mock error') with self.change_stream() as change_stream: self.insert_one_and_check(change_stream, {'_id': 1}) # Cause a cursor not found error on the next getMore. cursor = change_stream._cursor self.kill_change_stream_cursor(change_stream) cursor.close = raise_error self.insert_one_and_check(change_stream, {'_id': 2}) # Prose test no. 
9 @client_context.require_version_min(4, 0, 0) @client_context.require_version_max(4, 0, 7) def test_start_at_operation_time_caching(self): # Case 1: change stream not started with startAtOperationTime client, listener = self.client_with_listener("aggregate") with self.change_stream_with_client(client) as cs: self.kill_change_stream_cursor(cs) cs.try_next() cmd = listener.results['started'][-1].command self.assertIsNotNone(cmd["pipeline"][0]["$changeStream"].get( "startAtOperationTime")) # Case 2: change stream started with startAtOperationTime listener.results.clear() optime = self.get_start_at_operation_time() with self.change_stream_with_client( client, start_at_operation_time=optime) as cs: self.kill_change_stream_cursor(cs) cs.try_next() cmd = listener.results['started'][-1].command self.assertEqual(cmd["pipeline"][0]["$changeStream"].get( "startAtOperationTime"), optime, str([k.command for k in listener.results['started']])) # Prose test no. 10 - SKIPPED # This test is identical to prose test no. 3. # Prose test no. 11 @client_context.require_version_min(4, 0, 7) def test_resumetoken_empty_batch(self): client, listener = self._client_with_listener("getMore") with self.change_stream_with_client(client) as change_stream: self.assertIsNone(change_stream.try_next()) resume_token = change_stream.resume_token response = listener.results['succeeded'][0].reply self.assertEqual(resume_token, response["cursor"]["postBatchResumeToken"]) # Prose test no. 11 @client_context.require_version_min(4, 0, 7) def test_resumetoken_exhausted_batch(self): client, listener = self._client_with_listener("getMore") with self.change_stream_with_client(client) as change_stream: self._populate_and_exhaust_change_stream(change_stream) resume_token = change_stream.resume_token response = listener.results['succeeded'][-1].reply self.assertEqual(resume_token, response["cursor"]["postBatchResumeToken"]) # Prose test no. 12 @client_context.require_version_max(4, 0, 7) def test_resumetoken_empty_batch_legacy(self): resume_point = self.get_resume_token() # Empty resume token when neither resumeAfter or startAfter specified. with self.change_stream() as change_stream: change_stream.try_next() self.assertIsNone(change_stream.resume_token) # Resume token value is same as resumeAfter. with self.change_stream(resume_after=resume_point) as change_stream: change_stream.try_next() resume_token = change_stream.resume_token self.assertEqual(resume_token, resume_point) # Prose test no. 12 @client_context.require_version_max(4, 0, 7) def test_resumetoken_exhausted_batch_legacy(self): # Resume token is _id of last change. with self.change_stream() as change_stream: change = self._populate_and_exhaust_change_stream(change_stream) self.assertEqual(change_stream.resume_token, change["_id"]) resume_point = change['_id'] # Resume token is _id of last change even if resumeAfter is specified. with self.change_stream(resume_after=resume_point) as change_stream: change = self._populate_and_exhaust_change_stream(change_stream) self.assertEqual(change_stream.resume_token, change["_id"]) # Prose test no. 13 def test_resumetoken_partially_iterated_batch(self): # When batch has been iterated up to but not including the last element. # Resume token should be _id of previous change document. 
with self.change_stream() as change_stream: self.watched_collection( write_concern=WriteConcern('majority')).insert_many( [{"data": k} for k in range(3)]) for _ in range(2): change = next(change_stream) resume_token = change_stream.resume_token self.assertEqual(resume_token, change["_id"]) def _test_resumetoken_uniterated_nonempty_batch(self, resume_option): # When the batch is not empty and hasn't been iterated at all. # Resume token should be same as the resume option used. resume_point = self.get_resume_token() # Insert some documents so that firstBatch isn't empty. self.watched_collection( write_concern=WriteConcern("majority")).insert_many( [{'a': 1}, {'b': 2}, {'c': 3}]) # Resume token should be same as the resume option. with self.change_stream( **{resume_option: resume_point}) as change_stream: self.assertTrue(change_stream._cursor._has_next()) resume_token = change_stream.resume_token self.assertEqual(resume_token, resume_point) # Prose test no. 14 @client_context.require_no_mongos def test_resumetoken_uniterated_nonempty_batch_resumeafter(self): self._test_resumetoken_uniterated_nonempty_batch("resume_after") # Prose test no. 14 @client_context.require_no_mongos @client_context.require_version_min(4, 1, 1) def test_resumetoken_uniterated_nonempty_batch_startafter(self): self._test_resumetoken_uniterated_nonempty_batch("start_after") # Prose test no. 17 @client_context.require_version_min(4, 1, 1) def test_startafter_resume_uses_startafter_after_empty_getMore(self): # Resume should use startAfter after no changes have been returned. resume_point = self.get_resume_token() client, listener = self._client_with_listener("aggregate") with self.change_stream_with_client( client, start_after=resume_point) as change_stream: self.assertFalse(change_stream._cursor._has_next()) # No changes change_stream.try_next() # No changes self.kill_change_stream_cursor(change_stream) change_stream.try_next() # Resume attempt response = listener.results['started'][-1] self.assertIsNone( response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) self.assertIsNotNone( response.command["pipeline"][0]["$changeStream"].get("startAfter")) # Prose test no. 18 @client_context.require_version_min(4, 1, 1) def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self): # Resume should use resumeAfter after some changes have been returned. 
resume_point = self.get_resume_token() client, listener = self._client_with_listener("aggregate") with self.change_stream_with_client( client, start_after=resume_point) as change_stream: self.assertFalse(change_stream._cursor._has_next()) # No changes self.watched_collection().insert_one({}) next(change_stream) # Changes self.kill_change_stream_cursor(change_stream) change_stream.try_next() # Resume attempt response = listener.results['started'][-1] self.assertIsNotNone( response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) self.assertIsNone( response.command["pipeline"][0]["$changeStream"].get("startAfter")) class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestClusterChangeStream, cls).setUpClass() cls.dbs = [cls.db, cls.client.pymongo_test_2] @classmethod def tearDownClass(cls): for db in cls.dbs: cls.client.drop_database(db) super(TestClusterChangeStream, cls).tearDownClass() def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) def generate_invalidate_event(self, change_stream): self.skipTest("cluster-level change streams cannot be invalidated") def _test_get_invalidate_event(self, change_stream): # Cluster-level change streams don't get invalidated. pass def _test_invalidate_stops_iteration(self, change_stream): # Cluster-level change streams don't get invalidated. pass def _insert_and_check(self, change_stream, db, collname, doc): coll = db[collname] coll.insert_one(doc) change = next(change_stream) self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['ns'], {'db': db.name, 'coll': collname}) self.assertEqual(change['fullDocument'], doc) def insert_one_and_check(self, change_stream, doc): db = random.choice(self.dbs) collname = self.id() self._insert_and_check(change_stream, db, collname, doc) def test_simple(self): collnames = self.generate_unique_collnames(3) with self.change_stream() as change_stream: for db, collname in product(self.dbs, collnames): self._insert_and_check( change_stream, db, collname, {'_id': collname} ) def test_aggregate_cursor_blocks(self): """Test that an aggregate cursor blocks until a change is readable.""" with self.client.admin.aggregate( [{'$changeStream': {'allChangesForCluster': True}}], maxAwaitTimeMS=250) as change_stream: self._test_next_blocks(change_stream) def test_full_pipeline(self): """$changeStream must be the first stage in a change stream pipeline sent to the server. """ self._test_full_pipeline({'allChangesForCluster': True}) class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestDatabaseChangeStream, cls).setUpClass() def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) def generate_invalidate_event(self, change_stream): # Dropping the database invalidates the change stream. change_stream._client.drop_database(self.db.name) def _test_get_invalidate_event(self, change_stream): # Cache collection names. dropped_colls = self.db.list_collection_names() # Drop the watched database to get an invalidate event. 
self.generate_invalidate_event(change_stream) change = change_stream.next() # 4.1+ returns "drop" events for each collection in dropped database # and a "dropDatabase" event for the database itself. if change['operationType'] == 'drop': self.assertTrue(change['_id']) for _ in range(len(dropped_colls)): ns = change['ns'] self.assertEqual(ns['db'], change_stream._target.name) self.assertIn(ns['coll'], dropped_colls) change = change_stream.next() self.assertEqual(change['operationType'], 'dropDatabase') self.assertTrue(change['_id']) self.assertEqual(change['ns'], {'db': change_stream._target.name}) # Get next change. change = change_stream.next() self.assertTrue(change['_id']) self.assertEqual(change['operationType'], 'invalidate') self.assertNotIn('ns', change) self.assertNotIn('fullDocument', change) # The ChangeStream should be dead. with self.assertRaises(StopIteration): change_stream.next() def _test_invalidate_stops_iteration(self, change_stream): # Drop the watched database to get an invalidate event. change_stream._client.drop_database(self.db.name) # Check drop and dropDatabase events. for change in change_stream: self.assertIn(change['operationType'], ( 'drop', 'dropDatabase', 'invalidate')) # Last change must be invalidate. self.assertEqual(change['operationType'], 'invalidate') # Change stream must not allow further iteration. with self.assertRaises(StopIteration): change_stream.next() with self.assertRaises(StopIteration): next(change_stream) def _insert_and_check(self, change_stream, collname, doc): coll = self.db[collname] coll.insert_one(doc) change = next(change_stream) self.assertEqual(change['operationType'], 'insert') self.assertEqual(change['ns'], {'db': self.db.name, 'coll': collname}) self.assertEqual(change['fullDocument'], doc) def insert_one_and_check(self, change_stream, doc): self._insert_and_check(change_stream, self.id(), doc) def test_simple(self): collnames = self.generate_unique_collnames(3) with self.change_stream() as change_stream: for collname in collnames: self._insert_and_check( change_stream, collname, {'_id': uuid.uuid4()}) def test_isolation(self): # Ensure inserts to other dbs don't show up in our ChangeStream. other_db = self.client.pymongo_test_temp self.assertNotEqual( other_db, self.db, msg="Isolation must be tested on separate DBs") collname = self.id() with self.change_stream() as change_stream: other_db[collname].insert_one({'_id': uuid.uuid4()}) self._insert_and_check( change_stream, collname, {'_id': uuid.uuid4()}) self.client.drop_database(other_db) class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): @classmethod @client_context.require_version_min(3, 5, 11) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestCollectionChangeStream, cls).setUpClass() def setUp(self): # Use a new collection for each test. self.watched_collection().drop() self.watched_collection().insert_one({}) def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].get_collection( self.watched_collection().name).watch(*args, **kwargs) def generate_invalidate_event(self, change_stream): # Dropping the collection invalidates the change stream. change_stream._target.drop() def _test_invalidate_stops_iteration(self, change_stream): self.generate_invalidate_event(change_stream) # Check drop and dropDatabase events. for change in change_stream: self.assertIn(change['operationType'], ('drop', 'invalidate')) # Last change must be invalidate. 
self.assertEqual(change['operationType'], 'invalidate') # Change stream must not allow further iteration. with self.assertRaises(StopIteration): change_stream.next() with self.assertRaises(StopIteration): next(change_stream) def _test_get_invalidate_event(self, change_stream): # Drop the watched database to get an invalidate event. change_stream._target.drop() change = change_stream.next() # 4.1+ returns a "drop" change document. if change['operationType'] == 'drop': self.assertTrue(change['_id']) self.assertEqual(change['ns'], { 'db': change_stream._target.database.name, 'coll': change_stream._target.name}) # Last change should be invalidate. change = change_stream.next() self.assertTrue(change['_id']) self.assertEqual(change['operationType'], 'invalidate') self.assertNotIn('ns', change) self.assertNotIn('fullDocument', change) # The ChangeStream should be dead. with self.assertRaises(StopIteration): change_stream.next() def insert_one_and_check(self, change_stream, doc): self.watched_collection().insert_one(doc) change = next(change_stream) self.assertEqual(change['operationType'], 'insert') self.assertEqual( change['ns'], {'db': self.watched_collection().database.name, 'coll': self.watched_collection().name}) self.assertEqual(change['fullDocument'], doc) def test_raw(self): """Test with RawBSONDocument.""" raw_coll = self.watched_collection( codec_options=DEFAULT_RAW_BSON_OPTIONS) with raw_coll.watch() as change_stream: raw_doc = RawBSONDocument(encode({'_id': 1})) self.watched_collection().insert_one(raw_doc) change = next(change_stream) self.assertIsInstance(change, RawBSONDocument) self.assertEqual(change['operationType'], 'insert') self.assertEqual( change['ns']['db'], self.watched_collection().database.name) self.assertEqual( change['ns']['coll'], self.watched_collection().name) self.assertEqual(change['fullDocument'], raw_doc) def test_uuid_representations(self): """Test with uuid document _ids and different uuid_representation.""" for uuid_representation in ALL_UUID_REPRESENTATIONS: for id_subtype in (STANDARD, PYTHON_LEGACY): options = self.watched_collection().codec_options.with_options( uuid_representation=uuid_representation) coll = self.watched_collection(codec_options=options) with coll.watch() as change_stream: coll.insert_one( {'_id': Binary(uuid.uuid4().bytes, id_subtype)}) _ = change_stream.next() resume_token = change_stream.resume_token # Should not error. coll.watch(resume_after=resume_token) def test_document_id_order(self): """Test with document _ids that need their order preserved.""" random_keys = random.sample(string.ascii_letters, len(string.ascii_letters)) random_doc = {'_id': SON([(key, key) for key in random_keys])} for document_class in (dict, SON, RawBSONDocument): options = self.watched_collection().codec_options.with_options( document_class=document_class) coll = self.watched_collection(codec_options=options) with coll.watch() as change_stream: coll.insert_one(random_doc) _ = change_stream.next() resume_token = change_stream.resume_token # The resume token is always a document. self.assertIsInstance(resume_token, document_class) # Should not error. coll.watch(resume_after=resume_token) coll.delete_many({}) def test_read_concern(self): """Test readConcern is not validated by the driver.""" # Read concern 'local' is not allowed for $changeStream. coll = self.watched_collection(read_concern=ReadConcern('local')) with self.assertRaises(OperationFailure): coll.watch() # Does not error. 
coll = self.watched_collection(read_concern=ReadConcern('majority')) with coll.watch(): pass class TestAllScenarios(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): cls.listener = WhiteListEventListener("aggregate", "getMore") cls.client = rs_or_single_client(event_listeners=[cls.listener]) @classmethod def tearDownClass(cls): cls.client.close() def setUp(self): self.listener.results.clear() def setUpCluster(self, scenario_dict): assets = [(scenario_dict["database_name"], scenario_dict["collection_name"]), (scenario_dict.get("database2_name", "db2"), scenario_dict.get("collection2_name", "coll2"))] for db, coll in assets: self.client.drop_database(db) self.client[db].create_collection(coll) def setFailPoint(self, scenario_dict): fail_point = scenario_dict.get("failPoint") if fail_point is None: return elif not client_context.test_commands_enabled: self.skipTest("Test commands must be enabled") fail_cmd = SON([('configureFailPoint', 'failCommand')]) fail_cmd.update(fail_point) client_context.client.admin.command(fail_cmd) self.addCleanup( client_context.client.admin.command, 'configureFailPoint', fail_cmd['configureFailPoint'], mode='off') def assert_list_contents_are_subset(self, superlist, sublist): """Check that each element in sublist is a subset of the corresponding element in superlist.""" self.assertEqual(len(superlist), len(sublist)) for sup, sub in zip(superlist, sublist): if isinstance(sub, dict): self.assert_dict_is_subset(sup, sub) continue if isinstance(sub, (list, tuple)): self.assert_list_contents_are_subset(sup, sub) continue self.assertEqual(sup, sub) def assert_dict_is_subset(self, superdict, subdict): """Check that subdict is a subset of superdict.""" exempt_fields = ["documentKey", "_id", "getMore"] for key, value in iteritems(subdict): if key not in superdict: self.fail('Key %s not found in %s' % (key, superdict)) if isinstance(value, dict): self.assert_dict_is_subset(superdict[key], value) continue if isinstance(value, (list, tuple)): self.assert_list_contents_are_subset(superdict[key], value) continue if key in exempt_fields: # Only check for presence of these exempt fields, but not value. self.assertIn(key, superdict) else: self.assertEqual(superdict[key], value) def check_event(self, event, expectation_dict): if event is None: self.fail() for key, value in iteritems(expectation_dict): if isinstance(value, dict): self.assert_dict_is_subset(getattr(event, key), value) else: self.assertEqual(getattr(event, key), value) def tearDown(self): self.listener.results.clear() _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'change_streams' ) def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. 
    snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()


def get_change_stream(client, scenario_def, test):
    # Get target namespace on which to instantiate change stream
    target = test["target"]
    if target == "collection":
        db = client.get_database(scenario_def["database_name"])
        cs_target = db.get_collection(scenario_def["collection_name"])
    elif target == "database":
        cs_target = client.get_database(scenario_def["database_name"])
    elif target == "client":
        cs_target = client
    else:
        raise ValueError("Invalid target in spec")

    # Construct change stream kwargs dict
    cs_pipeline = test["changeStreamPipeline"]
    options = test["changeStreamOptions"]
    cs_options = {}
    for key, value in iteritems(options):
        cs_options[camel_to_snake(key)] = value

    # Create and return change stream
    return cs_target.watch(pipeline=cs_pipeline, **cs_options)


def run_operation(client, operation):
    # Apply specified operations
    opname = camel_to_snake(operation["name"])
    arguments = operation.get("arguments", {})
    if opname == 'rename':
        # Special case for rename operation.
        arguments = {'new_name': arguments["to"]}
    cmd = getattr(client.get_database(
        operation["database"]).get_collection(
            operation["collection"]), opname)
    return cmd(**arguments)


def create_test(scenario_def, test):
    def run_scenario(self):
        # Set up
        self.setUpCluster(scenario_def)
        self.setFailPoint(test)
        is_error = test["result"].get("error", False)
        try:
            with get_change_stream(
                    self.client, scenario_def, test) as change_stream:
                for operation in test["operations"]:
                    # Run specified operations
                    run_operation(self.client, operation)
                num_expected_changes = len(test["result"].get("success", []))
                changes = [
                    change_stream.next()
                    for _ in range(num_expected_changes)]
                # Run a next() to induce an error if one is expected and
                # there are no changes.
if is_error and not changes: change_stream.next() except OperationFailure as exc: if not is_error: raise expected_code = test["result"]["error"]["code"] self.assertEqual(exc.code, expected_code) else: # Check for expected output from change streams if test["result"].get("success"): for change, expected_changes in zip(changes, test["result"]["success"]): self.assert_dict_is_subset(change, expected_changes) self.assertEqual(len(changes), len(test["result"]["success"])) finally: # Check for expected events results = self.listener.results for idx, expectation in enumerate(test.get("expectations", [])): for event_type, event_desc in iteritems(expectation): results_key = event_type.split("_")[1] event = results[results_key][idx] if len(results[results_key]) > idx else None self.check_event(event, event_desc) return run_scenario def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json_util.loads(scenario_stream.read()) test_type = os.path.splitext(filename)[0] for test in scenario_def['tests']: new_test = create_test(scenario_def, test) new_test = client_context.require_no_mmap(new_test) if 'minServerVersion' in test: min_ver = tuple( int(elt) for elt in test['minServerVersion'].split('.')) new_test = client_context.require_version_min(*min_ver)( new_test) if 'maxServerVersion' in test: max_ver = tuple( int(elt) for elt in test['maxServerVersion'].split('.')) new_test = client_context.require_version_max(*max_ver)( new_test) topologies = test['topology'] new_test = client_context.require_cluster_type(topologies)( new_test) test_name = 'test_%s_%s_%s' % ( dirname, test_type.replace("-", "_"), str(test['description'].replace(" ", "_"))) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) create_tests() if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_client.py000066400000000000000000002364021374256237000171520ustar00rootroot00000000000000# Copyright 2013-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
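# Note on the spec-driven tests generated above: create_tests() walks the
# JSON files under test/change_streams and attaches one generated method per
# test case to TestAllScenarios.  As a hypothetical example, a spec file
# named change-streams.json whose test is described as "The default target
# is the collection" would yield a method named roughly:
#
#     test_change_streams_change_streams_The_default_target_is_the_collection
#
# (directory name, then file name with '-' replaced by '_', then the test
# description with spaces replaced by '_').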
"""Test the mongo_client module.""" import contextlib import copy import datetime import gc import os import signal import socket import struct import sys import time import threading import warnings sys.path[0:0] = [""] from bson import encode from bson.codec_options import CodecOptions, TypeEncoder, TypeRegistry from bson.py3compat import thread from bson.son import SON from bson.tz_util import utc import pymongo from pymongo import auth, message from pymongo.common import CONNECT_TIMEOUT, _UUID_REPRESENTATIONS from pymongo.command_cursor import CommandCursor from pymongo.compression_support import _HAVE_SNAPPY, _HAVE_ZSTD from pymongo.cursor import Cursor, CursorType from pymongo.database import Database from pymongo.errors import (AutoReconnect, ConfigurationError, ConnectionFailure, InvalidName, InvalidURI, NetworkTimeout, OperationFailure, ServerSelectionTimeoutError, WriteConcernError) from pymongo.monitoring import (ServerHeartbeatListener, ServerHeartbeatStartedEvent) from pymongo.mongo_client import MongoClient from pymongo.monotonic import time as monotonic_time from pymongo.driver_info import DriverInfo from pymongo.pool import SocketInfo, _METADATA from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import (any_server_selector, writable_server_selector) from pymongo.server_type import SERVER_TYPE from pymongo.settings import TOPOLOGY_TYPE from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.write_concern import WriteConcern from test import (client_context, client_knobs, SkipTest, unittest, IntegrationTest, db_pwd, db_user, MockClientTest, HAVE_IPADDRESS) from test.pymongo_mocks import MockClient from test.utils import (assertRaisesExactly, connected, delay, FunctionCallRecorder, get_pool, gevent_monkey_patched, ignore_deprecations, is_greenthread_patched, lazy_client_trial, NTHREADS, one, remove_all_users, rs_client, rs_or_single_client, rs_or_single_client_noauth, server_is_master_with_slave, single_client, wait_until) class ClientUnitTest(unittest.TestCase): """MongoClient tests that don't require a server.""" @classmethod @client_context.require_connection def setUpClass(cls): cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @classmethod def tearDownClass(cls): cls.client.close() def test_keyword_arg_defaults(self): client = MongoClient(socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, waitQueueMultiple=None, replicaSet=None, read_preference=ReadPreference.PRIMARY, ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=0, # ssl.CERT_NONE ssl_ca_certs=None, connect=False, serverSelectionTimeoutMS=12000) options = client._MongoClient__options pool_opts = options.pool_options self.assertEqual(None, pool_opts.socket_timeout) # socket.Socket.settimeout takes a float in seconds self.assertEqual(20.0, pool_opts.connect_timeout) self.assertEqual(None, pool_opts.wait_queue_timeout) self.assertEqual(None, pool_opts.wait_queue_multiple) self.assertTrue(pool_opts.socket_keepalive) self.assertEqual(None, pool_opts.ssl_context) self.assertEqual(None, options.replica_set_name) self.assertEqual(ReadPreference.PRIMARY, client.read_preference) self.assertAlmostEqual(12, client.server_selection_timeout) def test_connect_timeout(self): client = MongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client._MongoClient__options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) client = 
MongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client._MongoClient__options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) client = MongoClient( 'mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0', connect=False) pool_opts = client._MongoClient__options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) def test_types(self): self.assertRaises(TypeError, MongoClient, 1) self.assertRaises(TypeError, MongoClient, 1.14) self.assertRaises(TypeError, MongoClient, "localhost", "27017") self.assertRaises(TypeError, MongoClient, "localhost", 1.14) self.assertRaises(TypeError, MongoClient, "localhost", []) self.assertRaises(ConfigurationError, MongoClient, []) def test_max_pool_size_zero(self): with self.assertRaises(ValueError): MongoClient(maxPoolSize=0) def test_get_db(self): def make_db(base, name): return base[name] self.assertRaises(InvalidName, make_db, self.client, "") self.assertRaises(InvalidName, make_db, self.client, "te$t") self.assertRaises(InvalidName, make_db, self.client, "te.t") self.assertRaises(InvalidName, make_db, self.client, "te\\t") self.assertRaises(InvalidName, make_db, self.client, "te/t") self.assertRaises(InvalidName, make_db, self.client, "te st") self.assertTrue(isinstance(self.client.test, Database)) self.assertEqual(self.client.test, self.client["test"]) self.assertEqual(self.client.test, Database(self.client, "test")) def test_get_database(self): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) db = self.client.get_database( 'foo', codec_options, ReadPreference.SECONDARY, write_concern) self.assertEqual('foo', db.name) self.assertEqual(codec_options, db.codec_options) self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) def test_getattr(self): self.assertTrue(isinstance(self.client['_does_not_exist'], Database)) with self.assertRaises(AttributeError) as context: self.client._does_not_exist # Message should be: # "AttributeError: MongoClient has no attribute '_does_not_exist'. To # access the _does_not_exist database, use client['_does_not_exist']". self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) def test_iteration(self): def iterate(): [a for a in self.client] self.assertRaises(TypeError, iterate) def test_get_default_database(self): c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) # Test that default doesn't override the URI value. self.assertEqual(Database(c, 'foo'), c.get_default_database('bar')) codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) db = c.get_default_database( None, codec_options, ReadPreference.SECONDARY, write_concern) self.assertEqual('foo', db.name) self.assertEqual(codec_options, db.codec_options) self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database('foo')) def test_get_default_database_error(self): # URI with no database. 
c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False) self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( client_context.host, client_context.port) c = rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_get_database_default(self): c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False) self.assertEqual(Database(c, 'foo'), c.get_database()) def test_get_database_default_error(self): # URI with no database. c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False) self.assertRaises(ConfigurationError, c.get_database) def test_get_database_default_with_authsource(self): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % ( client_context.host, client_context.port) c = rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, 'foo'), c.get_database()) def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): MongoClient('mongodb://host/?readpreferencetags=dc:east') with self.assertRaises(ConfigurationError): MongoClient('mongodb://host/?' 'readpreference=primary&readpreferencetags=dc:east') def test_read_preference(self): c = rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode) self.assertEqual(c.read_preference, ReadPreference.NEAREST) def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata['application'] = {'name': 'foobar'} client = MongoClient( "mongodb://foo:27017/?appname=foobar&connect=false") options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) client = MongoClient('foo', 27017, appname='foobar', connect=False) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) # No error MongoClient(appname='x' * 128) self.assertRaises(ValueError, MongoClient, appname='x' * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, 'Foo', 1, 'a') self.assertRaises(TypeError, MongoClient, driver=1) self.assertRaises(TypeError, MongoClient, driver='abc') self.assertRaises(TypeError, MongoClient, driver=('Foo', '1', 'a')) # Test appending to driver info. 
metadata['driver']['name'] = 'PyMongo|FooDriver' metadata['driver']['version'] = '%s|1.2.3' % ( _METADATA['driver']['version'],) client = MongoClient('foo', 27017, appname='foobar', driver=DriverInfo('FooDriver', '1.2.3', None), connect=False) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) metadata['platform'] = '%s|FooPlatform' % ( _METADATA['platform'],) client = MongoClient('foo', 27017, appname='foobar', driver=DriverInfo('FooDriver', '1.2.3', 'FooPlatform'), connect=False) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) def test_kwargs_codec_options(self): class MyFloatType(object): def __init__(self, x): self.__x = x @property def x(self): return self.__x class MyFloatAsIntEncoder(TypeEncoder): python_type = MyFloatType def transform_python(self, value): return int(value) # Ensure codec options are passed in correctly document_class = SON type_registry = TypeRegistry([MyFloatAsIntEncoder()]) tz_aware = True uuid_representation_label = 'javaLegacy' unicode_decode_error_handler = 'ignore' tzinfo = utc c = MongoClient( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, uuidrepresentation=uuid_representation_label, unicode_decode_error_handler=unicode_decode_error_handler, tzinfo=tzinfo, connect=False ) self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.type_registry, type_registry) self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual( c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label]) self.assertEqual( c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual(c.codec_options.tzinfo, tzinfo) def test_uri_codec_options(self): # Ensure codec options are passed in correctly uuid_representation_label = 'javaLegacy' unicode_decode_error_handler = 'ignore' uri = ("mongodb://%s:%d/foo?tz_aware=true&uuidrepresentation=" "%s&unicode_decode_error_handler=%s" % ( client_context.host, client_context.port, uuid_representation_label, unicode_decode_error_handler)) c = MongoClient(uri, connect=False) self.assertEqual(c.codec_options.tz_aware, True) self.assertEqual( c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label]) self.assertEqual( c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = ("mongodb://localhost/?ssl=true&replicaSet=name" "&readPreference=primary") c = MongoClient(uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred") clopts = c._MongoClient__options opts = clopts._options self.assertEqual(opts['ssl'], False) self.assertEqual(clopts.replica_set_name, "newname") self.assertEqual( clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) @unittest.skipUnless( _HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. from pymongo.srv_resolver import resolver patched_resolver = FunctionCallRecorder(resolver.query) pymongo.srv_resolver.resolver.query = patched_resolver def reset_resolver(): pymongo.srv_resolver.resolver.query = resolver.query self.addCleanup(reset_resolver) # Setup. 
base_uri = "mongodb+srv://test5.test.build.10gen.cc" connectTimeoutMS = 5000 expected_kw_value = 5.0 uri_with_timeout = base_uri + "/?connectTimeoutMS=6000" expected_uri_value = 6.0 def test_scenario(args, kwargs, expected_value): patched_resolver.reset() MongoClient(*args, **kwargs) for _, kw in patched_resolver.call_list(): self.assertAlmostEqual(kw['lifetime'], expected_value) # No timeout specified. test_scenario((base_uri,), {}, CONNECT_TIMEOUT) # Timeout only specified in connection string. test_scenario((uri_with_timeout,), {}, expected_uri_value) # Timeout only specified in keyword arguments. kwarg = {'connectTimeoutMS': connectTimeoutMS} test_scenario((base_uri,), kwarg, expected_kw_value) # Timeout specified in both kwargs and connection string. test_scenario((uri_with_timeout,), kwarg, expected_kw_value) def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): MongoClient('mongodb://localhost/?ssl=true', tls=False, connect=False) # Matching SSL and TLS options should not cause errors. c = MongoClient('mongodb://localhost/?ssl=false', tls=False, connect=False) self.assertEqual(c._MongoClient__options._options['ssl'], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): MongoClient('mongodb://localhost/?tlsInsecure=true', connect=False, tlsAllowInvalidHostnames=True) # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): MongoClient('mongodb://localhost/?tlsInsecure=true', connect=False, ssl_cert_reqs=True) class TestClient(IntegrationTest): def test_max_idle_time_reaper(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove sockets when maxIdleTimeMS not set client = rs_or_single_client() server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) self.assertTrue(sock_info in server._pool.sockets) client.close() # Assert reaper removes idle socket and replaces it with a new one client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass # When the reaper runs at the same time as the get_socket, two # sockets could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.sockets), 1) wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket") client.close() # Assert reaper respects maxPoolSize when adding new sockets. client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two sockets from being created. self.assertEqual(1, len(server._pool.sockets)) wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket") client.close() # Assert reaper has removed idle socket and NOT replaced it client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info_one: pass # Assert that the pool does not close sockets prematurely. 
time.sleep(.300) with server._pool.get_socket({}) as sock_info_two: pass self.assertIs(sock_info_one, sock_info_two) wait_until( lambda: 0 == len(server._pool.sockets), "stale socket reaped and new one NOT added to the pool") client.close() def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=.1): client = rs_or_single_client() server = client._get_topology().select_server(any_server_selector) self.assertEqual(0, len(server._pool.sockets)) # Assert that pool started up at minPoolSize client = rs_or_single_client(minPoolSize=10) server = client._get_topology().select_server(any_server_selector) wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") # Assert that if a socket is closed, a new one takes its place with server._pool.get_socket({}) as sock_info: sock_info.close_socket(None) wait_until(lambda: 10 == len(server._pool.sockets), "a closed socket gets replaced from the pool") self.assertFalse(sock_info in server._pool.sockets) def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = rs_or_single_client(maxIdleTimeMS=500) server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) time.sleep(1) # Sleep so that the socket becomes stale. with server._pool.get_socket({}) as new_sock_info: self.assertNotEqual(sock_info, new_sock_info) self.assertEqual(1, len(server._pool.sockets)) self.assertFalse(sock_info in server._pool.sockets) self.assertTrue(new_sock_info in server._pool.sockets) # Test that sockets are reused if maxIdleTimeMS is not set. client = rs_or_single_client() server = client._get_topology().select_server(any_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) time.sleep(1) with server._pool.get_socket({}) as new_sock_info: self.assertEqual(sock_info, new_sock_info) self.assertEqual(1, len(server._pool.sockets)) def test_constants(self): """This test uses MongoClient explicitly to make sure that host and port are not overloaded. """ host, port = client_context.host, client_context.port kwargs = client_context.default_client_options.copy() if client_context.auth_enabled: kwargs['username'] = db_user kwargs['password'] = db_pwd # Set bad defaults. MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): connected(MongoClient(serverSelectionTimeoutMS=10, **kwargs)) # Override the defaults. No error. connected(MongoClient(host, port, **kwargs)) # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port # No error. 
connected(MongoClient(**kwargs)) def test_init_disconnected(self): host, port = client_context.host, client_context.port c = rs_or_single_client(connect=False) # is_primary causes client to block until connected self.assertIsInstance(c.is_primary, bool) c = rs_or_single_client(connect=False) self.assertIsInstance(c.is_mongos, bool) c = rs_or_single_client(connect=False) self.assertIsInstance(c.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) self.assertIsInstance(c.max_bson_size, int) c = rs_or_single_client(connect=False) self.assertFalse(c.primary) self.assertFalse(c.secondaries) c = rs_or_single_client(connect=False) self.assertIsInstance(c.max_write_batch_size, int) if client_context.is_rs: # The primary's host and port are from the replica set config. self.assertIsNotNone(c.address) else: self.assertEqual(c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_equality(self): c = connected(rs_or_single_client()) self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) def test_host_w_port(self): with self.assertRaises(ValueError): connected(MongoClient("%s:1234567" % (client_context.host,), connectTimeoutMS=1, serverSelectionTimeoutMS=10)) def test_repr(self): # Used to test 'eval' below. 
import bson client = MongoClient( 'mongodb://localhost:27017,localhost:27018/?replicaSet=replset' '&connectTimeoutMS=12345&w=1&wtimeoutms=100', connect=False, document_class=SON) the_repr = repr(client) self.assertIn('MongoClient(host=', the_repr) self.assertIn( "document_class=bson.son.SON, " "tz_aware=False, " "connect=False, ", the_repr) self.assertIn("connecttimeoutms=12345", the_repr) self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) self.assertEqual(eval(the_repr), client) client = MongoClient("localhost:27017,localhost:27018", replicaSet='replset', connectTimeoutMS=12345, socketTimeoutMS=None, w=1, wtimeoutms=100, connect=False) the_repr = repr(client) self.assertIn('MongoClient(host=', the_repr) self.assertIn( "document_class=dict, " "tz_aware=False, " "connect=False, ", the_repr) self.assertIn("connecttimeoutms=12345", the_repr) self.assertIn("replicaset='replset'", the_repr) self.assertIn("sockettimeoutms=None", the_repr) self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) self.assertEqual(eval(the_repr), client) def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") def test_list_databases(self): cmd_docs = self.client.admin.command('listDatabases')['databases'] cursor = self.client.list_databases() self.assertIsInstance(cursor, CommandCursor) helper_docs = list(cursor) self.assertTrue(len(helper_docs) > 0) self.assertEqual(helper_docs, cmd_docs) for doc in helper_docs: self.assertIs(type(doc), dict) client = rs_or_single_client(document_class=SON) for doc in client.list_databases(): self.assertIs(type(doc), dict) if client_context.version.at_least(3, 4, 2): self.client.pymongo_test.test.insert_one({}) cursor = self.client.list_databases(filter={"name": "admin"}) docs = list(cursor) self.assertEqual(1, len(docs)) self.assertEqual(docs[0]["name"], "admin") if client_context.version.at_least(3, 4, 3): cursor = self.client.list_databases(nameOnly=True) for doc in cursor: self.assertEqual(["name"], list(doc)) def _test_list_names(self, meth): self.client.pymongo_test.test.insert_one({"dummy": u"object"}) self.client.pymongo_test_mike.test.insert_one({"dummy": u"object"}) cmd_docs = self.client.admin.command("listDatabases")["databases"] cmd_names = [doc["name"] for doc in cmd_docs] db_names = meth() self.assertTrue("pymongo_test" in db_names) self.assertTrue("pymongo_test_mike" in db_names) self.assertEqual(db_names, cmd_names) def test_list_database_names(self): self._test_list_names(self.client.list_database_names) def test_database_names(self): self._test_list_names(self.client.database_names) def test_drop_database(self): self.assertRaises(TypeError, self.client.drop_database, 5) self.assertRaises(TypeError, self.client.drop_database, None) self.client.pymongo_test.test.insert_one({"dummy": u"object"}) self.client.pymongo_test2.test.insert_one({"dummy": u"object"}) dbs = self.client.list_database_names() self.assertIn("pymongo_test", dbs) self.assertIn("pymongo_test2", dbs) self.client.drop_database("pymongo_test") if client_context.version.at_least(3, 3, 9) and client_context.is_rs: wc_client = rs_or_single_client(w=len(client_context.nodes) + 1) with self.assertRaises(WriteConcernError): wc_client.drop_database('pymongo_test2') self.client.drop_database(self.client.pymongo_test2) raise SkipTest("This test often fails due to SERVER-2329") dbs = self.client.list_database_names() self.assertNotIn("pymongo_test", dbs) 
self.assertNotIn("pymongo_test2", dbs) def test_close(self): coll = self.client.pymongo_test.bar self.client.close() self.client.close() coll.count_documents({}) self.client.close() self.client.close() coll.count_documents({}) def test_close_kills_cursors(self): if sys.platform.startswith('java'): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") # Kill any cursors possibly queued up by previous tests. gc.collect() self.client._process_periodic_tasks() # Add some test data. coll = self.client.pymongo_test.test_close_kills_cursors docs_inserted = 1000 coll.insert_many([{"i": i} for i in range(docs_inserted)]) # Open a cursor and leave it open on the server. cursor = coll.find().batch_size(10) self.assertTrue(bool(next(cursor))) self.assertLess(cursor.retrieved, docs_inserted) # Open a command cursor and leave it open on the server. cursor = coll.aggregate([], batchSize=10) self.assertTrue(bool(next(cursor))) del cursor # Required for PyPy, Jython and other Python implementations that # don't use reference counting garbage collection. gc.collect() # Close the client and ensure the topology is closed. self.assertTrue(self.client._topology._opened) self.client.close() self.assertFalse(self.client._topology._opened) # The killCursors task should not need to re-open the topology. self.client._process_periodic_tasks() self.assertFalse(self.client._topology._opened) def test_close_stops_kill_cursors_thread(self): client = rs_client() client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) # Closing the client should stop the thread. client.close() self.assertTrue(client._kill_cursors_executor._stopped) # Reusing the closed client should restart the thread. client.admin.command('isMaster') self.assertFalse(client._kill_cursors_executor._stopped) # Again, closing the client should stop the thread. client.close() self.assertTrue(client._kill_cursors_executor._stopped) def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. client = rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. kc_thread = client._kill_cursors_executor._thread self.assertFalse(kc_thread and kc_thread.is_alive()) # Using the client should open topology and start the thread. client.admin.command('isMaster') self.assertTrue(client._topology._opened) kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) # Tear down. 
client.close() def test_close_does_not_open_servers(self): client = rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) client.close() self.assertEqual(topology._servers, {}) def test_close_closes_sockets(self): client = rs_client() self.addCleanup(client.close) client.test.test.find_one() topology = client._topology client.close() for server in topology._servers.values(): self.assertFalse(server._pool.sockets) self.assertTrue(server._monitor._executor._stopped) self.assertTrue(server._monitor._rtt_monitor._executor._stopped) self.assertFalse(server._monitor._pool.sockets) self.assertFalse(server._monitor._rtt_monitor._pool.sockets) def test_bad_uri(self): with self.assertRaises(InvalidURI): MongoClient("http://localhost") @client_context.require_auth def test_auth_from_uri(self): host, port = client_context.host, client_context.port client_context.create_user("admin", "admin", "pass") self.addCleanup(client_context.drop_user, "admin", "admin") self.addCleanup(remove_all_users, self.client.pymongo_test) client_context.create_user( "pymongo_test", "user", "pass", roles=['userAdmin', 'readWrite']) with self.assertRaises(OperationFailure): connected(rs_or_single_client( "mongodb://a:b@%s:%d" % (host, port))) # No error. connected(rs_or_single_client_noauth( "mongodb://admin:pass@%s:%d" % (host, port))) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): connected(rs_or_single_client(uri)) # No error. connected(rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port))) # Auth with lazy connection. rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False).pymongo_test.test.find_one() # Wrong password. bad_client = rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False) self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one) @client_context.require_auth def test_username_and_password(self): client_context.create_user("admin", "ad min", "pa/ss") self.addCleanup(client_context.drop_user, "admin", "ad min") c = rs_or_single_client(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. self.assertNotIn("ad min", repr(c)) self.assertNotIn("ad min", str(c)) self.assertNotIn("pa/ss", repr(c)) self.assertNotIn("pa/ss", str(c)) # Auth succeeds. c.server_info() with self.assertRaises(OperationFailure): rs_or_single_client(username="ad min", password="foo").server_info() @client_context.require_auth @ignore_deprecations def test_multiple_logins(self): client_context.create_user( 'pymongo_test', 'user1', 'pass', roles=['readWrite']) client_context.create_user( 'pymongo_test', 'user2', 'pass', roles=['readWrite']) self.addCleanup(remove_all_users, self.client.pymongo_test) client = rs_or_single_client_noauth( "mongodb://user1:pass@%s:%d/pymongo_test" % ( client_context.host, client_context.port)) client.pymongo_test.test.find_one() with self.assertRaises(OperationFailure): # Can't log in to the same database with multiple users. 
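            # PyMongo 3.x allows only one authenticated user per database at a
            # time, so authenticating as user2 here raises OperationFailure
            # until user1 is logged out below.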
client.pymongo_test.authenticate('user2', 'pass') client.pymongo_test.test.find_one() client.pymongo_test.logout() with self.assertRaises(OperationFailure): client.pymongo_test.test.find_one() client.pymongo_test.authenticate('user2', 'pass') client.pymongo_test.test.find_one() with self.assertRaises(OperationFailure): client.pymongo_test.authenticate('user1', 'pass') client.pymongo_test.test.find_one() @client_context.require_auth def test_lazy_auth_raises_operation_failure(self): lazy_client = rs_or_single_client_noauth( "mongodb://user:wrong@%s/pymongo_test" % (client_context.host,), connect=False) assertRaisesExactly( OperationFailure, lazy_client.test.collection.find_one) @client_context.require_no_tls def test_unix_socket(self): if not hasattr(socket, "AF_UNIX"): raise SkipTest("UNIX-sockets are not supported on this system") mongodb_socket = '/tmp/mongodb-%d.sock' % (client_context.port,) encoded_socket = ( '%2Ftmp%2F' + 'mongodb-%d.sock' % (client_context.port,)) if not os.access(mongodb_socket, os.R_OK): raise SkipTest("Socket file is not accessible") if client_context.auth_enabled: uri = "mongodb://%s:%s@%s" % (db_user, db_pwd, encoded_socket) else: uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. client = rs_or_single_client(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() self.assertTrue("pymongo_test" in dbs) self.assertTrue(mongodb_socket in repr(client)) # Confirm it fails with a missing socket. self.assertRaises( ConnectionFailure, connected, MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100)) def test_document_class(self): c = self.client db = c.pymongo_test db.test.insert_one({"x": 1}) self.assertEqual(dict, c.codec_options.document_class) self.assertTrue(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) c = rs_or_single_client(document_class=SON) db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) self.assertTrue(isinstance(db.test.find_one(), SON)) def test_timeouts(self): client = rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500) self.assertEqual(10.5, get_pool(client).opts.connect_timeout) self.assertEqual(10.5, get_pool(client).opts.socket_timeout) self.assertEqual(10.5, get_pool(client).opts.max_idle_time_seconds) self.assertEqual(10500, client.max_idle_time_ms) self.assertEqual(10.5, client.server_selection_timeout) def test_socket_timeout_ms_validation(self): c = rs_or_single_client(socketTimeoutMS=10 * 1000) self.assertEqual(10, get_pool(c).opts.socket_timeout) c = connected(rs_or_single_client(socketTimeoutMS=None)) self.assertEqual(None, get_pool(c).opts.socket_timeout) c = connected(rs_or_single_client(socketTimeoutMS=0)) self.assertEqual(None, get_pool(c).opts.socket_timeout) self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=-1) self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=1e10) self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS='foo') def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 timeout = rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) no_timeout.pymongo_test.drop_collection("test") no_timeout.pymongo_test.test.insert_one({"x": 1}) # A $where clause that takes a second longer than the timeout where_func = delay(timeout_sec + 1) def get_x(db): doc = next(db.test.find().where(where_func)) return doc["x"] self.assertEqual(1, 
get_x(no_timeout.pymongo_test)) self.assertRaises(NetworkTimeout, get_x, timeout.pymongo_test) def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.server_selection_timeout) client = MongoClient(serverSelectionTimeoutMS=0, connect=False) self.assertAlmostEqual(0, client.server_selection_timeout) self.assertRaises(ValueError, MongoClient, serverSelectionTimeoutMS="foo", connect=False) self.assertRaises(ValueError, MongoClient, serverSelectionTimeoutMS=-1, connect=False) self.assertRaises(ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False) client = MongoClient( 'mongodb://localhost/?serverSelectionTimeoutMS=100', connect=False) self.assertAlmostEqual(0.1, client.server_selection_timeout) client = MongoClient( 'mongodb://localhost/?serverSelectionTimeoutMS=0', connect=False) self.assertAlmostEqual(0, client.server_selection_timeout) # Test invalid timeout in URI ignored and set to default. client = MongoClient( 'mongodb://localhost/?serverSelectionTimeoutMS=-1', connect=False) self.assertAlmostEqual(30, client.server_selection_timeout) client = MongoClient( 'mongodb://localhost/?serverSelectionTimeoutMS=', connect=False) self.assertAlmostEqual(30, client.server_selection_timeout) def test_waitQueueTimeoutMS(self): client = rs_or_single_client(waitQueueTimeoutMS=2000) self.assertEqual(get_pool(client).opts.wait_queue_timeout, 2) def test_waitQueueMultiple(self): client = rs_or_single_client(maxPoolSize=3, waitQueueMultiple=2) pool = get_pool(client) self.assertEqual(pool.opts.wait_queue_multiple, 2) self.assertEqual(pool._socket_semaphore.waiter_semaphore.counter, 6) def test_socketKeepAlive(self): for socketKeepAlive in [True, False]: with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") client = rs_or_single_client(socketKeepAlive=socketKeepAlive) self.assertTrue(any("The socketKeepAlive option is deprecated" in str(k) for k in ctx)) pool = get_pool(client) self.assertEqual(socketKeepAlive, pool.opts.socket_keepalive) with pool.get_socket({}) as sock_info: keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertEqual(socketKeepAlive, bool(keepalive)) def test_tz_aware(self): self.assertRaises(ValueError, MongoClient, tz_aware='foo') aware = rs_or_single_client(tz_aware=True) naive = self.client aware.pymongo_test.drop_collection("test") now = datetime.datetime.utcnow() aware.pymongo_test.test.insert_one({"x": now}) self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo) self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo) self.assertEqual( aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None), naive.pymongo_test.test.find_one()["x"]) @client_context.require_ipv6 def test_ipv6(self): if client_context.tls: if not HAVE_IPADDRESS: raise SkipTest("Need the ipaddress module to test with SSL") if client_context.auth_enabled: auth_str = "%s:%s@" % (db_user, db_pwd) else: auth_str = "" uri = "mongodb://%s[::1]:%d" % (auth_str, client_context.port) if client_context.is_rs: uri += '/?replicaSet=' + client_context.replica_set_name client = rs_or_single_client_noauth(uri) client.pymongo_test.test.insert_one({"dummy": u"object"}) client.pymongo_test_bernie.test.insert_one({"dummy": u"object"}) dbs = client.list_database_names() self.assertTrue("pymongo_test" in dbs) self.assertTrue("pymongo_test_bernie" in dbs) @ignore_deprecations @client_context.require_no_mongos def 
test_fsync_lock_unlock(self): if server_is_master_with_slave(client_context.client): raise SkipTest('SERVER-7714') self.assertFalse(self.client.is_locked) # async flushing not supported on windows... if sys.platform not in ('cygwin', 'win32'): # Work around async becoming a reserved keyword in Python 3.7 opts = {'async': True} self.client.fsync(**opts) self.assertFalse(self.client.is_locked) self.client.fsync(lock=True) self.assertTrue(self.client.is_locked) locked = True self.client.unlock() for _ in range(5): locked = self.client.is_locked if not locked: break time.sleep(1) self.assertFalse(locked) def test_deprecated_methods(self): with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) with self.assertRaisesRegex(DeprecationWarning, 'is_locked is deprecated'): _ = self.client.is_locked if not client_context.is_mongos: with self.assertRaisesRegex(DeprecationWarning, 'fsync is deprecated'): self.client.fsync(lock=True) with self.assertRaisesRegex(DeprecationWarning, 'unlock is deprecated'): self.client.unlock() def test_contextlib(self): client = rs_or_single_client() client.pymongo_test.drop_collection("test") client.pymongo_test.test.insert_one({"foo": "bar"}) # The socket used for the previous commands has been returned to the # pool self.assertEqual(1, len(get_pool(client).sockets)) with contextlib.closing(client): self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) self.assertEqual(1, len(get_pool(client).sockets)) self.assertEqual(0, len(get_pool(client).sockets)) with client as client: self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) self.assertEqual(0, len(get_pool(client).sockets)) def test_interrupt_signal(self): if sys.platform.startswith('java'): # We can't figure out how to raise an exception on a thread that's # blocked on a socket, whether that's the main thread or a worker, # without simply killing the whole thread in Jython. This suggests # PYTHON-294 can't actually occur in Jython. raise SkipTest("Can't test interrupts in Jython") if is_greenthread_patched(): raise SkipTest("Can't reliably test interrupts with green threads") # Test fix for PYTHON-294 -- make sure MongoClient closes its # socket if it gets an interrupt while waiting to recv() from it. db = self.client.pymongo_test # A $where clause which takes 1.5 sec to execute where = delay(1.5) # Need exactly 1 document so find() will execute its $where clause once db.drop_collection('foo') db.foo.insert_one({'_id': 1}) old_signal_handler = None try: # Platform-specific hacks for raising a KeyboardInterrupt on the # main thread while find() is in-progress: On Windows, SIGALRM is # unavailable so we use a second thread. In our Evergreen setup on # Linux, the thread technique causes an error in the test at # sock.recv(): TypeError: 'int' object is not callable # We don't know what causes this, so we hack around it. if sys.platform == 'win32': def interrupter(): # Raises KeyboardInterrupt in the main thread time.sleep(0.25) thread.interrupt_main() thread.start_new_thread(interrupter, ()) else: # Convert SIGALRM to SIGINT -- it's hard to schedule a SIGINT # for one second in the future, but easy to schedule SIGALRM. def sigalarm(num, frame): raise KeyboardInterrupt old_signal_handler = signal.signal(signal.SIGALRM, sigalarm) signal.alarm(1) raised = False try: # Will be interrupted by a KeyboardInterrupt. 
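                # The $where delay(1.5) keeps the server busy for ~1.5 seconds,
                # longer than the 1 second alarm/interrupter scheduled above, so
                # the KeyboardInterrupt arrives while the driver is still
                # blocked reading the reply from the socket.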
next(db.foo.find({'$where': where})) except KeyboardInterrupt: raised = True # Can't use self.assertRaises() because it doesn't catch system # exceptions self.assertTrue(raised, "Didn't raise expected KeyboardInterrupt") # Raises AssertionError due to PYTHON-294 -- Mongo's response to # the previous find() is still waiting to be read on the socket, # so the request id's don't match. self.assertEqual( {'_id': 1}, next(db.foo.find()) ) finally: if old_signal_handler: signal.signal(signal.SIGALRM, old_signal_handler) def test_operation_failure(self): # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. client = single_client() client.pymongo_test.test.find_one() pool = get_pool(client) socket_count = len(pool.sockets) self.assertGreaterEqual(socket_count, 1) old_sock_info = next(iter(pool.sockets)) client.pymongo_test.test.drop() client.pymongo_test.test.insert_one({'_id': 'foo'}) self.assertRaises( OperationFailure, client.pymongo_test.test.insert_one, {'_id': 'foo'}) self.assertEqual(socket_count, len(pool.sockets)) new_sock_info = next(iter(pool.sockets)) self.assertEqual(old_sock_info, new_sock_info) def test_lazy_connect_w0(self): # Ensure that connect-on-demand works when the first operation is # an unacknowledged write. This exercises _writable_max_wire_version(). # Use a separate collection to avoid races where we're still # completing an operation on a collection while the next test begins. client_context.client.drop_database('test_lazy_connect_w0') self.addCleanup( client_context.client.drop_database, 'test_lazy_connect_w0') client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.insert_one({}) wait_until( lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, "find one document") client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.update_one({}, {'$set': {'x': 1}}) wait_until( lambda: client.test_lazy_connect_w0.test.find_one().get('x') == 1, "update one document") client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.delete_one({}) wait_until( lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, "delete one document") @client_context.require_no_mongos def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. # Ensure a socket. connected(client) # Cause a network error. sock_info = one(pool.sockets) sock_info.sock.close() cursor = collection.find(cursor_type=CursorType.EXHAUST) with self.assertRaises(ConnectionFailure): next(cursor) self.assertTrue(sock_info.closed) # The semaphore was decremented despite the error. self.assertTrue(pool._socket_semaphore.acquire(blocking=False)) @client_context.require_auth def test_auth_network_error(self): # Make sure there's no semaphore leak if we get a network error # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. c = connected(rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False)) # Simulate an authenticate() call on a different socket. 
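        # _build_credentials_tuple/_cache_credentials are private PyMongo 3.x
        # helpers: caching the credentials here mimics a prior authenticate()
        # without checking out a socket, so the next operation must log in on
        # the (soon to be broken) pooled socket itself.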
        credentials = auth._build_credentials_tuple(
            'DEFAULT', 'admin', db_user, db_pwd, {}, None)
        c._cache_credentials('test', credentials, connect=False)

        # Cause a network error on the actual socket.
        pool = get_pool(c)
        socket_info = one(pool.sockets)
        socket_info.sock.close()

        # SocketInfo.check_auth logs in with the new credential, but gets a
        # socket.error. Should be reraised as AutoReconnect.
        self.assertRaises(AutoReconnect, c.test.collection.find_one)

        # No semaphore leak, the pool is allowed to make a new socket.
        c.test.collection.find_one()

    @client_context.require_no_replica_set
    def test_connect_to_standalone_using_replica_set_name(self):
        client = single_client(replicaSet='anything',
                               serverSelectionTimeoutMS=100)

        with self.assertRaises(AutoReconnect):
            client.test.test.find_one()

    @client_context.require_replica_set
    def test_stale_getmore(self):
        # A cursor is created, but its member goes down and is removed from
        # the topology before the getMore message is sent. Test that
        # MongoClient._run_operation_with_response handles the error.
        with self.assertRaises(AutoReconnect):
            client = rs_client(connect=False,
                               serverSelectionTimeoutMS=100)
            client._run_operation_with_response(
                operation=message._GetMore('pymongo_test', 'collection',
                                           101, 1234, client.codec_options,
                                           ReadPreference.PRIMARY,
                                           None, client, None, None),
                unpack_res=Cursor(
                    client.pymongo_test.collection)._unpack_response,
                address=('not-a-member', 27017))

    def test_heartbeat_frequency_ms(self):
        class HeartbeatStartedListener(ServerHeartbeatListener):
            def __init__(self):
                self.results = []

            def started(self, event):
                self.results.append(event)

            def succeeded(self, event):
                pass

            def failed(self, event):
                pass

        old_init = ServerHeartbeatStartedEvent.__init__
        heartbeat_times = []

        def init(self, *args):
            old_init(self, *args)
            heartbeat_times.append(time.time())

        try:
            ServerHeartbeatStartedEvent.__init__ = init
            listener = HeartbeatStartedListener()
            uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % (
                client_context.host, client_context.port)
            client = single_client(uri, event_listeners=[listener])
            wait_until(lambda: len(listener.results) >= 2,
                       "record two ServerHeartbeatStartedEvents")

            # Default heartbeatFrequencyMS is 10 sec. Check the interval was
            # closer to 0.5 sec with heartbeatFrequencyMS configured.
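            # delta=2 leaves generous slack for slow CI hosts while still
            # distinguishing the configured 0.5 sec interval from the 10 sec
            # default.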
            self.assertAlmostEqual(
                heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2)

            client.close()
        finally:
            ServerHeartbeatStartedEvent.__init__ = old_init

    def test_small_heartbeat_frequency_ms(self):
        uri = "mongodb://example/?heartbeatFrequencyMS=499"
        with self.assertRaises(ConfigurationError) as context:
            MongoClient(uri)

        self.assertIn('heartbeatFrequencyMS', str(context.exception))

    def test_compression(self):
        def compression_settings(client):
            pool_options = client._MongoClient__options.pool_options
            return pool_options.compression_settings

        uri = "mongodb://localhost:27017/?compressors=zlib"
        client = MongoClient(uri, connect=False)
        opts = compression_settings(client)
        self.assertEqual(opts.compressors, ['zlib'])

        uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4"
        client = MongoClient(uri, connect=False)
        opts = compression_settings(client)
        self.assertEqual(opts.compressors, ['zlib'])
        self.assertEqual(opts.zlib_compression_level, 4)

        uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1"
        client = MongoClient(uri, connect=False)
        opts = compression_settings(client)
        self.assertEqual(opts.compressors, ['zlib'])
        self.assertEqual(opts.zlib_compression_level, -1)

        uri = "mongodb://localhost:27017"
        client = MongoClient(uri, connect=False)
        opts = compression_settings(client)
        self.assertEqual(opts.compressors, [])
        self.assertEqual(opts.zlib_compression_level, -1)

        uri = "mongodb://localhost:27017/?compressors=foobar"
        client = MongoClient(uri, connect=False)
        opts = compression_settings(client)
        self.assertEqual(opts.compressors, [])
        self.assertEqual(opts.zlib_compression_level, -1)

        uri = "mongodb://localhost:27017/?compressors=foobar,zlib"
        client = MongoClient(uri, connect=False)
        opts = compression_settings(client)
        self.assertEqual(opts.compressors, ['zlib'])
        self.assertEqual(opts.zlib_compression_level, -1)

        # According to the connection string spec, unsupported values
        # just raise a warning and are ignored.
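        # zlib itself only accepts compression levels -1 through 9 (-1 means
        # "use the default"), e.g. zlib.compress(b"payload", 4), so the
        # out-of-range URI values 10 and -2 tested below are ignored and the
        # level falls back to -1.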
uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ['zlib']) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ['zlib']) self.assertEqual(opts.zlib_compression_level, -1) if not _HAVE_SNAPPY: uri = "mongodb://localhost:27017/?compressors=snappy" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ['snappy']) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ['snappy', 'zlib']) if not _HAVE_ZSTD: uri = "mongodb://localhost:27017/?compressors=zstd" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ['zstd']) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" client = MongoClient(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ['zstd', 'zlib']) options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): client = single_client(zlibcompressionlevel=level) # No error client.pymongo_test.test.find_one() def test_reset_during_update_pool(self): client = rs_or_single_client(minPoolSize=10) self.addCleanup(client.close) client.admin.command('ping') pool = get_pool(client) generation = pool.generation # Continuously reset the pool. class ResetPoolThread(threading.Thread): def __init__(self, pool): super(ResetPoolThread, self).__init__() self.running = True self.pool = pool def stop(self): self.running = False def run(self): while self.running: self.pool.reset() time.sleep(0.001) t = ResetPoolThread(pool) t.start() # Ensure that update_pool completes without error even when the pool # is reset concurrently. try: while True: for _ in range(10): client._topology.update_pool( client._MongoClient__all_credentials) if generation != pool.generation: break finally: t.stop() t.join() client.admin.command('ping') def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 client = rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False) self.addCleanup(client.close) # Create a single connection in the pool. client.admin.command('ping') # Cause new connections stall for a few seconds. pool = get_pool(client) original_connect = pool.connect def stall_connect(*args, **kwargs): time.sleep(2) return original_connect(*args, **kwargs) pool.connect = stall_connect # Un-patch Pool.connect to break the cyclic reference. self.addCleanup(delattr, pool, 'connect') # Wait for the background thread to start creating connections wait_until(lambda: len(pool.sockets) > 1, 'start creating connections') # Assert that application operations do not block. 
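        # Each stalled connect takes ~2 seconds; if the background thread held
        # the topology or pool lock while connecting, these pings would also
        # stall, so asserting < 2 seconds per ping shows no lock is held.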
for _ in range(10): start = monotonic_time() client.admin.command('ping') total = monotonic_time() - start # Each ping command should not take more than 2 seconds self.assertLess(total, 2) @client_context.require_replica_set def test_direct_connection(self): # direct_connection=True should result in Single topology. client = rs_or_single_client(directConnection=True) client.admin.command('ping') self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) client.close() # direct_connection=False should result in RS topology. client = rs_or_single_client(directConnection=False) client.admin.command('ping') self.assertGreaterEqual(len(client.nodes), 1) self.assertIn(client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary]) client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): MongoClient(['host1', 'host2'], directConnection=True) class TestExhaustCursor(IntegrationTest): """Test that clients properly handle errors from exhaust cursors.""" def setUp(self): super(TestExhaustCursor, self).setUp() if client_context.is_mongos: raise SkipTest("mongos doesn't support exhaust, SERVER-2627") def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = connected(rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = get_pool(client) sock_info = one(pool.sockets) # This will cause OperationFailure in all mongo versions since # the value for $orderby must be a document. cursor = collection.find( SON([('$query', {}), ('$orderby', True)]), cursor_type=CursorType.EXHAUST) self.assertRaises(OperationFailure, cursor.next) self.assertFalse(sock_info.closed) # The socket was checked in and the semaphore was decremented. self.assertIn(sock_info, pool.sockets) self.assertTrue(pool._socket_semaphore.acquire(blocking=False)) def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) self.addCleanup(client_context.client.pymongo_test.test.drop) pool = get_pool(client) pool._check_interval_seconds = None # Never check. sock_info = one(pool.sockets) cursor = collection.find(cursor_type=CursorType.EXHAUST) # Initial query succeeds. cursor.next() # Cause a server error on getmore. def receive_message(request_id): # Discard the actual server response. SocketInfo.receive_message(sock_info, request_id) # responseFlags bit 1 is QueryFailure. 
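            # In the OP_REPLY responseFlags field, bit 0 is CursorNotFound and
            # bit 1 is QueryFailure; setting bit 1 in a faked reply makes the
            # driver treat this getMore as a server error.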
msg = struct.pack('= count, 'find %s %s event(s)' % (count, event)) def check_out(self, op): """Run the 'checkOut' operation.""" label = op['label'] with self.pool.get_socket({}, checkout=True) as sock_info: if label: self.labels[label] = sock_info else: self.addCleanup(sock_info.close_socket, None) def check_in(self, op): """Run the 'checkIn' operation.""" label = op['connection'] sock_info = self.labels[label] self.pool.return_socket(sock_info) def clear(self, op): """Run the 'clear' operation.""" self.pool.reset() def close(self, op): """Run the 'close' operation.""" self.pool.close() def run_operation(self, op): """Run a single operation in a test.""" op_name = camel_to_snake(op['name']) thread = op['thread'] meth = getattr(self, op_name) if thread: self.targets[thread].schedule(lambda: meth(op)) else: meth(op) def run_operations(self, ops): """Run a test's operations.""" for op in ops: self._ops.append(op) self.run_operation(op) def check_object(self, actual, expected): """Assert that the actual object matches the expected object.""" self.assertEqual(type(actual), OBJECT_TYPES[expected['type']]) for attr, expected_val in expected.items(): if attr == 'type': continue c2s = camel_to_snake(attr) actual_val = getattr(actual, c2s) if expected_val == 42: self.assertIsNotNone(actual_val) else: self.assertEqual(actual_val, expected_val) def check_event(self, actual, expected): """Assert that the actual event matches the expected event.""" self.check_object(actual, expected) def actual_events(self, ignore): """Return all the non-ignored events.""" ignore = tuple(OBJECT_TYPES[name] for name in ignore) return [event for event in self.listener.events if not isinstance(event, ignore)] def check_events(self, events, ignore): """Check the events of a test.""" actual_events = self.actual_events(ignore) for actual, expected in zip(actual_events, events): self.check_event(actual, expected) if len(events) > len(actual_events): self.fail('missing events: %r' % (events[len(actual_events):],)) elif len(events) < len(actual_events): self.fail('extra events: %r' % (actual_events[len(events):],)) def check_error(self, actual, expected): message = expected.pop('message') self.check_object(actual, expected) self.assertIn(message, str(actual)) def run_scenario(self, scenario_def, test): """Run a CMAP spec test.""" self.assertEqual(scenario_def['version'], 1) self.assertEqual(scenario_def['style'], 'unit') self.listener = CMAPListener() self._ops = [] opts = test['poolOptions'].copy() opts['event_listeners'] = [self.listener] client = single_client(**opts) self.addCleanup(client.close) self.pool = get_pool(client) # Map of target names to Thread objects. self.targets = dict() # Map of label names to Connection objects self.labels = dict() def cleanup(): for t in self.targets.values(): t.stop() for t in self.targets.values(): t.join(5) for conn in self.labels.values(): conn.close_socket(None) self.addCleanup(cleanup) try: if test['error']: with self.assertRaises(PyMongoError) as ctx: self.run_operations(test['operations']) self.check_error(ctx.exception, test['error']) else: self.run_operations(test['operations']) self.check_events(test['events'], test['ignore']) except Exception: # Print the events after a test failure. 
print() print('Failed test: %r' % (test['description'],)) print('Operations:') for op in self._ops: print(op) print('Threads:') print(self.targets) print('Connections:') print(self.labels) print('Events:') for event in self.listener.events: print(event) raise POOL_OPTIONS = { 'maxPoolSize': 50, 'minPoolSize': 1, 'maxIdleTimeMS': 10000, 'waitQueueTimeoutMS': 10000 } # # Prose tests. Numbers correspond to the prose test number in the spec. # def test_1_client_connection_pool_options(self): client = rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): client = rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) client.admin.command('isMaster') # Discover at least one secondary. if client_context.has_secondaries: client.admin.command( 'isMaster', read_preference=ReadPreference.SECONDARY) pools = get_pools(client) pool_opts = pools[0].opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) for pool in pools[1:]: self.assertEqual(pool.opts, pool_opts) def test_3_uri_connection_pool_options(self): opts = '&'.join(['%s=%s' % (k, v) for k, v in self.POOL_OPTIONS.items()]) uri = 'mongodb://%s/?%s' % (client_context.pair, opts) client = rs_or_single_client(uri, **self.credentials) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() client = single_client(event_listeners=[listener]) self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. client.admin.command('isMaster') self.assertEqual( listener.event_count(ConnectionCheckOutStartedEvent), 1) self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1) self.assertEqual(listener.event_count(ConnectionReadyEvent), 1) self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1) self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1) # Uses the existing connection. client.admin.command('isMaster') self.assertEqual( listener.event_count(ConnectionCheckOutStartedEvent), 2) self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2) self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2) client.close() self.assertEqual(listener.event_count(PoolClearedEvent), 1) self.assertEqual(listener.event_count(ConnectionClosedEvent), 1) def test_5_check_out_fails_connection_error(self): listener = CMAPListener() client = single_client(event_listeners=[listener]) self.addCleanup(client.close) pool = get_pool(client) def mock_connect(*args, **kwargs): raise ConnectionFailure('connect failed') pool.connect = mock_connect # Un-patch Pool.connect to break the cyclic reference. self.addCleanup(delattr, pool, 'connect') # Attempt to create a new connection. 
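        # The mocked Pool.connect above always fails, so the CMAP listener
        # should see the check-out start, then a check-out failure with reason
        # CONN_ERROR, and the pool should be cleared.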
with self.assertRaisesRegex(ConnectionFailure, 'connect failed'): client.admin.command('isMaster') self.assertIsInstance(listener.events[0], PoolCreatedEvent) self.assertIsInstance(listener.events[1], ConnectionCheckOutStartedEvent) self.assertIsInstance(listener.events[2], ConnectionCheckOutFailedEvent) self.assertIsInstance(listener.events[3], PoolClearedEvent) failed_event = listener.events[2] self.assertEqual( failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) def test_5_check_out_fails_auth_error(self): listener = CMAPListener() client = single_client(event_listeners=[listener]) self.addCleanup(client.close) pool = get_pool(client) connect = pool.connect def mock_check_auth(self, *args, **kwargs): self.close_socket(ConnectionClosedReason.ERROR) raise ConnectionFailure('auth failed') def mock_connect(*args, **kwargs): sock_info = connect(*args, **kwargs) sock_info.check_auth = functools.partial(mock_check_auth, sock_info) # Un-patch to break the cyclic reference. self.addCleanup(delattr, sock_info, 'check_auth') return sock_info pool.connect = mock_connect # Un-patch Pool.connect to break the cyclic reference. self.addCleanup(delattr, pool, 'connect') # Attempt to create a new connection. with self.assertRaisesRegex(ConnectionFailure, 'auth failed'): client.admin.command('isMaster') self.assertIsInstance(listener.events[0], PoolCreatedEvent) self.assertIsInstance(listener.events[1], ConnectionCheckOutStartedEvent) self.assertIsInstance(listener.events[2], ConnectionCreatedEvent) # Error happens here. self.assertIsInstance(listener.events[3], ConnectionClosedEvent) self.assertIsInstance(listener.events[4], ConnectionCheckOutFailedEvent) self.assertIsInstance(listener.events[5], PoolClearedEvent) failed_event = listener.events[4] self.assertEqual( failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) # # Extra non-spec tests # def assertRepr(self, obj): new_obj = eval(repr(obj)) self.assertEqual(type(new_obj), type(obj)) self.assertEqual(repr(new_obj), repr(obj)) def test_events_repr(self): host = ('localhost', 27017) self.assertRepr(ConnectionCheckedInEvent(host, 1)) self.assertRepr(ConnectionCheckedOutEvent(host, 1)) self.assertRepr(ConnectionCheckOutFailedEvent( host, ConnectionCheckOutFailedReason.POOL_CLOSED)) self.assertRepr(ConnectionClosedEvent( host, 1, ConnectionClosedReason.POOL_CLOSED)) self.assertRepr(ConnectionCreatedEvent(host, 1)) self.assertRepr(ConnectionReadyEvent(host, 1)) self.assertRepr(ConnectionCheckOutStartedEvent(host)) self.assertRepr(PoolCreatedEvent(host, {})) self.assertRepr(PoolClearedEvent(host)) self.assertRepr(PoolClosedEvent(host)) def create_test(scenario_def, test, name): def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario class CMAPTestCreator(TestCreator): def tests(self, scenario_def): """Extract the tests from a spec file. CMAP tests do not have a 'tests' field. The whole file represents a single test case. """ return [scenario_def] test_creator = CMAPTestCreator(create_test, TestCMAP, TestCMAP.TEST_PATH) test_creator.create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_code.py000066400000000000000000000073761374256237000166140ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the Code wrapper.""" import sys sys.path[0:0] = [""] from bson.code import Code from test import unittest class TestCode(unittest.TestCase): def test_types(self): self.assertRaises(TypeError, Code, 5) self.assertRaises(TypeError, Code, None) self.assertRaises(TypeError, Code, "aoeu", 5) self.assertRaises(TypeError, Code, u"aoeu", 5) self.assertTrue(Code("aoeu")) self.assertTrue(Code(u"aoeu")) self.assertTrue(Code("aoeu", {})) self.assertTrue(Code(u"aoeu", {})) def test_read_only(self): c = Code("blah") def set_c(): c.scope = 5 self.assertRaises(AttributeError, set_c) def test_code(self): a_string = "hello world" a_code = Code("hello world") self.assertTrue(a_code.startswith("hello")) self.assertTrue(a_code.endswith("world")) self.assertTrue(isinstance(a_code, Code)) self.assertFalse(isinstance(a_string, Code)) self.assertIsNone(a_code.scope) with_scope = Code('hello world', {'my_var': 5}) self.assertEqual({'my_var': 5}, with_scope.scope) empty_scope = Code('hello world', {}) self.assertEqual({}, empty_scope.scope) another_scope = Code(with_scope, {'new_var': 42}) self.assertEqual(str(with_scope), str(another_scope)) self.assertEqual({'new_var': 42, 'my_var': 5}, another_scope.scope) # No error. Code(u'héllø world¡') def test_repr(self): c = Code("hello world", {}) self.assertEqual(repr(c), "Code('hello world', {})") c.scope["foo"] = "bar" self.assertEqual(repr(c), "Code('hello world', {'foo': 'bar'})") c = Code("hello world", {"blah": 3}) self.assertEqual(repr(c), "Code('hello world', {'blah': 3})") c = Code("\x08\xFF") self.assertEqual(repr(c), "Code(%s, None)" % (repr("\x08\xFF"),)) def test_equality(self): b = Code("hello") c = Code("hello", {"foo": 5}) self.assertNotEqual(b, c) self.assertEqual(c, Code("hello", {"foo": 5})) self.assertNotEqual(c, Code("hello", {"foo": 6})) self.assertEqual(b, Code("hello")) self.assertEqual(b, Code("hello", None)) self.assertNotEqual(b, Code("hello ")) self.assertNotEqual("hello", Code("hello")) # Explicitly test inequality self.assertFalse(c != Code("hello", {"foo": 5})) self.assertFalse(b != Code("hello")) self.assertFalse(b != Code("hello", None)) def test_hash(self): self.assertRaises(TypeError, hash, Code("hello world")) def test_scope_preserved(self): a = Code("hello") b = Code("hello", {"foo": 5}) self.assertEqual(a, Code(a)) self.assertEqual(b, Code(b)) self.assertNotEqual(a, Code(b)) self.assertNotEqual(b, Code(a)) def test_scope_kwargs(self): self.assertEqual({"a": 1}, Code("", a=1).scope) self.assertEqual({"a": 1}, Code("", {"a": 2}, a=1).scope) self.assertEqual({"a": 1, "b": 2, "c": 3}, Code("", {"b": 2}, a=1, c=3).scope) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_collation.py000066400000000000000000000404151374256237000176550ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the collation module.""" import functools import warnings from pymongo.collation import ( Collation, CollationCaseFirst, CollationStrength, CollationAlternate, CollationMaxVariable) from pymongo.errors import ConfigurationError from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, ReplaceOne, UpdateMany, UpdateOne) from pymongo.write_concern import WriteConcern from test import unittest, client_context from test.utils import EventListener, ignore_deprecations, rs_or_single_client class TestCollationObject(unittest.TestCase): def test_constructor(self): self.assertRaises(TypeError, Collation, locale=42) # Fill in a locale to test the other options. _Collation = functools.partial(Collation, 'en_US') # No error. _Collation(caseFirst=CollationCaseFirst.UPPER) self.assertRaises(TypeError, _Collation, caseLevel='true') self.assertRaises(ValueError, _Collation, strength='six') self.assertRaises(TypeError, _Collation, numericOrdering='true') self.assertRaises(TypeError, _Collation, alternate=5) self.assertRaises(TypeError, _Collation, maxVariable=2) self.assertRaises(TypeError, _Collation, normalization='false') self.assertRaises(TypeError, _Collation, backwards='true') # No errors. Collation('en_US', future_option='bar', another_option=42) collation = Collation( 'en_US', caseLevel=True, caseFirst=CollationCaseFirst.UPPER, strength=CollationStrength.QUATERNARY, numericOrdering=True, alternate=CollationAlternate.SHIFTED, maxVariable=CollationMaxVariable.SPACE, normalization=True, backwards=True) self.assertEqual({ 'locale': 'en_US', 'caseLevel': True, 'caseFirst': 'upper', 'strength': 4, 'numericOrdering': True, 'alternate': 'shifted', 'maxVariable': 'space', 'normalization': True, 'backwards': True }, collation.document) self.assertEqual({ 'locale': 'en_US', 'backwards': True }, Collation('en_US', backwards=True).document) def raisesConfigurationErrorForOldMongoDB(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): if client_context.version.at_least(3, 3, 9): return func(self, *args, **kwargs) else: with self.assertRaises(ConfigurationError): return func(self, *args, **kwargs) return wrapper class TestCollation(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): cls.listener = EventListener() cls.client = rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test cls.collation = Collation('en_US') cls.warn_context = warnings.catch_warnings() cls.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) @classmethod def tearDownClass(cls): cls.warn_context.__exit__() cls.warn_context = None cls.client.close() def tearDown(self): self.listener.results.clear() def last_command_started(self): return self.listener.results['started'][-1].command def assertCollationInLastCommand(self): self.assertEqual( self.collation.document, self.last_command_started()['collation']) @raisesConfigurationErrorForOldMongoDB def test_create_collection(self): self.db.test.drop() self.db.create_collection('test', collation=self.collation) self.assertCollationInLastCommand() # Test passing collation as a dict 
as well. self.db.test.drop() self.listener.results.clear() self.db.create_collection('test', collation=self.collation.document) self.assertCollationInLastCommand() def test_index_model(self): model = IndexModel([('a', 1), ('b', -1)], collation=self.collation) self.assertEqual(self.collation.document, model.document['collation']) @raisesConfigurationErrorForOldMongoDB def test_create_index(self): self.db.test.create_index('foo', collation=self.collation) ci_cmd = self.listener.results['started'][0].command self.assertEqual( self.collation.document, ci_cmd['indexes'][0]['collation']) @raisesConfigurationErrorForOldMongoDB def test_ensure_index(self): self.db.test.ensure_index('foo', collation=self.collation) ci_cmd = self.listener.results['started'][0].command self.assertEqual( self.collation.document, ci_cmd['indexes'][0]['collation']) @raisesConfigurationErrorForOldMongoDB def test_aggregate(self): self.db.test.aggregate([{'$group': {'_id': 42}}], collation=self.collation) self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB @ignore_deprecations def test_count(self): self.db.test.count(collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() self.db.test.find(collation=self.collation).count() self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_count_documents(self): self.db.test.count_documents({}, collation=self.collation) self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_distinct(self): self.db.test.distinct('foo', collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() self.db.test.find(collation=self.collation).distinct('foo') self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_find_command(self): self.db.test.insert_one({'is this thing on?': True}) self.listener.results.clear() next(self.db.test.find(collation=self.collation)) self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_explain_command(self): self.listener.results.clear() self.db.test.find(collation=self.collation).explain() # The collation should be part of the explained command. 
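        # explain wraps the original find command, so the collation document
        # appears nested under the 'explain' key rather than at the top level.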
self.assertEqual( self.collation.document, self.last_command_started()['explain']['collation']) @raisesConfigurationErrorForOldMongoDB @client_context.require_version_max(4, 1, 0, -1) def test_group(self): self.db.test.group('foo', {'foo': {'$gt': 42}}, {}, 'function(a, b) { return a; }', collation=self.collation) self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_map_reduce(self): self.db.test.map_reduce('function() {}', 'function() {}', 'output', collation=self.collation) self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_delete(self): self.db.test.delete_one({'foo': 42}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['deletes'][0]['collation']) self.listener.results.clear() self.db.test.delete_many({'foo': 42}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['deletes'][0]['collation']) self.listener.results.clear() self.db.test.remove({'foo': 42}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['deletes'][0]['collation']) @raisesConfigurationErrorForOldMongoDB def test_update(self): self.db.test.update({'foo': 42}, {'$set': {'foo': 'bar'}}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['updates'][0]['collation']) self.listener.results.clear() self.db.test.save({'_id': 12345}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['updates'][0]['collation']) self.listener.results.clear() self.db.test.replace_one({'foo': 42}, {'foo': 43}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['updates'][0]['collation']) self.listener.results.clear() self.db.test.update_one({'foo': 42}, {'$set': {'foo': 43}}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['updates'][0]['collation']) self.listener.results.clear() self.db.test.update_many({'foo': 42}, {'$set': {'foo': 43}}, collation=self.collation) command = self.listener.results['started'][0].command self.assertEqual( self.collation.document, command['updates'][0]['collation']) @raisesConfigurationErrorForOldMongoDB def test_find_and(self): self.db.test.find_and_modify({'foo': 42}, {'$set': {'foo': 43}}, collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() self.db.test.find_one_and_delete({'foo': 42}, collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() self.db.test.find_one_and_update({'foo': 42}, {'$set': {'foo': 43}}, collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() self.db.test.find_one_and_replace({'foo': 42}, {'foo': 43}, collation=self.collation) self.assertCollationInLastCommand() @raisesConfigurationErrorForOldMongoDB def test_bulk_write(self): self.db.test.collection.bulk_write([ DeleteOne({'noCollation': 42}), DeleteMany({'noCollation': 42}), DeleteOne({'foo': 42}, collation=self.collation), DeleteMany({'foo': 42}, collation=self.collation), ReplaceOne({'noCollation': 24}, {'bar': 42}), UpdateOne({'noCollation': 84}, {'$set': {'bar': 10}}, upsert=True), UpdateMany({'noCollation': 45}, 
{'$set': {'bar': 42}}), ReplaceOne({'foo': 24}, {'foo': 42}, collation=self.collation), UpdateOne({'foo': 84}, {'$set': {'foo': 10}}, upsert=True, collation=self.collation), UpdateMany({'foo': 45}, {'$set': {'foo': 42}}, collation=self.collation) ]) delete_cmd = self.listener.results['started'][0].command update_cmd = self.listener.results['started'][1].command def check_ops(ops): for op in ops: if 'noCollation' in op['q']: self.assertNotIn('collation', op) else: self.assertEqual(self.collation.document, op['collation']) check_ops(delete_cmd['deletes']) check_ops(update_cmd['updates']) @raisesConfigurationErrorForOldMongoDB def test_bulk(self): bulk = self.db.test.initialize_ordered_bulk_op() bulk.find({'noCollation': 42}).remove_one() bulk.find({'noCollation': 42}).remove() bulk.find({'foo': 42}, collation=self.collation).remove_one() bulk.find({'foo': 42}, collation=self.collation).remove() bulk.find({'noCollation': 24}).replace_one({'bar': 42}) bulk.find({'noCollation': 84}).upsert().update_one( {'$set': {'foo': 10}}) bulk.find({'noCollation': 45}).update({'$set': {'bar': 42}}) bulk.find({'foo': 24}, collation=self.collation).replace_one( {'foo': 42}) bulk.find({'foo': 84}, collation=self.collation).upsert().update_one( {'$set': {'foo': 10}}) bulk.find({'foo': 45}, collation=self.collation).update({ '$set': {'foo': 42}}) bulk.execute() delete_cmd = self.listener.results['started'][0].command update_cmd = self.listener.results['started'][1].command def check_ops(ops): for op in ops: if 'noCollation' in op['q']: self.assertNotIn('collation', op) else: self.assertEqual(self.collation.document, op['collation']) check_ops(delete_cmd['deletes']) check_ops(update_cmd['updates']) @client_context.require_version_max(3, 3, 8) def test_mixed_bulk_collation(self): bulk = self.db.test.initialize_unordered_bulk_op() bulk.find({'foo': 42}).upsert().update_one( {'$set': {'bar': 10}}) bulk.find({'foo': 43}, collation=self.collation).remove_one() with self.assertRaises(ConfigurationError): bulk.execute() self.assertIsNone(self.db.test.find_one({'foo': 42})) @raisesConfigurationErrorForOldMongoDB def test_indexes_same_keys_different_collations(self): self.db.test.drop() usa_collation = Collation('en_US') ja_collation = Collation('ja') self.db.test.create_indexes([ IndexModel('fieldname', collation=usa_collation), IndexModel('fieldname', name='japanese_version', collation=ja_collation), IndexModel('fieldname', name='simple') ]) indexes = self.db.test.index_information() self.assertEqual(usa_collation.document['locale'], indexes['fieldname_1']['collation']['locale']) self.assertEqual(ja_collation.document['locale'], indexes['japanese_version']['collation']['locale']) self.assertNotIn('collation', indexes['simple']) self.db.test.drop_index('fieldname_1') indexes = self.db.test.index_information() self.assertIn('japanese_version', indexes) self.assertIn('simple', indexes) self.assertNotIn('fieldname', indexes) def test_unacknowledged_write(self): unacknowledged = WriteConcern(w=0) collection = self.db.get_collection( 'test', write_concern=unacknowledged) with self.assertRaises(ConfigurationError): collection.update_one( {'hello': 'world'}, {'$set': {'hello': 'moon'}}, collation=self.collation) bulk = collection.initialize_ordered_bulk_op() bulk.find({'hello': 'world'}, collation=self.collation).update_one( {'$set': {'hello': 'moon'}}) with self.assertRaises(ConfigurationError): bulk.execute() update_one = UpdateOne({'hello': 'world'}, {'$set': {'hello': 'moon'}}, collation=self.collation) with 
self.assertRaises(ConfigurationError): collection.bulk_write([update_one]) @raisesConfigurationErrorForOldMongoDB def test_cursor_collation(self): self.db.test.insert_one({'hello': 'world'}) next(self.db.test.find().collation(self.collation)) self.assertCollationInLastCommand() pymongo-3.11.0/test/test_collection.py000066400000000000000000002776771374256237000200520ustar00rootroot00000000000000# -*- coding: utf-8 -*- # Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the collection module.""" import contextlib import re import sys import threading from codecs import utf_8_decode from collections import defaultdict sys.path[0:0] = [""] from bson import encode from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.code import Code from bson.codec_options import CodecOptions from bson.objectid import ObjectId from bson.py3compat import itervalues from bson.son import SON from pymongo import (ASCENDING, DESCENDING, GEO2D, GEOHAYSTACK, GEOSPHERE, HASHED, TEXT) from pymongo.bulk import BulkWriteError from pymongo.collection import Collection, ReturnDocument from pymongo.command_cursor import CommandCursor from pymongo.cursor import CursorType from pymongo.errors import (ConfigurationError, DocumentTooLarge, DuplicateKeyError, ExecutionTimeout, InvalidDocument, InvalidName, InvalidOperation, OperationFailure, WriteConcernError) from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.mongo_client import MongoClient from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference from pymongo.results import (InsertOneResult, InsertManyResult, UpdateResult, DeleteResult) from pymongo.write_concern import WriteConcern from test import client_context, unittest from test.test_client import IntegrationTest from test.utils import (get_pool, ignore_deprecations, is_mongos, rs_or_single_client, single_client, wait_until, EventListener, IMPOSSIBLE_WRITE_CONCERN) class TestCollectionNoConnect(unittest.TestCase): """Test Collection features on a client that does not connect. 
""" @classmethod def setUpClass(cls): cls.db = MongoClient(connect=False).pymongo_test def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) def make_col(base, name): return base[name] self.assertRaises(InvalidName, make_col, self.db, "") self.assertRaises(InvalidName, make_col, self.db, "te$t") self.assertRaises(InvalidName, make_col, self.db, ".test") self.assertRaises(InvalidName, make_col, self.db, "test.") self.assertRaises(InvalidName, make_col, self.db, "tes..t") self.assertRaises(InvalidName, make_col, self.db.test, "") self.assertRaises(InvalidName, make_col, self.db.test, "te$t") self.assertRaises(InvalidName, make_col, self.db.test, ".test") self.assertRaises(InvalidName, make_col, self.db.test, "test.") self.assertRaises(InvalidName, make_col, self.db.test, "tes..t") self.assertRaises(InvalidName, make_col, self.db.test, "tes\x00t") def test_getattr(self): coll = self.db.test self.assertTrue(isinstance(coll['_does_not_exist'], Collection)) with self.assertRaises(AttributeError) as context: coll._does_not_exist # Message should be: # "AttributeError: Collection has no attribute '_does_not_exist'. To # access the test._does_not_exist collection, use # database['test._does_not_exist']." self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) coll2 = coll.with_options(write_concern=WriteConcern(w=0)) self.assertEqual(coll2.write_concern, WriteConcern(w=0)) self.assertNotEqual(coll.write_concern, coll2.write_concern) coll3 = coll2.subcoll self.assertEqual(coll2.write_concern, coll3.write_concern) coll4 = coll2["subcoll"] self.assertEqual(coll2.write_concern, coll4.write_concern) def test_iteration(self): self.assertRaises(TypeError, next, self.db) class TestCollection(IntegrationTest): @classmethod def setUpClass(cls): super(TestCollection, cls).setUpClass() cls.w = client_context.w @classmethod def tearDownClass(cls): cls.db.drop_collection("test_large_limit") def setUp(self): self.db.test.drop() def tearDown(self): self.db.test.drop() @contextlib.contextmanager def write_concern_collection(self): if client_context.version.at_least(3, 3, 9) and client_context.is_rs: with self.assertRaises(WriteConcernError): # Unsatisfiable write concern. yield Collection( self.db, 'test', write_concern=WriteConcern(w=len(client_context.nodes) + 1)) else: yield self.db.test def test_equality(self): self.assertTrue(isinstance(self.db.test, Collection)) self.assertEqual(self.db.test, self.db["test"]) self.assertEqual(self.db.test, Collection(self.db, "test")) self.assertEqual(self.db.test.mike, self.db["test.mike"]) self.assertEqual(self.db.test["mike"], self.db["test.mike"]) @client_context.require_version_min(3, 3, 9) def test_create(self): # No Exception. 
        db = client_context.client.pymongo_test
        db.create_test_no_wc.drop()
        wait_until(
            lambda: 'create_test_no_wc' not in db.list_collection_names(),
            'drop create_test_no_wc collection')
        Collection(db, name='create_test_no_wc', create=True)
        wait_until(
            lambda: 'create_test_no_wc' in db.list_collection_names(),
            'create create_test_no_wc collection')
        # SERVER-33317
        if (not client_context.is_mongos or
                not client_context.version.at_least(3, 7, 0)):
            with self.assertRaises(OperationFailure):
                Collection(
                    db, name='create-test-wc',
                    write_concern=IMPOSSIBLE_WRITE_CONCERN,
                    create=True)

    def test_drop_nonexistent_collection(self):
        self.db.drop_collection('test')
        self.assertFalse('test' in self.db.list_collection_names())

        # No exception
        self.db.drop_collection('test')

    def test_create_indexes(self):
        db = self.db

        self.assertRaises(TypeError, db.test.create_indexes, 'foo')
        self.assertRaises(TypeError, db.test.create_indexes, ['foo'])
        self.assertRaises(TypeError, IndexModel, 5)
        self.assertRaises(ValueError, IndexModel, [])

        db.test.drop_indexes()
        db.test.insert_one({})
        self.assertEqual(len(db.test.index_information()), 1)

        db.test.create_indexes([IndexModel("hello")])
        db.test.create_indexes([IndexModel([("hello", DESCENDING),
                                            ("world", ASCENDING)])])

        # Tuple instead of list.
        db.test.create_indexes([IndexModel((("world", ASCENDING),))])

        self.assertEqual(len(db.test.index_information()), 4)

        db.test.drop_indexes()
        names = db.test.create_indexes([IndexModel([("hello", DESCENDING),
                                                    ("world", ASCENDING)],
                                                   name="hello_world")])
        self.assertEqual(names, ["hello_world"])

        db.test.drop_indexes()
        self.assertEqual(len(db.test.index_information()), 1)
        db.test.create_indexes([IndexModel("hello")])
        self.assertTrue("hello_1" in db.test.index_information())

        db.test.drop_indexes()
        self.assertEqual(len(db.test.index_information()), 1)
        names = db.test.create_indexes([IndexModel([("hello", DESCENDING),
                                                    ("world", ASCENDING)]),
                                        IndexModel("hello")])
        info = db.test.index_information()
        for name in names:
            self.assertTrue(name in info)

        db.test.drop()
        db.test.insert_one({'a': 1})
        db.test.insert_one({'a': 1})
        self.assertRaises(
            DuplicateKeyError,
            db.test.create_indexes,
            [IndexModel('a', unique=True)])

        with self.write_concern_collection() as coll:
            coll.create_indexes([IndexModel('hello')])

    @client_context.require_version_max(4, 3, -1)
    def test_create_indexes_commitQuorum_requires_44(self):
        db = self.db
        with self.assertRaisesRegex(
                ConfigurationError,
                'Must be connected to MongoDB 4\.4\+ to use the commitQuorum '
                'option for createIndexes'):
            db.coll.create_indexes([IndexModel('a')], commitQuorum="majority")

    @client_context.require_no_standalone
    @client_context.require_version_min(4, 4, -1)
    def test_create_indexes_commitQuorum(self):
        self.db.coll.create_indexes([IndexModel('a')], commitQuorum="majority")

    def test_create_index(self):
        db = self.db

        self.assertRaises(TypeError, db.test.create_index, 5)
        self.assertRaises(TypeError, db.test.create_index, {"hello": 1})
        self.assertRaises(ValueError, db.test.create_index, [])

        db.test.drop_indexes()
        db.test.insert_one({})
        self.assertEqual(len(db.test.index_information()), 1)

        db.test.create_index("hello")
        db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)])

        # Tuple instead of list.
db.test.create_index((("world", ASCENDING),)) self.assertEqual(len(db.test.index_information()), 4) db.test.drop_indexes() ix = db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], name="hello_world") self.assertEqual(ix, "hello_world") db.test.drop_indexes() self.assertEqual(len(db.test.index_information()), 1) db.test.create_index("hello") self.assertTrue("hello_1" in db.test.index_information()) db.test.drop_indexes() self.assertEqual(len(db.test.index_information()), 1) db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)]) self.assertTrue("hello_-1_world_1" in db.test.index_information()) db.test.drop() db.test.insert_one({'a': 1}) db.test.insert_one({'a': 1}) self.assertRaises( DuplicateKeyError, db.test.create_index, 'a', unique=True) with self.write_concern_collection() as coll: coll.create_index([('hello', DESCENDING)]) def test_drop_index(self): db = self.db db.test.drop_indexes() db.test.create_index("hello") name = db.test.create_index("goodbye") self.assertEqual(len(db.test.index_information()), 3) self.assertEqual(name, "goodbye_1") db.test.drop_index(name) # Drop it again. with self.assertRaises(OperationFailure): db.test.drop_index(name) self.assertEqual(len(db.test.index_information()), 2) self.assertTrue("hello_1" in db.test.index_information()) db.test.drop_indexes() db.test.create_index("hello") name = db.test.create_index("goodbye") self.assertEqual(len(db.test.index_information()), 3) self.assertEqual(name, "goodbye_1") db.test.drop_index([("goodbye", ASCENDING)]) self.assertEqual(len(db.test.index_information()), 2) self.assertTrue("hello_1" in db.test.index_information()) with self.write_concern_collection() as coll: coll.drop_index('hello_1') @client_context.require_no_mongos @client_context.require_test_commands def test_index_management_max_time_ms(self): if (client_context.version[:2] == (3, 4) and client_context.version[2] < 4): raise unittest.SkipTest("SERVER-27711") coll = self.db.test self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: self.assertRaises( ExecutionTimeout, coll.create_index, "foo", maxTimeMS=1) self.assertRaises( ExecutionTimeout, coll.create_indexes, [IndexModel("foo")], maxTimeMS=1) self.assertRaises( ExecutionTimeout, coll.drop_index, "foo", maxTimeMS=1) self.assertRaises( ExecutionTimeout, coll.drop_indexes, maxTimeMS=1) self.assertRaises( ExecutionTimeout, coll.reindex, maxTimeMS=1) finally: self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_reindex(self): if not client_context.supports_reindex: raise unittest.SkipTest( "reindex is no longer supported by mongos 4.1+") db = self.db db.drop_collection("test") db.test.insert_one({"foo": "bar", "who": "what", "when": "how"}) db.test.create_index("foo") db.test.create_index("who") db.test.create_index("when") info = db.test.index_information() def check_result(result): self.assertEqual(4, result['nIndexes']) indexes = result['indexes'] names = [idx['name'] for idx in indexes] for name in names: self.assertTrue(name in info) for key in info: self.assertTrue(key in names) reindexed = db.test.reindex() if 'raw' in reindexed: # mongos for result in itervalues(reindexed['raw']): check_result(result) else: check_result(reindexed) coll = Collection( self.db, 'test', write_concern=WriteConcern(w=100)) # No error since writeConcern is not sent. 
coll.reindex() def test_list_indexes(self): db = self.db db.test.drop() db.test.insert_one({}) # create collection def map_indexes(indexes): return dict([(index["name"], index) for index in indexes]) indexes = list(db.test.list_indexes()) self.assertEqual(len(indexes), 1) self.assertTrue("_id_" in map_indexes(indexes)) db.test.create_index("hello") indexes = list(db.test.list_indexes()) self.assertEqual(len(indexes), 2) self.assertEqual(map_indexes(indexes)["hello_1"]["key"], SON([("hello", ASCENDING)])) db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) indexes = list(db.test.list_indexes()) self.assertEqual(len(indexes), 3) index_map = map_indexes(indexes) self.assertEqual(index_map["hello_-1_world_1"]["key"], SON([("hello", DESCENDING), ("world", ASCENDING)])) self.assertEqual(True, index_map["hello_-1_world_1"]["unique"]) # List indexes on a collection that does not exist. indexes = list(db.does_not_exist.list_indexes()) self.assertEqual(len(indexes), 0) # List indexes on a database that does not exist. indexes = list(self.client.db_does_not_exist.coll.list_indexes()) self.assertEqual(len(indexes), 0) def test_index_info(self): db = self.db db.test.drop() db.test.insert_one({}) # create collection self.assertEqual(len(db.test.index_information()), 1) self.assertTrue("_id_" in db.test.index_information()) db.test.create_index("hello") self.assertEqual(len(db.test.index_information()), 2) self.assertEqual(db.test.index_information()["hello_1"]["key"], [("hello", ASCENDING)]) db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) self.assertEqual(db.test.index_information()["hello_1"]["key"], [("hello", ASCENDING)]) self.assertEqual(len(db.test.index_information()), 3) self.assertEqual([("hello", DESCENDING), ("world", ASCENDING)], db.test.index_information()["hello_-1_world_1"]["key"] ) self.assertEqual( True, db.test.index_information()["hello_-1_world_1"]["unique"]) def test_index_geo2d(self): db = self.db db.test.drop_indexes() self.assertEqual('loc_2d', db.test.create_index([("loc", GEO2D)])) index_info = db.test.index_information()['loc_2d'] self.assertEqual([('loc', '2d')], index_info['key']) @client_context.require_no_mongos def test_index_haystack(self): db = self.db db.test.drop() _id = db.test.insert_one({ "pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant" }).inserted_id db.test.insert_one({ "pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant" }) db.test.insert_one({ "pos": {"long": 59.1, "lat": 87.2}, "type": "office" }) db.test.create_index( [("pos", GEOHAYSTACK), ("type", ASCENDING)], bucketSize=1 ) results = db.command(SON([ ("geoSearch", "test"), ("near", [33, 33]), ("maxDistance", 6), ("search", {"type": "restaurant"}), ("limit", 30), ]))['results'] self.assertEqual(2, len(results)) self.assertEqual({ "_id": _id, "pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant" }, results[0]) @client_context.require_no_mongos def test_index_text(self): db = self.db db.test.drop_indexes() self.assertEqual("t_text", db.test.create_index([("t", TEXT)])) index_info = db.test.index_information()["t_text"] self.assertTrue("weights" in index_info) db.test.insert_many([ {'t': 'spam eggs and spam'}, {'t': 'spam'}, {'t': 'egg sausage and bacon'}]) # MongoDB 2.6 text search. Create 'score' field in projection. cursor = db.test.find( {'$text': {'$search': 'spam'}}, {'score': {'$meta': 'textScore'}}) # Sort by 'score' field. 
cursor.sort([('score', {'$meta': 'textScore'})]) results = list(cursor) self.assertTrue(results[0]['score'] >= results[1]['score']) db.test.drop_indexes() def test_index_2dsphere(self): db = self.db db.test.drop_indexes() self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)])) for dummy, info in db.test.index_information().items(): field, idx_type = info['key'][0] if field == 'geo' and idx_type == '2dsphere': break else: self.fail("2dsphere index not found.") poly = {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} query = {"geo": {"$within": {"$geometry": poly}}} # This query will error without a 2dsphere index. db.test.find(query) db.test.drop_indexes() def test_index_hashed(self): db = self.db db.test.drop_indexes() self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)])) for dummy, info in db.test.index_information().items(): field, idx_type = info['key'][0] if field == 'a' and idx_type == 'hashed': break else: self.fail("hashed index not found.") db.test.drop_indexes() def test_index_sparse(self): db = self.db db.test.drop_indexes() db.test.create_index([('key', ASCENDING)], sparse=True) self.assertTrue(db.test.index_information()['key_1']['sparse']) def test_index_background(self): db = self.db db.test.drop_indexes() db.test.create_index([('keya', ASCENDING)]) db.test.create_index([('keyb', ASCENDING)], background=False) db.test.create_index([('keyc', ASCENDING)], background=True) self.assertFalse('background' in db.test.index_information()['keya_1']) self.assertFalse(db.test.index_information()['keyb_1']['background']) self.assertTrue(db.test.index_information()['keyc_1']['background']) def _drop_dups_setup(self, db): db.drop_collection('test') db.test.insert_one({'i': 1}) db.test.insert_one({'i': 2}) db.test.insert_one({'i': 2}) # duplicate db.test.insert_one({'i': 3}) @client_context.require_version_max(2, 6) def test_index_drop_dups(self): # Try dropping duplicates db = self.db self._drop_dups_setup(db) # No error, just drop the duplicate db.test.create_index([('i', ASCENDING)], unique=True, dropDups=True) # Duplicate was dropped self.assertEqual(3, db.test.count_documents({})) # Index was created, plus the index on _id self.assertEqual(2, len(db.test.index_information())) def test_index_dont_drop_dups(self): # Try *not* dropping duplicates db = self.db self._drop_dups_setup(db) # There's a duplicate def test_create(): db.test.create_index( [('i', ASCENDING)], unique=True, dropDups=False ) self.assertRaises(DuplicateKeyError, test_create) # Duplicate wasn't dropped self.assertEqual(4, db.test.count_documents({})) # Index wasn't created, only the default index on _id self.assertEqual(1, len(db.test.index_information())) # Get the plan dynamically because the explain format will change. def get_plan_stage(self, root, stage): if root.get('stage') == stage: return root elif "inputStage" in root: return self.get_plan_stage(root['inputStage'], stage) elif "inputStages" in root: for i in root['inputStages']: stage = self.get_plan_stage(i, stage) if stage: return stage elif "shards" in root: for i in root['shards']: stage = self.get_plan_stage(i['winningPlan'], stage) if stage: return stage return {} @client_context.require_version_min(3, 1, 9, -1) def test_index_filter(self): db = self.db db.drop_collection("test") # Test bad filter spec on create. 
self.assertRaises(OperationFailure, db.test.create_index, "x", partialFilterExpression=5) self.assertRaises(OperationFailure, db.test.create_index, "x", partialFilterExpression={"x": {"$asdasd": 3}}) self.assertRaises(OperationFailure, db.test.create_index, "x", partialFilterExpression={"$and": 5}) self.assertEqual("x_1", db.test.create_index( [('x', ASCENDING)], partialFilterExpression={"a": {"$lte": 1.5}})) db.test.insert_one({"x": 5, "a": 2}) db.test.insert_one({"x": 6, "a": 1}) # Operations that use the partial index. explain = db.test.find({"x": 6, "a": 1}).explain() stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], 'IXSCAN') self.assertEqual("x_1", stage.get('indexName')) self.assertTrue(stage.get('isPartial')) explain = db.test.find({"x": {"$gt": 1}, "a": 1}).explain() stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], 'IXSCAN') self.assertEqual("x_1", stage.get('indexName')) self.assertTrue(stage.get('isPartial')) explain = db.test.find({"x": 6, "a": {"$lte": 1}}).explain() stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], 'IXSCAN') self.assertEqual("x_1", stage.get('indexName')) self.assertTrue(stage.get('isPartial')) # Operations that do not use the partial index. explain = db.test.find({"x": 6, "a": {"$lte": 1.6}}).explain() stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], 'COLLSCAN') self.assertNotEqual({}, stage) explain = db.test.find({"x": 6}).explain() stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], 'COLLSCAN') self.assertNotEqual({}, stage) # Test drop_indexes. db.test.drop_index("x_1") explain = db.test.find({"x": 6, "a": 1}).explain() stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], 'COLLSCAN') self.assertNotEqual({}, stage) def test_field_selection(self): db = self.db db.drop_collection("test") doc = {"a": 1, "b": 5, "c": {"d": 5, "e": 10}} db.test.insert_one(doc) # Test field inclusion doc = next(db.test.find({}, ["_id"])) self.assertEqual(list(doc), ["_id"]) doc = next(db.test.find({}, ["a"])) l = list(doc) l.sort() self.assertEqual(l, ["_id", "a"]) doc = next(db.test.find({}, ["b"])) l = list(doc) l.sort() self.assertEqual(l, ["_id", "b"]) doc = next(db.test.find({}, ["c"])) l = list(doc) l.sort() self.assertEqual(l, ["_id", "c"]) doc = next(db.test.find({}, ["a"])) self.assertEqual(doc["a"], 1) doc = next(db.test.find({}, ["b"])) self.assertEqual(doc["b"], 5) doc = next(db.test.find({}, ["c"])) self.assertEqual(doc["c"], {"d": 5, "e": 10}) # Test inclusion of fields with dots doc = next(db.test.find({}, ["c.d"])) self.assertEqual(doc["c"], {"d": 5}) doc = next(db.test.find({}, ["c.e"])) self.assertEqual(doc["c"], {"e": 10}) doc = next(db.test.find({}, ["b", "c.e"])) self.assertEqual(doc["c"], {"e": 10}) doc = next(db.test.find({}, ["b", "c.e"])) l = list(doc) l.sort() self.assertEqual(l, ["_id", "b", "c"]) doc = next(db.test.find({}, ["b", "c.e"])) self.assertEqual(doc["b"], 5) # Test field exclusion doc = next(db.test.find({}, {"a": False, "b": 0})) l = list(doc) l.sort() self.assertEqual(l, ["_id", "c"]) doc = next(db.test.find({}, {"_id": False})) l = list(doc) self.assertFalse("_id" in l) def test_options(self): db = self.db db.drop_collection("test") db.create_collection("test", capped=True, size=4096) result = db.test.options() # mongos 2.2.x adds an $auth field when auth is enabled. 
result.pop('$auth', None) self.assertEqual(result, {"capped": True, 'size': 4096}) db.drop_collection("test") def test_insert_one(self): db = self.db db.test.drop() document = {"_id": 1000} result = db.test.insert_one(document) self.assertTrue(isinstance(result, InsertOneResult)) self.assertTrue(isinstance(result.inserted_id, int)) self.assertEqual(document["_id"], result.inserted_id) self.assertTrue(result.acknowledged) self.assertIsNotNone(db.test.find_one({"_id": document["_id"]})) self.assertEqual(1, db.test.count_documents({})) document = {"foo": "bar"} result = db.test.insert_one(document) self.assertTrue(isinstance(result, InsertOneResult)) self.assertTrue(isinstance(result.inserted_id, ObjectId)) self.assertEqual(document["_id"], result.inserted_id) self.assertTrue(result.acknowledged) self.assertIsNotNone(db.test.find_one({"_id": document["_id"]})) self.assertEqual(2, db.test.count_documents({})) db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.insert_one(document) self.assertTrue(isinstance(result, InsertOneResult)) self.assertTrue(isinstance(result.inserted_id, ObjectId)) self.assertEqual(document["_id"], result.inserted_id) self.assertFalse(result.acknowledged) # The insert failed duplicate key... wait_until(lambda: 2 == db.test.count_documents({}), 'forcing duplicate key error') document = RawBSONDocument( encode({'_id': ObjectId(), 'foo': 'bar'})) result = db.test.insert_one(document) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(result.inserted_id, None) def test_insert_many(self): db = self.db db.test.drop() docs = [{} for _ in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(isinstance(result.inserted_ids, list)) self.assertEqual(5, len(result.inserted_ids)) for doc in docs: _id = doc["_id"] self.assertTrue(isinstance(_id, ObjectId)) self.assertTrue(_id in result.inserted_ids) self.assertEqual(1, db.test.count_documents({'_id': _id})) self.assertTrue(result.acknowledged) docs = [{"_id": i} for i in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(isinstance(result.inserted_ids, list)) self.assertEqual(5, len(result.inserted_ids)) for doc in docs: _id = doc["_id"] self.assertTrue(isinstance(_id, int)) self.assertTrue(_id in result.inserted_ids) self.assertEqual(1, db.test.count_documents({"_id": _id})) self.assertTrue(result.acknowledged) docs = [RawBSONDocument(encode({"_id": i + 5})) for i in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(isinstance(result.inserted_ids, list)) self.assertEqual([], result.inserted_ids) db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) docs = [{} for _ in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertFalse(result.acknowledged) self.assertEqual(20, db.test.count_documents({})) def test_delete_one(self): self.db.test.drop() self.db.test.insert_one({"x": 1}) self.db.test.insert_one({"y": 1}) self.db.test.insert_one({"z": 1}) result = self.db.test.delete_one({"x": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertEqual(1, result.deleted_count) self.assertTrue(result.acknowledged) self.assertEqual(2, self.db.test.count_documents({})) result = self.db.test.delete_one({"y": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertEqual(1, result.deleted_count) 
self.assertTrue(result.acknowledged) self.assertEqual(1, self.db.test.count_documents({})) db = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) result = db.test.delete_one({"z": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) wait_until(lambda: 0 == db.test.count_documents({}), 'delete 1 documents') def test_delete_many(self): self.db.test.drop() self.db.test.insert_one({"x": 1}) self.db.test.insert_one({"x": 1}) self.db.test.insert_one({"y": 1}) self.db.test.insert_one({"y": 1}) result = self.db.test.delete_many({"x": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertEqual(2, result.deleted_count) self.assertTrue(result.acknowledged) self.assertEqual(0, self.db.test.count_documents({"x": 1})) db = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) result = db.test.delete_many({"y": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) wait_until( lambda: 0 == db.test.count_documents({}), 'delete 2 documents') def test_command_document_too_large(self): large = '*' * (self.client.max_bson_size + _COMMAND_OVERHEAD) coll = self.db.test self.assertRaises( DocumentTooLarge, coll.insert_one, {'data': large}) # update_one and update_many are the same self.assertRaises( DocumentTooLarge, coll.replace_one, {}, {'data': large}) self.assertRaises( DocumentTooLarge, coll.delete_one, {'data': large}) @client_context.require_version_min(3, 1, 9, -1) def test_insert_bypass_document_validation(self): db = self.db db.test.drop() db.create_collection("test", validator={"a": {"$exists": True}}) db_w0 = self.db.client.get_database( self.db.name, write_concern=WriteConcern(w=0)) # Test insert_one self.assertRaises(OperationFailure, db.test.insert_one, {"_id": 1, "x": 100}) result = db.test.insert_one({"_id": 1, "x": 100}, bypass_document_validation=True) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(1, result.inserted_id) result = db.test.insert_one({"_id":2, "a":0}) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(2, result.inserted_id) if client_context.version < (3, 6): # Uses OP_INSERT which does not support bypass_document_validation. 
self.assertRaises(OperationFailure, db_w0.test.insert_one, {"y": 1}, bypass_document_validation=True) else: db_w0.test.insert_one({"y": 1}, bypass_document_validation=True) wait_until(lambda: db_w0.test.find_one({"y": 1}), "find w:0 inserted document") # Test insert_many docs = [{"_id": i, "x": 100 - i} for i in range(3, 100)] self.assertRaises(OperationFailure, db.test.insert_many, docs) result = db.test.insert_many(docs, bypass_document_validation=True) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(97, len(result.inserted_ids)) for doc in docs: _id = doc["_id"] self.assertTrue(isinstance(_id, int)) self.assertTrue(_id in result.inserted_ids) self.assertEqual(1, db.test.count_documents({"x": doc["x"]})) self.assertTrue(result.acknowledged) docs = [{"_id": i, "a": 200 - i} for i in range(100, 200)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(97, len(result.inserted_ids)) for doc in docs: _id = doc["_id"] self.assertTrue(isinstance(_id, int)) self.assertTrue(_id in result.inserted_ids) self.assertEqual(1, db.test.count_documents({"a": doc["a"]})) self.assertTrue(result.acknowledged) self.assertRaises(OperationFailure, db_w0.test.insert_many, [{"x": 1}, {"x": 2}], bypass_document_validation=True) @client_context.require_version_min(3, 1, 9, -1) def test_replace_bypass_document_validation(self): db = self.db db.test.drop() db.create_collection("test", validator={"a": {"$exists": True}}) db_w0 = self.db.client.get_database( self.db.name, write_concern=WriteConcern(w=0)) # Test replace_one db.test.insert_one({"a": 101}) self.assertRaises(OperationFailure, db.test.replace_one, {"a": 101}, {"y": 1}) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(1, db.test.count_documents({"a": 101})) db.test.replace_one({"a": 101}, {"y": 1}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"a": 101})) self.assertEqual(1, db.test.count_documents({"y": 1})) db.test.replace_one({"y": 1}, {"a": 102}) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(0, db.test.count_documents({"a": 101})) self.assertEqual(1, db.test.count_documents({"a": 102})) db.test.insert_one({"y": 1}, bypass_document_validation=True) self.assertRaises(OperationFailure, db.test.replace_one, {"y": 1}, {"x": 101}) self.assertEqual(0, db.test.count_documents({"x": 101})) self.assertEqual(1, db.test.count_documents({"y": 1})) db.test.replace_one({"y": 1}, {"x": 101}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(1, db.test.count_documents({"x": 101})) db.test.replace_one({"x": 101}, {"a": 103}, bypass_document_validation=False) self.assertEqual(0, db.test.count_documents({"x": 101})) self.assertEqual(1, db.test.count_documents({"a": 103})) db.test.insert_one({"y": 1}, bypass_document_validation=True) if client_context.version < (3, 6): # Uses OP_UPDATE which does not support bypass_document_validation. 
self.assertRaises(OperationFailure, db_w0.test.replace_one, {"y": 1}, {"x": 1}, bypass_document_validation=True) else: db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") @client_context.require_version_min(3, 1, 9, -1) def test_update_bypass_document_validation(self): db = self.db db.test.drop() db.test.insert_one({"z": 5}) db.command(SON([("collMod", "test"), ("validator", {"z": {"$gte": 0}})])) db_w0 = self.db.client.get_database( self.db.name, write_concern=WriteConcern(w=0)) # Test update_one self.assertRaises(OperationFailure, db.test.update_one, {"z": 5}, {"$inc": {"z": -10}}) self.assertEqual(0, db.test.count_documents({"z": -5})) self.assertEqual(1, db.test.count_documents({"z": 5})) db.test.update_one({"z": 5}, {"$inc": {"z": -10}}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"z": 5})) self.assertEqual(1, db.test.count_documents({"z": -5})) db.test.update_one({"z": -5}, {"$inc": {"z": 6}}, bypass_document_validation=False) self.assertEqual(1, db.test.count_documents({"z": 1})) self.assertEqual(0, db.test.count_documents({"z": -5})) db.test.insert_one({"z": -10}, bypass_document_validation=True) self.assertRaises(OperationFailure, db.test.update_one, {"z": -10}, {"$inc": {"z": 1}}) self.assertEqual(0, db.test.count_documents({"z": -9})) self.assertEqual(1, db.test.count_documents({"z": -10})) db.test.update_one({"z": -10}, {"$inc": {"z": 1}}, bypass_document_validation=True) self.assertEqual(1, db.test.count_documents({"z": -9})) self.assertEqual(0, db.test.count_documents({"z": -10})) db.test.update_one({"z": -9}, {"$inc": {"z": 9}}, bypass_document_validation=False) self.assertEqual(0, db.test.count_documents({"z": -9})) self.assertEqual(1, db.test.count_documents({"z": 0})) db.test.insert_one({"y": 1, "x": 0}, bypass_document_validation=True) if client_context.version < (3, 6): # Uses OP_UPDATE which does not support bypass_document_validation. 
self.assertRaises(OperationFailure, db_w0.test.update_one, {"y": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) else: db_w0.test.update_one({"y": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) wait_until(lambda: db_w0.test.find_one({"y": 1, "x": 1}), "find w:0 updated document") # Test update_many db.test.insert_many([{"z": i} for i in range(3, 101)]) db.test.insert_one({"y": 0}, bypass_document_validation=True) self.assertRaises(OperationFailure, db.test.update_many, {}, {"$inc": {"z": -100}}) self.assertEqual(100, db.test.count_documents({"z": {"$gte": 0}})) self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) self.assertEqual(0, db.test.count_documents({"y": 0, "z": -100})) db.test.update_many({"z": {"$gte": 0}}, {"$inc": {"z": -100}}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"z": {"$gt": 0}})) self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) db.test.update_many({"z": {"$gt": -50}}, {"$inc": {"z": 100}}, bypass_document_validation=False) self.assertEqual(50, db.test.count_documents({"z": {"$gt": 0}})) self.assertEqual(50, db.test.count_documents({"z": {"$lt": 0}})) db.test.insert_many([{"z": -i} for i in range(50)], bypass_document_validation=True) self.assertRaises(OperationFailure, db.test.update_many, {}, {"$inc": {"z": 1}}) self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) self.assertEqual(50, db.test.count_documents({"z": {"$gt": 1}})) db.test.update_many({"z": {"$gte": 0}}, {"$inc": {"z": -100}}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"z": {"$gt": 0}})) self.assertEqual(150, db.test.count_documents({"z": {"$lte": 0}})) db.test.update_many({"z": {"$lte": 0}}, {"$inc": {"z": 100}}, bypass_document_validation=False) self.assertEqual(150, db.test.count_documents({"z": {"$gte": 0}})) self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) if client_context.version < (3, 6): # Uses OP_UPDATE which does not support bypass_document_validation. 
self.assertRaises(OperationFailure, db_w0.test.update_many, {"m": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) else: db_w0.test.update_many({"m": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) wait_until( lambda: db_w0.test.count_documents({"m": 1, "x": 1}) == 2, "find w:0 updated documents") @client_context.require_version_min(3, 1, 9, -1) def test_bypass_document_validation_bulk_write(self): db = self.db db.test.drop() db.create_collection("test", validator={"a": {"$gte": 0}}) db_w0 = self.db.client.get_database( self.db.name, write_concern=WriteConcern(w=0)) ops = [InsertOne({"a": -10}), InsertOne({"a": -11}), InsertOne({"a": -12}), UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), UpdateMany({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), ReplaceOne({"a": {"$lte": -10}}, {"a": -1})] db.test.bulk_write(ops, bypass_document_validation=True) self.assertEqual(3, db.test.count_documents({})) self.assertEqual(1, db.test.count_documents({"a": -11})) self.assertEqual(1, db.test.count_documents({"a": -1})) self.assertEqual(1, db.test.count_documents({"a": -9})) # Assert that the operations would fail without bypass_doc_val for op in ops: self.assertRaises(BulkWriteError, db.test.bulk_write, [op]) self.assertRaises(OperationFailure, db_w0.test.bulk_write, ops, bypass_document_validation=True) def test_find_by_default_dct(self): db = self.db db.test.insert_one({'foo': 'bar'}) dct = defaultdict(dict, [('foo', 'bar')]) self.assertIsNotNone(db.test.find_one(dct)) self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')])) def test_find_w_fields(self): db = self.db db.test.delete_many({}) db.test.insert_one({"x": 1, "mike": "awesome", "extra thing": "abcdefghijklmnopqrstuvwxyz"}) self.assertEqual(1, db.test.count_documents({})) doc = next(db.test.find({})) self.assertTrue("x" in doc) doc = next(db.test.find({})) self.assertTrue("mike" in doc) doc = next(db.test.find({})) self.assertTrue("extra thing" in doc) doc = next(db.test.find({}, ["x", "mike"])) self.assertTrue("x" in doc) doc = next(db.test.find({}, ["x", "mike"])) self.assertTrue("mike" in doc) doc = next(db.test.find({}, ["x", "mike"])) self.assertFalse("extra thing" in doc) doc = next(db.test.find({}, ["mike"])) self.assertFalse("x" in doc) doc = next(db.test.find({}, ["mike"])) self.assertTrue("mike" in doc) doc = next(db.test.find({}, ["mike"])) self.assertFalse("extra thing" in doc) def test_fields_specifier_as_dict(self): db = self.db db.test.delete_many({}) db.test.insert_one({"x": [1, 2, 3], "mike": "awesome"}) self.assertEqual([1, 2, 3], db.test.find_one()["x"]) self.assertEqual([2, 3], db.test.find_one( projection={"x": {"$slice": -2}})["x"]) self.assertTrue("x" not in db.test.find_one(projection={"x": 0})) self.assertTrue("mike" in db.test.find_one(projection={"x": 0})) def test_find_w_regex(self): db = self.db db.test.delete_many({}) db.test.insert_one({"x": "hello_world"}) db.test.insert_one({"x": "hello_mike"}) db.test.insert_one({"x": "hello_mikey"}) db.test.insert_one({"x": "hello_test"}) self.assertEqual(len(list(db.test.find())), 4) self.assertEqual(len(list(db.test.find({"x": re.compile("^hello.*")}))), 4) self.assertEqual(len(list(db.test.find({"x": re.compile("ello")}))), 4) self.assertEqual(len(list(db.test.find({"x": re.compile("^hello$")}))), 0) self.assertEqual(len(list(db.test.find({"x": re.compile("^hello_mi.*$")}))), 2) def test_id_can_be_anything(self): db = self.db db.test.delete_many({}) auto_id = {"hello": "world"} db.test.insert_one(auto_id) self.assertTrue(isinstance(auto_id["_id"], 
ObjectId)) numeric = {"_id": 240, "hello": "world"} db.test.insert_one(numeric) self.assertEqual(numeric["_id"], 240) obj = {"_id": numeric, "hello": "world"} db.test.insert_one(obj) self.assertEqual(obj["_id"], numeric) for x in db.test.find(): self.assertEqual(x["hello"], u"world") self.assertTrue("_id" in x) def test_invalid_key_names(self): db = self.db db.test.drop() db.test.insert_one({"hello": "world"}) db.test.insert_one({"hello": {"hello": "world"}}) self.assertRaises(InvalidDocument, db.test.insert_one, {"$hello": "world"}) self.assertRaises(InvalidDocument, db.test.insert_one, {"hello": {"$hello": "world"}}) db.test.insert_one({"he$llo": "world"}) db.test.insert_one({"hello": {"hello$": "world"}}) self.assertRaises(InvalidDocument, db.test.insert_one, {".hello": "world"}) self.assertRaises(InvalidDocument, db.test.insert_one, {"hello": {".hello": "world"}}) self.assertRaises(InvalidDocument, db.test.insert_one, {"hello.": "world"}) self.assertRaises(InvalidDocument, db.test.insert_one, {"hello": {"hello.": "world"}}) self.assertRaises(InvalidDocument, db.test.insert_one, {"hel.lo": "world"}) self.assertRaises(InvalidDocument, db.test.insert_one, {"hello": {"hel.lo": "world"}}) def test_unique_index(self): db = self.db db.drop_collection("test") db.test.create_index("hello") # No error. db.test.insert_one({"hello": "world"}) db.test.insert_one({"hello": "world"}) db.drop_collection("test") db.test.create_index("hello", unique=True) with self.assertRaises(DuplicateKeyError): db.test.insert_one({"hello": "world"}) db.test.insert_one({"hello": "world"}) def test_duplicate_key_error(self): db = self.db db.drop_collection("test") db.test.create_index("x", unique=True) db.test.insert_one({"_id": 1, "x": 1}) with self.assertRaises(DuplicateKeyError) as context: db.test.insert_one({"x": 1}) self.assertIsNotNone(context.exception.details) with self.assertRaises(DuplicateKeyError) as context: db.test.insert_one({"x": 1}) self.assertIsNotNone(context.exception.details) self.assertEqual(1, db.test.count_documents({})) def test_write_error_text_handling(self): db = self.db db.drop_collection("test") db.test.create_index("text", unique=True) # Test workaround for SERVER-24007 data = (b'a\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' 
b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83') text = utf_8_decode(data, None, True) db.test.insert_one({"text": text}) # Should raise DuplicateKeyError, not InvalidBSON self.assertRaises(DuplicateKeyError, db.test.insert_one, {"text": text}) self.assertRaises(DuplicateKeyError, db.test.insert, {"text": text}) self.assertRaises(DuplicateKeyError, db.test.insert, [{"text": text}]) self.assertRaises(DuplicateKeyError, db.test.replace_one, {"_id": ObjectId()}, {"text": text}, upsert=True) self.assertRaises(DuplicateKeyError, db.test.update, {"_id": ObjectId()}, {"text": text}, upsert=True) # Should raise BulkWriteError, not InvalidBSON self.assertRaises(BulkWriteError, db.test.insert_many, [{"text": text}]) def test_write_error_unicode(self): coll = self.db.test self.addCleanup(coll.drop) coll.create_index('a', unique=True) coll.insert_one({'a': u'unicode \U0001f40d'}) with self.assertRaisesRegex( DuplicateKeyError, 'E11000 duplicate key error') as ctx: coll.insert_one({'a': u'unicode \U0001f40d'}) # Once more for good measure. self.assertIn('E11000 duplicate key error', str(ctx.exception)) if sys.version_info[0] == 2: # Test unicode(error) conversion. self.assertIn('E11000 duplicate key error', unicode(ctx.exception)) def test_wtimeout(self): # Ensure setting wtimeout doesn't disable write concern altogether. # See SERVER-12596. collection = self.db.test collection.drop() collection.insert_one({'_id': 1}) coll = collection.with_options( write_concern=WriteConcern(w=1, wtimeout=1000)) self.assertRaises(DuplicateKeyError, coll.insert_one, {'_id': 1}) coll = collection.with_options( write_concern=WriteConcern(wtimeout=1000)) self.assertRaises(DuplicateKeyError, coll.insert_one, {'_id': 1}) def test_error_code(self): try: self.db.test.update_many({}, {"$thismodifierdoesntexist": 1}) except OperationFailure as exc: self.assertTrue(exc.code in (9, 10147, 16840, 17009)) # Just check that we set the error document. Fields # vary by MongoDB version. 
self.assertTrue(exc.details is not None) else: self.fail("OperationFailure was not raised") def test_index_on_subfield(self): db = self.db db.drop_collection("test") db.test.insert_one({"hello": {"a": 4, "b": 5}}) db.test.insert_one({"hello": {"a": 7, "b": 2}}) db.test.insert_one({"hello": {"a": 4, "b": 10}}) db.drop_collection("test") db.test.create_index("hello.a", unique=True) db.test.insert_one({"hello": {"a": 4, "b": 5}}) db.test.insert_one({"hello": {"a": 7, "b": 2}}) self.assertRaises(DuplicateKeyError, db.test.insert_one, {"hello": {"a": 4, "b": 10}}) def test_replace_one(self): db = self.db db.drop_collection("test") self.assertRaises(ValueError, lambda: db.test.replace_one({}, {"$set": {"x": 1}})) id1 = db.test.insert_one({"x": 1}).inserted_id result = db.test.replace_one({"x": 1}, {"y": 1}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"y": 1})) self.assertEqual(0, db.test.count_documents({"x": 1})) self.assertEqual(db.test.find_one(id1)["y"], 1) replacement = RawBSONDocument(encode({"_id": id1, "z": 1})) result = db.test.replace_one({"y": 1}, replacement, True) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"z": 1})) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(db.test.find_one(id1)["z"], 1) result = db.test.replace_one({"x": 2}, {"y": 2}, True) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(0, result.matched_count) self.assertTrue(result.modified_count in (None, 0)) self.assertTrue(isinstance(result.upserted_id, ObjectId)) self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"y": 2})) db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.replace_one({"x": 0}, {"y": 0}) self.assertTrue(isinstance(result, UpdateResult)) self.assertRaises(InvalidOperation, lambda: result.matched_count) self.assertRaises(InvalidOperation, lambda: result.modified_count) self.assertRaises(InvalidOperation, lambda: result.upserted_id) self.assertFalse(result.acknowledged) def test_update_one(self): db = self.db db.drop_collection("test") self.assertRaises(ValueError, lambda: db.test.update_one({}, {"x": 1})) id1 = db.test.insert_one({"x": 5}).inserted_id result = db.test.update_one({}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) self.assertEqual(db.test.find_one(id1)["x"], 6) id2 = db.test.insert_one({"x": 1}).inserted_id result = db.test.update_one({"x": 6}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) self.assertEqual(db.test.find_one(id1)["x"], 7) self.assertEqual(db.test.find_one(id2)["x"], 1) result = db.test.update_one({"x": 2}, {"$set": {"y": 1}}, True) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(0, result.matched_count) 
self.assertTrue(result.modified_count in (None, 0)) self.assertTrue(isinstance(result.upserted_id, ObjectId)) self.assertTrue(result.acknowledged) db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.update_one({"x": 0}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertRaises(InvalidOperation, lambda: result.matched_count) self.assertRaises(InvalidOperation, lambda: result.modified_count) self.assertRaises(InvalidOperation, lambda: result.upserted_id) self.assertFalse(result.acknowledged) def test_update_many(self): db = self.db db.drop_collection("test") self.assertRaises(ValueError, lambda: db.test.update_many({}, {"x": 1})) db.test.insert_one({"x": 4, "y": 3}) db.test.insert_one({"x": 5, "y": 5}) db.test.insert_one({"x": 4, "y": 4}) result = db.test.update_many({"x": 4}, {"$set": {"y": 5}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(2, result.matched_count) self.assertTrue(result.modified_count in (None, 2)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) self.assertEqual(3, db.test.count_documents({"y": 5})) result = db.test.update_many({"x": 5}, {"$set": {"y": 6}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (None, 1)) self.assertIsNone(result.upserted_id) self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"y": 6})) result = db.test.update_many({"x": 2}, {"$set": {"y": 1}}, True) self.assertTrue(isinstance(result, UpdateResult)) self.assertEqual(0, result.matched_count) self.assertTrue(result.modified_count in (None, 0)) self.assertTrue(isinstance(result.upserted_id, ObjectId)) self.assertTrue(result.acknowledged) db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.update_many({"x": 0}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertRaises(InvalidOperation, lambda: result.matched_count) self.assertRaises(InvalidOperation, lambda: result.modified_count) self.assertRaises(InvalidOperation, lambda: result.upserted_id) self.assertFalse(result.acknowledged) # MongoDB >= 3.5.8 allows dotted fields in updates @client_context.require_version_max(3, 5, 7) def test_update_with_invalid_keys(self): self.db.drop_collection("test") self.assertTrue(self.db.test.insert_one({"hello": "world"})) doc = self.db.test.find_one() doc['a.b'] = 'c' # Replace self.assertRaises(OperationFailure, self.db.test.replace_one, {"hello": "world"}, doc) # Upsert self.assertRaises(OperationFailure, self.db.test.replace_one, {"foo": "bar"}, doc, upsert=True) # Check that the last two ops didn't actually modify anything self.assertTrue('a.b' not in self.db.test.find_one()) def test_update_check_keys(self): self.db.drop_collection("test") self.assertTrue(self.db.test.insert_one({"hello": "world"})) # Modify shouldn't check keys... self.assertTrue(self.db.test.update_one({"hello": "world"}, {"$set": {"foo.bar": "baz"}}, upsert=True)) # I know this seems like testing the server but I'd like to be notified # by CI if the server's behavior changes here. doc = SON([("$set", {"foo.bar": "bim"}), ("hello", "world")]) self.assertRaises(OperationFailure, self.db.test.update_one, {"hello": "world"}, doc, upsert=True) # This is going to cause keys to be checked and raise InvalidDocument. # That's OK assuming the server's behavior in the previous assert # doesn't change. 
If the behavior changes checking the first key for # '$' in update won't be good enough anymore. doc = SON([("hello", "world"), ("$set", {"foo.bar": "bim"})]) self.assertRaises(OperationFailure, self.db.test.replace_one, {"hello": "world"}, doc, upsert=True) # Replace with empty document self.assertNotEqual(0, self.db.test.replace_one( {"hello": "world"}, {}).matched_count) def test_acknowledged_delete(self): db = self.db db.drop_collection("test") db.create_collection("test", capped=True, size=1000) db.test.insert_one({"x": 1}) self.assertEqual(1, db.test.count_documents({})) # Can't remove from capped collection. self.assertRaises(OperationFailure, db.test.delete_one, {"x": 1}) db.drop_collection("test") db.test.insert_one({"x": 1}) db.test.insert_one({"x": 1}) self.assertEqual(2, db.test.delete_many({}).deleted_count) self.assertEqual(0, db.test.delete_many({}).deleted_count) def test_manual_last_error(self): coll = self.db.get_collection("test", write_concern=WriteConcern(w=0)) coll.insert_one({"x": 1}) self.db.command("getlasterror", w=1, wtimeout=1) @ignore_deprecations def test_count(self): db = self.db db.drop_collection("test") self.assertEqual(db.test.count(), 0) db.test.insert_many([{}, {}]) self.assertEqual(db.test.count(), 2) db.test.insert_many([{'foo': 'bar'}, {'foo': 'baz'}]) self.assertEqual(db.test.find({'foo': 'bar'}).count(), 1) self.assertEqual(db.test.count({'foo': 'bar'}), 1) self.assertEqual(db.test.find({'foo': re.compile(r'ba.*')}).count(), 2) self.assertEqual( db.test.count({'foo': re.compile(r'ba.*')}), 2) def test_count_documents(self): db = self.db db.drop_collection("test") self.addCleanup(db.drop_collection, "test") self.assertEqual(db.test.count_documents({}), 0) db.wrong.insert_many([{}, {}]) self.assertEqual(db.test.count_documents({}), 0) db.test.insert_many([{}, {}]) self.assertEqual(db.test.count_documents({}), 2) db.test.insert_many([{'foo': 'bar'}, {'foo': 'baz'}]) self.assertEqual(db.test.count_documents({'foo': 'bar'}), 1) self.assertEqual( db.test.count_documents({'foo': re.compile(r'ba.*')}), 2) def test_estimated_document_count(self): db = self.db db.drop_collection("test") self.addCleanup(db.drop_collection, "test") self.assertEqual(db.test.estimated_document_count(), 0) db.wrong.insert_many([{}, {}]) self.assertEqual(db.test.estimated_document_count(), 0) db.test.insert_many([{}, {}]) self.assertEqual(db.test.estimated_document_count(), 2) def test_aggregate(self): db = self.db db.drop_collection("test") db.test.insert_one({'foo': [1, 2]}) self.assertRaises(TypeError, db.test.aggregate, "wow") pipeline = {"$project": {"_id": False, "foo": True}} # MongoDB 3.5.1+ requires either the 'cursor' or 'explain' options. if client_context.version.at_least(3, 5, 1): result = db.test.aggregate([pipeline]) else: result = db.test.aggregate([pipeline], useCursor=False) self.assertTrue(isinstance(result, CommandCursor)) self.assertEqual([{'foo': [1, 2]}], list(result)) # Test write concern. with self.write_concern_collection() as coll: coll.aggregate([{'$out': 'output-collection'}]) def test_aggregate_raw_bson(self): db = self.db db.drop_collection("test") db.test.insert_one({'foo': [1, 2]}) self.assertRaises(TypeError, db.test.aggregate, "wow") pipeline = {"$project": {"_id": False, "foo": True}} coll = db.get_collection( 'test', codec_options=CodecOptions(document_class=RawBSONDocument)) # MongoDB 3.5.1+ requires either the 'cursor' or 'explain' options. 
if client_context.version.at_least(3, 5, 1): result = coll.aggregate([pipeline]) else: result = coll.aggregate([pipeline], useCursor=False) self.assertTrue(isinstance(result, CommandCursor)) first_result = next(result) self.assertIsInstance(first_result, RawBSONDocument) self.assertEqual([1, 2], list(first_result['foo'])) def test_aggregation_cursor_validation(self): db = self.db projection = {'$project': {'_id': '$_id'}} cursor = db.test.aggregate([projection], cursor={}) self.assertTrue(isinstance(cursor, CommandCursor)) cursor = db.test.aggregate([projection], useCursor=True) self.assertTrue(isinstance(cursor, CommandCursor)) def test_aggregation_cursor(self): db = self.db if client_context.has_secondaries: # Test that getMore messages are sent to the right server. db = self.client.get_database( db.name, read_preference=ReadPreference.SECONDARY, write_concern=WriteConcern(w=self.w)) for collection_size in (10, 1000): db.drop_collection("test") db.test.insert_many([{'_id': i} for i in range(collection_size)]) expected_sum = sum(range(collection_size)) # Use batchSize to ensure multiple getMore messages cursor = db.test.aggregate( [{'$project': {'_id': '$_id'}}], batchSize=5) self.assertEqual( expected_sum, sum(doc['_id'] for doc in cursor)) # Test that batchSize is handled properly. cursor = db.test.aggregate([], batchSize=5) self.assertEqual(5, len(cursor._CommandCursor__data)) # Force a getMore cursor._CommandCursor__data.clear() next(cursor) # batchSize - 1 self.assertEqual(4, len(cursor._CommandCursor__data)) # Exhaust the cursor. There shouldn't be any errors. for doc in cursor: pass def test_aggregation_cursor_alive(self): self.db.test.delete_many({}) self.db.test.insert_many([{} for _ in range(3)]) self.addCleanup(self.db.test.delete_many, {}) cursor = self.db.test.aggregate(pipeline=[], cursor={'batchSize': 2}) n = 0 while True: cursor.next() n += 1 if 3 == n: self.assertFalse(cursor.alive) break self.assertTrue(cursor.alive) @client_context.require_no_mongos @client_context.require_version_max(4, 1, 0) @ignore_deprecations def test_parallel_scan(self): db = self.db db.drop_collection("test") if client_context.has_secondaries: # Test that getMore messages are sent to the right server. 
db = self.client.get_database( db.name, read_preference=ReadPreference.SECONDARY, write_concern=WriteConcern(w=self.w)) coll = db.test coll.insert_many([{'_id': i} for i in range(8000)]) docs = [] threads = [threading.Thread(target=docs.extend, args=(cursor,)) for cursor in coll.parallel_scan(3)] for t in threads: t.start() for t in threads: t.join() self.assertEqual( set(range(8000)), set(doc['_id'] for doc in docs)) @client_context.require_no_mongos @client_context.require_version_min(3, 3, 10) @client_context.require_version_max(4, 1, 0) @client_context.require_test_commands @ignore_deprecations def test_parallel_scan_max_time_ms(self): self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: self.assertRaises(ExecutionTimeout, self.db.test.parallel_scan, 3, maxTimeMS=1) finally: self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_large_limit(self): db = self.db db.drop_collection("test_large_limit") db.test_large_limit.create_index([('x', 1)]) my_str = "mongomongo" * 1000 db.test_large_limit.insert_many( {"x": i, "y": my_str} for i in range(2000)) i = 0 y = 0 for doc in db.test_large_limit.find(limit=1900).sort([('x', 1)]): i += 1 y += doc["x"] self.assertEqual(1900, i) self.assertEqual((1900 * 1899) / 2, y) def test_find_kwargs(self): db = self.db db.drop_collection("test") db.test.insert_many({"x": i} for i in range(10)) self.assertEqual(10, db.test.count_documents({})) total = 0 for x in db.test.find({}, skip=4, limit=2): total += x["x"] self.assertEqual(9, total) def test_rename(self): db = self.db db.drop_collection("test") db.drop_collection("foo") self.assertRaises(TypeError, db.test.rename, 5) self.assertRaises(InvalidName, db.test.rename, "") self.assertRaises(InvalidName, db.test.rename, "te$t") self.assertRaises(InvalidName, db.test.rename, ".test") self.assertRaises(InvalidName, db.test.rename, "test.") self.assertRaises(InvalidName, db.test.rename, "tes..t") self.assertEqual(0, db.test.count_documents({})) self.assertEqual(0, db.foo.count_documents({})) db.test.insert_many({"x": i} for i in range(10)) self.assertEqual(10, db.test.count_documents({})) db.test.rename("foo") self.assertEqual(0, db.test.count_documents({})) self.assertEqual(10, db.foo.count_documents({})) x = 0 for doc in db.foo.find(): self.assertEqual(x, doc["x"]) x += 1 db.test.insert_one({}) self.assertRaises(OperationFailure, db.foo.rename, "test") db.foo.rename("test", dropTarget=True) with self.write_concern_collection() as coll: coll.rename('foo') def test_find_one(self): db = self.db db.drop_collection("test") _id = db.test.insert_one({"hello": "world", "foo": "bar"}).inserted_id self.assertEqual("world", db.test.find_one()["hello"]) self.assertEqual(db.test.find_one(_id), db.test.find_one()) self.assertEqual(db.test.find_one(None), db.test.find_one()) self.assertEqual(db.test.find_one({}), db.test.find_one()) self.assertEqual(db.test.find_one({"hello": "world"}), db.test.find_one()) self.assertTrue("hello" in db.test.find_one(projection=["hello"])) self.assertTrue("hello" not in db.test.find_one(projection=["foo"])) self.assertTrue("hello" in db.test.find_one(projection=("hello",))) self.assertTrue("hello" not in db.test.find_one(projection=("foo",))) self.assertTrue("hello" in db.test.find_one(projection=set(["hello"]))) self.assertTrue("hello" not in db.test.find_one(projection=set(["foo"]))) self.assertTrue("hello" in db.test.find_one(projection=frozenset(["hello"]))) self.assertTrue("hello" not in 
db.test.find_one(projection=frozenset(["foo"]))) self.assertEqual(["_id"], list(db.test.find_one(projection=[]))) self.assertEqual(None, db.test.find_one({"hello": "foo"})) self.assertEqual(None, db.test.find_one(ObjectId())) def test_find_one_non_objectid(self): db = self.db db.drop_collection("test") db.test.insert_one({"_id": 5}) self.assertTrue(db.test.find_one(5)) self.assertFalse(db.test.find_one(6)) def test_find_one_with_find_args(self): db = self.db db.drop_collection("test") db.test.insert_many([{"x": i} for i in range(1, 4)]) self.assertEqual(1, db.test.find_one()["x"]) self.assertEqual(2, db.test.find_one(skip=1, limit=2)["x"]) def test_find_with_sort(self): db = self.db db.drop_collection("test") db.test.insert_many([{"x": 2}, {"x": 1}, {"x": 3}]) self.assertEqual(2, db.test.find_one()["x"]) self.assertEqual(1, db.test.find_one(sort=[("x", 1)])["x"]) self.assertEqual(3, db.test.find_one(sort=[("x", -1)])["x"]) def to_list(things): return [thing["x"] for thing in things] self.assertEqual([2, 1, 3], to_list(db.test.find())) self.assertEqual([1, 2, 3], to_list(db.test.find(sort=[("x", 1)]))) self.assertEqual([3, 2, 1], to_list(db.test.find(sort=[("x", -1)]))) self.assertRaises(TypeError, db.test.find, sort=5) self.assertRaises(TypeError, db.test.find, sort="hello") self.assertRaises(ValueError, db.test.find, sort=["hello", 1]) # TODO doesn't actually test functionality, just that it doesn't blow up def test_cursor_timeout(self): list(self.db.test.find(no_cursor_timeout=True)) list(self.db.test.find(no_cursor_timeout=False)) def test_exhaust(self): if is_mongos(self.db.client): self.assertRaises(InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST) return # Limit is incompatible with exhaust. self.assertRaises(InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST, limit=5) cur = self.db.test.find(cursor_type=CursorType.EXHAUST) self.assertRaises(InvalidOperation, cur.limit, 5) cur = self.db.test.find(limit=5) self.assertRaises(InvalidOperation, cur.add_option, 64) cur = self.db.test.find() cur.add_option(64) self.assertRaises(InvalidOperation, cur.limit, 5) self.db.drop_collection("test") # Insert enough documents to require more than one batch self.db.test.insert_many([{'i': i} for i in range(150)]) client = rs_or_single_client(maxPoolSize=1) socks = get_pool(client).sockets # Make sure the socket is returned after exhaustion. cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST) next(cur) self.assertEqual(0, len(socks)) for _ in cur: pass self.assertEqual(1, len(socks)) # Same as previous but don't call next() for _ in client[self.db.name].test.find(cursor_type=CursorType.EXHAUST): pass self.assertEqual(1, len(socks)) # If the Cursor instance is discarded before being # completely iterated we have to close and # discard the socket. cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST) next(cur) self.assertEqual(0, len(socks)) if sys.platform.startswith('java') or 'PyPy' in sys.version: # Don't wait for GC or use gc.collect(), it's unreliable. cur.close() cur = None # The socket should be discarded. 
self.assertEqual(0, len(socks)) def test_distinct(self): self.db.drop_collection("test") test = self.db.test test.insert_many([{"a": 1}, {"a": 2}, {"a": 2}, {"a": 2}, {"a": 3}]) distinct = test.distinct("a") distinct.sort() self.assertEqual([1, 2, 3], distinct) distinct = test.find({'a': {'$gt': 1}}).distinct("a") distinct.sort() self.assertEqual([2, 3], distinct) distinct = test.distinct('a', {'a': {'$gt': 1}}) distinct.sort() self.assertEqual([2, 3], distinct) self.db.drop_collection("test") test.insert_one({"a": {"b": "a"}, "c": 12}) test.insert_one({"a": {"b": "b"}, "c": 12}) test.insert_one({"a": {"b": "c"}, "c": 12}) test.insert_one({"a": {"b": "c"}, "c": 12}) distinct = test.distinct("a.b") distinct.sort() self.assertEqual(["a", "b", "c"], distinct) def test_query_on_query_field(self): self.db.drop_collection("test") self.db.test.insert_one({"query": "foo"}) self.db.test.insert_one({"bar": "foo"}) self.assertEqual(1, self.db.test.count_documents({"query": {"$ne": None}})) self.assertEqual(1, len(list(self.db.test.find({"query": {"$ne": None}}))) ) def test_min_query(self): self.db.drop_collection("test") self.db.test.insert_many([{"x": 1}, {"x": 2}]) self.db.test.create_index("x") cursor = self.db.test.find({"$min": {"x": 2}, "$query": {}}) if client_context.requires_hint_with_min_max_queries: cursor = cursor.hint("x_1") docs = list(cursor) self.assertEqual(1, len(docs)) self.assertEqual(2, docs[0]["x"]) def test_numerous_inserts(self): # Ensure we don't exceed server's 1000-document batch size limit. self.db.test.drop() n_docs = 2100 self.db.test.insert_many([{} for _ in range(n_docs)]) self.assertEqual(n_docs, self.db.test.count_documents({})) self.db.test.drop() def test_map_reduce(self): db = self.db db.drop_collection("test") db.test.insert_one({"id": 1, "tags": ["dog", "cat"]}) db.test.insert_one({"id": 2, "tags": ["cat"]}) db.test.insert_one({"id": 3, "tags": ["mouse", "cat", "dog"]}) db.test.insert_one({"id": 4, "tags": []}) map = Code("function () {" " this.tags.forEach(function(z) {" " emit(z, 1);" " });" "}") reduce = Code("function (key, values) {" " var total = 0;" " for (var i = 0; i < values.length; i++) {" " total += values[i];" " }" " return total;" "}") result = db.test.map_reduce(map, reduce, out='mrunittests') self.assertEqual(3, result.find_one({"_id": "cat"})["value"]) self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) db.test.insert_one({"id": 5, "tags": ["hampster"]}) result = db.test.map_reduce(map, reduce, out='mrunittests') self.assertEqual(1, result.find_one({"_id": "hampster"})["value"]) db.test.delete_one({"id": 5}) result = db.test.map_reduce(map, reduce, out={'merge': 'mrunittests'}) self.assertEqual(3, result.find_one({"_id": "cat"})["value"]) self.assertEqual(1, result.find_one({"_id": "hampster"})["value"]) result = db.test.map_reduce(map, reduce, out={'reduce': 'mrunittests'}) self.assertEqual(6, result.find_one({"_id": "cat"})["value"]) self.assertEqual(4, result.find_one({"_id": "dog"})["value"]) self.assertEqual(2, result.find_one({"_id": "mouse"})["value"]) self.assertEqual(1, result.find_one({"_id": "hampster"})["value"]) result = db.test.map_reduce( map, reduce, out={'replace': 'mrunittests'} ) self.assertEqual(3, result.find_one({"_id": "cat"})["value"]) self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) # Create the output database. 
db.client.mrtestdb.mrunittests.insert_one({}) result = db.test.map_reduce(map, reduce, out=SON([('replace', 'mrunittests'), ('db', 'mrtestdb') ])) self.assertEqual(3, result.find_one({"_id": "cat"})["value"]) self.assertEqual(2, result.find_one({"_id": "dog"})["value"]) self.assertEqual(1, result.find_one({"_id": "mouse"})["value"]) self.client.drop_database('mrtestdb') full_result = db.test.map_reduce(map, reduce, out='mrunittests', full_response=True) self.assertEqual('mrunittests', full_result["result"]) if client_context.version < (4, 3): self.assertEqual(6, full_result["counts"]["emit"]) result = db.test.map_reduce(map, reduce, out='mrunittests', limit=2) self.assertEqual(2, result.find_one({"_id": "cat"})["value"]) self.assertEqual(1, result.find_one({"_id": "dog"})["value"]) self.assertEqual(None, result.find_one({"_id": "mouse"})) result = db.test.map_reduce(map, reduce, out={'inline': 1}) self.assertTrue(isinstance(result, dict)) self.assertTrue('results' in result) self.assertTrue(result['results'][1]["_id"] in ("cat", "dog", "mouse")) result = db.test.inline_map_reduce(map, reduce) self.assertTrue(isinstance(result, list)) self.assertEqual(3, len(result)) self.assertTrue(result[1]["_id"] in ("cat", "dog", "mouse")) full_result = db.test.inline_map_reduce(map, reduce, full_response=True) self.assertEqual(3, len(full_result["results"])) if client_context.version < (4, 3): self.assertEqual(6, full_result["counts"]["emit"]) with self.write_concern_collection() as coll: coll.map_reduce(map, reduce, 'output') def test_messages_with_unicode_collection_names(self): db = self.db db[u"Employés"].insert_one({"x": 1}) db[u"Employés"].replace_one({"x": 1}, {"x": 2}) db[u"Employés"].delete_many({}) db[u"Employés"].find_one() list(db[u"Employés"].find()) def test_drop_indexes_non_existent(self): self.db.drop_collection("test") self.db.test.drop_indexes() # This is really a bson test but easier to just reproduce it here... # (Shame on me) def test_bad_encode(self): c = self.db.test c.drop() self.assertRaises(InvalidDocument, c.insert_one, {"x": c}) class BadGetAttr(dict): def __getattr__(self, name): pass bad = BadGetAttr([('foo', 'bar')]) c.insert_one({'bad': bad}) self.assertEqual('bar', c.find_one()['bad']['foo']) @client_context.require_version_max(3, 5, 5) def test_array_filters_unsupported(self): c = self.db.test with self.assertRaises(ConfigurationError): c.update_one( {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) with self.assertRaises(ConfigurationError): c.update_many( {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) with self.assertRaises(ConfigurationError): c.find_one_and_update( {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) def test_array_filters_validation(self): # array_filters must be a list. 
c = self.db.test with self.assertRaises(TypeError): c.update_one({}, {'$set': {'a': 1}}, array_filters={}) with self.assertRaises(TypeError): c.update_many({}, {'$set': {'a': 1}}, array_filters={}) with self.assertRaises(TypeError): c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={}) def test_array_filters_unacknowledged(self): c_w0 = self.db.test.with_options(write_concern=WriteConcern(w=0)) with self.assertRaises(ConfigurationError): c_w0.update_one({}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) with self.assertRaises(ConfigurationError): c_w0.update_many({}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) with self.assertRaises(ConfigurationError): c_w0.find_one_and_update({}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) def test_find_one_and(self): c = self.db.test c.drop() c.insert_one({'_id': 1, 'i': 1}) self.assertEqual({'_id': 1, 'i': 1}, c.find_one_and_update({'_id': 1}, {'$inc': {'i': 1}})) self.assertEqual({'_id': 1, 'i': 3}, c.find_one_and_update( {'_id': 1}, {'$inc': {'i': 1}}, return_document=ReturnDocument.AFTER)) self.assertEqual({'_id': 1, 'i': 3}, c.find_one_and_delete({'_id': 1})) self.assertEqual(None, c.find_one({'_id': 1})) self.assertEqual(None, c.find_one_and_update({'_id': 1}, {'$inc': {'i': 1}})) self.assertEqual({'_id': 1, 'i': 1}, c.find_one_and_update( {'_id': 1}, {'$inc': {'i': 1}}, return_document=ReturnDocument.AFTER, upsert=True)) self.assertEqual({'_id': 1, 'i': 2}, c.find_one_and_update( {'_id': 1}, {'$inc': {'i': 1}}, return_document=ReturnDocument.AFTER)) self.assertEqual({'_id': 1, 'i': 3}, c.find_one_and_replace( {'_id': 1}, {'i': 3, 'j': 1}, projection=['i'], return_document=ReturnDocument.AFTER)) self.assertEqual({'i': 4}, c.find_one_and_update( {'_id': 1}, {'$inc': {'i': 1}}, projection={'i': 1, '_id': 0}, return_document=ReturnDocument.AFTER)) c.drop() for j in range(5): c.insert_one({'j': j, 'i': 0}) sort = [('j', DESCENDING)] self.assertEqual(4, c.find_one_and_update({}, {'$inc': {'i': 1}}, sort=sort)['j']) def test_find_one_and_write_concern(self): listener = EventListener() db = single_client(event_listeners=[listener])[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection( 'test', write_concern=WriteConcern(w=0)) # default WriteConcern. c_default = db.get_collection('test', write_concern=WriteConcern()) results = listener.results # Authenticate the client and throw out auth commands from the listener. db.command('ismaster') results.clear() if client_context.version.at_least(3, 1, 9, -1): c_w0.find_and_modify( {'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertEqual( {'w': 0}, results['started'][0].command['writeConcern']) results.clear() c_w0.find_one_and_update( {'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertEqual( {'w': 0}, results['started'][0].command['writeConcern']) results.clear() c_w0.find_one_and_replace({'_id': 1}, {'foo': 'bar'}) self.assertEqual( {'w': 0}, results['started'][0].command['writeConcern']) results.clear() c_w0.find_one_and_delete({'_id': 1}) self.assertEqual( {'w': 0}, results['started'][0].command['writeConcern']) results.clear() # Test write concern errors. 
if client_context.is_rs: c_wc_error = db.get_collection( 'test', write_concern=WriteConcern( w=len(client_context.nodes) + 1)) self.assertRaises( WriteConcernError, c_wc_error.find_and_modify, {'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertRaises( WriteConcernError, c_wc_error.find_one_and_update, {'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertRaises( WriteConcernError, c_wc_error.find_one_and_replace, {'w': 0}, results['started'][0].command['writeConcern']) self.assertRaises( WriteConcernError, c_wc_error.find_one_and_delete, {'w': 0}, results['started'][0].command['writeConcern']) results.clear() else: c_w0.find_and_modify( {'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_w0.find_one_and_update( {'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_w0.find_one_and_replace({'_id': 1}, {'foo': 'bar'}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_w0.find_one_and_delete({'_id': 1}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_default.find_and_modify({'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_default.find_one_and_update({'_id': 1}, {'$set': {'foo': 'bar'}}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_default.find_one_and_replace({'_id': 1}, {'foo': 'bar'}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() c_default.find_one_and_delete({'_id': 1}) self.assertNotIn('writeConcern', results['started'][0].command) results.clear() def test_find_with_nested(self): c = self.db.test c.drop() c.insert_many([{'i': i} for i in range(5)]) # [0, 1, 2, 3, 4] self.assertEqual( [2], [i['i'] for i in c.find({ '$and': [ { # This clause gives us [1,2,4] '$or': [ {'i': {'$lte': 2}}, {'i': {'$gt': 3}}, ], }, { # This clause gives us [2,3] '$or': [ {'i': 2}, {'i': 3}, ] }, ] })] ) self.assertEqual( [0, 1, 2], [i['i'] for i in c.find({ '$or': [ { # This clause gives us [2] '$and': [ {'i': {'$gte': 2}}, {'i': {'$lt': 3}}, ], }, { # This clause gives us [0,1] '$and': [ {'i': {'$gt': -100}}, {'i': {'$lt': 2}}, ] }, ] })] ) def test_find_regex(self): c = self.db.test c.drop() c.insert_one({'r': re.compile('.*')}) self.assertTrue(isinstance(c.find_one()['r'], Regex)) for doc in c.find(): self.assertTrue(isinstance(doc['r'], Regex)) def test_find_command_generation(self): cmd = _gen_find_command('coll', {'$query': {'foo': 1}, '$dumb': 2}, None, 0, 0, 0, None, DEFAULT_READ_CONCERN, None, None) self.assertEqual( cmd.to_dict(), SON([('find', 'coll'), ('$dumb', 2), ('filter', {'foo': 1})]).to_dict()) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_command_monitoring_spec.py000066400000000000000000000234651374256237000225740ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
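# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original pymongo test suite): the command
# monitoring spec tests below capture events through an EventListener helper.
# As a minimal, hedged illustration of the public pymongo 3.x API they build
# on, a command listener can be registered roughly like this. The
# RecordingListener name and the "example" namespace are made up for this
# sketch.
from collections import defaultdict

from pymongo import MongoClient, monitoring


class RecordingListener(monitoring.CommandListener):
    """Collect started/succeeded/failed command events for inspection."""

    def __init__(self):
        self.results = defaultdict(list)

    def started(self, event):
        # event.command_name is e.g. 'insert'; event.command is the document sent.
        self.results['started'].append(event)

    def succeeded(self, event):
        # event.reply holds the server's response document.
        self.results['succeeded'].append(event)

    def failed(self, event):
        self.results['failed'].append(event)


if __name__ == '__main__':
    listener = RecordingListener()
    client = MongoClient(event_listeners=[listener])
    client.example.coll.insert_one({'x': 1})
    print(listener.results['started'][0].command_name)  # 'insert'
# ---------------------------------------------------------------------------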
"""Run the command monitoring spec tests.""" import os import re import sys sys.path[0:0] = [""] import pymongo from bson import json_util from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern from test import unittest, client_context from test.utils import single_client, wait_until, EventListener, parse_read_preference # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'command_monitoring') def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() class TestAllScenarios(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): cls.listener = EventListener() cls.client = single_client(event_listeners=[cls.listener]) @classmethod def tearDownClass(cls): cls.client.close() def tearDown(self): self.listener.results.clear() def format_actual_results(results): started = results['started'] succeeded = results['succeeded'] failed = results['failed'] msg = "\nStarted: %r" % (started[0].command if len(started) else None,) msg += "\nSucceeded: %r" % (succeeded[0].reply if len(succeeded) else None,) msg += "\nFailed: %r" % (failed[0].failure if len(failed) else None,) return msg def create_test(scenario_def, test): def run_scenario(self): dbname = scenario_def['database_name'] collname = scenario_def['collection_name'] coll = self.client[dbname][collname] coll.drop() coll.insert_many(scenario_def['data']) self.listener.results.clear() name = camel_to_snake(test['operation']['name']) if 'read_preference' in test['operation']: coll = coll.with_options(read_preference=parse_read_preference( test['operation']['read_preference'])) if 'collectionOptions' in test['operation']: colloptions = test['operation']['collectionOptions'] if 'writeConcern' in colloptions: concern = colloptions['writeConcern'] coll = coll.with_options( write_concern=WriteConcern(**concern)) test_args = test['operation']['arguments'] if 'options' in test_args: options = test_args.pop('options') test_args.update(options) args = {} for arg in test_args: args[camel_to_snake(arg)] = test_args[arg] if name == 'bulk_write': bulk_args = [] for request in args['requests']: opname = request['name'] klass = opname[0:1].upper() + opname[1:] arg = getattr(pymongo, klass)(**request['arguments']) bulk_args.append(arg) try: coll.bulk_write(bulk_args, args.get('ordered', True)) except OperationFailure: pass elif name == 'find': if 'sort' in args: args['sort'] = list(args['sort'].items()) for arg in 'skip', 'limit': if arg in args: args[arg] = int(args[arg]) try: # Iterate the cursor. tuple(coll.find(**args)) except OperationFailure: pass # Wait for the killCursors thread to run if necessary. 
if 'limit' in args and client_context.version[:2] < (3, 1): self.client._kill_cursors_executor.wake() started = self.listener.results['started'] succeeded = self.listener.results['succeeded'] wait_until( lambda: started[-1].command_name == 'killCursors', "publish a start event for killCursors.") wait_until( lambda: succeeded[-1].command_name == 'killCursors', "publish a succeeded event for killCursors.") else: try: getattr(coll, name)(**args) except OperationFailure: pass res = self.listener.results for expectation in test['expectations']: event_type = next(iter(expectation)) if event_type == "command_started_event": event = res['started'][0] if len(res['started']) else None if event is not None: # The tests substitute 42 for any number other than 0. if (event.command_name == 'getMore' and event.command['getMore']): event.command['getMore'] = 42 elif event.command_name == 'killCursors': event.command['cursors'] = [42] elif event_type == "command_succeeded_event": event = ( res['succeeded'].pop(0) if len(res['succeeded']) else None) if event is not None: reply = event.reply # The tests substitute 42 for any number other than 0, # and "" for any error message. if 'writeErrors' in reply: for doc in reply['writeErrors']: # Remove any new fields the server adds. The tests # only have index, code, and errmsg. diff = set(doc) - set(['index', 'code', 'errmsg']) for field in diff: doc.pop(field) doc['code'] = 42 doc['errmsg'] = "" elif 'cursor' in reply: if reply['cursor']['id']: reply['cursor']['id'] = 42 elif event.command_name == 'killCursors': # Make the tests continue to pass when the killCursors # command is actually in use. if 'cursorsKilled' in reply: reply.pop('cursorsKilled') reply['cursorsUnknown'] = [42] # Found succeeded event. Pop related started event. res['started'].pop(0) elif event_type == "command_failed_event": event = res['failed'].pop(0) if len(res['failed']) else None if event is not None: # Found failed event. Pop related started event. res['started'].pop(0) else: self.fail("Unknown event type") if event is None: event_name = event_type.split('_')[1] self.fail( "Expected %s event for %s command. Actual " "results:%s" % ( event_name, expectation[event_type]['command_name'], format_actual_results(res))) for attr, expected in expectation[event_type].items(): if 'options' in expected: options = expected.pop('options') expected.update(options) actual = getattr(event, attr) if isinstance(expected, dict): for key, val in expected.items(): self.assertEqual(val, actual[key]) else: self.assertEqual(actual, expected) return run_scenario def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json_util.loads(scenario_stream.read()) assert bool(scenario_def.get('tests')), "tests cannot be empty" # Construct test from scenario. 
for test in scenario_def['tests']: new_test = create_test(scenario_def, test) if "ignore_if_server_version_greater_than" in test: version = test["ignore_if_server_version_greater_than"] ver = tuple(int(elt) for elt in version.split('.')) new_test = client_context.require_version_max(*ver)( new_test) if "ignore_if_server_version_less_than" in test: version = test["ignore_if_server_version_less_than"] ver = tuple(int(elt) for elt in version.split('.')) new_test = client_context.require_version_min(*ver)( new_test) if "ignore_if_topology_type" in test: types = set(test["ignore_if_topology_type"]) if "sharded" in types: new_test = client_context.require_no_mongos(None)( new_test) test_name = 'test_%s_%s_%s' % ( dirname, os.path.splitext(filename)[0], str(test['description'].replace(" ", "_"))) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_common.py000066400000000000000000000176301374256237000171640ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the pymongo common module.""" import sys import uuid sys.path[0:0] = [""] from bson.binary import UUIDLegacy, PYTHON_LEGACY, STANDARD from bson.code import Code from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern from test import client_context, unittest, IntegrationTest from test.utils import connected, rs_or_single_client, single_client @client_context.require_connection def setUpModule(): pass class TestCommon(IntegrationTest): def test_uuid_representation(self): coll = self.db.uuid coll.drop() # Test property self.assertEqual(PYTHON_LEGACY, coll.codec_options.uuid_representation) # Test basic query uu = uuid.uuid4() # Insert as binary subtype 3 coll.insert_one({'uu': uu}) self.assertEqual(uu, coll.find_one({'uu': uu})['uu']) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) self.assertEqual(STANDARD, coll.codec_options.uuid_representation) self.assertEqual(None, coll.find_one({'uu': uu})) self.assertEqual(uu, coll.find_one({'uu': UUIDLegacy(uu)})['uu']) # Test count_documents self.assertEqual(0, coll.count_documents({'uu': uu})) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) self.assertEqual(1, coll.count_documents({'uu': uu})) # Test delete coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) coll.delete_one({'uu': uu}) self.assertEqual(1, coll.count_documents({})) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) coll.delete_one({'uu': uu}) self.assertEqual(0, coll.count_documents({})) # Test update_one coll.insert_one({'_id': uu, 'i': 1}) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) coll.update_one({'_id': uu}, {'$set': {'i': 2}}) coll = self.db.get_collection( "uuid", 
CodecOptions(uuid_representation=PYTHON_LEGACY)) self.assertEqual(1, coll.find_one({'_id': uu})['i']) coll.update_one({'_id': uu}, {'$set': {'i': 2}}) self.assertEqual(2, coll.find_one({'_id': uu})['i']) # Test Cursor.distinct self.assertEqual([2], coll.find({'_id': uu}).distinct('i')) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) self.assertEqual([], coll.find({'_id': uu}).distinct('i')) # Test findAndModify self.assertEqual(None, coll.find_one_and_update({'_id': uu}, {'$set': {'i': 5}})) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) self.assertEqual(2, coll.find_one_and_update({'_id': uu}, {'$set': {'i': 5}})['i']) self.assertEqual(5, coll.find_one({'_id': uu})['i']) # Test command self.assertEqual(5, self.db.command('findAndModify', 'uuid', update={'$set': {'i': 6}}, query={'_id': uu})['value']['i']) self.assertEqual(6, self.db.command( 'findAndModify', 'uuid', update={'$set': {'i': 7}}, query={'_id': UUIDLegacy(uu)})['value']['i']) # Test (inline)_map_reduce coll.drop() coll.insert_one({"_id": uu, "x": 1, "tags": ["dog", "cat"]}) coll.insert_one({"_id": uuid.uuid4(), "x": 3, "tags": ["mouse", "cat", "dog"]}) map = Code("function () {" " this.tags.forEach(function(z) {" " emit(z, 1);" " });" "}") reduce = Code("function (key, values) {" " var total = 0;" " for (var i = 0; i < values.length; i++) {" " total += values[i];" " }" " return total;" "}") coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) q = {"_id": uu} result = coll.inline_map_reduce(map, reduce, query=q) self.assertEqual([], result) result = coll.map_reduce(map, reduce, "results", query=q) self.assertEqual(0, self.db.results.count_documents({})) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) q = {"_id": uu} result = coll.inline_map_reduce(map, reduce, query=q) self.assertEqual(2, len(result)) result = coll.map_reduce(map, reduce, "results", query=q) self.assertEqual(2, self.db.results.count_documents({})) self.db.drop_collection("result") coll.drop() def test_write_concern(self): c = rs_or_single_client(connect=False) self.assertEqual(WriteConcern(), c.write_concern) c = rs_or_single_client(connect=False, w=2, wtimeout=1000) wc = WriteConcern(w=2, wtimeout=1000) self.assertEqual(wc, c.write_concern) # Can we override back to the server default? 
db = c.get_database('pymongo_test', write_concern=WriteConcern()) self.assertEqual(db.write_concern, WriteConcern()) db = c.pymongo_test self.assertEqual(wc, db.write_concern) coll = db.test self.assertEqual(wc, coll.write_concern) cwc = WriteConcern(j=True) coll = db.get_collection('test', write_concern=cwc) self.assertEqual(cwc, coll.write_concern) self.assertEqual(wc, db.write_concern) def test_mongo_client(self): pair = client_context.pair m = rs_or_single_client(w=0) coll = m.pymongo_test.write_concern_test coll.drop() doc = {"_id": ObjectId()} coll.insert_one(doc) self.assertTrue(coll.insert_one(doc)) coll = coll.with_options(write_concern=WriteConcern(w=1)) self.assertRaises(OperationFailure, coll.insert_one, doc) m = rs_or_single_client() coll = m.pymongo_test.write_concern_test new_coll = coll.with_options(write_concern=WriteConcern(w=0)) self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) m = rs_or_single_client("mongodb://%s/" % (pair,), replicaSet=client_context.replica_set_name) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) m = rs_or_single_client("mongodb://%s/?w=0" % (pair,), replicaSet=client_context.replica_set_name) coll = m.pymongo_test.write_concern_test coll.insert_one(doc) # Equality tests direct = connected(single_client(w=0)) direct2 = connected(single_client("mongodb://%s/?w=0" % (pair,), **self.credentials)) self.assertEqual(direct, direct2) self.assertFalse(direct != direct2) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_connections_survive_primary_stepdown_spec.py000066400000000000000000000130041374256237000264700ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test compliance with the connections survive primary step down spec.""" import sys sys.path[0:0] = [""] from bson import SON from pymongo import monitoring from pymongo.errors import NotMasterError from pymongo.write_concern import WriteConcern from test import (client_context, unittest, IntegrationTest) from test.utils import (CMAPListener, ensure_all_connected, repl_set_step_down, rs_or_single_client) class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): @classmethod @client_context.require_replica_set def setUpClass(cls): super(TestConnectionsSurvivePrimaryStepDown, cls).setUpClass() cls.listener = CMAPListener() cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for sockets that # survive a replica set election. 
ensure_all_connected(cls.client) cls.listener.reset() cls.db = cls.client.get_database( "step-down", write_concern=WriteConcern("majority")) cls.coll = cls.db.get_collection( "step-down", write_concern=WriteConcern("majority")) @classmethod def tearDownClass(cls): cls.client.close() def setUp(self): # Note that all ops use same write-concern as self.db (majority). self.db.drop_collection("step-down") self.db.create_collection("step-down") self.listener.reset() def set_fail_point(self, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) self.client.admin.command(cmd) def verify_pool_cleared(self): self.assertEqual( self.listener.event_count(monitoring.PoolClearedEvent), 1) def verify_pool_not_cleared(self): self.assertEqual( self.listener.event_count(monitoring.PoolClearedEvent), 0) @client_context.require_version_min(4, 2, -1) def test_get_more_iteration(self): # Insert 5 documents with WC majority. self.coll.insert_many([{'data': k} for k in range(5)]) # Start a find operation and retrieve first batch of results. batch_size = 2 cursor = self.coll.find(batch_size=batch_size) for _ in range(batch_size): cursor.next() # Force step-down the primary. repl_set_step_down(self.client, replSetStepDown=5, force=True) # Get next batch of results. for _ in range(batch_size): cursor.next() # Verify pool not cleared. self.verify_pool_not_cleared() # Attempt insertion to mark server description as stale and prevent a # notMaster error on the subsequent operation. try: self.coll.insert_one({}) except NotMasterError: pass # Next insert should succeed on the new primary without clearing pool. self.coll.insert_one({}) self.verify_pool_not_cleared() def run_scenario(self, error_code, retry, pool_status_checker): # Set fail point. self.set_fail_point({"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": error_code}}) self.addCleanup(self.set_fail_point, {"mode": "off"}) # Insert record and verify failure. with self.assertRaises(NotMasterError) as exc: self.coll.insert_one({"test": 1}) self.assertEqual(exc.exception.details['code'], error_code) # Retry before CMAPListener assertion if retry_before=True. if retry: self.coll.insert_one({"test": 1}) # Verify pool cleared/not cleared. pool_status_checker() # Always retry here to ensure discovery of new primary. self.coll.insert_one({"test": 1}) @client_context.require_version_min(4, 2, -1) @client_context.require_test_commands def test_not_master_keep_connection_pool(self): self.run_scenario(10107, True, self.verify_pool_not_cleared) @client_context.require_version_min(4, 0, 0) @client_context.require_version_max(4, 1, 0, -1) @client_context.require_test_commands def test_not_master_reset_connection_pool(self): self.run_scenario(10107, False, self.verify_pool_cleared) @client_context.require_version_min(4, 0, 0) @client_context.require_test_commands def test_shutdown_in_progress(self): self.run_scenario(91, False, self.verify_pool_cleared) @client_context.require_version_min(4, 0, 0) @client_context.require_test_commands def test_interrupted_at_shutdown(self): self.run_scenario(11600, False, self.verify_pool_cleared) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_crud_v1.py000066400000000000000000000221161374256237000172320ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the collection module.""" import json import os import re import sys sys.path[0:0] = [""] from bson.py3compat import iteritems from pymongo import operations, WriteConcern from pymongo.command_cursor import CommandCursor from pymongo.cursor import Cursor from pymongo.errors import PyMongoError from pymongo.read_concern import ReadConcern from pymongo.results import _WriteResult, BulkWriteResult from pymongo.operations import (InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany) from test import unittest, client_context, IntegrationTest from test.utils import (camel_to_snake, camel_to_upper_camel, camel_to_snake_args, drop_collections, TestCreator) # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'crud', 'v1') class TestAllScenarios(IntegrationTest): pass def check_result(self, expected_result, result): if isinstance(result, _WriteResult): for res in expected_result: prop = camel_to_snake(res) msg = "%s : %r != %r" % (prop, expected_result, result) # SPEC-869: Only BulkWriteResult has upserted_count. if (prop == "upserted_count" and not isinstance(result, BulkWriteResult)): if result.upserted_id is not None: upserted_count = 1 else: upserted_count = 0 self.assertEqual(upserted_count, expected_result[res], msg) elif prop == "inserted_ids": # BulkWriteResult does not have inserted_ids. if isinstance(result, BulkWriteResult): self.assertEqual(len(expected_result[res]), result.inserted_count) else: # InsertManyResult may be compared to [id1] from the # crud spec or {"0": id1} from the retryable write spec. ids = expected_result[res] if isinstance(ids, dict): ids = [ids[str(i)] for i in range(len(ids))] self.assertEqual(ids, result.inserted_ids, msg) elif prop == "upserted_ids": # Convert indexes from strings to integers. ids = expected_result[res] expected_ids = {} for str_index in ids: expected_ids[int(str_index)] = ids[str_index] self.assertEqual(expected_ids, result.upserted_ids, msg) else: self.assertEqual( getattr(result, prop), expected_result[res], msg) else: self.assertEqual(result, expected_result) def run_operation(collection, test): # Convert command from CamelCase to pymongo.collection method. operation = camel_to_snake(test['operation']['name']) cmd = getattr(collection, operation) # Convert arguments to snake_case and handle special cases. arguments = test['operation']['arguments'] options = arguments.pop("options", {}) for option_name in options: arguments[camel_to_snake(option_name)] = options[option_name] if operation == "bulk_write": # Parse each request into a bulk write model. requests = [] for request in arguments["requests"]: bulk_model = camel_to_upper_camel(request["name"]) bulk_class = getattr(operations, bulk_model) bulk_arguments = camel_to_snake_args(request["arguments"]) requests.append(bulk_class(**bulk_arguments)) arguments["requests"] = requests else: for arg_name in list(arguments): c2s = camel_to_snake(arg_name) # PyMongo accepts sort as list of tuples. 
if arg_name == "sort": sort_dict = arguments[arg_name] arguments[arg_name] = list(iteritems(sort_dict)) # Named "key" instead not fieldName. if arg_name == "fieldName": arguments["key"] = arguments.pop(arg_name) # Aggregate uses "batchSize", while find uses batch_size. elif arg_name == "batchSize" and operation == "aggregate": continue # Requires boolean returnDocument. elif arg_name == "returnDocument": arguments[c2s] = arguments.pop(arg_name) == "After" else: arguments[c2s] = arguments.pop(arg_name) result = cmd(**arguments) if operation == "aggregate": if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: out = collection.database[arguments["pipeline"][-1]["$out"]] result = out.find() if isinstance(result, Cursor) or isinstance(result, CommandCursor): return list(result) return result def create_test(scenario_def, test, name): def run_scenario(self): # Cleanup state and load data (if provided). drop_collections(self.db) data = scenario_def.get('data') if data: self.db.test.with_options( write_concern=WriteConcern(w="majority")).insert_many( scenario_def['data']) # Run operations and check results or errors. expected_result = test.get('outcome', {}).get('result') expected_error = test.get('outcome', {}).get('error') if expected_error is True: with self.assertRaises(PyMongoError): run_operation(self.db.test, test) else: result = run_operation(self.db.test, test) check_result(self, expected_result, result) # Assert final state is expected. expected_c = test['outcome'].get('collection') if expected_c is not None: expected_name = expected_c.get('name') if expected_name is not None: db_coll = self.db[expected_name] else: db_coll = self.db.test db_coll = db_coll.with_options( read_concern=ReadConcern(level="local")) self.assertEqual(list(db_coll.find()), expected_c['data']) return run_scenario test_creator = TestCreator(create_test, TestAllScenarios, _TEST_PATH) test_creator.create_tests() class TestWriteOpsComparison(unittest.TestCase): def test_InsertOneEquals(self): self.assertEqual(InsertOne({'foo': 42}), InsertOne({'foo': 42})) def test_InsertOneNotEquals(self): self.assertNotEqual(InsertOne({'foo': 42}), InsertOne({'foo': 23})) def test_DeleteOneEquals(self): self.assertEqual(DeleteOne({'foo': 42}), DeleteOne({'foo': 42})) def test_DeleteOneNotEquals(self): self.assertNotEqual(DeleteOne({'foo': 42}), DeleteOne({'foo': 23})) def test_DeleteManyEquals(self): self.assertEqual(DeleteMany({'foo': 42}), DeleteMany({'foo': 42})) def test_DeleteManyNotEquals(self): self.assertNotEqual(DeleteMany({'foo': 42}), DeleteMany({'foo': 23})) def test_DeleteOneNotEqualsDeleteMany(self): self.assertNotEqual(DeleteOne({'foo': 42}), DeleteMany({'foo': 42})) def test_ReplaceOneEquals(self): self.assertEqual(ReplaceOne({'foo': 42}, {'bar': 42}, upsert=False), ReplaceOne({'foo': 42}, {'bar': 42}, upsert=False)) def test_ReplaceOneNotEquals(self): self.assertNotEqual(ReplaceOne({'foo': 42}, {'bar': 42}, upsert=False), ReplaceOne({'foo': 42}, {'bar': 42}, upsert=True)) def test_UpdateOneEquals(self): self.assertEqual(UpdateOne({'foo': 42}, {'$set': {'bar': 42}}), UpdateOne({'foo': 42}, {'$set': {'bar': 42}})) def test_UpdateOneNotEquals(self): self.assertNotEqual(UpdateOne({'foo': 42}, {'$set': {'bar': 42}}), UpdateOne({'foo': 42}, {'$set': {'bar': 23}})) def test_UpdateManyEquals(self): self.assertEqual(UpdateMany({'foo': 42}, {'$set': {'bar': 42}}), UpdateMany({'foo': 42}, {'$set': {'bar': 42}})) def test_UpdateManyNotEquals(self): self.assertNotEqual(UpdateMany({'foo': 42}, {'$set': {'bar': 42}}), 
UpdateMany({'foo': 42}, {'$set': {'bar': 23}})) def test_UpdateOneNotEqualsUpdateMany(self): self.assertNotEqual(UpdateOne({'foo': 42}, {'$set': {'bar': 42}}), UpdateMany({'foo': 42}, {'$set': {'bar': 42}})) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_crud_v2.py000066400000000000000000000043601374256237000172340ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the collection module.""" import os import sys sys.path[0:0] = [""] from test import unittest from test.utils import TestCreator from test.utils_spec_runner import SpecRunner # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'crud', 'v2') # Default test database and collection names. TEST_DB = 'testdb' TEST_COLLECTION = 'testcollection' class TestSpec(SpecRunner): def get_scenario_db_name(self, scenario_def): """Crud spec says database_name is optional.""" return scenario_def.get('database_name', TEST_DB) def get_scenario_coll_name(self, scenario_def): """Crud spec says collection_name is optional.""" return scenario_def.get('collection_name', TEST_COLLECTION) def get_object_name(self, op): """Crud spec says object is optional and defaults to 'collection'.""" return op.get('object', 'collection') def get_outcome_coll_name(self, outcome, collection): """Crud spec says outcome has an optional 'collection.name'.""" return outcome['collection'].get('name', collection.name) def setup_scenario(self, scenario_def): """Allow specs to override a test's setup.""" # PYTHON-1935 Only create the collection if there is data to insert. if scenario_def['data']: super(TestSpec, self).setup_scenario(scenario_def) def create_test(scenario_def, test, name): def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_test, TestSpec, _TEST_PATH) test_creator.create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_cursor.py000066400000000000000000001731261374256237000172140ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
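# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original pymongo test suite): the cursor
# tests below exercise chained Cursor modifiers. As a small, hedged example
# of the behaviour under test, modifiers such as sort/skip/limit/batch_size/
# max_time_ms may be chained freely before iteration starts, but changing
# them after the first document has been fetched raises InvalidOperation.
# The database and collection names here are placeholders.
from pymongo import ASCENDING, MongoClient
from pymongo.errors import InvalidOperation

if __name__ == '__main__':
    coll = MongoClient().example_db.example_coll
    coll.drop()
    coll.insert_many([{'x': i} for i in range(100)])

    cursor = (coll.find({'x': {'$gte': 10}})
              .sort('x', ASCENDING)
              .skip(5)
              .limit(10)
              .batch_size(4)
              .max_time_ms(500))
    print([doc['x'] for doc in cursor])  # [15, 16, ..., 24]

    started = coll.find()
    next(started)  # begin iteration
    try:
        started.limit(5)  # too late: the query has already been sent
    except InvalidOperation:
        print('cannot modify a cursor once iteration has begun')
# ---------------------------------------------------------------------------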
"""Test the cursor module.""" import copy import gc import itertools import random import re import sys import time import threading import warnings sys.path[0:0] = [""] from bson import decode_all from bson.code import Code from bson.py3compat import PY3 from bson.son import SON from pymongo import (ASCENDING, DESCENDING, ALL, OFF) from pymongo.collation import Collation from pymongo.cursor import Cursor, CursorType from pymongo.errors import (ConfigurationError, ExecutionTimeout, InvalidOperation, OperationFailure) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from test import (client_context, unittest, IntegrationTest) from test.utils import (EventListener, ignore_deprecations, rs_or_single_client, WhiteListEventListener) if PY3: long = int class TestCursor(IntegrationTest): def test_deepcopy_cursor_littered_with_regexes(self): cursor = self.db.test.find({ "x": re.compile("^hmmm.*"), "y": [re.compile("^hmm.*")], "z": {"a": [re.compile("^hm.*")]}, re.compile("^key.*"): {"a": [re.compile("^hm.*")]}}) cursor2 = copy.deepcopy(cursor) self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) def test_add_remove_option(self): cursor = self.db.test.find() self.assertEqual(0, cursor._Cursor__query_flags) cursor.add_option(2) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE) self.assertEqual(2, cursor2._Cursor__query_flags) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.add_option(32) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(34, cursor2._Cursor__query_flags) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.add_option(128) cursor2 = self.db.test.find( cursor_type=CursorType.TAILABLE_AWAIT).add_option(128) self.assertEqual(162, cursor2._Cursor__query_flags) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) self.assertEqual(162, cursor._Cursor__query_flags) cursor.add_option(128) self.assertEqual(162, cursor._Cursor__query_flags) cursor.remove_option(128) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(34, cursor2._Cursor__query_flags) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(32) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE) self.assertEqual(2, cursor2._Cursor__query_flags) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) self.assertEqual(2, cursor._Cursor__query_flags) cursor.remove_option(32) self.assertEqual(2, cursor._Cursor__query_flags) # Timeout cursor = self.db.test.find(no_cursor_timeout=True) self.assertEqual(16, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(16) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(16) self.assertEqual(0, cursor._Cursor__query_flags) # Tailable / Await data cursor = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(34, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(34) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(32) self.assertEqual(2, cursor._Cursor__query_flags) # Partial cursor = self.db.test.find(allow_partial_results=True) self.assertEqual(128, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(128) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(128) self.assertEqual(0, 
cursor._Cursor__query_flags) def test_add_remove_option_exhaust(self): # Exhaust - which mongos doesn't support if client_context.is_mongos: with self.assertRaises(InvalidOperation): self.db.test.find(cursor_type=CursorType.EXHAUST) else: cursor = self.db.test.find(cursor_type=CursorType.EXHAUST) self.assertEqual(64, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(64) self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) self.assertTrue(cursor._Cursor__exhaust) cursor.remove_option(64) self.assertEqual(0, cursor._Cursor__query_flags) self.assertFalse(cursor._Cursor__exhaust) def test_allow_disk_use(self): db = self.db db.pymongo_test.drop() coll = db.pymongo_test self.assertRaises(TypeError, coll.find().allow_disk_use, 'baz') cursor = coll.find().allow_disk_use(True) self.assertEqual(True, cursor._Cursor__allow_disk_use) cursor = coll.find().allow_disk_use(False) self.assertEqual(False, cursor._Cursor__allow_disk_use) def test_max_time_ms(self): db = self.db db.pymongo_test.drop() coll = db.pymongo_test self.assertRaises(TypeError, coll.find().max_time_ms, 'foo') coll.insert_one({"amalia": 1}) coll.insert_one({"amalia": 2}) coll.find().max_time_ms(None) coll.find().max_time_ms(long(1)) cursor = coll.find().max_time_ms(999) self.assertEqual(999, cursor._Cursor__max_time_ms) cursor = coll.find().max_time_ms(10).max_time_ms(1000) self.assertEqual(1000, cursor._Cursor__max_time_ms) cursor = coll.find().max_time_ms(999) c2 = cursor.clone() self.assertEqual(999, c2._Cursor__max_time_ms) self.assertTrue("$maxTimeMS" in cursor._Cursor__query_spec()) self.assertTrue("$maxTimeMS" in c2._Cursor__query_spec()) self.assertTrue(coll.find_one(max_time_ms=1000)) client = self.client if (not client_context.is_mongos and client_context.test_commands_enabled): # Cursor parses server timeout error in response to initial query. 
client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: cursor = coll.find().max_time_ms(1) try: next(cursor) except ExecutionTimeout: pass else: self.fail("ExecutionTimeout not raised") self.assertRaises(ExecutionTimeout, coll.find_one, max_time_ms=1) finally: client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") @client_context.require_version_min(3, 1, 9, -1) def test_max_await_time_ms(self): db = self.db db.pymongo_test.drop() coll = db.create_collection("pymongo_test", capped=True, size=4096) self.assertRaises(TypeError, coll.find().max_await_time_ms, 'foo') coll.insert_one({"amalia": 1}) coll.insert_one({"amalia": 2}) coll.find().max_await_time_ms(None) coll.find().max_await_time_ms(long(1)) # When cursor is not tailable_await cursor = coll.find() self.assertEqual(None, cursor._Cursor__max_await_time_ms) cursor = coll.find().max_await_time_ms(99) self.assertEqual(None, cursor._Cursor__max_await_time_ms) # If cursor is tailable_await and timeout is unset cursor = coll.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(None, cursor._Cursor__max_await_time_ms) # If cursor is tailable_await and timeout is set cursor = coll.find( cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99) self.assertEqual(99, cursor._Cursor__max_await_time_ms) cursor = coll.find( cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms( 10).max_await_time_ms(90) self.assertEqual(90, cursor._Cursor__max_await_time_ms) listener = WhiteListEventListener('find', 'getMore') coll = rs_or_single_client( event_listeners=[listener])[self.db.name].pymongo_test results = listener.results # Tailable_await defaults. list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT)) # find self.assertFalse('maxTimeMS' in results['started'][0].command) # getMore self.assertFalse('maxTimeMS' in results['started'][1].command) results.clear() # Tailable_await with max_await_time_ms set. 
list(coll.find( cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99)) # find self.assertEqual('find', results['started'][0].command_name) self.assertFalse('maxTimeMS' in results['started'][0].command) # getMore self.assertEqual('getMore', results['started'][1].command_name) self.assertTrue('maxTimeMS' in results['started'][1].command) self.assertEqual(99, results['started'][1].command['maxTimeMS']) results.clear() # Tailable_await with max_time_ms list(coll.find( cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(99)) # find self.assertEqual('find', results['started'][0].command_name) self.assertTrue('maxTimeMS' in results['started'][0].command) self.assertEqual(99, results['started'][0].command['maxTimeMS']) # getMore self.assertEqual('getMore', results['started'][1].command_name) self.assertFalse('maxTimeMS' in results['started'][1].command) results.clear() # Tailable_await with both max_time_ms and max_await_time_ms list(coll.find( cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms( 99).max_await_time_ms(99)) # find self.assertEqual('find', results['started'][0].command_name) self.assertTrue('maxTimeMS' in results['started'][0].command) self.assertEqual(99, results['started'][0].command['maxTimeMS']) # getMore self.assertEqual('getMore', results['started'][1].command_name) self.assertTrue('maxTimeMS' in results['started'][1].command) self.assertEqual(99, results['started'][1].command['maxTimeMS']) results.clear() # Non tailable_await with max_await_time_ms list(coll.find(batch_size=1).max_await_time_ms(99)) # find self.assertEqual('find', results['started'][0].command_name) self.assertFalse('maxTimeMS' in results['started'][0].command) # getMore self.assertEqual('getMore', results['started'][1].command_name) self.assertFalse('maxTimeMS' in results['started'][1].command) results.clear() # Non tailable_await with max_time_ms list(coll.find(batch_size=1).max_time_ms(99)) # find self.assertEqual('find', results['started'][0].command_name) self.assertTrue('maxTimeMS' in results['started'][0].command) self.assertEqual(99, results['started'][0].command['maxTimeMS']) # getMore self.assertEqual('getMore', results['started'][1].command_name) self.assertFalse('maxTimeMS' in results['started'][1].command) # Non tailable_await with both max_time_ms and max_await_time_ms list(coll.find(batch_size=1).max_time_ms(99).max_await_time_ms(88)) # find self.assertEqual('find', results['started'][0].command_name) self.assertTrue('maxTimeMS' in results['started'][0].command) self.assertEqual(99, results['started'][0].command['maxTimeMS']) # getMore self.assertEqual('getMore', results['started'][1].command_name) self.assertFalse('maxTimeMS' in results['started'][1].command) @client_context.require_test_commands @client_context.require_no_mongos def test_max_time_ms_getmore(self): # Test that Cursor handles server timeout error in response to getmore. coll = self.db.pymongo_test coll.insert_many([{} for _ in range(200)]) cursor = coll.find().max_time_ms(100) # Send initial query before turning on failpoint. next(cursor) self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: try: # Iterate up to first getmore. 
list(cursor) except ExecutionTimeout: pass else: self.fail("ExecutionTimeout not raised") finally: self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_explain(self): a = self.db.test.find() a.explain() for _ in a: break b = a.explain() # "cursor" pre MongoDB 2.7.6, "executionStats" post self.assertTrue("cursor" in b or "executionStats" in b) def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = WhiteListEventListener("explain") client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client.pymongo_test.test.with_options( read_concern=ReadConcern(level="local")) self.assertTrue(coll.find().explain()) started = listener.results['started'] self.assertEqual(len(started), 1) self.assertNotIn("readConcern", started[0].command) def test_hint(self): db = self.db self.assertRaises(TypeError, db.test.find().hint, 5.5) db.test.drop() db.test.insert_many([{"num": i, "foo": i} for i in range(100)]) self.assertRaises(OperationFailure, db.test.find({"num": 17, "foo": 17}) .hint([("num", ASCENDING)]).explain) self.assertRaises(OperationFailure, db.test.find({"num": 17, "foo": 17}) .hint([("foo", ASCENDING)]).explain) spec = [("num", DESCENDING)] index = db.test.create_index(spec) first = next(db.test.find()) self.assertEqual(0, first.get('num')) first = next(db.test.find().hint(spec)) self.assertEqual(99, first.get('num')) self.assertRaises(OperationFailure, db.test.find({"num": 17, "foo": 17}) .hint([("foo", ASCENDING)]).explain) a = db.test.find({"num": 17}) a.hint(spec) for _ in a: break self.assertRaises(InvalidOperation, a.hint, spec) def test_hint_by_name(self): db = self.db db.test.drop() db.test.insert_many([{"i": i} for i in range(100)]) db.test.create_index([('i', DESCENDING)], name='fooindex') first = next(db.test.find()) self.assertEqual(0, first.get('i')) first = next(db.test.find().hint('fooindex')) self.assertEqual(99, first.get('i')) def test_limit(self): db = self.db self.assertRaises(TypeError, db.test.find().limit, None) self.assertRaises(TypeError, db.test.find().limit, "hello") self.assertRaises(TypeError, db.test.find().limit, 5.5) self.assertTrue(db.test.find().limit(long(5))) db.test.drop() db.test.insert_many([{"x": i} for i in range(100)]) count = 0 for _ in db.test.find(): count += 1 self.assertEqual(count, 100) count = 0 for _ in db.test.find().limit(20): count += 1 self.assertEqual(count, 20) count = 0 for _ in db.test.find().limit(99): count += 1 self.assertEqual(count, 99) count = 0 for _ in db.test.find().limit(1): count += 1 self.assertEqual(count, 1) count = 0 for _ in db.test.find().limit(0): count += 1 self.assertEqual(count, 100) count = 0 for _ in db.test.find().limit(0).limit(50).limit(10): count += 1 self.assertEqual(count, 10) a = db.test.find() a.limit(10) for _ in a: break self.assertRaises(InvalidOperation, a.limit, 5) @ignore_deprecations # Ignore max without hint. def test_max(self): db = self.db db.test.drop() j_index = [("j", ASCENDING)] db.test.create_index(j_index) db.test.insert_many([{"j": j, "k": j} for j in range(10)]) def find(max_spec, expected_index): cursor = db.test.find().max(max_spec) if client_context.requires_hint_with_min_max_queries: cursor = cursor.hint(expected_index) return cursor cursor = find([("j", 3)], j_index) self.assertEqual(len(list(cursor)), 3) # Tuple. cursor = find((("j", 3),), j_index) self.assertEqual(len(list(cursor)), 3) # Compound index. 
index_keys = [("j", ASCENDING), ("k", ASCENDING)] db.test.create_index(index_keys) cursor = find([("j", 3), ("k", 3)], index_keys) self.assertEqual(len(list(cursor)), 3) # Wrong order. cursor = find([("k", 3), ("j", 3)], index_keys) self.assertRaises(OperationFailure, list, cursor) # No such index. cursor = find([("k", 3)], "k") self.assertRaises(OperationFailure, list, cursor) self.assertRaises(TypeError, db.test.find().max, 10) self.assertRaises(TypeError, db.test.find().max, {"j": 10}) @ignore_deprecations # Ignore min without hint. def test_min(self): db = self.db db.test.drop() j_index = [("j", ASCENDING)] db.test.create_index(j_index) db.test.insert_many([{"j": j, "k": j} for j in range(10)]) def find(min_spec, expected_index): cursor = db.test.find().min(min_spec) if client_context.requires_hint_with_min_max_queries: cursor = cursor.hint(expected_index) return cursor cursor = find([("j", 3)], j_index) self.assertEqual(len(list(cursor)), 7) # Tuple. cursor = find((("j", 3),), j_index) self.assertEqual(len(list(cursor)), 7) # Compound index. index_keys = [("j", ASCENDING), ("k", ASCENDING)] db.test.create_index(index_keys) cursor = find([("j", 3), ("k", 3)], index_keys) self.assertEqual(len(list(cursor)), 7) # Wrong order. cursor = find([("k", 3), ("j", 3)], index_keys) self.assertRaises(OperationFailure, list, cursor) # No such index. cursor = find([("k", 3)], "k") self.assertRaises(OperationFailure, list, cursor) self.assertRaises(TypeError, db.test.find().min, 10) self.assertRaises(TypeError, db.test.find().min, {"j": 10}) @client_context.require_version_max(4, 1, -1) def test_min_max_without_hint(self): coll = self.db.test j_index = [("j", ASCENDING)] coll.create_index(j_index) with warnings.catch_warnings(record=True) as warns: warnings.simplefilter("default", DeprecationWarning) list(coll.find().min([("j", 3)])) self.assertIn('using a min/max query operator', str(warns[0])) # Ensure the warning is raised with the proper stack level. 
del warns[:] list(coll.find().min([("j", 3)])) self.assertIn('using a min/max query operator', str(warns[0])) del warns[:] list(coll.find().max([("j", 3)])) self.assertIn('using a min/max query operator', str(warns[0])) def test_batch_size(self): db = self.db db.test.drop() db.test.insert_many([{"x": x} for x in range(200)]) self.assertRaises(TypeError, db.test.find().batch_size, None) self.assertRaises(TypeError, db.test.find().batch_size, "hello") self.assertRaises(TypeError, db.test.find().batch_size, 5.5) self.assertRaises(ValueError, db.test.find().batch_size, -1) self.assertTrue(db.test.find().batch_size(long(5))) a = db.test.find() for _ in a: break self.assertRaises(InvalidOperation, a.batch_size, 5) def cursor_count(cursor, expected_count): count = 0 for _ in cursor: count += 1 self.assertEqual(expected_count, count) cursor_count(db.test.find().batch_size(0), 200) cursor_count(db.test.find().batch_size(1), 200) cursor_count(db.test.find().batch_size(2), 200) cursor_count(db.test.find().batch_size(5), 200) cursor_count(db.test.find().batch_size(100), 200) cursor_count(db.test.find().batch_size(500), 200) cursor_count(db.test.find().batch_size(0).limit(1), 1) cursor_count(db.test.find().batch_size(1).limit(1), 1) cursor_count(db.test.find().batch_size(2).limit(1), 1) cursor_count(db.test.find().batch_size(5).limit(1), 1) cursor_count(db.test.find().batch_size(100).limit(1), 1) cursor_count(db.test.find().batch_size(500).limit(1), 1) cursor_count(db.test.find().batch_size(0).limit(10), 10) cursor_count(db.test.find().batch_size(1).limit(10), 10) cursor_count(db.test.find().batch_size(2).limit(10), 10) cursor_count(db.test.find().batch_size(5).limit(10), 10) cursor_count(db.test.find().batch_size(100).limit(10), 10) cursor_count(db.test.find().batch_size(500).limit(10), 10) cur = db.test.find().batch_size(1) next(cur) if client_context.version.at_least(3, 1, 9): # find command batchSize should be 1 self.assertEqual(0, len(cur._Cursor__data)) else: # OP_QUERY ntoreturn should be 2 self.assertEqual(1, len(cur._Cursor__data)) next(cur) self.assertEqual(0, len(cur._Cursor__data)) next(cur) self.assertEqual(0, len(cur._Cursor__data)) next(cur) self.assertEqual(0, len(cur._Cursor__data)) def test_limit_and_batch_size(self): db = self.db db.test.drop() db.test.insert_many([{"x": x} for x in range(500)]) curs = db.test.find().limit(0).batch_size(10) next(curs) self.assertEqual(10, curs._Cursor__retrieved) curs = db.test.find(limit=0, batch_size=10) next(curs) self.assertEqual(10, curs._Cursor__retrieved) curs = db.test.find().limit(-2).batch_size(0) next(curs) self.assertEqual(2, curs._Cursor__retrieved) curs = db.test.find(limit=-2, batch_size=0) next(curs) self.assertEqual(2, curs._Cursor__retrieved) curs = db.test.find().limit(-4).batch_size(5) next(curs) self.assertEqual(4, curs._Cursor__retrieved) curs = db.test.find(limit=-4, batch_size=5) next(curs) self.assertEqual(4, curs._Cursor__retrieved) curs = db.test.find().limit(50).batch_size(500) next(curs) self.assertEqual(50, curs._Cursor__retrieved) curs = db.test.find(limit=50, batch_size=500) next(curs) self.assertEqual(50, curs._Cursor__retrieved) curs = db.test.find().batch_size(500) next(curs) self.assertEqual(500, curs._Cursor__retrieved) curs = db.test.find(batch_size=500) next(curs) self.assertEqual(500, curs._Cursor__retrieved) curs = db.test.find().limit(50) next(curs) self.assertEqual(50, curs._Cursor__retrieved) curs = db.test.find(limit=50) next(curs) self.assertEqual(50, curs._Cursor__retrieved) # these two might be shaky, 
as the default # is set by the server. as of 2.0.0-rc0, 101 # or 1MB (whichever is smaller) is default # for queries without ntoreturn curs = db.test.find() next(curs) self.assertEqual(101, curs._Cursor__retrieved) curs = db.test.find().limit(0).batch_size(0) next(curs) self.assertEqual(101, curs._Cursor__retrieved) curs = db.test.find(limit=0, batch_size=0) next(curs) self.assertEqual(101, curs._Cursor__retrieved) def test_skip(self): db = self.db self.assertRaises(TypeError, db.test.find().skip, None) self.assertRaises(TypeError, db.test.find().skip, "hello") self.assertRaises(TypeError, db.test.find().skip, 5.5) self.assertRaises(ValueError, db.test.find().skip, -5) self.assertTrue(db.test.find().skip(long(5))) db.drop_collection("test") db.test.insert_many([{"x": i} for i in range(100)]) for i in db.test.find(): self.assertEqual(i["x"], 0) break for i in db.test.find().skip(20): self.assertEqual(i["x"], 20) break for i in db.test.find().skip(99): self.assertEqual(i["x"], 99) break for i in db.test.find().skip(1): self.assertEqual(i["x"], 1) break for i in db.test.find().skip(0): self.assertEqual(i["x"], 0) break for i in db.test.find().skip(0).skip(50).skip(10): self.assertEqual(i["x"], 10) break for i in db.test.find().skip(1000): self.fail() a = db.test.find() a.skip(10) for _ in a: break self.assertRaises(InvalidOperation, a.skip, 5) def test_sort(self): db = self.db self.assertRaises(TypeError, db.test.find().sort, 5) self.assertRaises(ValueError, db.test.find().sort, []) self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING) self.assertRaises(TypeError, db.test.find().sort, [("hello", DESCENDING)], DESCENDING) db.test.drop() unsort = list(range(10)) random.shuffle(unsort) db.test.insert_many([{"x": i} for i in unsort]) asc = [i["x"] for i in db.test.find().sort("x", ASCENDING)] self.assertEqual(asc, list(range(10))) asc = [i["x"] for i in db.test.find().sort("x")] self.assertEqual(asc, list(range(10))) asc = [i["x"] for i in db.test.find().sort([("x", ASCENDING)])] self.assertEqual(asc, list(range(10))) expect = list(reversed(range(10))) desc = [i["x"] for i in db.test.find().sort("x", DESCENDING)] self.assertEqual(desc, expect) desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])] self.assertEqual(desc, expect) desc = [i["x"] for i in db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)] self.assertEqual(desc, expect) expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)] shuffled = list(expected) random.shuffle(shuffled) db.test.drop() for (a, b) in shuffled: db.test.insert_one({"a": a, "b": b}) result = [(i["a"], i["b"]) for i in db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)])] self.assertEqual(result, expected) a = db.test.find() a.sort("x", ASCENDING) for _ in a: break self.assertRaises(InvalidOperation, a.sort, "x", ASCENDING) @ignore_deprecations def test_count(self): db = self.db db.test.drop() self.assertEqual(0, db.test.find().count()) db.test.insert_many([{"x": i} for i in range(10)]) self.assertEqual(10, db.test.find().count()) self.assertTrue(isinstance(db.test.find().count(), int)) self.assertEqual(10, db.test.find().limit(5).count()) self.assertEqual(10, db.test.find().skip(5).count()) self.assertEqual(1, db.test.find({"x": 1}).count()) self.assertEqual(5, db.test.find({"x": {"$lt": 5}}).count()) a = db.test.find() b = a.count() for _ in a: break self.assertEqual(b, a.count()) self.assertEqual(0, db.test.acollectionthatdoesntexist.find().count()) @ignore_deprecations def test_count_with_hint(self): collection = 
self.db.test collection.drop() collection.insert_many([{'i': 1}, {'i': 2}]) self.assertEqual(2, collection.find().count()) collection.create_index([('i', 1)]) self.assertEqual(1, collection.find({'i': 1}).hint("_id_").count()) self.assertEqual(2, collection.find().hint("_id_").count()) self.assertRaises(OperationFailure, collection.find({'i': 1}).hint("BAD HINT").count) # Create a sparse index which should have no entries. collection.create_index([('x', 1)], sparse=True) self.assertEqual(0, collection.find({'i': 1}).hint("x_1").count()) self.assertEqual( 0, collection.find({'i': 1}).hint([("x", 1)]).count()) if client_context.version.at_least(3, 3, 2): self.assertEqual(0, collection.find().hint("x_1").count()) self.assertEqual(0, collection.find().hint([("x", 1)]).count()) else: self.assertEqual(2, collection.find().hint("x_1").count()) self.assertEqual(2, collection.find().hint([("x", 1)]).count()) @ignore_deprecations def test_where(self): db = self.db db.test.drop() a = db.test.find() self.assertRaises(TypeError, a.where, 5) self.assertRaises(TypeError, a.where, None) self.assertRaises(TypeError, a.where, {}) db.test.insert_many([{"x": i} for i in range(10)]) self.assertEqual(3, len(list(db.test.find().where('this.x < 3')))) self.assertEqual(3, len(list(db.test.find().where(Code('this.x < 3'))))) code_with_scope = Code('this.x < i', {"i": 3}) if client_context.version.at_least(4, 3, 3): # MongoDB 4.4 removed support for Code with scope. with self.assertRaises(OperationFailure): list(db.test.find().where(code_with_scope)) code_with_empty_scope = Code('this.x < 3', {}) with self.assertRaises(OperationFailure): list(db.test.find().where(code_with_empty_scope)) else: self.assertEqual( 3, len(list(db.test.find().where(code_with_scope)))) self.assertEqual(10, len(list(db.test.find()))) self.assertEqual(3, db.test.find().where('this.x < 3').count()) self.assertEqual(10, db.test.find().count()) self.assertEqual(3, db.test.find().where(u'this.x < 3').count()) self.assertEqual([0, 1, 2], [a["x"] for a in db.test.find().where('this.x < 3')]) self.assertEqual([], [a["x"] for a in db.test.find({"x": 5}).where('this.x < 3')]) self.assertEqual([5], [a["x"] for a in db.test.find({"x": 5}).where('this.x > 3')]) cursor = db.test.find().where('this.x < 3').where('this.x > 7') self.assertEqual([8, 9], [a["x"] for a in cursor]) a = db.test.find() b = a.where('this.x > 3') for _ in a: break self.assertRaises(InvalidOperation, a.where, 'this.x < 3') def test_rewind(self): self.db.test.insert_many([{"x": i} for i in range(1, 4)]) cursor = self.db.test.find().limit(2) count = 0 for _ in cursor: count += 1 self.assertEqual(2, count) count = 0 for _ in cursor: count += 1 self.assertEqual(0, count) cursor.rewind() count = 0 for _ in cursor: count += 1 self.assertEqual(2, count) cursor.rewind() count = 0 for _ in cursor: break cursor.rewind() for _ in cursor: count += 1 self.assertEqual(2, count) self.assertEqual(cursor, cursor.rewind()) # manipulate, oplog_reply, and snapshot are all deprecated. 
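# A short sketch of the rewind()/clone() distinction tested here: rewind()
# restarts the *same* cursor from the beginning, while clone() returns a new,
# independent cursor built from the same query options.  Assumes a local
# server; names are illustrative.
from pymongo import MongoClient

coll = MongoClient().examples.rewindable
coll.drop()
coll.insert_many([{"x": i} for i in range(1, 4)])

cursor = coll.find().limit(2)
assert len(list(cursor)) == 2      # first pass
assert len(list(cursor)) == 0      # exhausted
cursor.rewind()
assert len(list(cursor)) == 2      # same cursor, re-run from the start
fresh = cursor.clone()             # copies the query spec and the limit
assert len(list(fresh)) == 2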
@ignore_deprecations def test_clone(self): self.db.test.insert_many([{"x": i} for i in range(1, 4)]) cursor = self.db.test.find().limit(2) count = 0 for _ in cursor: count += 1 self.assertEqual(2, count) count = 0 for _ in cursor: count += 1 self.assertEqual(0, count) cursor = cursor.clone() cursor2 = cursor.clone() count = 0 for _ in cursor: count += 1 self.assertEqual(2, count) for _ in cursor2: count += 1 self.assertEqual(4, count) cursor.rewind() count = 0 for _ in cursor: break cursor = cursor.clone() for _ in cursor: count += 1 self.assertEqual(2, count) self.assertNotEqual(cursor, cursor.clone()) # Just test attributes cursor = self.db.test.find({"x": re.compile("^hello.*")}, projection={'_id': False}, skip=1, no_cursor_timeout=True, cursor_type=CursorType.TAILABLE_AWAIT, sort=[("x", 1)], allow_partial_results=True, oplog_replay=True, batch_size=123, manipulate=False, collation={'locale': 'en_US'}, hint=[("_id", 1)], max_scan=100, max_time_ms=1000, return_key=True, show_record_id=True, snapshot=True, allow_disk_use=True).limit(2) cursor.min([('a', 1)]).max([('b', 3)]) cursor.add_option(128) cursor.comment('hi!') # Every attribute should be the same. cursor2 = cursor.clone() self.assertEqual(cursor.__dict__, cursor2.__dict__) # Shallow copies can so can mutate cursor2 = copy.copy(cursor) cursor2._Cursor__projection['cursor2'] = False self.assertTrue('cursor2' in cursor._Cursor__projection) # Deepcopies and shouldn't mutate cursor3 = copy.deepcopy(cursor) cursor3._Cursor__projection['cursor3'] = False self.assertFalse('cursor3' in cursor._Cursor__projection) cursor4 = cursor.clone() cursor4._Cursor__projection['cursor4'] = False self.assertFalse('cursor4' in cursor._Cursor__projection) # Test memo when deepcopying queries query = {"hello": "world"} query["reflexive"] = query cursor = self.db.test.find(query) cursor2 = copy.deepcopy(cursor) self.assertNotEqual(id(cursor._Cursor__spec), id(cursor2._Cursor__spec)) self.assertEqual(id(cursor2._Cursor__spec['reflexive']), id(cursor2._Cursor__spec)) self.assertEqual(len(cursor2._Cursor__spec), 2) # Ensure hints are cloned as the correct type cursor = self.db.test.find().hint([('z', 1), ("a", 1)]) cursor2 = copy.deepcopy(cursor) self.assertTrue(isinstance(cursor2._Cursor__hint, SON)) self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint) def test_clone_empty(self): self.db.test.delete_many({}) self.db.test.insert_many([{"x": i} for i in range(1, 4)]) cursor = self.db.test.find()[2:2] cursor2 = cursor.clone() self.assertRaises(StopIteration, cursor.next) self.assertRaises(StopIteration, cursor2.next) @ignore_deprecations def test_count_with_fields(self): self.db.test.drop() self.db.test.insert_one({"x": 1}) self.assertEqual(1, self.db.test.find({}, ["a"]).count()) def test_bad_getitem(self): self.assertRaises(TypeError, lambda x: self.db.test.find()[x], "hello") self.assertRaises(TypeError, lambda x: self.db.test.find()[x], 5.5) self.assertRaises(TypeError, lambda x: self.db.test.find()[x], None) def test_getitem_slice_index(self): self.db.drop_collection("test") self.db.test.insert_many([{"i": i} for i in range(100)]) count = itertools.count self.assertRaises(IndexError, lambda: self.db.test.find()[-1:]) self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2]) for a, b in zip(count(0), self.db.test.find()): self.assertEqual(a, b['i']) self.assertEqual(100, len(list(self.db.test.find()[0:]))) for a, b in zip(count(0), self.db.test.find()[0:]): self.assertEqual(a, b['i']) self.assertEqual(80, 
len(list(self.db.test.find()[20:]))) for a, b in zip(count(20), self.db.test.find()[20:]): self.assertEqual(a, b['i']) for a, b in zip(count(99), self.db.test.find()[99:]): self.assertEqual(a, b['i']) for i in self.db.test.find()[1000:]: self.fail() self.assertEqual(5, len(list(self.db.test.find()[20:25]))) self.assertEqual(5, len(list( self.db.test.find()[long(20):long(25)]))) for a, b in zip(count(20), self.db.test.find()[20:25]): self.assertEqual(a, b['i']) self.assertEqual(80, len(list(self.db.test.find()[40:45][20:]))) for a, b in zip(count(20), self.db.test.find()[40:45][20:]): self.assertEqual(a, b['i']) self.assertEqual(80, len(list(self.db.test.find()[40:45].limit(0).skip(20)) ) ) for a, b in zip(count(20), self.db.test.find()[40:45].limit(0).skip(20)): self.assertEqual(a, b['i']) self.assertEqual(80, len(list(self.db.test.find().limit(10).skip(40)[20:])) ) for a, b in zip(count(20), self.db.test.find().limit(10).skip(40)[20:]): self.assertEqual(a, b['i']) self.assertEqual(1, len(list(self.db.test.find()[:1]))) self.assertEqual(5, len(list(self.db.test.find()[:5]))) self.assertEqual(1, len(list(self.db.test.find()[99:100]))) self.assertEqual(1, len(list(self.db.test.find()[99:1000]))) self.assertEqual(0, len(list(self.db.test.find()[10:10]))) self.assertEqual(0, len(list(self.db.test.find()[:0]))) self.assertEqual(80, len(list(self.db.test.find()[10:10].limit(0).skip(20)) ) ) self.assertRaises(IndexError, lambda: self.db.test.find()[10:8]) def test_getitem_numeric_index(self): self.db.drop_collection("test") self.db.test.insert_many([{"i": i} for i in range(100)]) self.assertEqual(0, self.db.test.find()[0]['i']) self.assertEqual(50, self.db.test.find()[50]['i']) self.assertEqual(50, self.db.test.find().skip(50)[0]['i']) self.assertEqual(50, self.db.test.find().skip(49)[1]['i']) self.assertEqual(50, self.db.test.find()[long(50)]['i']) self.assertEqual(99, self.db.test.find()[99]['i']) self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1) self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100) self.assertRaises(IndexError, lambda x: self.db.test.find().skip(50)[x], 50) @ignore_deprecations def test_count_with_limit_and_skip(self): self.assertRaises(TypeError, self.db.test.find().count, "foo") def check_len(cursor, length): self.assertEqual(len(list(cursor)), cursor.count(True)) self.assertEqual(length, cursor.count(True)) self.db.drop_collection("test") self.db.test.insert_many([{"i": i} for i in range(100)]) check_len(self.db.test.find(), 100) check_len(self.db.test.find().limit(10), 10) check_len(self.db.test.find().limit(110), 100) check_len(self.db.test.find().skip(10), 90) check_len(self.db.test.find().skip(110), 0) check_len(self.db.test.find().limit(10).skip(10), 10) check_len(self.db.test.find()[10:20], 10) check_len(self.db.test.find().limit(10).skip(95), 5) check_len(self.db.test.find()[95:105], 5) def test_len(self): self.assertRaises(TypeError, len, self.db.test.find()) def test_properties(self): self.assertEqual(self.db.test, self.db.test.find().collection) def set_coll(): self.db.test.find().collection = "hello" self.assertRaises(AttributeError, set_coll) def test_get_more(self): db = self.db db.drop_collection("test") db.test.insert_many([{'i': i} for i in range(10)]) self.assertEqual(10, len(list(db.test.find().batch_size(5)))) def test_tailable(self): db = self.db db.drop_collection("test") db.create_collection("test", capped=True, size=1000, max=3) self.addCleanup(db.drop_collection, "test") cursor = 
db.test.find(cursor_type=CursorType.TAILABLE) db.test.insert_one({"x": 1}) count = 0 for doc in cursor: count += 1 self.assertEqual(1, doc["x"]) self.assertEqual(1, count) db.test.insert_one({"x": 2}) count = 0 for doc in cursor: count += 1 self.assertEqual(2, doc["x"]) self.assertEqual(1, count) db.test.insert_one({"x": 3}) count = 0 for doc in cursor: count += 1 self.assertEqual(3, doc["x"]) self.assertEqual(1, count) # Capped rollover - the collection can never # have more than 3 documents. Just make sure # this doesn't raise... db.test.insert_many([{"x": i} for i in range(4, 7)]) self.assertEqual(0, len(list(cursor))) # and that the cursor doesn't think it's still alive. self.assertFalse(cursor.alive) self.assertEqual(3, db.test.count_documents({})) # __getitem__(index) for cursor in (db.test.find(cursor_type=CursorType.TAILABLE), db.test.find(cursor_type=CursorType.TAILABLE_AWAIT)): self.assertEqual(4, cursor[0]["x"]) self.assertEqual(5, cursor[1]["x"]) self.assertEqual(6, cursor[2]["x"]) cursor.rewind() self.assertEqual([4], [doc["x"] for doc in cursor[0:1]]) cursor.rewind() self.assertEqual([5], [doc["x"] for doc in cursor[1:2]]) cursor.rewind() self.assertEqual([6], [doc["x"] for doc in cursor[2:3]]) cursor.rewind() self.assertEqual([4, 5], [doc["x"] for doc in cursor[0:2]]) cursor.rewind() self.assertEqual([5, 6], [doc["x"] for doc in cursor[1:3]]) cursor.rewind() self.assertEqual([4, 5, 6], [doc["x"] for doc in cursor[0:3]]) def test_concurrent_close(self): """Ensure a tailable can be closed from another thread.""" db = self.db db.drop_collection("test") db.create_collection("test", capped=True, size=1000, max=3) self.addCleanup(db.drop_collection, "test") cursor = db.test.find(cursor_type=CursorType.TAILABLE) def iterate_cursor(): while cursor.alive: for doc in cursor: pass t = threading.Thread(target=iterate_cursor) t.start() time.sleep(1) cursor.close() self.assertFalse(cursor.alive) t.join(3) self.assertFalse(t.is_alive()) def test_distinct(self): self.db.drop_collection("test") self.db.test.insert_many( [{"a": 1}, {"a": 2}, {"a": 2}, {"a": 2}, {"a": 3}]) distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a") distinct.sort() self.assertEqual([1, 2], distinct) self.db.drop_collection("test") self.db.test.insert_one({"a": {"b": "a"}, "c": 12}) self.db.test.insert_one({"a": {"b": "b"}, "c": 8}) self.db.test.insert_one({"a": {"b": "c"}, "c": 12}) self.db.test.insert_one({"a": {"b": "c"}, "c": 8}) distinct = self.db.test.find({"c": 8}).distinct("a.b") distinct.sort() self.assertEqual(["b", "c"], distinct) @client_context.require_version_max(4, 1, 0, -1) def test_max_scan(self): self.db.drop_collection("test") self.db.test.insert_many([{} for _ in range(100)]) self.assertEqual(100, len(list(self.db.test.find()))) self.assertEqual(50, len(list(self.db.test.find().max_scan(50)))) self.assertEqual(50, len(list(self.db.test.find() .max_scan(90).max_scan(50)))) def test_with_statement(self): self.db.drop_collection("test") self.db.test.insert_many([{} for _ in range(100)]) c1 = self.db.test.find() with self.db.test.find() as c2: self.assertTrue(c2.alive) self.assertFalse(c2.alive) with self.db.test.find() as c2: self.assertEqual(100, len(list(c2))) self.assertFalse(c2.alive) self.assertTrue(c1.alive) @client_context.require_no_mongos @ignore_deprecations def test_comment(self): # MongoDB 3.1.5 changed the ns for commands. 
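# A minimal sketch of the tailable-cursor pattern tested above: on a capped
# collection the cursor survives exhausting its current results and returns
# documents inserted afterwards.  Assumes a local server; names are
# illustrative.
from pymongo import MongoClient
from pymongo.cursor import CursorType

db = MongoClient().examples
db.drop_collection("log")
db.create_collection("log", capped=True, size=1000, max=3)

tail = db.log.find(cursor_type=CursorType.TAILABLE)
db.log.insert_one({"x": 1})
assert [doc["x"] for doc in tail] == [1]
db.log.insert_one({"x": 2})
assert [doc["x"] for doc in tail] == [2]   # same cursor, only the new document
assert tail.alive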
regex = {'$regex': r'pymongo_test.(\$cmd|test)'} if client_context.version.at_least(3, 5, 8, -1): query_key = "command.comment" elif client_context.version.at_least(3, 1, 8, -1): query_key = "query.comment" else: query_key = "query.$comment" self.client.drop_database(self.db) self.db.set_profiling_level(ALL) try: list(self.db.test.find().comment('foo')) op = self.db.system.profile.find({'ns': 'pymongo_test.test', 'op': 'query', query_key: 'foo'}) self.assertEqual(op.count(), 1) self.db.test.find().comment('foo').count() op = self.db.system.profile.find({'ns': regex, 'op': 'command', 'command.count': 'test', 'command.comment': 'foo'}) self.assertEqual(op.count(), 1) self.db.test.find().comment('foo').distinct('type') op = self.db.system.profile.find({'ns': regex, 'op': 'command', 'command.distinct': 'test', 'command.comment': 'foo'}) self.assertEqual(op.count(), 1) finally: self.db.set_profiling_level(OFF) self.db.system.profile.drop() self.db.test.insert_many([{}, {}]) cursor = self.db.test.find() next(cursor) self.assertRaises(InvalidOperation, cursor.comment, 'hello') def test_modifiers(self): c = self.db.test # "modifiers" is deprecated. with ignore_deprecations(): cur = c.find() self.assertTrue('$query' not in cur._Cursor__query_spec()) cur = c.find().comment("testing").max_time_ms(500) self.assertTrue('$query' in cur._Cursor__query_spec()) self.assertEqual(cur._Cursor__query_spec()["$comment"], "testing") self.assertEqual(cur._Cursor__query_spec()["$maxTimeMS"], 500) cur = c.find( modifiers={"$maxTimeMS": 500, "$comment": "testing"}) self.assertTrue('$query' in cur._Cursor__query_spec()) self.assertEqual(cur._Cursor__query_spec()["$comment"], "testing") self.assertEqual(cur._Cursor__query_spec()["$maxTimeMS"], 500) # Keyword arg overwrites modifier. # If we remove the "modifiers" arg, delete this test after checking # that TestCommandMonitoring.test_find_options covers all cases. cur = c.find(comment="hi", modifiers={"$comment": "bye"}) self.assertEqual(cur._Cursor__query_spec()["$comment"], "hi") cur = c.find(max_scan=1, modifiers={"$maxScan": 2}) self.assertEqual(cur._Cursor__query_spec()["$maxScan"], 1) cur = c.find(max_time_ms=1, modifiers={"$maxTimeMS": 2}) self.assertEqual(cur._Cursor__query_spec()["$maxTimeMS"], 1) cur = c.find(min=1, modifiers={"$min": 2}) self.assertEqual(cur._Cursor__query_spec()["$min"], 1) cur = c.find(max=1, modifiers={"$max": 2}) self.assertEqual(cur._Cursor__query_spec()["$max"], 1) cur = c.find(return_key=True, modifiers={"$returnKey": False}) self.assertEqual(cur._Cursor__query_spec()["$returnKey"], True) cur = c.find(hint=[("a", 1)], modifiers={"$hint": {"b": "1"}}) self.assertEqual(cur._Cursor__query_spec()["$hint"], {"a": 1}) # The arg is named show_record_id after the "find" command arg, the # modifier is named $showDiskLoc for the OP_QUERY modifier. It's # stored as $showDiskLoc then upgraded to showRecordId if we send a # "find" command. 
cur = c.find(show_record_id=True, modifiers={"$showDiskLoc": False}) self.assertEqual(cur._Cursor__query_spec()["$showDiskLoc"], True) if not client_context.version.at_least(3, 7, 3): cur = c.find(snapshot=True, modifiers={"$snapshot": False}) self.assertEqual(cur._Cursor__query_spec()["$snapshot"], True) def test_alive(self): self.db.test.delete_many({}) self.db.test.insert_many([{} for _ in range(3)]) self.addCleanup(self.db.test.delete_many, {}) cursor = self.db.test.find().batch_size(2) n = 0 while True: cursor.next() n += 1 if 3 == n: self.assertFalse(cursor.alive) break self.assertTrue(cursor.alive) def test_close_kills_cursor_synchronously(self): # Kill any cursors possibly queued up by previous tests. gc.collect() self.client._process_periodic_tasks() listener = WhiteListEventListener("killCursors") results = listener.results client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors # Add some test data. docs_inserted = 1000 coll.insert_many([{"i": i} for i in range(docs_inserted)]) results.clear() # Close a cursor while it's still open on the server. cursor = coll.find().batch_size(10) self.assertTrue(bool(next(cursor))) self.assertLess(cursor.retrieved, docs_inserted) cursor.close() def assertCursorKilled(): self.assertEqual(1, len(results["started"])) self.assertEqual("killCursors", results["started"][0].command_name) self.assertEqual(1, len(results["succeeded"])) self.assertEqual("killCursors", results["succeeded"][0].command_name) assertCursorKilled() results.clear() # Close a command cursor while it's still open on the server. cursor = coll.aggregate([], batchSize=10) self.assertTrue(bool(next(cursor))) cursor.close() # The cursor should be killed if it had a non-zero id. if cursor.cursor_id: assertCursorKilled() else: self.assertEqual(0, len(results["started"])) def test_delete_not_initialized(self): # Creating a cursor with invalid arguments will not run __init__ # but will still call __del__, eg test.find(invalidKwarg=1). cursor = Cursor.__new__(Cursor) # Skip calling __init__ cursor.__del__() # no error @client_context.require_version_min(3, 6) def test_getMore_does_not_send_readPreference(self): listener = WhiteListEventListener('find', 'getMore') client = rs_or_single_client( event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test coll.delete_many({}) coll.insert_many([{} for _ in range(5)]) self.addCleanup(coll.drop) list(coll.find(batch_size=3)) started = listener.results['started'] self.assertEqual(2, len(started)) self.assertEqual('find', started[0].command_name) self.assertIn('$readPreference', started[0].command) self.assertEqual('getMore', started[1].command_name) self.assertNotIn('$readPreference', started[1].command) class TestRawBatchCursor(IntegrationTest): def test_find_raw(self): c = self.db.test c.drop() docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] c.insert_many(docs) batches = list(c.find_raw_batches().sort('_id')) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) def test_manipulate(self): c = self.db.test with self.assertRaises(InvalidOperation): c.find_raw_batches(manipulate=True) def test_explain(self): c = self.db.test c.insert_one({}) explanation = c.find_raw_batches().explain() self.assertIsInstance(explanation, dict) def test_clone(self): cursor = self.db.test.find_raw_batches() # Copy of a RawBatchCursor is also a RawBatchCursor, not a Cursor. 
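# A minimal sketch of the raw-batch reads tested above: find_raw_batches()
# yields each server batch as undecoded BSON bytes, which can be decoded later
# with bson.decode_all().  Assumes a local server; names are illustrative.
from bson import decode_all
from pymongo import MongoClient

coll = MongoClient().examples.raw
coll.drop()
docs = [{"_id": i, "x": 3.0 * i} for i in range(10)]
coll.insert_many(docs)

batches = list(coll.find_raw_batches().sort("_id"))
assert isinstance(batches[0], bytes)
assert decode_all(b"".join(batches)) == docs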
self.assertIsInstance(next(cursor.clone()), bytes) self.assertIsInstance(next(copy.copy(cursor)), bytes) @client_context.require_no_mongos def test_exhaust(self): c = self.db.test c.drop() c.insert_many({'_id': i} for i in range(200)) result = b''.join(c.find_raw_batches(cursor_type=CursorType.EXHAUST)) self.assertEqual([{'_id': i} for i in range(200)], decode_all(result)) def test_server_error(self): with self.assertRaises(OperationFailure) as exc: next(self.db.test.find_raw_batches({'x': {'$bad': 1}})) # The server response was decoded, not left raw. self.assertIsInstance(exc.exception.details, dict) def test_get_item(self): with self.assertRaises(InvalidOperation): self.db.test.find_raw_batches()[0] @client_context.require_version_min(3, 4) def test_collation(self): next(self.db.test.find_raw_batches(collation=Collation('en_US'))) @client_context.require_version_max(3, 2) def test_collation_error(self): with self.assertRaises(ConfigurationError): next(self.db.test.find_raw_batches(collation=Collation('en_US'))) @client_context.require_version_min(3, 2) def test_read_concern(self): c = self.db.get_collection("test", read_concern=ReadConcern("majority")) next(c.find_raw_batches()) @client_context.require_version_max(3, 1) def test_read_concern_error(self): c = self.db.get_collection("test", read_concern=ReadConcern("majority")) with self.assertRaises(ConfigurationError): next(c.find_raw_batches()) def test_monitoring(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{'_id': i} for i in range(10)]) listener.results.clear() cursor = c.find_raw_batches(batch_size=4) # First raw batch of 4 documents. next(cursor) started = listener.results['started'][0] succeeded = listener.results['succeeded'][0] self.assertEqual(0, len(listener.results['failed'])) self.assertEqual('find', started.command_name) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('find', succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") # The batch is a list of one raw bytes object. self.assertEqual(len(csr["firstBatch"]), 1) self.assertEqual(decode_all(csr["firstBatch"][0]), [{'_id': i} for i in range(0, 4)]) listener.results.clear() # Next raw batch of 4 documents. next(cursor) try: results = listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertEqual('getMore', started.command_name) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('getMore', succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") self.assertEqual(len(csr["nextBatch"]), 1) self.assertEqual(decode_all(csr["nextBatch"][0]), [{'_id': i} for i in range(4, 8)]) finally: # Finish the cursor. 
tuple(cursor) class TestRawBatchCommandCursor(IntegrationTest): @classmethod def setUpClass(cls): super(TestRawBatchCommandCursor, cls).setUpClass() def test_aggregate_raw(self): c = self.db.test c.drop() docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] c.insert_many(docs) batches = list(c.aggregate_raw_batches([{'$sort': {'_id': 1}}])) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) def test_server_error(self): c = self.db.test c.drop() docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] c.insert_many(docs) c.insert_one({'_id': 10, 'x': 'not a number'}) with self.assertRaises(OperationFailure) as exc: list(self.db.test.aggregate_raw_batches([{ '$sort': {'_id': 1}, }, { '$project': {'x': {'$multiply': [2, '$x']}} }], batchSize=4)) # The server response was decoded, not left raw. self.assertIsInstance(exc.exception.details, dict) def test_get_item(self): with self.assertRaises(InvalidOperation): self.db.test.aggregate_raw_batches([])[0] @client_context.require_version_min(3, 4) def test_collation(self): next(self.db.test.aggregate_raw_batches([], collation=Collation('en_US'))) @client_context.require_version_max(3, 2) def test_collation_error(self): with self.assertRaises(ConfigurationError): next(self.db.test.aggregate_raw_batches([], collation=Collation('en_US'))) def test_monitoring(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{'_id': i} for i in range(10)]) listener.results.clear() cursor = c.aggregate_raw_batches([{'$sort': {'_id': 1}}], batchSize=4) # Start cursor, no initial batch. started = listener.results['started'][0] succeeded = listener.results['succeeded'][0] self.assertEqual(0, len(listener.results['failed'])) self.assertEqual('aggregate', started.command_name) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('aggregate', succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") # First batch is empty. self.assertEqual(len(csr["firstBatch"]), 0) listener.results.clear() # Batches of 4 documents. n = 0 for batch in cursor: results = listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertEqual('getMore', started.command_name) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('getMore', succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") self.assertEqual(len(csr["nextBatch"]), 1) self.assertEqual(csr["nextBatch"][0], batch) self.assertEqual(decode_all(batch), [{'_id': i} for i in range(n, min(n + 4, 10))]) n += 4 listener.results.clear() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_cursor_manager.py000066400000000000000000000057341374256237000207050ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
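# The aggregation counterpart of the raw-batch reads tested above, sketched
# under the same assumptions (local server, illustrative names):
# aggregate_raw_batches() runs a pipeline but leaves each result batch as raw
# BSON bytes for the caller to decode.
from bson import decode_all
from pymongo import MongoClient

coll = MongoClient().examples.raw
coll.drop()
docs = [{"_id": i, "x": 3.0 * i} for i in range(10)]
coll.insert_many(docs)
raw = b"".join(coll.aggregate_raw_batches([{"$sort": {"_id": 1}}], batchSize=4))
assert decode_all(raw) == docs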
"""Test the cursor_manager module.""" import sys import warnings sys.path[0:0] = [""] from pymongo.cursor_manager import CursorManager from pymongo.errors import CursorNotFound from pymongo.message import _CursorAddress from test import (client_context, client_knobs, unittest, IntegrationTest, SkipTest) from test.utils import rs_or_single_client, wait_until class TestCursorManager(IntegrationTest): @classmethod def setUpClass(cls): super(TestCursorManager, cls).setUpClass() cls.warn_context = warnings.catch_warnings() cls.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) cls.collection = cls.db.test cls.collection.drop() # Ensure two batches. cls.collection.insert_many([{'_id': i} for i in range(200)]) @classmethod def tearDownClass(cls): cls.warn_context.__exit__() cls.warn_context = None cls.collection.drop() def test_cursor_manager_validation(self): with self.assertRaises(TypeError): client_context.client.set_cursor_manager(1) def test_cursor_manager(self): self.close_was_called = False test_case = self class CM(CursorManager): def __init__(self, client): super(CM, self).__init__(client) def close(self, cursor_id, address): test_case.close_was_called = True super(CM, self).close(cursor_id, address) with client_knobs(kill_cursor_frequency=0.01): client = rs_or_single_client(maxPoolSize=1) client.set_cursor_manager(CM) # Create a cursor on the same client so we're certain the getMore # is sent after the killCursors message. cursor = client.pymongo_test.test.find().batch_size(1) next(cursor) client.close_cursor( cursor.cursor_id, _CursorAddress(self.client.address, self.collection.full_name)) def raises_cursor_not_found(): try: next(cursor) return False except CursorNotFound: return True wait_until(raises_cursor_not_found, 'close cursor') self.assertTrue(self.close_was_called) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_custom_types.py000066400000000000000000001033571374256237000204340ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test support for callbacks to encode/decode custom types.""" import datetime import sys import tempfile from collections import OrderedDict from decimal import Decimal from random import random sys.path[0:0] = [""] from bson import (Decimal128, decode, decode_all, decode_file_iter, decode_iter, encode, RE_TYPE, _BUILT_IN_TYPES, _dict_to_bson, _bson_to_dict) from bson.code import Code from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder, TypeEncoder, TypeRegistry) from bson.errors import InvalidDocument from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument from bson.py3compat import text_type from gridfs import GridIn, GridOut from pymongo.collection import ReturnDocument from pymongo.errors import DuplicateKeyError from pymongo.message import _CursorAddress from test import client_context, unittest from test.test_client import IntegrationTest from test.utils import ignore_deprecations, rs_client class DecimalEncoder(TypeEncoder): @property def python_type(self): return Decimal def transform_python(self, value): return Decimal128(value) class DecimalDecoder(TypeDecoder): @property def bson_type(self): return Decimal128 def transform_bson(self, value): return value.to_decimal() class DecimalCodec(DecimalDecoder, DecimalEncoder): pass DECIMAL_CODECOPTS = CodecOptions( type_registry=TypeRegistry([DecimalCodec()])) class UndecipherableInt64Type(object): def __init__(self, value): self.value = value def __eq__(self, other): if isinstance(other, type(self)): return self.value == other.value # Does not compare equal to integers. return False class UndecipherableIntDecoder(TypeDecoder): bson_type = Int64 def transform_bson(self, value): return UndecipherableInt64Type(value) class UndecipherableIntEncoder(TypeEncoder): python_type = UndecipherableInt64Type def transform_python(self, value): return Int64(value.value) UNINT_DECODER_CODECOPTS = CodecOptions( type_registry=TypeRegistry([UndecipherableIntDecoder(), ])) UNINT_CODECOPTS = CodecOptions(type_registry=TypeRegistry( [UndecipherableIntDecoder(), UndecipherableIntEncoder()])) class UppercaseTextDecoder(TypeDecoder): bson_type = text_type def transform_bson(self, value): return value.upper() UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry( [UppercaseTextDecoder(),])) def type_obfuscating_decoder_factory(rt_type): class ResumeTokenToNanDecoder(TypeDecoder): bson_type = rt_type def transform_bson(self, value): return "NaN" return ResumeTokenToNanDecoder class CustomBSONTypeTests(object): def roundtrip(self, doc): bsonbytes = encode(doc, codec_options=self.codecopts) rt_document = decode(bsonbytes, codec_options=self.codecopts) self.assertEqual(doc, rt_document) def test_encode_decode_roundtrip(self): self.roundtrip({'average': Decimal('56.47')}) self.roundtrip({'average': {'b': Decimal('56.47')}}) self.roundtrip({'average': [Decimal('56.47')]}) self.roundtrip({'average': [[Decimal('56.47')]]}) self.roundtrip({'average': [{'b': Decimal('56.47')}]}) def test_decode_all(self): documents = [] for dec in range(3): documents.append({'average': Decimal('56.4%s' % (dec,))}) bsonstream = bytes() for doc in documents: bsonstream += encode(doc, codec_options=self.codecopts) self.assertEqual( decode_all(bsonstream, self.codecopts), documents) def test__bson_to_dict(self): document = {'average': Decimal('56.47')} rawbytes = encode(document, codec_options=self.codecopts) decoded_document = _bson_to_dict(rawbytes, self.codecopts) self.assertEqual(document, decoded_document) def test__dict_to_bson(self): 
document = {'average': Decimal('56.47')} rawbytes = encode(document, codec_options=self.codecopts) encoded_document = _dict_to_bson(document, False, self.codecopts) self.assertEqual(encoded_document, rawbytes) def _generate_multidocument_bson_stream(self): inp_num = [str(random() * 100)[:4] for _ in range(10)] docs = [{'n': Decimal128(dec)} for dec in inp_num] edocs = [{'n': Decimal(dec)} for dec in inp_num] bsonstream = b"" for doc in docs: bsonstream += encode(doc) return edocs, bsonstream def test_decode_iter(self): expected, bson_data = self._generate_multidocument_bson_stream() for expected_doc, decoded_doc in zip( expected, decode_iter(bson_data, self.codecopts)): self.assertEqual(expected_doc, decoded_doc) def test_decode_file_iter(self): expected, bson_data = self._generate_multidocument_bson_stream() fileobj = tempfile.TemporaryFile() fileobj.write(bson_data) fileobj.seek(0) for expected_doc, decoded_doc in zip( expected, decode_file_iter(fileobj, self.codecopts)): self.assertEqual(expected_doc, decoded_doc) fileobj.close() class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): cls.codecopts = DECIMAL_CODECOPTS class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): codec_options = CodecOptions( type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder()))) cls.codecopts = codec_options class TestBSONFallbackEncoder(unittest.TestCase): def _get_codec_options(self, fallback_encoder): type_registry = TypeRegistry(fallback_encoder=fallback_encoder) return CodecOptions(type_registry=type_registry) def test_simple(self): codecopts = self._get_codec_options(lambda x: Decimal128(x)) document = {'average': Decimal('56.47')} bsonbytes = encode(document, codec_options=codecopts) exp_document = {'average': Decimal128('56.47')} exp_bsonbytes = encode(exp_document) self.assertEqual(bsonbytes, exp_bsonbytes) def test_erroring_fallback_encoder(self): codecopts = self._get_codec_options(lambda _: 1/0) # fallback converter should not be invoked when encoding known types. encode( {'a': 1, 'b': Decimal128('1.01'), 'c': {'arr': ['abc', 3.678]}}, codec_options=codecopts) # expect an error when encoding a custom type. 
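# A minimal sketch of the fallback-encoder hook tested here: a plain function
# the BSON encoder calls for any value whose type it does not otherwise know
# how to encode.  The Decimal-to-Decimal128 conversion below is illustrative.
from decimal import Decimal
from bson import Decimal128, encode
from bson.codec_options import CodecOptions, TypeRegistry

def decimal_fallback(value):
    if isinstance(value, Decimal):
        return Decimal128(value)
    raise TypeError("cannot encode %r" % (type(value),))

opts = CodecOptions(type_registry=TypeRegistry(fallback_encoder=decimal_fallback))
assert (encode({"price": Decimal("9.99")}, codec_options=opts) ==
        encode({"price": Decimal128("9.99")}))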
document = {'average': Decimal('56.47')} with self.assertRaises(ZeroDivisionError): encode(document, codec_options=codecopts) def test_noop_fallback_encoder(self): codecopts = self._get_codec_options(lambda x: x) document = {'average': Decimal('56.47')} with self.assertRaises(InvalidDocument): encode(document, codec_options=codecopts) def test_type_unencodable_by_fallback_encoder(self): def fallback_encoder(value): try: return Decimal128(value) except: raise TypeError("cannot encode type %s" % (type(value))) codecopts = self._get_codec_options(fallback_encoder) document = {'average': Decimal} with self.assertRaises(TypeError): encode(document, codec_options=codecopts) class TestBSONTypeEnDeCodecs(unittest.TestCase): def test_instantiation(self): msg = "Can't instantiate abstract class .* with abstract methods .*" def run_test(base, attrs, fail): codec = type('testcodec', (base,), attrs) if fail: with self.assertRaisesRegex(TypeError, msg): codec() else: codec() class MyType(object): pass run_test(TypeEncoder, {'python_type': MyType,}, fail=True) run_test(TypeEncoder, {'transform_python': lambda s, x: x}, fail=True) run_test(TypeEncoder, {'transform_python': lambda s, x: x, 'python_type': MyType}, fail=False) run_test(TypeDecoder, {'bson_type': Decimal128, }, fail=True) run_test(TypeDecoder, {'transform_bson': lambda s, x: x}, fail=True) run_test(TypeDecoder, {'transform_bson': lambda s, x: x, 'bson_type': Decimal128}, fail=False) run_test(TypeCodec, {'bson_type': Decimal128, 'python_type': MyType}, fail=True) run_test(TypeCodec, {'transform_bson': lambda s, x: x, 'transform_python': lambda s, x: x}, fail=True) run_test(TypeCodec, {'python_type': MyType, 'transform_python': lambda s, x: x, 'transform_bson': lambda s, x: x, 'bson_type': Decimal128}, fail=False) def test_type_checks(self): self.assertTrue(issubclass(TypeCodec, TypeEncoder)) self.assertTrue(issubclass(TypeCodec, TypeDecoder)) self.assertFalse(issubclass(TypeDecoder, TypeEncoder)) self.assertFalse(issubclass(TypeEncoder, TypeDecoder)) class TestBSONCustomTypeEncoderAndFallbackEncoderTandem(unittest.TestCase): @classmethod def setUpClass(cls): class TypeA(object): def __init__(self, x): self.value = x class TypeB(object): def __init__(self, x): self.value = x # transforms A, and only A into B def fallback_encoder_A2B(value): assert isinstance(value, TypeA) return TypeB(value.value) # transforms A, and only A into something encodable def fallback_encoder_A2BSON(value): assert isinstance(value, TypeA) return value.value # transforms B into something encodable class B2BSON(TypeEncoder): python_type = TypeB def transform_python(self, value): return value.value # transforms A into B # technically, this isn't a proper type encoder as the output is not # BSON-encodable. class A2B(TypeEncoder): python_type = TypeA def transform_python(self, value): return TypeB(value.value) # transforms B into A # technically, this isn't a proper type encoder as the output is not # BSON-encodable. 
class B2A(TypeEncoder): python_type = TypeB def transform_python(self, value): return TypeA(value.value) cls.TypeA = TypeA cls.TypeB = TypeB cls.fallback_encoder_A2B = staticmethod(fallback_encoder_A2B) cls.fallback_encoder_A2BSON = staticmethod(fallback_encoder_A2BSON) cls.B2BSON = B2BSON cls.B2A = B2A cls.A2B = A2B def test_encode_fallback_then_custom(self): codecopts = CodecOptions(type_registry=TypeRegistry( [self.B2BSON()], fallback_encoder=self.fallback_encoder_A2B)) testdoc = {'x': self.TypeA(123)} expected_bytes = encode({'x': 123}) self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) def test_encode_custom_then_fallback(self): codecopts = CodecOptions(type_registry=TypeRegistry( [self.B2A()], fallback_encoder=self.fallback_encoder_A2BSON)) testdoc = {'x': self.TypeB(123)} expected_bytes = encode({'x': 123}) self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) def test_chaining_encoders_fails(self): codecopts = CodecOptions(type_registry=TypeRegistry( [self.A2B(), self.B2BSON()])) with self.assertRaises(InvalidDocument): encode({'x': self.TypeA(123)}, codec_options=codecopts) def test_infinite_loop_exceeds_max_recursion_depth(self): codecopts = CodecOptions(type_registry=TypeRegistry( [self.B2A()], fallback_encoder=self.fallback_encoder_A2B)) # Raises max recursion depth exceeded error with self.assertRaises(RuntimeError): encode({'x': self.TypeA(100)}, codec_options=codecopts) class TestTypeRegistry(unittest.TestCase): @classmethod def setUpClass(cls): class MyIntType(object): def __init__(self, x): assert isinstance(x, int) self.x = x class MyStrType(object): def __init__(self, x): assert isinstance(x, str) self.x = x class MyIntCodec(TypeCodec): @property def python_type(self): return MyIntType @property def bson_type(self): return int def transform_python(self, value): return value.x def transform_bson(self, value): return MyIntType(value) class MyStrCodec(TypeCodec): @property def python_type(self): return MyStrType @property def bson_type(self): return str def transform_python(self, value): return value.x def transform_bson(self, value): return MyStrType(value) def fallback_encoder(value): return value cls.types = (MyIntType, MyStrType) cls.codecs = (MyIntCodec, MyStrCodec) cls.fallback_encoder = fallback_encoder def test_simple(self): codec_instances = [codec() for codec in self.codecs] def assert_proper_initialization(type_registry, codec_instances): self.assertEqual(type_registry._encoder_map, { self.types[0]: codec_instances[0].transform_python, self.types[1]: codec_instances[1].transform_python}) self.assertEqual(type_registry._decoder_map, { int: codec_instances[0].transform_bson, str: codec_instances[1].transform_bson}) self.assertEqual( type_registry._fallback_encoder, self.fallback_encoder) type_registry = TypeRegistry(codec_instances, self.fallback_encoder) assert_proper_initialization(type_registry, codec_instances) type_registry = TypeRegistry( fallback_encoder=self.fallback_encoder, type_codecs=codec_instances) assert_proper_initialization(type_registry, codec_instances) # Ensure codec list held by the type registry doesn't change if we # mutate the initial list. 
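# A minimal end-to-end sketch of the registry machinery tested here: one
# TypeCodec encodes decimal.Decimal as BSON Decimal128 on writes and decodes
# it back on reads once the registry is attached to a collection through its
# codec options.  Assumes a local server; names are illustrative.
from decimal import Decimal
from bson import Decimal128
from bson.codec_options import CodecOptions, TypeCodec, TypeRegistry
from pymongo import MongoClient

class PriceCodec(TypeCodec):
    python_type = Decimal      # custom Python type handled on encode
    bson_type = Decimal128     # BSON type handled on decode
    def transform_python(self, value):
        return Decimal128(value)
    def transform_bson(self, value):
        return value.to_decimal()

opts = CodecOptions(type_registry=TypeRegistry([PriceCodec()]))
coll = MongoClient().examples.get_collection("prices", codec_options=opts)
coll.drop()
coll.insert_one({"price": Decimal("9.99")})
assert coll.find_one()["price"] == Decimal("9.99")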
codec_instances_copy = list(codec_instances) codec_instances.pop(0) self.assertListEqual( type_registry._TypeRegistry__type_codecs, codec_instances_copy) def test_simple_separate_codecs(self): class MyIntEncoder(TypeEncoder): python_type = self.types[0] def transform_python(self, value): return value.x class MyIntDecoder(TypeDecoder): bson_type = int def transform_bson(self, value): return self.types[0](value) codec_instances = [MyIntDecoder(), MyIntEncoder()] type_registry = TypeRegistry(codec_instances) self.assertEqual( type_registry._encoder_map, {MyIntEncoder.python_type: codec_instances[1].transform_python}) self.assertEqual( type_registry._decoder_map, {MyIntDecoder.bson_type: codec_instances[0].transform_bson}) def test_initialize_fail(self): err_msg = ("Expected an instance of TypeEncoder, TypeDecoder, " "or TypeCodec, got .* instead") with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry(self.codecs) with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry([type('AnyType', (object,), {})()]) err_msg = "fallback_encoder %r is not a callable" % (True,) with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry([], True) err_msg = "fallback_encoder %r is not a callable" % ('hello',) with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry(fallback_encoder='hello') def test_type_registry_repr(self): codec_instances = [codec() for codec in self.codecs] type_registry = TypeRegistry(codec_instances) r = ("TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % ( codec_instances, None)) self.assertEqual(r, repr(type_registry)) def test_type_registry_eq(self): codec_instances = [codec() for codec in self.codecs] self.assertEqual( TypeRegistry(codec_instances), TypeRegistry(codec_instances)) codec_instances_2 = [codec() for codec in self.codecs] self.assertNotEqual( TypeRegistry(codec_instances), TypeRegistry(codec_instances_2)) def test_builtin_types_override_fails(self): def run_test(base, attrs): msg = (r"TypeEncoders cannot change how built-in types " r"are encoded \(encoder .* transforms type .*\)") for pytype in _BUILT_IN_TYPES: attrs.update({'python_type': pytype, 'transform_python': lambda x: x}) codec = type('testcodec', (base, ), attrs) codec_instance = codec() with self.assertRaisesRegex(TypeError, msg): TypeRegistry([codec_instance,]) # Test only some subtypes as not all can be subclassed. 
if pytype in [bool, type(None), RE_TYPE,]: continue class MyType(pytype): pass attrs.update({'python_type': MyType, 'transform_python': lambda x: x}) codec = type('testcodec', (base, ), attrs) codec_instance = codec() with self.assertRaisesRegex(TypeError, msg): TypeRegistry([codec_instance,]) run_test(TypeEncoder, {}) run_test(TypeCodec, {'bson_type': Decimal128, 'transform_bson': lambda x: x}) class TestCollectionWCustomType(IntegrationTest): def setUp(self): self.db.test.drop() def tearDown(self): self.db.test.drop() def test_command_errors_w_custom_type_decoder(self): db = self.db test_doc = {'_id': 1, 'data': 'a'} test = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS) result = test.insert_one(test_doc) self.assertEqual(result.inserted_id, test_doc['_id']) with self.assertRaises(DuplicateKeyError): test.insert_one(test_doc) def test_find_w_custom_type_decoder(self): db = self.db input_docs = [ {'x': Int64(k)} for k in [1, 2, 3]] for doc in input_docs: db.test.insert_one(doc) test = db.get_collection( 'test', codec_options=UNINT_DECODER_CODECOPTS) for doc in test.find({}, batch_size=1): self.assertIsInstance(doc['x'], UndecipherableInt64Type) def test_find_w_custom_type_decoder_and_document_class(self): def run_test(doc_cls): db = self.db input_docs = [ {'x': Int64(k)} for k in [1, 2, 3]] for doc in input_docs: db.test.insert_one(doc) test = db.get_collection('test', codec_options=CodecOptions( type_registry=TypeRegistry([UndecipherableIntDecoder()]), document_class=doc_cls)) for doc in test.find({}, batch_size=1): self.assertIsInstance(doc, doc_cls) self.assertIsInstance(doc['x'], UndecipherableInt64Type) for doc_cls in [RawBSONDocument, OrderedDict]: run_test(doc_cls) @client_context.require_version_max(4, 1, 0, -1) def test_group_w_custom_type(self): db = self.db test = db.get_collection('test', codec_options=UNINT_CODECOPTS) test.insert_many([ {'sku': 'a', 'qty': UndecipherableInt64Type(2)}, {'sku': 'b', 'qty': UndecipherableInt64Type(5)}, {'sku': 'a', 'qty': UndecipherableInt64Type(1)}]) self.assertEqual([{'sku': 'b', 'qty': UndecipherableInt64Type(5)},], test.group(["sku", "qty"], {"sku": "b"}, {}, "function (obj, prev) { }")) def test_aggregate_w_custom_type_decoder(self): db = self.db db.test.insert_many([ {'status': 'in progress', 'qty': Int64(1)}, {'status': 'complete', 'qty': Int64(10)}, {'status': 'in progress', 'qty': Int64(1)}, {'status': 'complete', 'qty': Int64(10)}, {'status': 'in progress', 'qty': Int64(1)},]) test = db.get_collection( 'test', codec_options=UNINT_DECODER_CODECOPTS) pipeline = [ {'$match': {'status': 'complete'}}, {'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},] result = test.aggregate(pipeline) res = list(result)[0] self.assertEqual(res['_id'], 'complete') self.assertIsInstance(res['total_qty'], UndecipherableInt64Type) self.assertEqual(res['total_qty'].value, 20) def test_distinct_w_custom_type(self): self.db.drop_collection("test") test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS) values = [ UndecipherableInt64Type(1), UndecipherableInt64Type(2), UndecipherableInt64Type(3), {"b": UndecipherableInt64Type(3)}] test.insert_many({"a": val} for val in values) self.assertEqual(values, test.distinct("a")) def test_map_reduce_w_custom_type(self): test = self.db.get_collection( 'test', codec_options=UPPERSTR_DECODER_CODECOPTS) test.insert_many([ {'_id': 1, 'sku': 'abcd', 'qty': 1}, {'_id': 2, 'sku': 'abcd', 'qty': 2}, {'_id': 3, 'sku': 'abcd', 'qty': 3}]) map = Code("function () {" " emit(this.sku, 
this.qty);" "}") reduce = Code("function (key, values) {" " return Array.sum(values);" "}") result = test.map_reduce(map, reduce, out={'inline': 1}) self.assertTrue(isinstance(result, dict)) self.assertTrue('results' in result) self.assertEqual(result['results'][0], {'_id': 'ABCD', 'value': 6}) result = test.inline_map_reduce(map, reduce) self.assertTrue(isinstance(result, list)) self.assertEqual(1, len(result)) self.assertEqual(result[0]["_id"], 'ABCD') full_result = test.inline_map_reduce(map, reduce, full_response=True) result = full_result['results'] self.assertTrue(isinstance(result, list)) self.assertEqual(1, len(result)) self.assertEqual(result[0]["_id"], 'ABCD') def test_find_one_and__w_custom_type_decoder(self): db = self.db c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS) c.insert_one({'_id': 1, 'x': Int64(1)}) doc = c.find_one_and_update({'_id': 1}, {'$inc': {'x': 1}}, return_document=ReturnDocument.AFTER) self.assertEqual(doc['_id'], 1) self.assertIsInstance(doc['x'], UndecipherableInt64Type) self.assertEqual(doc['x'].value, 2) doc = c.find_one_and_replace({'_id': 1}, {'x': Int64(3), 'y': True}, return_document=ReturnDocument.AFTER) self.assertEqual(doc['_id'], 1) self.assertIsInstance(doc['x'], UndecipherableInt64Type) self.assertEqual(doc['x'].value, 3) self.assertEqual(doc['y'], True) doc = c.find_one_and_delete({'y': True}) self.assertEqual(doc['_id'], 1) self.assertIsInstance(doc['x'], UndecipherableInt64Type) self.assertEqual(doc['x'].value, 3) self.assertIsNone(c.find_one()) @ignore_deprecations def test_find_and_modify_w_custom_type_decoder(self): db = self.db c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS) c.insert_one({'_id': 1, 'x': Int64(1)}) doc = c.find_and_modify({'_id': 1}, {'$inc': {'x': Int64(10)}}) self.assertEqual(doc['_id'], 1) self.assertIsInstance(doc['x'], UndecipherableInt64Type) self.assertEqual(doc['x'].value, 1) doc = c.find_one() self.assertEqual(doc['_id'], 1) self.assertIsInstance(doc['x'], UndecipherableInt64Type) self.assertEqual(doc['x'].value, 11) class TestGridFileCustomType(IntegrationTest): def setUp(self): self.db.drop_collection('fs.files') self.db.drop_collection('fs.chunks') def test_grid_out_custom_opts(self): db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS) one = GridIn(db.fs, _id=5, filename="my_file", contentType="text/html", chunkSize=1000, aliases=["foo"], metadata={"foo": 'red', "bar": 'blue'}, bar=3, baz="hello") one.write(b"hello world") one.close() two = GridOut(db.fs, 5) self.assertEqual("my_file", two.name) self.assertEqual("my_file", two.filename) self.assertEqual(5, two._id) self.assertEqual(11, two.length) self.assertEqual("text/html", two.content_type) self.assertEqual(1000, two.chunk_size) self.assertTrue(isinstance(two.upload_date, datetime.datetime)) self.assertEqual(["foo"], two.aliases) self.assertEqual({"foo": 'red', "bar": 'blue'}, two.metadata) self.assertEqual(3, two.bar) self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", two.md5) for attr in ["_id", "name", "content_type", "length", "chunk_size", "upload_date", "aliases", "metadata", "md5"]: self.assertRaises(AttributeError, setattr, two, attr, 5) class ChangeStreamsWCustomTypesTestMixin(object): def change_stream(self, *args, **kwargs): return self.watched_target.watch(*args, **kwargs) def insert_and_check(self, change_stream, insert_doc, expected_doc): self.input_target.insert_one(insert_doc) change = next(change_stream) self.assertEqual(change['fullDocument'], expected_doc) def 
kill_change_stream_cursor(self, change_stream): # Cause a cursor not found error on the next getMore. cursor = change_stream._cursor address = _CursorAddress(cursor.address, cursor._CommandCursor__ns) client = self.input_target.database.client client._close_cursor_now(cursor.cursor_id, address) def test_simple(self): codecopts = CodecOptions(type_registry=TypeRegistry([ UndecipherableIntEncoder(), UppercaseTextDecoder()])) self.create_targets(codec_options=codecopts) input_docs = [ {'_id': UndecipherableInt64Type(1), 'data': 'hello'}, {'_id': 2, 'data': 'world'}, {'_id': UndecipherableInt64Type(3), 'data': '!'},] expected_docs = [ {'_id': 1, 'data': 'HELLO'}, {'_id': 2, 'data': 'WORLD'}, {'_id': 3, 'data': '!'},] change_stream = self.change_stream() self.insert_and_check(change_stream, input_docs[0], expected_docs[0]) self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, input_docs[1], expected_docs[1]) self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, input_docs[2], expected_docs[2]) def test_custom_type_in_pipeline(self): codecopts = CodecOptions(type_registry=TypeRegistry([ UndecipherableIntEncoder(), UppercaseTextDecoder()])) self.create_targets(codec_options=codecopts) input_docs = [ {'_id': UndecipherableInt64Type(1), 'data': 'hello'}, {'_id': 2, 'data': 'world'}, {'_id': UndecipherableInt64Type(3), 'data': '!'}] expected_docs = [ {'_id': 2, 'data': 'WORLD'}, {'_id': 3, 'data': '!'}] # UndecipherableInt64Type should be encoded with the TypeRegistry. change_stream = self.change_stream( [{'$match': {'documentKey._id': { '$gte': UndecipherableInt64Type(2)}}}]) self.input_target.insert_one(input_docs[0]) self.insert_and_check(change_stream, input_docs[1], expected_docs[0]) self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, input_docs[2], expected_docs[1]) def test_break_resume_token(self): # Get one document from a change stream to determine resumeToken type. self.create_targets() change_stream = self.change_stream() self.input_target.insert_one({"data": "test"}) change = next(change_stream) resume_token_decoder = type_obfuscating_decoder_factory( type(change['_id']['_data'])) # Custom-decoding the resumeToken type breaks resume tokens. codecopts = CodecOptions(type_registry=TypeRegistry([ resume_token_decoder(), UndecipherableIntEncoder()])) # Re-create targets, change stream and proceed. 
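# --- Illustrative aside (not part of the original test suite) ----------------
# A minimal sketch of the custom-type machinery exercised above, assuming the
# PyMongo 3.8+ TypeCodec / TypeRegistry API: the codec teaches the driver to
# round-trip decimal.Decimal through BSON Decimal128, and the resulting
# CodecOptions can be handed to get_collection() or to watch() so change
# stream documents are decoded the same way.  "example_db" and the collection
# name are placeholders, not fixtures from this suite.
from decimal import Decimal

from bson.codec_options import CodecOptions, TypeCodec, TypeRegistry
from bson.decimal128 import Decimal128


class DecimalCodec(TypeCodec):
    python_type = Decimal    # custom Python type to encode
    bson_type = Decimal128   # BSON type to decode

    def transform_python(self, value):
        # Called while encoding documents that contain Decimal values.
        return Decimal128(value)

    def transform_bson(self, value):
        # Called while decoding Decimal128 values read from the server.
        return value.to_decimal()


def decimal_aware_collection(example_db):
    opts = CodecOptions(type_registry=TypeRegistry([DecimalCodec()]))
    return example_db.get_collection("prices", codec_options=opts)
# ------------------------------------------------------------------------------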
self.create_targets(codec_options=codecopts) docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}] change_stream = self.change_stream() self.insert_and_check(change_stream, docs[0], docs[0]) self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, docs[1], docs[1]) self.kill_change_stream_cursor(change_stream) self.insert_and_check(change_stream, docs[2], docs[2]) def test_document_class(self): def run_test(doc_cls): codecopts = CodecOptions(type_registry=TypeRegistry([ UppercaseTextDecoder(), UndecipherableIntEncoder()]), document_class=doc_cls) self.create_targets(codec_options=codecopts) change_stream = self.change_stream() doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'} self.input_target.insert_one(doc) change = next(change_stream) self.assertIsInstance(change, doc_cls) self.assertEqual(change['fullDocument']['a'], 101) self.assertEqual(change['fullDocument']['b'], 'XYZ') for doc_cls in [OrderedDict, RawBSONDocument]: run_test(doc_cls) class TestCollectionChangeStreamsWCustomTypes( IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @classmethod @client_context.require_version_min(3, 6, 0) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestCollectionChangeStreamsWCustomTypes, cls).setUpClass() def tearDown(self): self.input_target.drop() def create_targets(self, *args, **kwargs): self.watched_target = self.db.get_collection( 'test', *args, **kwargs) self.input_target = self.watched_target # Ensure the collection exists and is empty. self.input_target.insert_one({}) self.input_target.delete_many({}) class TestDatabaseChangeStreamsWCustomTypes( IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestDatabaseChangeStreamsWCustomTypes, cls).setUpClass() def tearDown(self): self.input_target.drop() self.client.drop_database(self.watched_target) def create_targets(self, *args, **kwargs): self.watched_target = self.client.get_database( self.db.name, *args, **kwargs) self.input_target = self.watched_target.test # Insert a record to ensure db, coll are created. self.input_target.insert_one({'data': 'dummy'}) class TestClusterChangeStreamsWCustomTypes( IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_no_mmap @client_context.require_no_standalone def setUpClass(cls): super(TestClusterChangeStreamsWCustomTypes, cls).setUpClass() def tearDown(self): self.input_target.drop() self.client.drop_database(self.db) def create_targets(self, *args, **kwargs): codec_options = kwargs.pop('codec_options', None) if codec_options: kwargs['type_registry'] = codec_options.type_registry kwargs['document_class'] = codec_options.document_class self.watched_target = rs_client(*args, **kwargs) self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. self.input_target.insert_one({'data': 'dummy'}) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_database.py000066400000000000000000001271531374256237000174420ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the database module.""" import datetime import re import sys import warnings sys.path[0:0] = [""] from bson.code import Code from bson.codec_options import CodecOptions from bson.int64 import Int64 from bson.regex import Regex from bson.dbref import DBRef from bson.objectid import ObjectId from bson.py3compat import string_type, text_type, PY3 from bson.son import SON from pymongo import (ALL, auth, OFF, SLOW_ONLY, helpers) from pymongo.collection import Collection from pymongo.database import Database from pymongo.errors import (CollectionInvalid, ConfigurationError, ExecutionTimeout, InvalidName, OperationFailure, WriteConcernError) from pymongo.mongo_client import MongoClient from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP from pymongo.write_concern import WriteConcern from test import (client_context, SkipTest, unittest, IntegrationTest) from test.utils import (EventListener, ignore_deprecations, remove_all_users, rs_or_single_client_noauth, rs_or_single_client, server_started_with_auth, wait_until, IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener) from test.test_custom_types import DECIMAL_CODECOPTS if PY3: long = int class TestDatabaseNoConnect(unittest.TestCase): """Test Database features on a client that does not connect. """ @classmethod def setUpClass(cls): cls.client = MongoClient(connect=False) def test_name(self): self.assertRaises(TypeError, Database, self.client, 4) self.assertRaises(InvalidName, Database, self.client, "my db") self.assertRaises(InvalidName, Database, self.client, 'my"db') self.assertRaises(InvalidName, Database, self.client, "my\x00db") self.assertRaises(InvalidName, Database, self.client, u"my\u0000db") self.assertEqual("name", Database(self.client, "name").name) def test_get_collection(self): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) read_concern = ReadConcern('majority') coll = self.client.pymongo_test.get_collection( 'foo', codec_options, ReadPreference.SECONDARY, write_concern, read_concern) self.assertEqual('foo', coll.name) self.assertEqual(codec_options, coll.codec_options) self.assertEqual(ReadPreference.SECONDARY, coll.read_preference) self.assertEqual(write_concern, coll.write_concern) self.assertEqual(read_concern, coll.read_concern) def test_getattr(self): db = self.client.pymongo_test self.assertTrue(isinstance(db['_does_not_exist'], Collection)) with self.assertRaises(AttributeError) as context: db._does_not_exist # Message should be: "AttributeError: Database has no attribute # '_does_not_exist'. To access the _does_not_exist collection, # use database['_does_not_exist']". 
self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) def test_iteration(self): self.assertRaises(TypeError, next, self.client.pymongo_test) class TestDatabase(IntegrationTest): def test_equality(self): self.assertNotEqual(Database(self.client, "test"), Database(self.client, "mike")) self.assertEqual(Database(self.client, "test"), Database(self.client, "test")) # Explicitly test inequality self.assertFalse(Database(self.client, "test") != Database(self.client, "test")) def test_get_coll(self): db = Database(self.client, "pymongo_test") self.assertEqual(db.test, db["test"]) self.assertEqual(db.test, Collection(db, "test")) self.assertNotEqual(db.test, Collection(db, "mike")) self.assertEqual(db.test.mike, db["test.mike"]) def test_repr(self): self.assertEqual(repr(Database(self.client, "pymongo_test")), "Database(%r, %s)" % (self.client, repr(u"pymongo_test"))) def test_create_collection(self): db = Database(self.client, "pymongo_test") db.test.insert_one({"hello": "world"}) self.assertRaises(CollectionInvalid, db.create_collection, "test") db.drop_collection("test") self.assertRaises(TypeError, db.create_collection, 5) self.assertRaises(TypeError, db.create_collection, None) self.assertRaises(InvalidName, db.create_collection, "coll..ection") test = db.create_collection("test") self.assertTrue(u"test" in db.list_collection_names()) test.insert_one({"hello": u"world"}) self.assertEqual(db.test.find_one()["hello"], "world") db.drop_collection("test.foo") db.create_collection("test.foo") self.assertTrue(u"test.foo" in db.list_collection_names()) self.assertRaises(CollectionInvalid, db.create_collection, "test.foo") def _test_collection_names(self, meth, **no_system_kwargs): db = Database(self.client, "pymongo_test") db.test.insert_one({"dummy": u"object"}) db.test.mike.insert_one({"dummy": u"object"}) colls = getattr(db, meth)() self.assertTrue("test" in colls) self.assertTrue("test.mike" in colls) for coll in colls: self.assertTrue("$" not in coll) db.systemcoll.test.insert_one({}) no_system_collections = getattr(db, meth)(**no_system_kwargs) for coll in no_system_collections: self.assertTrue(not coll.startswith("system.")) self.assertIn("systemcoll.test", no_system_collections) # Force more than one batch. db = self.client.many_collections for i in range(101): db["coll" + str(i)].insert_one({}) # No Error try: getattr(db, meth)() finally: self.client.drop_database("many_collections") def test_collection_names(self): self._test_collection_names( 'collection_names', include_system_collections=False) def test_list_collection_names(self): self._test_collection_names( 'list_collection_names', filter={ "name": {"$regex": r"^(?!system\.)"}}) def test_list_collection_names_filter(self): listener = OvertCommandListener() results = listener.results client = rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] db.capped.drop() db.create_collection("capped", capped=True, size=4096) db.capped.insert_one({}) db.non_capped.insert_one({}) self.addCleanup(client.drop_database, db.name) # Should not send nameOnly. for filter in ({'options.capped': True}, {'options.capped': True, 'name': 'capped'}): results.clear() names = db.list_collection_names(filter=filter) self.assertEqual(names, ["capped"]) self.assertNotIn("nameOnly", results["started"][0].command) # Should send nameOnly (except on 2.6). 
for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}): results.clear() names = db.list_collection_names(filter=filter) self.assertIn("capped", names) self.assertIn("non_capped", names) command = results["started"][0].command if client_context.version >= (3, 0): self.assertIn("nameOnly", command) self.assertTrue(command["nameOnly"]) else: self.assertNotIn("nameOnly", command) def test_list_collections(self): self.client.drop_database("pymongo_test") db = Database(self.client, "pymongo_test") db.test.insert_one({"dummy": u"object"}) db.test.mike.insert_one({"dummy": u"object"}) results = db.list_collections() colls = [result["name"] for result in results] # All the collections present. self.assertTrue("test" in colls) self.assertTrue("test.mike" in colls) # No collection containing a '$'. for coll in colls: self.assertTrue("$" not in coll) # Duplicate check. coll_cnt = {} for coll in colls: try: # Found duplicate. coll_cnt[coll] += 1 self.assertTrue(False) except KeyError: coll_cnt[coll] = 1 coll_cnt = {} # Checking if is there any collection which don't exists. if (len(set(colls) - set(["test","test.mike"])) == 0 or len(set(colls) - set(["test","test.mike","system.indexes"])) == 0): self.assertTrue(True) else: self.assertTrue(False) colls = db.list_collections(filter={"name": {"$regex": "^test$"}}) self.assertEqual(1, len(list(colls))) colls = db.list_collections(filter={"name": {"$regex": "^test.mike$"}}) self.assertEqual(1, len(list(colls))) db.drop_collection("test") db.create_collection("test", capped=True, size=4096) results = db.list_collections(filter={'options.capped': True}) colls = [result["name"] for result in results] # Checking only capped collections are present self.assertTrue("test" in colls) self.assertFalse("test.mike" in colls) # No collection containing a '$'. for coll in colls: self.assertTrue("$" not in coll) # Duplicate check. coll_cnt = {} for coll in colls: try: # Found duplicate. coll_cnt[coll] += 1 self.assertTrue(False) except KeyError: coll_cnt[coll] = 1 coll_cnt = {} # Checking if is there any collection which don't exists. if (len(set(colls) - set(["test"])) == 0 or len(set(colls) - set(["test","system.indexes"])) == 0): self.assertTrue(True) else: self.assertTrue(False) self.client.drop_database("pymongo_test") def test_collection_names_single_socket(self): # Test that Database.collection_names only requires one socket. client = rs_or_single_client(maxPoolSize=1) client.drop_database('test_collection_names_single_socket') db = client.test_collection_names_single_socket for i in range(200): db.create_collection(str(i)) db.list_collection_names() # Must not hang. 
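# --- Illustrative aside (not part of the original test suite) ----------------
# Sketch of the listCollections filtering checked above: a name-only filter
# lets the driver send nameOnly:true, while filtering on collection options
# (such as "capped") requires full collection information.  "example_db" is a
# placeholder Database object.
def capped_collection_names(example_db):
    # Options filter: the server must return each collection's options.
    return example_db.list_collection_names(filter={"options.capped": True})


def non_system_collection_names(example_db):
    # Name filter only: eligible for the nameOnly optimisation.
    return example_db.list_collection_names(
        filter={"name": {"$regex": r"^(?!system\.)"}})
# ------------------------------------------------------------------------------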
client.drop_database('test_collection_names_single_socket') def test_drop_collection(self): db = Database(self.client, "pymongo_test") self.assertRaises(TypeError, db.drop_collection, 5) self.assertRaises(TypeError, db.drop_collection, None) db.test.insert_one({"dummy": u"object"}) self.assertTrue("test" in db.list_collection_names()) db.drop_collection("test") self.assertFalse("test" in db.list_collection_names()) db.test.insert_one({"dummy": u"object"}) self.assertTrue("test" in db.list_collection_names()) db.drop_collection(u"test") self.assertFalse("test" in db.list_collection_names()) db.test.insert_one({"dummy": u"object"}) self.assertTrue("test" in db.list_collection_names()) db.drop_collection(db.test) self.assertFalse("test" in db.list_collection_names()) db.test.insert_one({"dummy": u"object"}) self.assertTrue("test" in db.list_collection_names()) db.test.drop() self.assertFalse("test" in db.list_collection_names()) db.test.drop() db.drop_collection(db.test.doesnotexist) if client_context.version.at_least(3, 3, 9) and client_context.is_rs: db_wc = Database(self.client, 'pymongo_test', write_concern=IMPOSSIBLE_WRITE_CONCERN) with self.assertRaises(WriteConcernError): db_wc.drop_collection('test') def test_validate_collection(self): db = self.client.pymongo_test self.assertRaises(TypeError, db.validate_collection, 5) self.assertRaises(TypeError, db.validate_collection, None) db.test.insert_one({"dummy": u"object"}) self.assertRaises(OperationFailure, db.validate_collection, "test.doesnotexist") self.assertRaises(OperationFailure, db.validate_collection, db.test.doesnotexist) self.assertTrue(db.validate_collection("test")) self.assertTrue(db.validate_collection(db.test)) self.assertTrue(db.validate_collection(db.test, full=True)) self.assertTrue(db.validate_collection(db.test, scandata=True)) self.assertTrue(db.validate_collection(db.test, scandata=True, full=True)) self.assertTrue(db.validate_collection(db.test, True, True)) @client_context.require_version_min(4, 3, 3) def test_validate_collection_background(self): db = self.client.pymongo_test db.test.insert_one({"dummy": u"object"}) coll = db.test self.assertTrue(db.validate_collection(coll, background=False)) # The inMemory storage engine does not support background=True. if client_context.storage_engine != 'inMemory': self.assertTrue(db.validate_collection(coll, background=True)) self.assertTrue( db.validate_collection(coll, scandata=True, background=True)) # The server does not support background=True with full=True. # Assert that we actually send the background option by checking # that this combination fails. 
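# --- Illustrative aside (not part of the original test suite) ----------------
# Sketch of the validate helper exercised by the surrounding tests: it wraps
# the server's "validate" command and raises OperationFailure if the
# collection does not exist.  "example_db" is a placeholder.
def validate_test_collection(example_db, thorough=False):
    # full=True asks the server for a slower, more thorough validation.
    return example_db.validate_collection("test", full=thorough)
# ------------------------------------------------------------------------------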
with self.assertRaises(OperationFailure): db.validate_collection(coll, full=True, background=True) @client_context.require_no_mongos def test_profiling_levels(self): db = self.client.pymongo_test self.assertEqual(db.profiling_level(), OFF) # default self.assertRaises(ValueError, db.set_profiling_level, 5.5) self.assertRaises(ValueError, db.set_profiling_level, None) self.assertRaises(ValueError, db.set_profiling_level, -1) self.assertRaises(TypeError, db.set_profiling_level, SLOW_ONLY, 5.5) self.assertRaises(TypeError, db.set_profiling_level, SLOW_ONLY, '1') db.set_profiling_level(SLOW_ONLY) self.assertEqual(db.profiling_level(), SLOW_ONLY) db.set_profiling_level(ALL) self.assertEqual(db.profiling_level(), ALL) db.set_profiling_level(OFF) self.assertEqual(db.profiling_level(), OFF) db.set_profiling_level(SLOW_ONLY, 50) self.assertEqual(50, db.command("profile", -1)['slowms']) db.set_profiling_level(ALL, -1) self.assertEqual(-1, db.command("profile", -1)['slowms']) db.set_profiling_level(OFF, 100) # back to default self.assertEqual(100, db.command("profile", -1)['slowms']) @client_context.require_no_mongos def test_profiling_info(self): db = self.client.pymongo_test db.system.profile.drop() db.set_profiling_level(ALL) db.test.find_one() db.set_profiling_level(OFF) info = db.profiling_info() self.assertTrue(isinstance(info, list)) # Check if we're going to fail because of SERVER-4754, in which # profiling info isn't collected if mongod was started with --auth if server_started_with_auth(self.client): raise SkipTest( "We need SERVER-4754 fixed for the rest of this test to pass" ) self.assertTrue(len(info) >= 1) # These basically clue us in to server changes. self.assertTrue(isinstance(info[0]['responseLength'], int)) self.assertTrue(isinstance(info[0]['millis'], int)) self.assertTrue(isinstance(info[0]['client'], string_type)) self.assertTrue(isinstance(info[0]['user'], string_type)) self.assertTrue(isinstance(info[0]['ns'], string_type)) self.assertTrue(isinstance(info[0]['op'], string_type)) self.assertTrue(isinstance(info[0]["ts"], datetime.datetime)) @client_context.require_no_mongos @ignore_deprecations def test_errors(self): # We must call getlasterror, etc. on same socket as last operation. db = rs_or_single_client(maxPoolSize=1).pymongo_test db.reset_error_history() self.assertEqual(None, db.error()) if client_context.supports_getpreverror: self.assertEqual(None, db.previous_error()) db.test.insert_one({"_id": 1}) unacked = db.test.with_options(write_concern=WriteConcern(w=0)) unacked.insert_one({"_id": 1}) self.assertTrue(db.error()) if client_context.supports_getpreverror: self.assertTrue(db.previous_error()) unacked.insert_one({"_id": 1}) self.assertTrue(db.error()) if client_context.supports_getpreverror: prev_error = db.previous_error() self.assertEqual(prev_error["nPrev"], 1) del prev_error["nPrev"] prev_error.pop("lastOp", None) error = db.error() error.pop("lastOp", None) # getLastError includes "connectionId" in recent # server versions, getPrevError does not. 
error.pop("connectionId", None) self.assertEqualReply(error, prev_error) db.test.find_one() self.assertEqual(None, db.error()) if client_context.supports_getpreverror: self.assertTrue(db.previous_error()) self.assertEqual(db.previous_error()["nPrev"], 2) db.reset_error_history() self.assertEqual(None, db.error()) if client_context.supports_getpreverror: self.assertEqual(None, db.previous_error()) def test_command(self): self.maxDiff = None db = self.client.admin first = db.command("buildinfo") second = db.command({"buildinfo": 1}) third = db.command("buildinfo", 1) self.assertEqualReply(first, second) self.assertEqualReply(second, third) # We use 'aggregate' as our example command, since it's an easy way to # retrieve a BSON regex from a collection using a command. But until # MongoDB 2.3.2, aggregation turned regexes into strings: SERVER-6470. # Note: MongoDB 3.5.2 requires the 'cursor' or 'explain' option for # aggregate. @client_context.require_version_max(3, 5, 0) def test_command_with_regex(self): db = self.client.pymongo_test db.test.drop() db.test.insert_one({'r': re.compile('.*')}) db.test.insert_one({'r': Regex('.*')}) result = db.command('aggregate', 'test', pipeline=[]) for doc in result['result']: self.assertTrue(isinstance(doc['r'], Regex)) def test_password_digest(self): self.assertRaises(TypeError, auth._password_digest, 5) self.assertRaises(TypeError, auth._password_digest, True) self.assertRaises(TypeError, auth._password_digest, None) self.assertTrue(isinstance(auth._password_digest("mike", "password"), text_type)) self.assertEqual(auth._password_digest("mike", "password"), u"cd7e45b3b2767dc2fa9b6b548457ed00") self.assertEqual(auth._password_digest("mike", "password"), auth._password_digest(u"mike", u"password")) self.assertEqual(auth._password_digest("Gustave", u"Dor\xe9"), u"81e0e2364499209f466e75926a162d73") @client_context.require_auth def test_authenticate_add_remove_user(self): # "self.client" is logged in as root. auth_db = self.client.pymongo_test def check_auth(username, password): c = rs_or_single_client_noauth( username=username, password=password, authSource="pymongo_test") c.pymongo_test.collection.find_one() # Configuration errors self.assertRaises(ValueError, auth_db.add_user, "user", '') self.assertRaises(TypeError, auth_db.add_user, "user", 'password', 15) self.assertRaises(TypeError, auth_db.add_user, "user", 'password', 'True') self.assertRaises(ConfigurationError, auth_db.add_user, "user", 'password', True, roles=['read']) with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) self.assertRaises(DeprecationWarning, auth_db.add_user, "user", "password") self.assertRaises(DeprecationWarning, auth_db.add_user, "user", "password", True) with ignore_deprecations(): self.assertRaises(ConfigurationError, auth_db.add_user, "user", "password", digestPassword=True) # Add / authenticate / remove auth_db.add_user("mike", "password", roles=["read"]) self.addCleanup(remove_all_users, auth_db) self.assertRaises(TypeError, check_auth, 5, "password") self.assertRaises(TypeError, check_auth, "mike", 5) self.assertRaises(OperationFailure, check_auth, "mike", "not a real password") self.assertRaises(OperationFailure, check_auth, "faker", "password") check_auth("mike", "password") if not client_context.version.at_least(3, 7, 2) or HAVE_STRINGPREP: # Unicode name and password. 
check_auth(u"mike", u"password") auth_db.remove_user("mike") self.assertRaises( OperationFailure, check_auth, "mike", "password") # Add / authenticate / change password self.assertRaises( OperationFailure, check_auth, "Gustave", u"Dor\xe9") auth_db.add_user("Gustave", u"Dor\xe9", roles=["read"]) check_auth("Gustave", u"Dor\xe9") # Change password. auth_db.add_user("Gustave", "password", roles=["read"]) self.assertRaises( OperationFailure, check_auth, "Gustave", u"Dor\xe9") check_auth("Gustave", u"password") @client_context.require_auth @ignore_deprecations def test_make_user_readonly(self): # "self.client" is logged in as root. auth_db = self.client.pymongo_test # Make a read-write user. auth_db.add_user('jesse', 'pw') self.addCleanup(remove_all_users, auth_db) # Check that we're read-write by default. c = rs_or_single_client_noauth(username='jesse', password='pw', authSource='pymongo_test') c.pymongo_test.collection.insert_one({}) # Make the user read-only. auth_db.add_user('jesse', 'pw', read_only=True) c = rs_or_single_client_noauth(username='jesse', password='pw', authSource='pymongo_test') self.assertRaises(OperationFailure, c.pymongo_test.collection.insert_one, {}) @client_context.require_auth @ignore_deprecations def test_default_roles(self): # "self.client" is logged in as root. auth_admin = self.client.admin auth_admin.add_user('test_default_roles', 'pass') self.addCleanup(client_context.drop_user, 'admin', 'test_default_roles') info = auth_admin.command( 'usersInfo', 'test_default_roles')['users'][0] self.assertEqual("root", info['roles'][0]['role']) # Read only "admin" user auth_admin.add_user('ro-admin', 'pass', read_only=True) self.addCleanup(client_context.drop_user, 'admin', 'ro-admin') info = auth_admin.command('usersInfo', 'ro-admin')['users'][0] self.assertEqual("readAnyDatabase", info['roles'][0]['role']) # "Non-admin" user auth_db = self.client.pymongo_test auth_db.add_user('user', 'pass') self.addCleanup(remove_all_users, auth_db) info = auth_db.command('usersInfo', 'user')['users'][0] self.assertEqual("dbOwner", info['roles'][0]['role']) # Read only "Non-admin" user auth_db.add_user('ro-user', 'pass', read_only=True) info = auth_db.command('usersInfo', 'ro-user')['users'][0] self.assertEqual("read", info['roles'][0]['role']) @client_context.require_auth @ignore_deprecations def test_new_user_cmds(self): # "self.client" is logged in as root. auth_db = self.client.pymongo_test auth_db.add_user("amalia", "password", roles=["userAdmin"]) self.addCleanup(client_context.drop_user, "pymongo_test", "amalia") db = rs_or_single_client_noauth(username="amalia", password="password", authSource="pymongo_test").pymongo_test # This tests the ability to update user attributes. db.add_user("amalia", "new_password", customData={"secret": "koalas"}) user_info = db.command("usersInfo", "amalia") self.assertTrue(user_info["users"]) amalia_user = user_info["users"][0] self.assertEqual(amalia_user["user"], "amalia") self.assertEqual(amalia_user["customData"], {"secret": "koalas"}) @client_context.require_auth @ignore_deprecations def test_authenticate_multiple(self): # "self.client" is logged in as root. 
self.client.drop_database("pymongo_test") self.client.drop_database("pymongo_test1") admin_db_auth = self.client.admin users_db_auth = self.client.pymongo_test admin_db_auth.add_user( 'ro-admin', 'pass', roles=["userAdmin", "readAnyDatabase"]) self.addCleanup(client_context.drop_user, 'admin', 'ro-admin') users_db_auth.add_user( 'user', 'pass', roles=["userAdmin", "readWrite"]) self.addCleanup(remove_all_users, users_db_auth) # Non-root client. listener = EventListener() client = rs_or_single_client_noauth(event_listeners=[listener]) admin_db = client.admin users_db = client.pymongo_test other_db = client.pymongo_test1 self.assertRaises(OperationFailure, users_db.test.find_one) self.assertEqual(listener.started_command_names(), ['find']) listener.reset() # Regular user should be able to query its own db, but # no other. users_db.authenticate('user', 'pass') if client_context.version.at_least(3, 0): self.assertEqual(listener.started_command_names()[0], 'saslStart') else: self.assertEqual(listener.started_command_names()[0], 'getnonce') self.assertEqual(0, users_db.test.count_documents({})) self.assertRaises(OperationFailure, other_db.test.find_one) listener.reset() # Admin read-only user should be able to query any db, # but not write. admin_db.authenticate('ro-admin', 'pass') if client_context.version.at_least(3, 0): self.assertEqual(listener.started_command_names()[0], 'saslStart') else: self.assertEqual(listener.started_command_names()[0], 'getnonce') self.assertEqual(None, other_db.test.find_one()) self.assertRaises(OperationFailure, other_db.test.insert_one, {}) # Close all sockets. client.close() listener.reset() # We should still be able to write to the regular user's db. self.assertTrue(users_db.test.delete_many({})) names = listener.started_command_names() if client_context.version.at_least(4, 4, -1): # No speculation with multiple users (but we do skipEmptyExchange). self.assertEqual( names, ['saslStart', 'saslContinue', 'saslStart', 'saslContinue', 'delete']) elif client_context.version.at_least(3, 0): self.assertEqual( names, ['saslStart', 'saslContinue', 'saslContinue', 'saslStart', 'saslContinue', 'saslContinue', 'delete']) else: self.assertEqual( names, ['getnonce', 'authenticate', 'getnonce', 'authenticate', 'delete']) # And read from other dbs... self.assertEqual(0, other_db.test.count_documents({})) # But still not write to other dbs. self.assertRaises(OperationFailure, other_db.test.insert_one, {}) def test_id_ordering(self): # PyMongo attempts to have _id show up first # when you iterate key/value pairs in a document. # This isn't reliable since python dicts don't # guarantee any particular order. This will never # work right in Jython or any Python or environment # with hash randomization enabled (e.g. tox). 
db = self.client.pymongo_test db.test.drop() db.test.insert_one(SON([("hello", "world"), ("_id", 5)])) db = self.client.get_database( "pymongo_test", codec_options=CodecOptions(document_class=SON)) cursor = db.test.find() for x in cursor: for (k, v) in x.items(): self.assertEqual(k, "_id") break def test_deref(self): db = self.client.pymongo_test db.test.drop() self.assertRaises(TypeError, db.dereference, 5) self.assertRaises(TypeError, db.dereference, "hello") self.assertRaises(TypeError, db.dereference, None) self.assertEqual(None, db.dereference(DBRef("test", ObjectId()))) obj = {"x": True} key = db.test.insert_one(obj).inserted_id self.assertEqual(obj, db.dereference(DBRef("test", key))) self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test"))) self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo")) self.assertEqual(None, db.dereference(DBRef("test", 4))) obj = {"_id": 4} db.test.insert_one(obj) self.assertEqual(obj, db.dereference(DBRef("test", 4))) def test_deref_kwargs(self): db = self.client.pymongo_test db.test.drop() db.test.insert_one({"_id": 4, "foo": "bar"}) db = self.client.get_database( "pymongo_test", codec_options=CodecOptions(document_class=SON)) self.assertEqual(SON([("foo", "bar")]), db.dereference(DBRef("test", 4), projection={"_id": False})) @client_context.require_no_auth @client_context.require_version_max(4, 1, 0) def test_eval(self): db = self.client.pymongo_test db.test.drop() with ignore_deprecations(): self.assertRaises(TypeError, db.eval, None) self.assertRaises(TypeError, db.eval, 5) self.assertRaises(TypeError, db.eval, []) self.assertEqual(3, db.eval("function (x) {return x;}", 3)) self.assertEqual(3, db.eval(u"function (x) {return x;}", 3)) self.assertEqual(None, db.eval("function (x) {db.test.save({y:x});}", 5)) self.assertEqual(db.test.find_one()["y"], 5) self.assertEqual(5, db.eval("function (x, y) {return x + y;}", 2, 3)) self.assertEqual(5, db.eval("function () {return 5;}")) self.assertEqual(5, db.eval("2 + 3;")) self.assertEqual(5, db.eval(Code("2 + 3;"))) self.assertRaises(OperationFailure, db.eval, Code("return i;")) self.assertEqual(2, db.eval(Code("return i;", {"i": 2}))) self.assertEqual(5, db.eval(Code("i + 3;", {"i": 2}))) self.assertRaises(OperationFailure, db.eval, "5 ++ 5;") # TODO some of these tests belong in the collection level testing. 
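# --- Illustrative aside (not part of the original test suite) ----------------
# Sketch of the insert/find/replace/delete round trip the next tests walk
# through; the collection and documents are placeholders.
def crud_round_trip(example_db):
    coll = example_db.test
    inserted_id = coll.insert_one({"hello": "world"}).inserted_id
    fetched = coll.find_one({"_id": inserted_id})   # same document back
    coll.replace_one({"_id": inserted_id}, {"hello": "mike"})
    coll.delete_one({"_id": inserted_id})
    return fetched
# ------------------------------------------------------------------------------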
def test_insert_find_one(self): db = self.client.pymongo_test db.test.drop() a_doc = SON({"hello": u"world"}) a_key = db.test.insert_one(a_doc).inserted_id self.assertTrue(isinstance(a_doc["_id"], ObjectId)) self.assertEqual(a_doc["_id"], a_key) self.assertEqual(a_doc, db.test.find_one({"_id": a_doc["_id"]})) self.assertEqual(a_doc, db.test.find_one(a_key)) self.assertEqual(None, db.test.find_one(ObjectId())) self.assertEqual(a_doc, db.test.find_one({"hello": u"world"})) self.assertEqual(None, db.test.find_one({"hello": u"test"})) b = db.test.find_one() b["hello"] = u"mike" db.test.replace_one({"_id": b["_id"]}, b) self.assertNotEqual(a_doc, db.test.find_one(a_key)) self.assertEqual(b, db.test.find_one(a_key)) self.assertEqual(b, db.test.find_one()) count = 0 for _ in db.test.find(): count += 1 self.assertEqual(count, 1) def test_long(self): db = self.client.pymongo_test db.test.drop() db.test.insert_one({"x": long(9223372036854775807)}) retrieved = db.test.find_one()['x'] self.assertEqual(Int64(9223372036854775807), retrieved) self.assertIsInstance(retrieved, Int64) db.test.delete_many({}) db.test.insert_one({"x": Int64(1)}) retrieved = db.test.find_one()['x'] self.assertEqual(Int64(1), retrieved) self.assertIsInstance(retrieved, Int64) def test_delete(self): db = self.client.pymongo_test db.test.drop() db.test.insert_one({"x": 1}) db.test.insert_one({"x": 2}) db.test.insert_one({"x": 3}) length = 0 for _ in db.test.find(): length += 1 self.assertEqual(length, 3) db.test.delete_one({"x": 1}) length = 0 for _ in db.test.find(): length += 1 self.assertEqual(length, 2) db.test.delete_one(db.test.find_one()) db.test.delete_one(db.test.find_one()) self.assertEqual(db.test.find_one(), None) db.test.insert_one({"x": 1}) db.test.insert_one({"x": 2}) db.test.insert_one({"x": 3}) self.assertTrue(db.test.find_one({"x": 2})) db.test.delete_one({"x": 2}) self.assertFalse(db.test.find_one({"x": 2})) self.assertTrue(db.test.find_one()) db.test.delete_many({}) self.assertFalse(db.test.find_one()) @client_context.require_no_auth @client_context.require_version_max(4, 1, 0) def test_system_js(self): db = self.client.pymongo_test db.system.js.delete_many({}) self.assertEqual(0, db.system.js.count_documents({})) db.system_js.add = "function(a, b) { return a + b; }" self.assertEqual('add', db.system.js.find_one()['_id']) self.assertEqual(1, db.system.js.count_documents({})) self.assertEqual(6, db.system_js.add(1, 5)) del db.system_js.add self.assertEqual(0, db.system.js.count_documents({})) db.system_js['add'] = "function(a, b) { return a + b; }" self.assertEqual('add', db.system.js.find_one()['_id']) self.assertEqual(1, db.system.js.count_documents({})) self.assertEqual(6, db.system_js['add'](1, 5)) del db.system_js['add'] self.assertEqual(0, db.system.js.count_documents({})) self.assertRaises(OperationFailure, db.system_js.add, 1, 5) # TODO right now CodeWScope doesn't work w/ system js # db.system_js.scope = Code("return hello;", {"hello": 8}) # self.assertEqual(8, db.system_js.scope()) self.assertRaises(OperationFailure, db.system_js.non_existant) def test_system_js_list(self): db = self.client.pymongo_test db.system.js.delete_many({}) self.assertEqual([], db.system_js.list()) db.system_js.foo = "function() { return 'blah'; }" self.assertEqual(["foo"], db.system_js.list()) db.system_js.bar = "function() { return 'baz'; }" self.assertEqual(set(["foo", "bar"]), set(db.system_js.list())) del db.system_js.foo self.assertEqual(["bar"], db.system_js.list()) def test_command_response_without_ok(self): # 
Sometimes (SERVER-10891) the server's response to a badly-formatted # command document will have no 'ok' field. We should raise # OperationFailure instead of KeyError. self.assertRaises(OperationFailure, helpers._check_command_response, {}, None) try: helpers._check_command_response({'$err': 'foo'}, None) except OperationFailure as e: self.assertEqual(e.args[0], "foo, full error: {'$err': 'foo'}") else: self.fail("_check_command_response didn't raise OperationFailure") def test_mongos_response(self): error_document = { 'ok': 0, 'errmsg': 'outer', 'raw': {'shard0/host0,host1': {'ok': 0, 'errmsg': 'inner'}}} with self.assertRaises(OperationFailure) as context: helpers._check_command_response(error_document, None) self.assertIn('inner', str(context.exception)) # If a shard has no primary and you run a command like dbstats, which # cannot be run on a secondary, mongos's response includes empty "raw" # errors. See SERVER-15428. error_document = { 'ok': 0, 'errmsg': 'outer', 'raw': {'shard0/host0,host1': {}}} with self.assertRaises(OperationFailure) as context: helpers._check_command_response(error_document, None) self.assertIn('outer', str(context.exception)) # Raw error has ok: 0 but no errmsg. Not a known case, but test it. error_document = { 'ok': 0, 'errmsg': 'outer', 'raw': {'shard0/host0,host1': {'ok': 0}}} with self.assertRaises(OperationFailure) as context: helpers._check_command_response(error_document, None) self.assertIn('outer', str(context.exception)) @client_context.require_test_commands @client_context.require_no_mongos def test_command_max_time_ms(self): self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: db = self.client.pymongo_test db.command('count', 'test') self.assertRaises(ExecutionTimeout, db.command, 'count', 'test', maxTimeMS=1) pipeline = [{'$project': {'name': 1, 'count': 1}}] # Database command helper. db.command('aggregate', 'test', pipeline=pipeline, cursor={}) self.assertRaises(ExecutionTimeout, db.command, 'aggregate', 'test', pipeline=pipeline, cursor={}, maxTimeMS=1) # Collection helper. db.test.aggregate(pipeline=pipeline) self.assertRaises(ExecutionTimeout, db.test.aggregate, pipeline, maxTimeMS=1) finally: self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_with_options(self): codec_options = DECIMAL_CODECOPTS read_preference = ReadPreference.SECONDARY_PREFERRED write_concern = WriteConcern(j=True) read_concern = ReadConcern(level="majority") # List of all options to compare. 
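# --- Illustrative aside (not part of the original test suite) ----------------
# Sketch of Database.with_options() as covered by the surrounding test: it
# returns a copy of the database with selected options replaced, while the
# name and client stay the same.  The option values here are placeholders.
from pymongo import ReadPreference
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern


def tuned_database(example_db):
    return example_db.with_options(
        read_preference=ReadPreference.SECONDARY_PREFERRED,
        write_concern=WriteConcern(w="majority"),
        read_concern=ReadConcern(level="majority"))
# ------------------------------------------------------------------------------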
allopts = ['name', 'client', 'codec_options', 'read_preference', 'write_concern', 'read_concern'] db1 = self.client.get_database( 'with_options_test', codec_options=codec_options, read_preference=read_preference, write_concern=write_concern, read_concern=read_concern) # Case 1: swap no options db2 = db1.with_options() for opt in allopts: self.assertEqual(getattr(db1, opt), getattr(db2, opt)) # Case 2: swap all options newopts = {'codec_options': CodecOptions(), 'read_preference': ReadPreference.PRIMARY, 'write_concern': WriteConcern(w=1), 'read_concern': ReadConcern(level="local")} db2 = db1.with_options(**newopts) for opt in newopts: self.assertEqual( getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) def test_current_op_codec_options(self): class MySON(SON): pass opts = CodecOptions(document_class=MySON) db = self.client.get_database("pymongo_test", codec_options=opts) current_op = db.current_op(True) self.assertTrue(current_op['inprog']) self.assertIsInstance(current_op, MySON) class TestDatabaseAggregation(IntegrationTest): def setUp(self): self.pipeline = [{"$listLocalSessions": {}}, {"$limit": 1}, {"$addFields": {"dummy": "dummy field"}}, {"$project": {"_id": 0, "dummy": 1}}] self.result = {"dummy": "dummy field"} self.admin = self.client.admin @client_context.require_version_min(3, 6, 0) def test_database_aggregation(self): with self.admin.aggregate(self.pipeline) as cursor: result = next(cursor) self.assertEqual(result, self.result) @client_context.require_version_min(3, 6, 0) @client_context.require_no_mongos def test_database_aggregation_fake_cursor(self): coll_name = "test_output" if client_context.version < (4, 3): db_name = "admin" write_stage = {"$out": coll_name} else: # SERVER-43287 disallows writing with $out to the admin db, use # $merge instead. db_name = "pymongo_test" write_stage = { "$merge": {"into": {"db": db_name, "coll": coll_name}}} output_coll = self.client[db_name][coll_name] output_coll.drop() self.addCleanup(output_coll.drop) admin = self.admin.with_options(write_concern=WriteConcern(w=0)) pipeline = self.pipeline[:] pipeline.append(write_stage) with admin.aggregate(pipeline) as cursor: with self.assertRaises(StopIteration): next(cursor) result = wait_until(output_coll.find_one, "read unacknowledged write") self.assertEqual(result["dummy"], self.result["dummy"]) @client_context.require_version_max(3, 6, 0, -1) def test_database_aggregation_unsupported(self): err_msg = r"Database.aggregate\(\) is only supported on MongoDB 3.6\+." with self.assertRaisesRegex(ConfigurationError, err_msg): with self.admin.aggregate(self.pipeline) as _: pass if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_dbref.py000066400000000000000000000123061374256237000167510ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for the dbref module.""" import pickle import sys sys.path[0:0] = [""] from bson.dbref import DBRef from bson.objectid import ObjectId from test import unittest from copy import deepcopy class TestDBRef(unittest.TestCase): def test_creation(self): a = ObjectId() self.assertRaises(TypeError, DBRef) self.assertRaises(TypeError, DBRef, "coll") self.assertRaises(TypeError, DBRef, 4, a) self.assertRaises(TypeError, DBRef, 1.5, a) self.assertRaises(TypeError, DBRef, a, a) self.assertRaises(TypeError, DBRef, None, a) self.assertRaises(TypeError, DBRef, "coll", a, 5) self.assertTrue(DBRef("coll", a)) self.assertTrue(DBRef(u"coll", a)) self.assertTrue(DBRef(u"coll", 5)) self.assertTrue(DBRef(u"coll", 5, "database")) def test_read_only(self): a = DBRef("coll", ObjectId()) def foo(): a.collection = "blah" def bar(): a.id = "aoeu" self.assertEqual("coll", a.collection) a.id self.assertEqual(None, a.database) self.assertRaises(AttributeError, foo) self.assertRaises(AttributeError, bar) def test_repr(self): self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))), "DBRef('coll', ObjectId('1234567890abcdef12345678'))") self.assertEqual(repr(DBRef(u"coll", ObjectId("1234567890abcdef12345678"))), "DBRef(%s, ObjectId('1234567890abcdef12345678'))" % (repr(u'coll'),) ) self.assertEqual(repr(DBRef("coll", 5, foo="bar")), "DBRef('coll', 5, foo='bar')") self.assertEqual(repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")), "DBRef('coll', ObjectId('1234567890abcdef12345678'), " "'foo')") def test_equality(self): obj_id = ObjectId("1234567890abcdef12345678") self.assertEqual(DBRef('foo', 5), DBRef('foo', 5)) self.assertEqual(DBRef("coll", obj_id), DBRef(u"coll", obj_id)) self.assertNotEqual(DBRef("coll", obj_id), DBRef(u"coll", obj_id, "foo")) self.assertNotEqual(DBRef("coll", obj_id), DBRef("col", obj_id)) self.assertNotEqual(DBRef("coll", obj_id), DBRef("coll", ObjectId(b"123456789011"))) self.assertNotEqual(DBRef("coll", obj_id), 4) self.assertEqual(DBRef("coll", obj_id, "foo"), DBRef(u"coll", obj_id, "foo")) self.assertNotEqual(DBRef("coll", obj_id, "foo"), DBRef(u"coll", obj_id, "bar")) # Explicitly test inequality self.assertFalse(DBRef('foo', 5) != DBRef('foo', 5)) self.assertFalse(DBRef("coll", obj_id) != DBRef(u"coll", obj_id)) self.assertFalse(DBRef("coll", obj_id, "foo") != DBRef(u"coll", obj_id, "foo")) def test_kwargs(self): self.assertEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="bar")) self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5)) self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="baz")) self.assertEqual("bar", DBRef("coll", 5, foo="bar").foo) self.assertRaises(AttributeError, getattr, DBRef("coll", 5, foo="bar"), "bar") def test_deepcopy(self): a = DBRef('coll', 'asdf', 'db', x=[1]) b = deepcopy(a) self.assertEqual(a, b) self.assertNotEqual(id(a), id(b.x)) self.assertEqual(a.x, b.x) self.assertNotEqual(id(a.x), id(b.x)) b.x[0] = 2 self.assertEqual(a.x, [1]) self.assertEqual(b.x, [2]) def test_pickling(self): dbr = DBRef('coll', 5, foo='bar') for protocol in [0, 1, 2, -1]: pkl = pickle.dumps(dbr, protocol=protocol) dbr2 = pickle.loads(pkl) self.assertEqual(dbr, dbr2) def test_dbref_hash(self): dbref_1a = DBRef('collection', 'id', 'database') dbref_1b = DBRef('collection', 'id', 'database') self.assertEqual(hash(dbref_1a), hash(dbref_1b)) dbref_2a = DBRef('collection', 'id', 'database', custom='custom') dbref_2b = DBRef('collection', 'id', 'database', custom='custom') self.assertEqual(hash(dbref_2a), 
hash(dbref_2b)) self.assertNotEqual(hash(dbref_1a), hash(dbref_2a)) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_decimal128.py000066400000000000000000000055351374256237000175260ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for Decimal128.""" import codecs import glob import json import os.path import pickle import sys from binascii import unhexlify from decimal import Decimal, DecimalException sys.path[0:0] = [""] from bson import BSON from bson.decimal128 import Decimal128, create_decimal128_context from bson.json_util import dumps, loads from bson.py3compat import b from test import client_context, unittest class TestDecimal128(unittest.TestCase): def test_round_trip(self): if not client_context.version.at_least(3, 3, 6): raise unittest.SkipTest( 'Round trip test requires MongoDB >= 3.3.6') coll = client_context.client.pymongo_test.test coll.drop() dec128 = Decimal128.from_bid( b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0') coll.insert_one({'dec128': dec128}) doc = coll.find_one({'dec128': dec128}) self.assertIsNotNone(doc) self.assertEqual(doc['dec128'], dec128) def test_pickle(self): dec128 = Decimal128.from_bid( b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0') for protocol in range(pickle.HIGHEST_PROTOCOL + 1): pkl = pickle.dumps(dec128, protocol=protocol) self.assertEqual(dec128, pickle.loads(pkl)) def test_special(self): dnan = Decimal('NaN') dnnan = Decimal('-NaN') dsnan = Decimal('sNaN') dnsnan = Decimal('-sNaN') dnan128 = Decimal128(dnan) dnnan128 = Decimal128(dnnan) dsnan128 = Decimal128(dsnan) dnsnan128 = Decimal128(dnsnan) # Due to the comparison rules for decimal.Decimal we have to # compare strings. self.assertEqual(str(dnan), str(dnan128.to_decimal())) self.assertEqual(str(dnnan), str(dnnan128.to_decimal())) self.assertEqual(str(dsnan), str(dsnan128.to_decimal())) self.assertEqual(str(dnsnan), str(dnsnan128.to_decimal())) def test_decimal128_context(self): ctx = create_decimal128_context() self.assertEqual("NaN", str(ctx.copy().create_decimal(".13.1"))) self.assertEqual("Infinity", str(ctx.copy().create_decimal("1E6145"))) self.assertEqual("0E-6176", str(ctx.copy().create_decimal("1E-6177"))) if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_discovery_and_monitoring.py000066400000000000000000000340761374256237000227750ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
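# --- Illustrative aside (not part of the original test suite) ----------------
# Sketch of the Decimal128 round trip covered by the tests above: values are
# built from decimal.Decimal via a context matching the IEEE 754-2008
# decimal128 limits and converted back with to_decimal().
from decimal import Decimal

from bson.decimal128 import Decimal128, create_decimal128_context


def decimal128_round_trip(text_value="3.14159"):
    ctx = create_decimal128_context()
    as_bson = Decimal128(ctx.create_decimal(text_value))
    return as_bson.to_decimal() == Decimal(text_value)
# ------------------------------------------------------------------------------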
"""Test the topology module.""" import os import sys import threading import time sys.path[0:0] = [""] from bson import json_util, Timestamp from pymongo import (common, monitoring) from pymongo.errors import (AutoReconnect, ConfigurationError, NetworkTimeout, NotMasterError, OperationFailure) from pymongo.helpers import _check_command_response from pymongo.ismaster import IsMaster from pymongo.server_description import ServerDescription, SERVER_TYPE from pymongo.settings import TopologySettings from pymongo.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.uri_parser import parse_uri from test import unittest, IntegrationTest from test.utils import (assertion_context, client_context, Barrier, get_pool, server_name_to_type, rs_or_single_client, TestCreator, wait_until) from test.utils_spec_runner import SpecRunner, SpecRunnerThread # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'discovery_and_monitoring') class MockMonitor(object): def __init__(self, server_description, topology, pool, topology_settings): self._server_description = server_description def cancel_check(self): pass def open(self): pass def close(self): pass def join(self): pass def request_check(self): pass def create_mock_topology(uri, monitor_class=MockMonitor): parsed_uri = parse_uri(uri) replica_set_name = None direct_connection = None if 'replicaset' in parsed_uri['options']: replica_set_name = parsed_uri['options']['replicaset'] if 'directConnection' in parsed_uri['options']: direct_connection = parsed_uri['options']['directConnection'] topology_settings = TopologySettings( parsed_uri['nodelist'], replica_set_name=replica_set_name, monitor_class=monitor_class, direct_connection=direct_connection) c = Topology(topology_settings) c.open() return c def got_ismaster(topology, server_address, ismaster_response): server_description = ServerDescription( server_address, IsMaster(ismaster_response), 0) topology.on_change(server_description) def got_app_error(topology, app_error): server_address = common.partition_node(app_error['address']) server = topology.get_server_by_address(server_address) error_type = app_error['type'] generation = app_error.get('generation', server.pool.generation) when = app_error['when'] max_wire_version = app_error['maxWireVersion'] # XXX: We could get better test coverage by mocking the errors on the # Pool/SocketInfo. 
try: if error_type == 'command': _check_command_response(app_error['response'], max_wire_version) elif error_type == 'network': raise AutoReconnect('mock non-timeout network error') elif error_type == 'timeout': raise NetworkTimeout('mock network timeout error') else: raise AssertionError('unknown error type: %s' % (error_type,)) assert False except (AutoReconnect, NotMasterError, OperationFailure) as e: if when == 'beforeHandshakeCompletes': completed_handshake = False elif when == 'afterHandshakeCompletes': completed_handshake = True else: assert False, 'Unknown when field %s' % (when,) topology.handle_error( server_address, _ErrorContext(e, max_wire_version, generation, completed_handshake)) def get_type(topology, hostname): description = topology.get_server_by_address((hostname, 27017)).description return description.server_type class TestAllScenarios(unittest.TestCase): pass def topology_type_name(topology_type): return TOPOLOGY_TYPE._fields[topology_type] def server_type_name(server_type): return SERVER_TYPE._fields[server_type] def check_outcome(self, topology, outcome): expected_servers = outcome['servers'] # Check weak equality before proceeding. self.assertEqual( len(topology.description.server_descriptions()), len(expected_servers)) if outcome.get('compatible') is False: with self.assertRaises(ConfigurationError): topology.description.check_compatible() else: # No error. topology.description.check_compatible() # Since lengths are equal, every actual server must have a corresponding # expected server. for expected_server_address, expected_server in expected_servers.items(): node = common.partition_node(expected_server_address) self.assertTrue(topology.has_server(node)) actual_server = topology.get_server_by_address(node) actual_server_description = actual_server.description expected_server_type = server_name_to_type(expected_server['type']) self.assertEqual( server_type_name(expected_server_type), server_type_name(actual_server_description.server_type)) self.assertEqual( expected_server.get('setName'), actual_server_description.replica_set_name) self.assertEqual( expected_server.get('setVersion'), actual_server_description.set_version) self.assertEqual( expected_server.get('electionId'), actual_server_description.election_id) self.assertEqual( expected_server.get('topologyVersion'), actual_server_description.topology_version) expected_pool = expected_server.get('pool') if expected_pool: self.assertEqual( expected_pool.get('generation'), actual_server.pool.generation) self.assertEqual(outcome['setName'], topology.description.replica_set_name) self.assertEqual(outcome.get('logicalSessionTimeoutMinutes'), topology.description.logical_session_timeout_minutes) expected_topology_type = getattr(TOPOLOGY_TYPE, outcome['topologyType']) self.assertEqual(topology_type_name(expected_topology_type), topology_type_name(topology.description.topology_type)) def create_test(scenario_def): def run_scenario(self): c = create_mock_topology(scenario_def['uri']) for i, phase in enumerate(scenario_def['phases']): # Including the phase description makes failures easier to debug. 
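# --- Illustrative aside (not part of the original test suite) ----------------
# Sketch of watching the same SDAM transitions these scenarios assert on, but
# from application code, using a pymongo.monitoring.ServerListener registered
# on the client.  The URI is a placeholder.
from pymongo import MongoClient, monitoring


class UnknownServerRecorder(monitoring.ServerListener):
    """Collects the addresses of servers that get marked Unknown."""

    def __init__(self):
        self.marked_unknown = []

    def opened(self, event):
        pass

    def description_changed(self, event):
        # Fires on every description change for a single server.
        if not event.new_description.is_server_type_known:
            self.marked_unknown.append(event.server_address)

    def closed(self, event):
        pass


def client_with_sdam_recording(uri="mongodb://localhost:27017"):
    recorder = UnknownServerRecorder()
    client = MongoClient(uri, event_listeners=[recorder])
    return client, recorder
# ------------------------------------------------------------------------------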
description = phase.get('description', str(i)) with assertion_context('phase: %s' % (description,)): for response in phase.get('responses', []): got_ismaster( c, common.partition_node(response[0]), response[1]) for app_error in phase.get('applicationErrors', []): got_app_error(c, app_error) check_outcome(self, c, phase['outcome']) return run_scenario def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json_util.loads(scenario_stream.read()) # Construct test from scenario. new_test = create_test(scenario_def) test_name = 'test_%s_%s' % ( dirname, os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) create_tests() class TestClusterTimeComparison(unittest.TestCase): def test_cluster_time_comparison(self): t = create_mock_topology('mongodb://host') def send_cluster_time(time, inc, should_update): old = t.max_cluster_time() new = {'clusterTime': Timestamp(time, inc)} got_ismaster(t, ('host', 27017), {'ok': 1, 'minWireVersion': 0, 'maxWireVersion': 6, '$clusterTime': new}) actual = t.max_cluster_time() if should_update: self.assertEqual(actual, new) else: self.assertEqual(actual, old) send_cluster_time(0, 1, True) send_cluster_time(2, 2, True) send_cluster_time(2, 1, False) send_cluster_time(1, 3, False) send_cluster_time(2, 3, True) class TestIgnoreStaleErrors(IntegrationTest): def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = Barrier(N_THREADS, timeout=30) client = rs_or_single_client(minPoolSize=N_THREADS) self.addCleanup(client.close) # Wait for initial discovery. client.admin.command('ping') pool = get_pool(client) starting_generation = pool.generation wait_until(lambda: len(pool.sockets) == N_THREADS, 'created sockets') def mock_command(*args, **kwargs): # Synchronize all threads to ensure they use the same generation. barrier.wait() raise AutoReconnect('mock SocketInfo.command error') for sock in pool.sockets: sock.command = mock_command def insert_command(i): try: client.test.command('insert', 'test', documents=[{'i': i}]) except AutoReconnect as exc: pass threads = [] for i in range(N_THREADS): threads.append(threading.Thread(target=insert_command, args=(i,))) for t in threads: t.start() for t in threads: t.join() # Expect a single pool reset for the network error self.assertEqual(starting_generation+1, pool.generation) # Server should be selectable. client.admin.command('ping') class TestIntegration(SpecRunner): # Location of JSON test specifications. TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'discovery_and_monitoring_integration') def _event_count(self, event): if event == 'ServerMarkedUnknownEvent': def marked_unknown(e): return (isinstance(e, monitoring.ServerDescriptionChangedEvent) and not e.new_description.is_server_type_known) return len(self.server_listener.matching(marked_unknown)) # Only support CMAP events for now. self.assertTrue(event.startswith('Pool') or event.startswith('Conn')) event_type = getattr(monitoring, event) return self.pool_listener.event_count(event_type) def assert_event_count(self, event, count): """Run the assertEventCount test operation. Assert the given event was published exactly `count` times. """ self.assertEqual(self._event_count(event), count, 'expected %s not %r' % (count, event)) def wait_for_event(self, event, count): """Run the waitForEvent test operation. 
Wait for a number of events to be published, or fail. """ wait_until(lambda: self._event_count(event) >= count, 'find %s %s event(s)' % (count, event)) def configure_fail_point(self, fail_point): """Run the configureFailPoint test operation. """ self.set_fail_point(fail_point) self.addCleanup(self.set_fail_point, { 'configureFailPoint': fail_point['configureFailPoint'], 'mode': 'off'}) def run_admin_command(self, command, **kwargs): """Run the runAdminCommand test operation. """ self.client.admin.command(command, **kwargs) def record_primary(self): """Run the recordPrimary test operation. """ self._previous_primary = self.scenario_client.primary def wait_for_primary_change(self, timeout_ms): """Run the waitForPrimaryChange test operation. """ def primary_changed(): primary = self.scenario_client.primary if primary is None: return False return primary != self._previous_primary timeout = timeout_ms/1000.0 wait_until(primary_changed, 'change primary', timeout=timeout) def wait(self, ms): """Run the "wait" test operation. """ time.sleep(ms/1000.0) def start_thread(self, name): """Run the 'startThread' thread operation.""" thread = SpecRunnerThread(name) thread.start() self.targets[name] = thread def run_on_thread(self, sessions, collection, name, operation): """Run the 'runOnThread' operation.""" thread = self.targets[name] thread.schedule(lambda: self._run_op( sessions, collection, operation, False)) def wait_for_thread(self, name): """Run the 'waitForThread' operation.""" thread = self.targets[name] thread.stop() thread.join() if thread.exc: raise thread.exc def create_spec_test(scenario_def, test, name): @client_context.require_test_commands def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_spec_test, TestIntegration, TestIntegration.TEST_PATH) test_creator.create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_dns.py000066400000000000000000000076371374256237000164660ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
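# An illustrative sketch, separate from the test definitions below, of the
# helper the SRV seedlist tests rely on: pymongo.uri_parser.split_hosts turns
# a comma-separated seedlist string into (host, port) pairs without doing any
# DNS lookup, which is what the expected "seeds"/"hosts" lists are compared
# against. The hostnames here are placeholders only.
from pymongo.uri_parser import split_hosts

_example_seeds = split_hosts(
    'localhost.test.build.10gen.cc:27017,localhost.test.build.10gen.cc')
assert _example_seeds == [('localhost.test.build.10gen.cc', 27017),
                          ('localhost.test.build.10gen.cc', 27017)]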
"""Run the SRV support tests.""" import glob import json import os import sys sys.path[0:0] = [""] from pymongo.common import validate_read_preference_tags from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.errors import ConfigurationError from pymongo.mongo_client import MongoClient from pymongo.uri_parser import parse_uri, split_hosts from test import client_context, unittest from test.utils import wait_until _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'srv_seedlist') class TestDNS(unittest.TestCase): pass def create_test(test_case): @client_context.require_replica_set @client_context.require_tls def run_test(self): if not _HAVE_DNSPYTHON: raise unittest.SkipTest("DNS tests require the dnspython module") uri = test_case['uri'] seeds = test_case['seeds'] hosts = test_case['hosts'] options = test_case.get('options') if seeds: seeds = split_hosts(','.join(seeds)) if hosts: hosts = frozenset(split_hosts(','.join(hosts))) if seeds: result = parse_uri(uri, validate=True) self.assertEqual(sorted(result['nodelist']), sorted(seeds)) if options: opts = result['options'] if 'readpreferencetags' in opts: rpts = validate_read_preference_tags( 'readPreferenceTags', opts.pop('readpreferencetags')) opts['readPreferenceTags'] = rpts self.assertEqual(result['options'], options) hostname = next(iter(client_context.client.nodes))[0] # The replica set members must be configured as 'localhost'. if hostname == 'localhost': copts = client_context.default_client_options.copy() if client_context.tls is True: # Our test certs don't support the SRV hosts used in these tests. copts['ssl_match_hostname'] = False client = MongoClient(uri, **copts) # Force server selection client.admin.command('ismaster') wait_until( lambda: hosts == client.nodes, 'match test hosts to client nodes') else: try: parse_uri(uri) except (ConfigurationError, ValueError): pass else: self.fail("failed to raise an exception") return run_test def create_tests(): for filename in glob.glob(os.path.join(_TEST_PATH, '*.json')): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with open(filename) as dns_test_file: test_method = create_test(json.load(dns_test_file)) setattr(TestDNS, 'test_' + test_suffix, test_method) create_tests() class TestParsingErrors(unittest.TestCase): @unittest.skipUnless(_HAVE_DNSPYTHON, "DNS tests require the dnspython module") def test_invalid_host(self): self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb", MongoClient, "mongodb+srv://mongodb") self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb.com", MongoClient, "mongodb+srv://mongodb.com") if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_encryption.py000066400000000000000000001266671374256237000201010ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test client side encryption spec.""" import base64 import copy import os import traceback import socket import sys import uuid sys.path[0:0] = [""] from bson import encode, json_util from bson.binary import (Binary, JAVA_LEGACY, STANDARD, UUID_SUBTYPE) from bson.codec_options import CodecOptions from bson.errors import BSONError from bson.json_util import JSONOptions from bson.son import SON from pymongo.cursor import CursorType from pymongo.encryption import (Algorithm, ClientEncryption) from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT from pymongo.errors import (BulkWriteError, ConfigurationError, EncryptionError, InvalidOperation, OperationFailure, WriteError) from pymongo.mongo_client import MongoClient from pymongo.operations import InsertOne from pymongo.write_concern import WriteConcern from test import unittest, IntegrationTest, PyMongoTestCase, client_context from test.utils import (TestCreator, camel_to_snake_args, OvertCommandListener, rs_or_single_client, wait_until) from test.utils_spec_runner import SpecRunner def get_client_opts(client): return client._MongoClient__options KMS_PROVIDERS = {'local': {'key': b'\x00'*96}} class TestAutoEncryptionOpts(PyMongoTestCase): @unittest.skipIf(_HAVE_PYMONGOCRYPT, 'pymongocrypt is installed') def test_init_requires_pymongocrypt(self): with self.assertRaises(ConfigurationError): AutoEncryptionOpts({}, 'keyvault.datakeys') @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') def test_init(self): opts = AutoEncryptionOpts({}, 'keyvault.datakeys') self.assertEqual(opts._kms_providers, {}) self.assertEqual(opts._key_vault_namespace, 'keyvault.datakeys') self.assertEqual(opts._key_vault_client, None) self.assertEqual(opts._schema_map, None) self.assertEqual(opts._bypass_auto_encryption, False) self.assertEqual(opts._mongocryptd_uri, 'mongodb://localhost:27020') self.assertEqual(opts._mongocryptd_bypass_spawn, False) self.assertEqual(opts._mongocryptd_spawn_path, 'mongocryptd') self.assertEqual( opts._mongocryptd_spawn_args, ['--idleShutdownTimeoutSecs=60']) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') def test_init_spawn_args(self): # User can override idleShutdownTimeoutSecs opts = AutoEncryptionOpts( {}, 'keyvault.datakeys', mongocryptd_spawn_args=['--idleShutdownTimeoutSecs=88']) self.assertEqual( opts._mongocryptd_spawn_args, ['--idleShutdownTimeoutSecs=88']) # idleShutdownTimeoutSecs is added by default opts = AutoEncryptionOpts( {}, 'keyvault.datakeys', mongocryptd_spawn_args=[]) self.assertEqual( opts._mongocryptd_spawn_args, ['--idleShutdownTimeoutSecs=60']) # Also added when other options are given opts = AutoEncryptionOpts( {}, 'keyvault.datakeys', mongocryptd_spawn_args=['--quiet', '--port=27020']) self.assertEqual( opts._mongocryptd_spawn_args, ['--quiet', '--port=27020', '--idleShutdownTimeoutSecs=60']) class TestClientOptions(PyMongoTestCase): def test_default(self): client = MongoClient(connect=False) self.addCleanup(client.close) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) client = MongoClient(auto_encryption_opts=None, connect=False) self.addCleanup(client.close) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') client = MongoClient(auto_encryption_opts=opts, connect=False) self.addCleanup(client.close) 
self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) class EncryptionIntegrationTest(IntegrationTest): """Base class for encryption integration tests.""" @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') @client_context.require_version_min(4, 2, -1) def setUpClass(cls): super(EncryptionIntegrationTest, cls).setUpClass() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, 6) def assertBinaryUUID(self, val): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) # Location of JSON test files. BASE = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'client-side-encryption') SPEC_PATH = os.path.join(BASE, 'spec') OPTS = CodecOptions(uuid_representation=STANDARD) # Use SON to preserve the order of fields while parsing json. Use tz_aware # =False to match how CodecOptions decodes dates. JSON_OPTS = JSONOptions(document_class=SON, uuid_representation=STANDARD, tz_aware=False) def read(*paths): with open(os.path.join(BASE, *paths)) as fp: return fp.read() def json_data(*paths): return json_util.loads(read(*paths), json_options=JSON_OPTS) def bson_data(*paths): return encode(json_data(*paths), codec_options=OPTS) class TestClientSimple(EncryptionIntegrationTest): def _test_auto_encrypt(self, opts): client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) # Create the encrypted field's data key. key_vault = create_key_vault( self.client.keyvault.datakeys, json_data('custom', 'key-document-local.json')) self.addCleanup(key_vault.drop) # Collection.insert_one/insert_many auto encrypts. docs = [{'_id': 0, 'ssn': '000'}, {'_id': 1, 'ssn': '111'}, {'_id': 2, 'ssn': '222'}, {'_id': 3, 'ssn': '333'}, {'_id': 4, 'ssn': '444'}, {'_id': 5, 'ssn': '555'}] encrypted_coll = client.pymongo_test.test encrypted_coll.insert_one(docs[0]) encrypted_coll.insert_many(docs[1:3]) unack = encrypted_coll.with_options(write_concern=WriteConcern(w=0)) unack.insert_one(docs[3]) unack.insert_many(docs[4:], ordered=False) wait_until(lambda: self.db.test.count_documents({}) == len(docs), 'insert documents with w=0') # Database.command auto decrypts. res = client.pymongo_test.command( 'find', 'test', filter={'ssn': '000'}) decrypted_docs = res['cursor']['firstBatch'] self.assertEqual(decrypted_docs, [{'_id': 0, 'ssn': '000'}]) # Collection.find auto decrypts. decrypted_docs = list(encrypted_coll.find()) self.assertEqual(decrypted_docs, docs) # Collection.find auto decrypts getMores. decrypted_docs = list(encrypted_coll.find(batch_size=1)) self.assertEqual(decrypted_docs, docs) # Collection.aggregate auto decrypts. decrypted_docs = list(encrypted_coll.aggregate([])) self.assertEqual(decrypted_docs, docs) # Collection.aggregate auto decrypts getMores. decrypted_docs = list(encrypted_coll.aggregate([], batchSize=1)) self.assertEqual(decrypted_docs, docs) # Collection.distinct auto decrypts. decrypted_ssns = encrypted_coll.distinct('ssn') self.assertEqual(set(decrypted_ssns), set(d['ssn'] for d in docs)) # Make sure the field is actually encrypted. for encrypted_doc in self.db.test.find(): self.assertIsInstance(encrypted_doc['_id'], int) self.assertEncrypted(encrypted_doc['ssn']) # Attempt to encrypt an unencodable object. with self.assertRaises(BSONError): encrypted_coll.insert_one({'unencodeable': object()}) def test_auto_encrypt(self): # Configure the encrypted field via jsonSchema. 
json_schema = json_data('custom', 'schema.json') create_with_schema(self.db.test, json_schema) self.addCleanup(self.db.test.drop) opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') self._test_auto_encrypt(opts) def test_auto_encrypt_local_schema_map(self): # Configure the encrypted field via the local schema_map option. schemas = {'pymongo_test.test': json_data('custom', 'schema.json')} opts = AutoEncryptionOpts( KMS_PROVIDERS, 'keyvault.datakeys', schema_map=schemas) self._test_auto_encrypt(opts) def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) client.admin.command('isMaster') client.close() with self.assertRaisesRegex(InvalidOperation, 'Cannot use MongoClient after close'): client.admin.command('isMaster') class TestClientMaxWireVersion(IntegrationTest): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') def setUpClass(cls): super(TestClientMaxWireVersion, cls).setUpClass() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) msg = 'Auto-encryption requires a minimum MongoDB version of 4.2' with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.insert_one({}) with self.assertRaisesRegex(ConfigurationError, msg): client.admin.command('isMaster') with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.find_one({}) with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.bulk_write([InsertOne({})]) def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) msg = 'find_raw_batches does not support auto encryption' with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.find_raw_batches({}) msg = 'aggregate_raw_batches does not support auto encryption' with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.aggregate_raw_batches([]) if client_context.is_mongos: msg = 'Exhaust cursors are not supported by mongos' else: msg = 'exhaust cursors do not support auto encryption' with self.assertRaisesRegex(InvalidOperation, msg): next(client.test.test.find(cursor_type=CursorType.EXHAUST)) class TestExplicitSimple(EncryptionIntegrationTest): def test_encrypt_decrypt(self): client_encryption = ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) self.addCleanup(client_encryption.close) # Use standard UUID representation. key_vault = client_context.client.keyvault.get_collection( 'datakeys', codec_options=OPTS) self.addCleanup(key_vault.drop) # Create the encrypted field's data key. key_id = client_encryption.create_data_key( 'local', key_alt_names=['name']) self.assertBinaryUUID(key_id) self.assertTrue(key_vault.find_one({'_id': key_id})) # Create an unused data key to make sure filtering works. unused_key_id = client_encryption.create_data_key( 'local', key_alt_names=['unused']) self.assertBinaryUUID(unused_key_id) self.assertTrue(key_vault.find_one({'_id': unused_key_id})) doc = {'_id': 0, 'ssn': '000'} encrypted_ssn = client_encryption.encrypt( doc['ssn'], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id) # Ensure encryption via key_alt_name for the same key produces the # same output. 
encrypted_ssn2 = client_encryption.encrypt( doc['ssn'], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name='name') self.assertEqual(encrypted_ssn, encrypted_ssn2) # Test decryption. decrypted_ssn = client_encryption.decrypt(encrypted_ssn) self.assertEqual(decrypted_ssn, doc['ssn']) def test_validation(self): client_encryption = ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) self.addCleanup(client_encryption.close) msg = 'value to decrypt must be a bson.binary.Binary with subtype 6' with self.assertRaisesRegex(TypeError, msg): client_encryption.decrypt('str') with self.assertRaisesRegex(TypeError, msg): client_encryption.decrypt(Binary(b'123')) msg = 'key_id must be a bson.binary.Binary with subtype 4' algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic with self.assertRaisesRegex(TypeError, msg): client_encryption.encrypt('str', algo, key_id=uuid.uuid4()) with self.assertRaisesRegex(TypeError, msg): client_encryption.encrypt('str', algo, key_id=Binary(b'123')) def test_bson_errors(self): client_encryption = ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) self.addCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. unencodable_value = object() with self.assertRaises(BSONError): client_encryption.encrypt( unencodable_value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=Binary(uuid.uuid4().bytes, UUID_SUBTYPE)) def test_codec_options(self): with self.assertRaisesRegex(TypeError, 'codec_options must be'): ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None) opts = CodecOptions(uuid_representation=JAVA_LEGACY) client_encryption_legacy = ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, opts) self.addCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = client_encryption_legacy.create_data_key('local') # Encrypt a UUID with JAVA_LEGACY codec options. value = uuid.uuid4() encrypted_legacy = client_encryption_legacy.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id) decrypted_value_legacy = client_encryption_legacy.decrypt( encrypted_legacy) self.assertEqual(decrypted_value_legacy, value) # Encrypt the same UUID with STANDARD codec options. client_encryption = ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) self.addCleanup(client_encryption.close) encrypted_standard = client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id) decrypted_standard = client_encryption.decrypt(encrypted_standard) self.assertEqual(decrypted_standard, value) # Test that codec_options is applied during encryption. self.assertNotEqual(encrypted_standard, encrypted_legacy) # Test that codec_options is applied during decryption. self.assertEqual( client_encryption_legacy.decrypt(encrypted_standard), value) self.assertNotEqual( client_encryption.decrypt(encrypted_legacy), value) def test_close(self): client_encryption = ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) client_encryption.close() # Close can be called multiple times. 
client_encryption.close() algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic msg = 'Cannot use closed ClientEncryption' with self.assertRaisesRegex(InvalidOperation, msg): client_encryption.create_data_key('local') with self.assertRaisesRegex(InvalidOperation, msg): client_encryption.encrypt('val', algo, key_alt_name='name') with self.assertRaisesRegex(InvalidOperation, msg): client_encryption.decrypt(Binary(b'', 6)) def test_with_statement(self): with ClientEncryption( KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) as client_encryption: pass with self.assertRaisesRegex( InvalidOperation, 'Cannot use closed ClientEncryption'): client_encryption.create_data_key('local') # Spec tests AWS_CREDS = { 'accessKeyId': os.environ.get('FLE_AWS_KEY', ''), 'secretAccessKey': os.environ.get('FLE_AWS_SECRET', '') } class TestSpec(SpecRunner): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') @client_context.require_version_min(3, 6) # SpecRunner requires sessions. def setUpClass(cls): super(TestSpec, cls).setUpClass() def parse_auto_encrypt_opts(self, opts): """Parse clientOptions.autoEncryptOpts.""" opts = camel_to_snake_args(opts) kms_providers = opts['kms_providers'] if 'aws' in kms_providers: kms_providers['aws'] = AWS_CREDS if not any(AWS_CREDS.values()): self.skipTest('AWS environment credentials are not set') if 'key_vault_namespace' not in opts: opts['key_vault_namespace'] = 'keyvault.datakeys' opts = dict(opts) return AutoEncryptionOpts(**opts) def parse_client_options(self, opts): """Override clientOptions parsing to support autoEncryptOpts.""" encrypt_opts = opts.pop('autoEncryptOpts') if encrypt_opts: opts['auto_encryption_opts'] = self.parse_auto_encrypt_opts( encrypt_opts) return super(TestSpec, self).parse_client_options(opts) def get_object_name(self, op): """Default object is collection.""" return op.get('object', 'collection') def maybe_skip_scenario(self, test): super(TestSpec, self).maybe_skip_scenario(test) desc = test['description'].lower() if 'type=symbol' in desc: self.skipTest('PyMongo does not support the symbol type') if desc == 'explain a find with deterministic encryption': # PyPy and Python 3.6+ have ordered dict. if sys.version_info[:2] < (3, 6) and 'PyPy' not in sys.version: self.skipTest( 'explain test does not work without ordered dict') def setup_scenario(self, scenario_def): """Override a test's setup.""" key_vault_data = scenario_def['key_vault_data'] if key_vault_data: coll = client_context.client.get_database( 'keyvault', write_concern=WriteConcern(w='majority'), codec_options=OPTS)['datakeys'] coll.drop() coll.insert_many(key_vault_data) db_name = self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) db = client_context.client.get_database( db_name, write_concern=WriteConcern(w='majority'), codec_options=OPTS) coll = db[coll_name] coll.drop() json_schema = scenario_def['json_schema'] if json_schema: db.create_collection( coll_name, validator={'$jsonSchema': json_schema}, codec_options=OPTS) else: db.create_collection(coll_name) if scenario_def['data']: # Load data. coll.insert_many(scenario_def['data']) def allowable_errors(self, op): """Override expected error classes.""" errors = super(TestSpec, self).allowable_errors(op) # An updateOne test expects encryption to error when no $ operator # appears but pymongo raises a client side ValueError in this case. 
if op['name'] == 'updateOne': errors += (ValueError,) return errors def create_test(scenario_def, test, name): @client_context.require_test_commands def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_test, TestSpec, SPEC_PATH) test_creator.create_tests() # Prose Tests LOCAL_MASTER_KEY = base64.b64decode( b'Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ' b'5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk') LOCAL_KEY_ID = Binary( base64.b64decode(b'LOCALAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) AWS_KEY_ID = Binary( base64.b64decode(b'AWSAAAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) def create_with_schema(coll, json_schema): """Create and return a Collection with a jsonSchema.""" coll.with_options(write_concern=WriteConcern(w='majority')).drop() return coll.database.create_collection( coll.name, validator={'$jsonSchema': json_schema}, codec_options=OPTS) def create_key_vault(vault, *data_keys): """Create the key vault collection with optional data keys.""" vault = vault.with_options( write_concern=WriteConcern(w='majority'), codec_options=OPTS) vault.drop() if data_keys: vault.insert_many(data_keys) return vault class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): @classmethod @unittest.skipUnless(all(AWS_CREDS.values()), 'AWS environment credentials are not set') def setUpClass(cls): super(TestDataKeyDoubleEncryption, cls).setUpClass() @staticmethod def kms_providers(): return {'aws': AWS_CREDS, 'local': {'key': LOCAL_MASTER_KEY}} def test_data_key(self): listener = OvertCommandListener() client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) client.db.coll.drop() vault = create_key_vault(client.keyvault.datakeys) self.addCleanup(vault.drop) # Configure the encrypted field via the local schema_map option. schemas = { "db.coll": { "bsonType": "object", "properties": { "encrypted_placeholder": { "encrypt": { "keyId": "/placeholder", "bsonType": "string", "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Random" } } } } } opts = AutoEncryptionOpts( self.kms_providers(), 'keyvault.datakeys', schema_map=schemas) client_encrypted = rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation='standard') self.addCleanup(client_encrypted.close) client_encryption = ClientEncryption( self.kms_providers(), 'keyvault.datakeys', client, OPTS) self.addCleanup(client_encryption.close) # Local create data key. listener.reset() local_datakey_id = client_encryption.create_data_key( 'local', key_alt_names=['local_altname']) self.assertBinaryUUID(local_datakey_id) cmd = listener.results['started'][-1] self.assertEqual('insert', cmd.command_name) self.assertEqual({'w': 'majority'}, cmd.command.get('writeConcern')) docs = list(vault.find({'_id': local_datakey_id})) self.assertEqual(len(docs), 1) self.assertEqual(docs[0]['masterKey']['provider'], 'local') # Local encrypt by key_id. local_encrypted = client_encryption.encrypt( 'hello local', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=local_datakey_id) self.assertEncrypted(local_encrypted) client_encrypted.db.coll.insert_one( {'_id': 'local', 'value': local_encrypted}) doc_decrypted = client_encrypted.db.coll.find_one({'_id': 'local'}) self.assertEqual(doc_decrypted['value'], 'hello local') # Local encrypt by key_alt_name. 
local_encrypted_altname = client_encryption.encrypt( 'hello local', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name='local_altname') self.assertEqual(local_encrypted_altname, local_encrypted) # AWS create data key. listener.reset() master_key = { 'region': 'us-east-1', 'key': 'arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-' '9f25-e30687b580d0' } aws_datakey_id = client_encryption.create_data_key( 'aws', master_key=master_key, key_alt_names=['aws_altname']) self.assertBinaryUUID(aws_datakey_id) cmd = listener.results['started'][-1] self.assertEqual('insert', cmd.command_name) self.assertEqual({'w': 'majority'}, cmd.command.get('writeConcern')) docs = list(vault.find({'_id': aws_datakey_id})) self.assertEqual(len(docs), 1) self.assertEqual(docs[0]['masterKey']['provider'], 'aws') # AWS encrypt by key_id. aws_encrypted = client_encryption.encrypt( 'hello aws', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=aws_datakey_id) self.assertEncrypted(aws_encrypted) client_encrypted.db.coll.insert_one( {'_id': 'aws', 'value': aws_encrypted}) doc_decrypted = client_encrypted.db.coll.find_one({'_id': 'aws'}) self.assertEqual(doc_decrypted['value'], 'hello aws') # AWS encrypt by key_alt_name. aws_encrypted_altname = client_encryption.encrypt( 'hello aws', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name='aws_altname') self.assertEqual(aws_encrypted_altname, aws_encrypted) # Explicitly encrypting an auto encrypted field. msg = (r'Cannot encrypt element of type binData because schema ' r'requires that type is one of: \[ string \]') with self.assertRaisesRegex(EncryptionError, msg): client_encrypted.db.coll.insert_one( {'encrypted_placeholder': local_encrypted}) class TestExternalKeyVault(EncryptionIntegrationTest): @staticmethod def kms_providers(): return {'local': {'key': LOCAL_MASTER_KEY}} def _test_external_key_vault(self, with_external_key_vault): self.client.db.coll.drop() vault = create_key_vault( self.client.keyvault.datakeys, json_data('corpus', 'corpus-key-local.json'), json_data('corpus', 'corpus-key-aws.json')) self.addCleanup(vault.drop) # Configure the encrypted field via the local schema_map option. schemas = {'db.coll': json_data('external', 'external-schema.json')} if with_external_key_vault: key_vault_client = rs_or_single_client( username='fake-user', password='fake-pwd') self.addCleanup(key_vault_client.close) else: key_vault_client = client_context.client opts = AutoEncryptionOpts( self.kms_providers(), 'keyvault.datakeys', schema_map=schemas, key_vault_client=key_vault_client) client_encrypted = rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation='standard') self.addCleanup(client_encrypted.close) client_encryption = ClientEncryption( self.kms_providers(), 'keyvault.datakeys', key_vault_client, OPTS) self.addCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. with self.assertRaises(EncryptionError) as ctx: client_encrypted.db.coll.insert_one({"encrypted": "test"}) # AuthenticationFailed error. self.assertIsInstance(ctx.exception.cause, OperationFailure) self.assertEqual(ctx.exception.cause.code, 18) else: client_encrypted.db.coll.insert_one({"encrypted": "test"}) if with_external_key_vault: # Authentication error. with self.assertRaises(EncryptionError) as ctx: client_encryption.encrypt( "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=LOCAL_KEY_ID) # AuthenticationFailed error. 
self.assertIsInstance(ctx.exception.cause, OperationFailure) self.assertEqual(ctx.exception.cause.code, 18) else: client_encryption.encrypt( "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=LOCAL_KEY_ID) def test_external_key_vault_1(self): self._test_external_key_vault(True) def test_external_key_vault_2(self): self._test_external_key_vault(False) class TestViews(EncryptionIntegrationTest): @staticmethod def kms_providers(): return {'local': {'key': LOCAL_MASTER_KEY}} def test_views_are_prohibited(self): self.client.db.view.drop() self.client.db.create_collection('view', viewOn='coll') self.addCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), 'keyvault.datakeys') client_encrypted = rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation='standard') self.addCleanup(client_encrypted.close) with self.assertRaisesRegex( EncryptionError, 'cannot auto encrypt a view'): client_encrypted.db.view.insert_one({}) class TestCorpus(EncryptionIntegrationTest): @classmethod @unittest.skipUnless(all(AWS_CREDS.values()), 'AWS environment credentials are not set') def setUpClass(cls): super(TestCorpus, cls).setUpClass() @staticmethod def kms_providers(): return {'aws': AWS_CREDS, 'local': {'key': LOCAL_MASTER_KEY}} @staticmethod def fix_up_schema(json_schema): """Remove deprecated symbol/dbPointer types from json schema.""" for key in json_schema['properties'].keys(): if '_symbol_' in key or '_dbPointer_' in key: del json_schema['properties'][key] return json_schema @staticmethod def fix_up_curpus(corpus): """Disallow deprecated symbol/dbPointer types from corpus test.""" for key in corpus: if '_symbol_' in key or '_dbPointer_' in key: corpus[key]['allowed'] = False return corpus @staticmethod def fix_up_curpus_encrypted(corpus_encrypted, corpus): """Fix the expected values for deprecated symbol/dbPointer types.""" for key in corpus_encrypted: if '_symbol_' in key or '_dbPointer_' in key: corpus_encrypted[key] = copy.deepcopy(corpus[key]) return corpus_encrypted def _test_corpus(self, opts): # Drop and create the collection 'db.coll' with jsonSchema. coll = create_with_schema( self.client.db.coll, self.fix_up_schema(json_data('corpus', 'corpus-schema.json'))) self.addCleanup(coll.drop) vault = create_key_vault( self.client.keyvault.datakeys, json_data('corpus', 'corpus-key-local.json'), json_data('corpus', 'corpus-key-aws.json')) self.addCleanup(vault.drop) client_encrypted = rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation='standard') self.addCleanup(client_encrypted.close) client_encryption = ClientEncryption( self.kms_providers(), 'keyvault.datakeys', client_context.client, OPTS) self.addCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data('corpus', 'corpus.json')) corpus_copied = SON() for key, value in corpus.items(): corpus_copied[key] = copy.deepcopy(value) if key in ('_id', 'altname_aws', 'altname_local'): continue if value['method'] == 'auto': continue if value['method'] == 'explicit': identifier = value['identifier'] self.assertIn(identifier, ('id', 'altname')) kms = value['kms'] self.assertIn(kms, ('local', 'aws')) if identifier == 'id': if kms == 'local': kwargs = dict(key_id=LOCAL_KEY_ID) else: kwargs = dict(key_id=AWS_KEY_ID) else: kwargs = dict(key_alt_name=kms) self.assertIn(value['algo'], ('det', 'rand')) if value['algo'] == 'det': algo = (Algorithm. 
AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic) else: algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random try: encrypted_val = client_encryption.encrypt( value['value'], algo, **kwargs) if not value['allowed']: self.fail('encrypt should have failed: %r: %r' % ( key, value)) corpus_copied[key]['value'] = encrypted_val except Exception: if value['allowed']: tb = traceback.format_exc() self.fail('encrypt failed: %r: %r, traceback: %s' % ( key, value, tb)) client_encrypted.db.coll.insert_one(corpus_copied) corpus_decrypted = client_encrypted.db.coll.find_one() self.assertEqual(corpus_decrypted, corpus) corpus_encrypted_expected = self.fix_up_curpus_encrypted(json_data( 'corpus', 'corpus-encrypted.json'), corpus) corpus_encrypted_actual = coll.find_one() for key, value in corpus_encrypted_actual.items(): if key in ('_id', 'altname_aws', 'altname_local'): continue if value['algo'] == 'det': self.assertEqual( value['value'], corpus_encrypted_expected[key]['value'], key) elif value['algo'] == 'rand' and value['allowed']: self.assertNotEqual( value['value'], corpus_encrypted_expected[key]['value'], key) if value['allowed']: decrypt_actual = client_encryption.decrypt(value['value']) decrypt_expected = client_encryption.decrypt( corpus_encrypted_expected[key]['value']) self.assertEqual(decrypt_actual, decrypt_expected, key) else: self.assertEqual(value['value'], corpus[key]['value'], key) def test_corpus(self): opts = AutoEncryptionOpts(self.kms_providers(), 'keyvault.datakeys') self._test_corpus(opts) def test_corpus_local_schema(self): # Configure the encrypted field via the local schema_map option. schemas = {'db.coll': self.fix_up_schema( json_data('corpus', 'corpus-schema.json'))} opts = AutoEncryptionOpts( self.kms_providers(), 'keyvault.datakeys', schema_map=schemas) self._test_corpus(opts) _2_MiB = 2097152 _16_MiB = 16777216 class TestBsonSizeBatches(EncryptionIntegrationTest): """Prose tests for BSON size limits and batch splitting.""" @classmethod def setUpClass(cls): super(TestBsonSizeBatches, cls).setUpClass() db = client_context.client.db cls.coll = db.coll cls.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data('limits', 'limits-schema.json') db.create_collection( 'coll', validator={'$jsonSchema': json_schema}, codec_options=OPTS, write_concern=WriteConcern(w='majority')) # Create the key vault. coll = client_context.client.get_database( 'keyvault', write_concern=WriteConcern(w='majority'), codec_options=OPTS)['datakeys'] coll.drop() coll.insert_one(json_data('limits', 'limits-key.json')) opts = AutoEncryptionOpts( {'local': {'key': LOCAL_MASTER_KEY}}, 'keyvault.datakeys') cls.listener = OvertCommandListener() cls.client_encrypted = rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener]) cls.coll_encrypted = cls.client_encrypted.db.coll @classmethod def tearDownClass(cls): cls.coll_encrypted.drop() cls.client_encrypted.close() super(TestBsonSizeBatches, cls).tearDownClass() def test_01_insert_succeeds_under_2MiB(self): doc = {'_id': 'over_2mib_under_16mib', 'unencrypted': 'a' * _2_MiB} self.coll_encrypted.insert_one(doc) # Same with bulk_write. doc['_id'] = 'over_2mib_under_16mib_bulk' self.coll_encrypted.bulk_write([InsertOne(doc)]) def test_02_insert_succeeds_over_2MiB_post_encryption(self): doc = {'_id': 'encryption_exceeds_2mib', 'unencrypted': 'a' * ((2**21) - 2000)} doc.update(json_data('limits', 'limits-doc.json')) self.coll_encrypted.insert_one(doc) # Same with bulk_write. 
doc['_id'] = 'encryption_exceeds_2mib_bulk' self.coll_encrypted.bulk_write([InsertOne(doc)]) def test_03_bulk_batch_split(self): doc1 = {'_id': 'over_2mib_1', 'unencrypted': 'a' * _2_MiB} doc2 = {'_id': 'over_2mib_2', 'unencrypted': 'a' * _2_MiB} self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) self.assertEqual( self.listener.started_command_names(), ['insert', 'insert']) def test_04_bulk_batch_split(self): limits_doc = json_data('limits', 'limits-doc.json') doc1 = {'_id': 'encryption_exceeds_2mib_1', 'unencrypted': 'a' * (_2_MiB - 2000)} doc1.update(limits_doc) doc2 = {'_id': 'encryption_exceeds_2mib_2', 'unencrypted': 'a' * (_2_MiB - 2000)} doc2.update(limits_doc) self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) self.assertEqual( self.listener.started_command_names(), ['insert', 'insert']) def test_05_insert_succeeds_just_under_16MiB(self): doc = {'_id': 'under_16mib', 'unencrypted': 'a' * (_16_MiB - 2000)} self.coll_encrypted.insert_one(doc) # Same with bulk_write. doc['_id'] = 'under_16mib_bulk' self.coll_encrypted.bulk_write([InsertOne(doc)]) def test_06_insert_fails_over_16MiB(self): limits_doc = json_data('limits', 'limits-doc.json') doc = {'_id': 'encryption_exceeds_16mib', 'unencrypted': 'a' * (_16_MiB - 2000)} doc.update(limits_doc) with self.assertRaisesRegex(WriteError, 'object to insert too large'): self.coll_encrypted.insert_one(doc) # Same with bulk_write. doc['_id'] = 'encryption_exceeds_16mib_bulk' with self.assertRaises(BulkWriteError) as ctx: self.coll_encrypted.bulk_write([InsertOne(doc)]) err = ctx.exception.details['writeErrors'][0] self.assertEqual(2, err['code']) self.assertIn('object to insert too large', err['errmsg']) class TestCustomEndpoint(EncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" @classmethod @unittest.skipUnless(all(AWS_CREDS.values()), 'AWS environment credentials are not set') def setUpClass(cls): super(TestCustomEndpoint, cls).setUpClass() cls.client_encryption = ClientEncryption( {'aws': AWS_CREDS}, 'keyvault.datakeys', client_context.client, OPTS) def _test_create_data_key(self, master_key): data_key_id = self.client_encryption.create_data_key( 'aws', master_key=master_key) encrypted = self.client_encryption.encrypt( 'test', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id) self.assertEqual('test', self.client_encryption.decrypt(encrypted)) def test_02_aws_region_key(self): self._test_create_data_key({ "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0") }) def test_03_aws_region_key_endpoint(self): self._test_create_data_key({ "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "kms.us-east-1.amazonaws.com" }) def test_04_aws_region_key_endpoint_port(self): self._test_create_data_key({ "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "kms.us-east-1.amazonaws.com:443" }) def test_05_endpoint_invalid_port(self): master_key = { "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "kms.us-east-1.amazonaws.com:12345" } with self.assertRaises(EncryptionError) as ctx: self.client_encryption.create_data_key( 'aws', master_key=master_key) self.assertIsInstance(ctx.exception.cause, socket.error) def 
test_05_endpoint_wrong_region(self): master_key = { "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "kms.us-east-2.amazonaws.com" } # The full error should be something like: # "Credential should be scoped to a valid region, not 'us-east-1'" # but we only check for "us-east-1" to avoid breaking on slight # changes to AWS' error message. with self.assertRaisesRegex(EncryptionError, 'us-east-1'): self.client_encryption.create_data_key( 'aws', master_key=master_key) def test_05_endpoint_invalid_host(self): master_key = { "region": "us-east-1", "key": ("arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), "endpoint": "example.com" } with self.assertRaisesRegex(EncryptionError, 'parse error'): self.client_encryption.create_data_key( 'aws', master_key=master_key) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_errors.py000066400000000000000000000051151374256237000172030ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import traceback sys.path[0:0] = [""] from pymongo.errors import (NotMasterError, OperationFailure) from test import (PyMongoTestCase, unittest) class TestErrors(PyMongoTestCase): def test_not_master_error(self): exc = NotMasterError("not master test", {"errmsg": "error"}) self.assertIn("full error", str(exc)) try: raise exc except NotMasterError: self.assertIn("full error", traceback.format_exc()) def test_operation_failure(self): exc = OperationFailure("operation failure test", 10, {"errmsg": "error"}) self.assertIn("full error", str(exc)) try: raise exc except OperationFailure: self.assertIn("full error", traceback.format_exc()) def _test_unicode_strs(self, exc): if sys.version_info[0] == 2: self.assertEqual("unicode \xf0\x9f\x90\x8d, full error: {" "'errmsg': u'unicode \\U0001f40d'}", str(exc)) elif 'PyPy' in sys.version: # PyPy displays unicode in repr differently. self.assertEqual("unicode \U0001f40d, full error: {" "'errmsg': 'unicode \\U0001f40d'}", str(exc)) else: self.assertEqual("unicode \U0001f40d, full error: {" "'errmsg': 'unicode \U0001f40d'}", str(exc)) try: raise exc except Exception: self.assertIn("full error", traceback.format_exc()) def test_unicode_strs_operation_failure(self): exc = OperationFailure(u'unicode \U0001f40d', 10, {"errmsg": u'unicode \U0001f40d'}) self._test_unicode_strs(exc) def test_unicode_strs_not_master_error(self): exc = NotMasterError(u'unicode \U0001f40d', {"errmsg": u'unicode \U0001f40d'}) self._test_unicode_strs(exc) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_examples.py000066400000000000000000001124431374256237000175100ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MongoDB documentation examples in Python.""" import datetime import sys import threading sys.path[0:0] = [""] import pymongo from pymongo.errors import ConnectionFailure, OperationFailure from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern from test import client_context, unittest, IntegrationTest from test.utils import rs_client class TestSampleShellCommands(IntegrationTest): @classmethod def setUpClass(cls): super(TestSampleShellCommands, cls).setUpClass() # Run once before any tests run. cls.db.inventory.drop() @classmethod def tearDownClass(cls): cls.client.drop_database("pymongo_test") def tearDown(self): # Run after every test. self.db.inventory.drop() def test_first_three_examples(self): db = self.db # Start Example 1 db.inventory.insert_one( {"item": "canvas", "qty": 100, "tags": ["cotton"], "size": {"h": 28, "w": 35.5, "uom": "cm"}}) # End Example 1 self.assertEqual(db.inventory.count_documents({}), 1) # Start Example 2 cursor = db.inventory.find({"item": "canvas"}) # End Example 2 self.assertEqual(cursor.count(), 1) # Start Example 3 db.inventory.insert_many([ {"item": "journal", "qty": 25, "tags": ["blank", "red"], "size": {"h": 14, "w": 21, "uom": "cm"}}, {"item": "mat", "qty": 85, "tags": ["gray"], "size": {"h": 27.9, "w": 35.5, "uom": "cm"}}, {"item": "mousepad", "qty": 25, "tags": ["gel", "blue"], "size": {"h": 19, "w": 22.85, "uom": "cm"}}]) # End Example 3 self.assertEqual(db.inventory.count_documents({}), 4) def test_query_top_level_fields(self): db = self.db # Start Example 6 db.inventory.insert_many([ {"item": "journal", "qty": 25, "size": {"h": 14, "w": 21, "uom": "cm"}, "status": "A"}, {"item": "notebook", "qty": 50, "size": {"h": 8.5, "w": 11, "uom": "in"}, "status": "A"}, {"item": "paper", "qty": 100, "size": {"h": 8.5, "w": 11, "uom": "in"}, "status": "D"}, {"item": "planner", "qty": 75, "size": {"h": 22.85, "w": 30, "uom": "cm"}, "status": "D"}, {"item": "postcard", "qty": 45, "size": {"h": 10, "w": 15.25, "uom": "cm"}, "status": "A"}]) # End Example 6 self.assertEqual(db.inventory.count_documents({}), 5) # Start Example 7 cursor = db.inventory.find({}) # End Example 7 self.assertEqual(len(list(cursor)), 5) # Start Example 9 cursor = db.inventory.find({"status": "D"}) # End Example 9 self.assertEqual(len(list(cursor)), 2) # Start Example 10 cursor = db.inventory.find({"status": {"$in": ["A", "D"]}}) # End Example 10 self.assertEqual(len(list(cursor)), 5) # Start Example 11 cursor = db.inventory.find({"status": "A", "qty": {"$lt": 30}}) # End Example 11 self.assertEqual(len(list(cursor)), 1) # Start Example 12 cursor = db.inventory.find( {"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) # End Example 12 self.assertEqual(len(list(cursor)), 3) # Start Example 13 cursor = db.inventory.find({ "status": "A", "$or": [{"qty": {"$lt": 30}}, {"item": {"$regex": "^p"}}]}) # End Example 13 self.assertEqual(len(list(cursor)), 2) def test_query_embedded_documents(self): db = self.db # Start Example 14 # Subdocument key order matters in a few of these examples so we have # to use 
bson.son.SON instead of a Python dict. from bson.son import SON db.inventory.insert_many([ {"item": "journal", "qty": 25, "size": SON([("h", 14), ("w", 21), ("uom", "cm")]), "status": "A"}, {"item": "notebook", "qty": 50, "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), "status": "A"}, {"item": "paper", "qty": 100, "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), "status": "D"}, {"item": "planner", "qty": 75, "size": SON([("h", 22.85), ("w", 30), ("uom", "cm")]), "status": "D"}, {"item": "postcard", "qty": 45, "size": SON([("h", 10), ("w", 15.25), ("uom", "cm")]), "status": "A"}]) # End Example 14 # Start Example 15 cursor = db.inventory.find( {"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) # End Example 15 self.assertEqual(len(list(cursor)), 1) # Start Example 16 cursor = db.inventory.find( {"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) # End Example 16 self.assertEqual(len(list(cursor)), 0) # Start Example 17 cursor = db.inventory.find({"size.uom": "in"}) # End Example 17 self.assertEqual(len(list(cursor)), 2) # Start Example 18 cursor = db.inventory.find({"size.h": {"$lt": 15}}) # End Example 18 self.assertEqual(len(list(cursor)), 4) # Start Example 19 cursor = db.inventory.find( {"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) # End Example 19 self.assertEqual(len(list(cursor)), 1) def test_query_arrays(self): db = self.db # Start Example 20 db.inventory.insert_many([ {"item": "journal", "qty": 25, "tags": ["blank", "red"], "dim_cm": [14, 21]}, {"item": "notebook", "qty": 50, "tags": ["red", "blank"], "dim_cm": [14, 21]}, {"item": "paper", "qty": 100, "tags": ["red", "blank", "plain"], "dim_cm": [14, 21]}, {"item": "planner", "qty": 75, "tags": ["blank", "red"], "dim_cm": [22.85, 30]}, {"item": "postcard", "qty": 45, "tags": ["blue"], "dim_cm": [10, 15.25]}]) # End Example 20 # Start Example 21 cursor = db.inventory.find({"tags": ["red", "blank"]}) # End Example 21 self.assertEqual(len(list(cursor)), 1) # Start Example 22 cursor = db.inventory.find({"tags": {"$all": ["red", "blank"]}}) # End Example 22 self.assertEqual(len(list(cursor)), 4) # Start Example 23 cursor = db.inventory.find({"tags": "red"}) # End Example 23 self.assertEqual(len(list(cursor)), 4) # Start Example 24 cursor = db.inventory.find({"dim_cm": {"$gt": 25}}) # End Example 24 self.assertEqual(len(list(cursor)), 1) # Start Example 25 cursor = db.inventory.find({"dim_cm": {"$gt": 15, "$lt": 20}}) # End Example 25 self.assertEqual(len(list(cursor)), 4) # Start Example 26 cursor = db.inventory.find( {"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) # End Example 26 self.assertEqual(len(list(cursor)), 1) # Start Example 27 cursor = db.inventory.find({"dim_cm.1": {"$gt": 25}}) # End Example 27 self.assertEqual(len(list(cursor)), 1) # Start Example 28 cursor = db.inventory.find({"tags": {"$size": 3}}) # End Example 28 self.assertEqual(len(list(cursor)), 1) def test_query_array_of_documents(self): db = self.db # Start Example 29 # Subdocument key order matters in a few of these examples so we have # to use bson.son.SON instead of a Python dict. 
from bson.son import SON db.inventory.insert_many([ {"item": "journal", "instock": [ SON([("warehouse", "A"), ("qty", 5)]), SON([("warehouse", "C"), ("qty", 15)])]}, {"item": "notebook", "instock": [ SON([("warehouse", "C"), ("qty", 5)])]}, {"item": "paper", "instock": [ SON([("warehouse", "A"), ("qty", 60)]), SON([("warehouse", "B"), ("qty", 15)])]}, {"item": "planner", "instock": [ SON([("warehouse", "A"), ("qty", 40)]), SON([("warehouse", "B"), ("qty", 5)])]}, {"item": "postcard", "instock": [ SON([("warehouse", "B"), ("qty", 15)]), SON([("warehouse", "C"), ("qty", 35)])]}]) # End Example 29 # Start Example 30 cursor = db.inventory.find( {"instock": SON([("warehouse", "A"), ("qty", 5)])}) # End Example 30 self.assertEqual(len(list(cursor)), 1) # Start Example 31 cursor = db.inventory.find( {"instock": SON([("qty", 5), ("warehouse", "A")])}) # End Example 31 self.assertEqual(len(list(cursor)), 0) # Start Example 32 cursor = db.inventory.find({'instock.0.qty': {"$lte": 20}}) # End Example 32 self.assertEqual(len(list(cursor)), 3) # Start Example 33 cursor = db.inventory.find({'instock.qty': {"$lte": 20}}) # End Example 33 self.assertEqual(len(list(cursor)), 5) # Start Example 34 cursor = db.inventory.find( {"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) # End Example 34 self.assertEqual(len(list(cursor)), 1) # Start Example 35 cursor = db.inventory.find( {"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) # End Example 35 self.assertEqual(len(list(cursor)), 3) # Start Example 36 cursor = db.inventory.find({"instock.qty": {"$gt": 10, "$lte": 20}}) # End Example 36 self.assertEqual(len(list(cursor)), 4) # Start Example 37 cursor = db.inventory.find( {"instock.qty": 5, "instock.warehouse": "A"}) # End Example 37 self.assertEqual(len(list(cursor)), 2) def test_query_null(self): db = self.db # Start Example 38 db.inventory.insert_many([{"_id": 1, "item": None}, {"_id": 2}]) # End Example 38 # Start Example 39 cursor = db.inventory.find({"item": None}) # End Example 39 self.assertEqual(len(list(cursor)), 2) # Start Example 40 cursor = db.inventory.find({"item": {"$type": 10}}) # End Example 40 self.assertEqual(len(list(cursor)), 1) # Start Example 41 cursor = db.inventory.find({"item": {"$exists": False}}) # End Example 41 self.assertEqual(len(list(cursor)), 1) def test_projection(self): db = self.db # Start Example 42 db.inventory.insert_many([ {"item": "journal", "status": "A", "size": {"h": 14, "w": 21, "uom": "cm"}, "instock": [{"warehouse": "A", "qty": 5}]}, {"item": "notebook", "status": "A", "size": {"h": 8.5, "w": 11, "uom": "in"}, "instock": [{"warehouse": "C", "qty": 5}]}, {"item": "paper", "status": "D", "size": {"h": 8.5, "w": 11, "uom": "in"}, "instock": [{"warehouse": "A", "qty": 60}]}, {"item": "planner", "status": "D", "size": {"h": 22.85, "w": 30, "uom": "cm"}, "instock": [{"warehouse": "A", "qty": 40}]}, {"item": "postcard", "status": "A", "size": {"h": 10, "w": 15.25, "uom": "cm"}, "instock": [ {"warehouse": "B", "qty": 15}, {"warehouse": "C", "qty": 35}]}]) # End Example 42 # Start Example 43 cursor = db.inventory.find({"status": "A"}) # End Example 43 self.assertEqual(len(list(cursor)), 3) # Start Example 44 cursor = db.inventory.find( {"status": "A"}, {"item": 1, "status": 1}) # End Example 44 for doc in cursor: self.assertTrue("_id" in doc) self.assertTrue("item" in doc) self.assertTrue("status" in doc) self.assertFalse("size" in doc) self.assertFalse("instock" in doc) # Start Example 45 cursor = db.inventory.find( {"status": "A"}, {"item": 1, 
"status": 1, "_id": 0}) # End Example 45 for doc in cursor: self.assertFalse("_id" in doc) self.assertTrue("item" in doc) self.assertTrue("status" in doc) self.assertFalse("size" in doc) self.assertFalse("instock" in doc) # Start Example 46 cursor = db.inventory.find( {"status": "A"}, {"status": 0, "instock": 0}) # End Example 46 for doc in cursor: self.assertTrue("_id" in doc) self.assertTrue("item" in doc) self.assertFalse("status" in doc) self.assertTrue("size" in doc) self.assertFalse("instock" in doc) # Start Example 47 cursor = db.inventory.find( {"status": "A"}, {"item": 1, "status": 1, "size.uom": 1}) # End Example 47 for doc in cursor: self.assertTrue("_id" in doc) self.assertTrue("item" in doc) self.assertTrue("status" in doc) self.assertTrue("size" in doc) self.assertFalse("instock" in doc) size = doc['size'] self.assertTrue('uom' in size) self.assertFalse('h' in size) self.assertFalse('w' in size) # Start Example 48 cursor = db.inventory.find({"status": "A"}, {"size.uom": 0}) # End Example 48 for doc in cursor: self.assertTrue("_id" in doc) self.assertTrue("item" in doc) self.assertTrue("status" in doc) self.assertTrue("size" in doc) self.assertTrue("instock" in doc) size = doc['size'] self.assertFalse('uom' in size) self.assertTrue('h' in size) self.assertTrue('w' in size) # Start Example 49 cursor = db.inventory.find( {"status": "A"}, {"item": 1, "status": 1, "instock.qty": 1}) # End Example 49 for doc in cursor: self.assertTrue("_id" in doc) self.assertTrue("item" in doc) self.assertTrue("status" in doc) self.assertFalse("size" in doc) self.assertTrue("instock" in doc) for subdoc in doc['instock']: self.assertFalse('warehouse' in subdoc) self.assertTrue('qty' in subdoc) # Start Example 50 cursor = db.inventory.find( {"status": "A"}, {"item": 1, "status": 1, "instock": {"$slice": -1}}) # End Example 50 for doc in cursor: self.assertTrue("_id" in doc) self.assertTrue("item" in doc) self.assertTrue("status" in doc) self.assertFalse("size" in doc) self.assertTrue("instock" in doc) self.assertEqual(len(doc["instock"]), 1) def test_update_and_replace(self): db = self.db # Start Example 51 db.inventory.insert_many([ {"item": "canvas", "qty": 100, "size": {"h": 28, "w": 35.5, "uom": "cm"}, "status": "A"}, {"item": "journal", "qty": 25, "size": {"h": 14, "w": 21, "uom": "cm"}, "status": "A"}, {"item": "mat", "qty": 85, "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, "status": "A"}, {"item": "mousepad", "qty": 25, "size": {"h": 19, "w": 22.85, "uom": "cm"}, "status": "P"}, {"item": "notebook", "qty": 50, "size": {"h": 8.5, "w": 11, "uom": "in"}, "status": "P"}, {"item": "paper", "qty": 100, "size": {"h": 8.5, "w": 11, "uom": "in"}, "status": "D"}, {"item": "planner", "qty": 75, "size": {"h": 22.85, "w": 30, "uom": "cm"}, "status": "D"}, {"item": "postcard", "qty": 45, "size": {"h": 10, "w": 15.25, "uom": "cm"}, "status": "A"}, {"item": "sketchbook", "qty": 80, "size": {"h": 14, "w": 21, "uom": "cm"}, "status": "A"}, {"item": "sketch pad", "qty": 95, "size": {"h": 22.85, "w": 30.5, "uom": "cm"}, "status": "A"}]) # End Example 51 # Start Example 52 db.inventory.update_one( {"item": "paper"}, {"$set": {"size.uom": "cm", "status": "P"}, "$currentDate": {"lastModified": True}}) # End Example 52 for doc in db.inventory.find({"item": "paper"}): self.assertEqual(doc["size"]["uom"], "cm") self.assertEqual(doc["status"], "P") self.assertTrue("lastModified" in doc) # Start Example 53 db.inventory.update_many( {"qty": {"$lt": 50}}, {"$set": {"size.uom": "in", "status": "P"}, "$currentDate": 
{"lastModified": True}}) # End Example 53 for doc in db.inventory.find({"qty": {"$lt": 50}}): self.assertEqual(doc["size"]["uom"], "in") self.assertEqual(doc["status"], "P") self.assertTrue("lastModified" in doc) # Start Example 54 db.inventory.replace_one( {"item": "paper"}, {"item": "paper", "instock": [ {"warehouse": "A", "qty": 60}, {"warehouse": "B", "qty": 40}]}) # End Example 54 for doc in db.inventory.find({"item": "paper"}, {"_id": 0}): self.assertEqual(len(doc.keys()), 2) self.assertTrue("item" in doc) self.assertTrue("instock" in doc) self.assertEqual(len(doc["instock"]), 2) def test_delete(self): db = self.db # Start Example 55 db.inventory.insert_many([ {"item": "journal", "qty": 25, "size": {"h": 14, "w": 21, "uom": "cm"}, "status": "A"}, {"item": "notebook", "qty": 50, "size": {"h": 8.5, "w": 11, "uom": "in"}, "status": "P"}, {"item": "paper", "qty": 100, "size": {"h": 8.5, "w": 11, "uom": "in"}, "status": "D"}, {"item": "planner", "qty": 75, "size": {"h": 22.85, "w": 30, "uom": "cm"}, "status": "D"}, {"item": "postcard", "qty": 45, "size": {"h": 10, "w": 15.25, "uom": "cm"}, "status": "A"}]) # End Example 55 self.assertEqual(db.inventory.count_documents({}), 5) # Start Example 57 db.inventory.delete_many({"status": "A"}) # End Example 57 self.assertEqual(db.inventory.count_documents({}), 3) # Start Example 58 db.inventory.delete_one({"status": "D"}) # End Example 58 self.assertEqual(db.inventory.count_documents({}), 2) # Start Example 56 db.inventory.delete_many({}) # End Example 56 self.assertEqual(db.inventory.count_documents({}), 0) @client_context.require_version_min(3, 5, 11) @client_context.require_replica_set @client_context.require_no_mmap def test_change_streams(self): db = self.db done = False def insert_docs(): while not done: db.inventory.insert_one({"username": "alice"}) db.inventory.delete_one({"username": "alice"}) t = threading.Thread(target=insert_docs) t.start() try: # 1. 
The database for reactive, real-time applications # Start Changestream Example 1 cursor = db.inventory.watch() document = next(cursor) # End Changestream Example 1 # Start Changestream Example 2 cursor = db.inventory.watch(full_document='updateLookup') document = next(cursor) # End Changestream Example 2 # Start Changestream Example 3 resume_token = cursor.resume_token cursor = db.inventory.watch(resume_after=resume_token) document = next(cursor) # End Changestream Example 3 # Start Changestream Example 4 pipeline = [ {'$match': {'fullDocument.username': 'alice'}}, {'$addFields': {'newField': 'this is an added field!'}} ] cursor = db.inventory.watch(pipeline=pipeline) document = next(cursor) # End Changestream Example 4 finally: done = True t.join() def test_aggregate_examples(self): db = self.db # Start Aggregation Example 1 db.sales.aggregate([ {"$match": {"items.fruit": "banana"}}, {"$sort": {"date": 1}} ]) # End Aggregation Example 1 # Start Aggregation Example 2 db.sales.aggregate([ {"$unwind": "$items"}, {"$match": {"items.fruit": "banana"}}, {"$group": { "_id": {"day": {"$dayOfWeek": "$date"}}, "count": {"$sum": "$items.quantity"}} }, {"$project": { "dayOfWeek": "$_id.day", "numberSold": "$count", "_id": 0} }, {"$sort": {"numberSold": 1}} ]) # End Aggregation Example 2 # Start Aggregation Example 3 db.sales.aggregate([ {"$unwind": "$items"}, {"$group": { "_id": {"day": {"$dayOfWeek": "$date"}}, "items_sold": {"$sum": "$items.quantity"}, "revenue": { "$sum": { "$multiply": [ "$items.quantity", "$items.price"] } } } }, {"$project": { "day": "$_id.day", "revenue": 1, "items_sold": 1, "discount": { "$cond": { "if": {"$lte": ["$revenue", 250]}, "then": 25, "else": 0 } } } } ]) # End Aggregation Example 3 # $lookup was new in 3.2. The let and pipeline options # were added in 3.6. if client_context.version.at_least(3, 6, 0): # Start Aggregation Example 4 db.air_alliances.aggregate([ {"$lookup": { "from": "air_airlines", "let": {"constituents": "$airlines"}, "pipeline": [ {"$match": {"$expr": {"$in": ["$name", "$$constituents"]}}} ], "as": "airlines" } }, {"$project": { "_id": 0, "name": 1, "airlines": { "$filter": { "input": "$airlines", "as": "airline", "cond": {"$eq": ["$$airline.country", "Canada"]} } } } } ]) # End Aggregation Example 4 def test_commands(self): db = self.db db.restaurants.insert_one({}) # Start runCommand Example 1 db.command("buildInfo") # End runCommand Example 1 # Start runCommand Example 2 db.command("collStats", "restaurants") # End runCommand Example 2 def test_index_management(self): db = self.db # Start Index Example 1 db.records.create_index("score") # End Index Example 1 # Start Index Example 1 db.restaurants.create_index( [("cuisine", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], partialFilterExpression={"rating": {"$gt": 5}} ) # End Index Example 1 @client_context.require_version_min(3, 6, 0) @client_context.require_replica_set def test_misc(self): # Marketing examples client = self.client self.addCleanup(client.drop_database, "test") self.addCleanup(client.drop_database, "my_database") # 2. Tunable consistency controls collection = client.my_database.my_collection with client.start_session() as session: collection.insert_one({'_id': 1}, session=session) collection.update_one( {'_id': 1}, {"$set": {"a": 1}}, session=session) for doc in collection.find({}, session=session): pass # 3. 
Exploiting the power of arrays collection = client.test.array_updates_test collection.update_one( {'_id': 1}, {"$set": {"a.$[i].b": 2}}, array_filters=[{"i.b": 0}]) class TestTransactionExamples(IntegrationTest): @client_context.require_transactions def test_transactions(self): # Transaction examples client = self.client self.addCleanup(client.drop_database, "hr") self.addCleanup(client.drop_database, "reporting") employees = client.hr.employees events = client.reporting.events employees.insert_one({"employee": 3, "status": "Active"}) events.insert_one( {"employee": 3, "status": {"new": "Active", "old": None}}) # Start Transactions Intro Example 1 def update_employee_info(session): employees_coll = session.client.hr.employees events_coll = session.client.reporting.events with session.start_transaction( read_concern=ReadConcern("snapshot"), write_concern=WriteConcern(w="majority")): employees_coll.update_one( {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session) events_coll.insert_one( {"employee": 3, "status": { "new": "Inactive", "old": "Active"}}, session=session) while True: try: # Commit uses write concern set at transaction start. session.commit_transaction() print("Transaction committed.") break except (ConnectionFailure, OperationFailure) as exc: # Can retry commit if exc.has_error_label( "UnknownTransactionCommitResult"): print("UnknownTransactionCommitResult, retrying " "commit operation ...") continue else: print("Error during commit ...") raise # End Transactions Intro Example 1 with client.start_session() as session: update_employee_info(session) employee = employees.find_one({"employee": 3}) self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Inactive') # Start Transactions Retry Example 1 def run_transaction_with_retry(txn_func, session): while True: try: txn_func(session) # performs transaction break except (ConnectionFailure, OperationFailure) as exc: print("Transaction aborted. Caught exception during " "transaction.") # If transient error, retry the whole transaction if exc.has_error_label("TransientTransactionError"): print("TransientTransactionError, retrying" "transaction ...") continue else: raise # End Transactions Retry Example 1 with client.start_session() as session: run_transaction_with_retry(update_employee_info, session) employee = employees.find_one({"employee": 3}) self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Inactive') # Start Transactions Retry Example 2 def commit_with_retry(session): while True: try: # Commit uses write concern set at transaction start. 
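# --- Illustrative sketch, not part of the test suite -----------------------
# Minimal form of the transaction pattern exercised above: the
# session.start_transaction() context manager commits on a clean exit and
# aborts if the block raises, so a simple transaction needs no explicit
# commit call.  Assumes a replica-set deployment at the hypothetical URI
# below; database and collection names are placeholders.
from pymongo import MongoClient
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

client = MongoClient("mongodb://localhost:27017/?replicaSet=rs0")

# The collections must already exist on servers that cannot create
# collections inside a transaction.
client.shop.orders.insert_one({"sku": "abc", "qty": 0})
client.shop.stock.insert_one({"sku": "abc", "qty": 100})

with client.start_session() as session:
    with session.start_transaction(
            read_concern=ReadConcern("snapshot"),
            write_concern=WriteConcern(w="majority")):
        # Every operation in the transaction must be passed the session.
        client.shop.orders.insert_one({"sku": "abc", "qty": 1},
                                      session=session)
        client.shop.stock.update_one({"sku": "abc"}, {"$inc": {"qty": -1}},
                                     session=session)
# ---------------------------------------------------------------------------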
session.commit_transaction() print("Transaction committed.") break except (ConnectionFailure, OperationFailure) as exc: # Can retry commit if exc.has_error_label("UnknownTransactionCommitResult"): print("UnknownTransactionCommitResult, retrying " "commit operation ...") continue else: print("Error during commit ...") raise # End Transactions Retry Example 2 # Test commit_with_retry from the previous examples def _insert_employee_retry_commit(session): with session.start_transaction(): employees.insert_one( {"employee": 4, "status": "Active"}, session=session) events.insert_one( {"employee": 4, "status": {"new": "Active", "old": None}}, session=session) commit_with_retry(session) with client.start_session() as session: run_transaction_with_retry(_insert_employee_retry_commit, session) employee = employees.find_one({"employee": 4}) self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Active') # Start Transactions Retry Example 3 def run_transaction_with_retry(txn_func, session): while True: try: txn_func(session) # performs transaction break except (ConnectionFailure, OperationFailure) as exc: # If transient error, retry the whole transaction if exc.has_error_label("TransientTransactionError"): print("TransientTransactionError, retrying " "transaction ...") continue else: raise def commit_with_retry(session): while True: try: # Commit uses write concern set at transaction start. session.commit_transaction() print("Transaction committed.") break except (ConnectionFailure, OperationFailure) as exc: # Can retry commit if exc.has_error_label("UnknownTransactionCommitResult"): print("UnknownTransactionCommitResult, retrying " "commit operation ...") continue else: print("Error during commit ...") raise # Updates two collections in a transactions def update_employee_info(session): employees_coll = session.client.hr.employees events_coll = session.client.reporting.events with session.start_transaction( read_concern=ReadConcern("snapshot"), write_concern=WriteConcern(w="majority"), read_preference=ReadPreference.PRIMARY): employees_coll.update_one( {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session) events_coll.insert_one( {"employee": 3, "status": { "new": "Inactive", "old": "Active"}}, session=session) commit_with_retry(session) # Start a session. with client.start_session() as session: try: run_transaction_with_retry(update_employee_info, session) except Exception as exc: # Do something with error. raise # End Transactions Retry Example 3 employee = employees.find_one({"employee": 3}) self.assertIsNotNone(employee) self.assertEqual(employee['status'], 'Inactive') MongoClient = lambda _: rs_client() uriString = None # Start Transactions withTxn API Example 1 # For a replica set, include the replica set name and a seedlist of the members in the URI string; e.g. # uriString = 'mongodb://mongodb0.example.com:27017,mongodb1.example.com:27017/?replicaSet=myRepl' # For a sharded cluster, connect to the mongos instances; e.g. # uriString = 'mongodb://mongos0.example.com:27017,mongos1.example.com:27017/' client = MongoClient(uriString) wc_majority = WriteConcern("majority", wtimeout=1000) # Prereq: Create collections. client.get_database( "mydb1", write_concern=wc_majority).foo.insert_one({'abc': 0}) client.get_database( "mydb2", write_concern=wc_majority).bar.insert_one({'xyz': 0}) # Step 1: Define the callback that specifies the sequence of operations to perform inside the transactions. 
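# --- Illustrative sketch, not part of the test suite -----------------------
# The prerequisite step above attaches a write concern with get_database()
# instead of changing the client-wide default; the same per-object override
# works at the collection level.  Assumes a local server; names are
# placeholders.
from pymongo import MongoClient
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern

client = MongoClient("localhost", 27017)

wc_majority = WriteConcern("majority", wtimeout=1000)
mydb = client.get_database("mydb1", write_concern=wc_majority)
mydb.foo.insert_one({"abc": 0})        # acknowledged by a majority of members

foo_local = mydb.get_collection("foo", read_concern=ReadConcern("local"))
print(foo_local.find_one({"abc": 0}))
# ---------------------------------------------------------------------------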
def callback(session): collection_one = session.client.mydb1.foo collection_two = session.client.mydb2.bar # Important:: You must pass the session to the operations. collection_one.insert_one({'abc': 1}, session=session) collection_two.insert_one({'xyz': 999}, session=session) # Step 2: Start a client session. with client.start_session() as session: # Step 3: Use with_transaction to start a transaction, execute the callback, and commit (or abort on error). session.with_transaction( callback, read_concern=ReadConcern('local'), write_concern=wc_majority, read_preference=ReadPreference.PRIMARY) # End Transactions withTxn API Example 1 class TestCausalConsistencyExamples(IntegrationTest): @client_context.require_version_min(3, 6, 0) @client_context.require_secondaries_count(1) @client_context.require_no_mmap def test_causal_consistency(self): # Causal consistency examples client = self.client self.addCleanup(client.drop_database, 'test') client.test.drop_collection('items') client.test.items.insert_one({ 'sku': "111", 'name': 'Peanuts', 'start':datetime.datetime.today()}) # Start Causal Consistency Example 1 with client.start_session(causal_consistency=True) as s1: current_date = datetime.datetime.today() items = client.get_database( 'test', read_concern=ReadConcern('majority'), write_concern=WriteConcern('majority', wtimeout=1000)).items items.update_one( {'sku': "111", 'end': None}, {'$set': {'end': current_date}}, session=s1) items.insert_one( {'sku': "nuts-111", 'name': "Pecans", 'start': current_date}, session=s1) # End Causal Consistency Example 1 # Start Causal Consistency Example 2 with client.start_session(causal_consistency=True) as s2: s2.advance_cluster_time(s1.cluster_time) s2.advance_operation_time(s1.operation_time) items = client.get_database( 'test', read_preference=ReadPreference.SECONDARY, read_concern=ReadConcern('majority'), write_concern=WriteConcern('majority', wtimeout=1000)).items for item in items.find({'end': None}, session=s2): print(item) # End Causal Consistency Example 2 if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_grid_file.py000066400000000000000000000537041374256237000176220ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the grid_file module. """ import datetime import sys import zipfile sys.path[0:0] = [""] from bson.objectid import ObjectId from bson.py3compat import StringIO from gridfs import GridFS from gridfs.grid_file import (DEFAULT_CHUNK_SIZE, _SEEK_CUR, _SEEK_END, GridIn, GridOut, GridOutCursor) from gridfs.errors import NoFile from pymongo import MongoClient from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError from pymongo.message import _CursorAddress from test import (IntegrationTest, unittest, qcheck) from test.utils import rs_or_single_client, EventListener class TestGridFileNoConnect(unittest.TestCase): """Test GridFile features on a client that does not connect. 
""" @classmethod def setUpClass(cls): cls.db = MongoClient(connect=False).pymongo_test def test_grid_in_custom_opts(self): self.assertRaises(TypeError, GridIn, "foo") a = GridIn(self.db.fs, _id=5, filename="my_file", contentType="text/html", chunkSize=1000, aliases=["foo"], metadata={"foo": 1, "bar": 2}, bar=3, baz="hello") self.assertEqual(5, a._id) self.assertEqual("my_file", a.filename) self.assertEqual("my_file", a.name) self.assertEqual("text/html", a.content_type) self.assertEqual(1000, a.chunk_size) self.assertEqual(["foo"], a.aliases) self.assertEqual({"foo": 1, "bar": 2}, a.metadata) self.assertEqual(3, a.bar) self.assertEqual("hello", a.baz) self.assertRaises(AttributeError, getattr, a, "mike") b = GridIn(self.db.fs, content_type="text/html", chunk_size=1000, baz=100) self.assertEqual("text/html", b.content_type) self.assertEqual(1000, b.chunk_size) self.assertEqual(100, b.baz) class TestGridFile(IntegrationTest): def setUp(self): self.db.drop_collection('fs.files') self.db.drop_collection('fs.chunks') def test_basic(self): f = GridIn(self.db.fs, filename="test") f.write(b"hello world") f.close() self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) g = GridOut(self.db.fs, f._id) self.assertEqual(b"hello world", g.read()) # make sure it's still there... g = GridOut(self.db.fs, f._id) self.assertEqual(b"hello world", g.read()) f = GridIn(self.db.fs, filename="test") f.close() self.assertEqual(2, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) g = GridOut(self.db.fs, f._id) self.assertEqual(b"", g.read()) # test that reading 0 returns proper type self.assertEqual(b"", g.read(0)) def test_md5(self): f = GridIn(self.db.fs) f.write(b"hello world\n") f.close() self.assertEqual("6f5902ac237024bdd0c176cb93063dc4", f.md5) def test_alternate_collection(self): self.db.alt.files.delete_many({}) self.db.alt.chunks.delete_many({}) f = GridIn(self.db.alt) f.write(b"hello world") f.close() self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) g = GridOut(self.db.alt, f._id) self.assertEqual(b"hello world", g.read()) # test that md5 still works... 
self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", g.md5) def test_grid_in_default_opts(self): self.assertRaises(TypeError, GridIn, "foo") a = GridIn(self.db.fs) self.assertTrue(isinstance(a._id, ObjectId)) self.assertRaises(AttributeError, setattr, a, "_id", 5) self.assertEqual(None, a.filename) self.assertEqual(None, a.name) a.filename = "my_file" self.assertEqual("my_file", a.filename) self.assertEqual("my_file", a.name) self.assertEqual(None, a.content_type) a.content_type = "text/html" self.assertEqual("text/html", a.content_type) self.assertRaises(AttributeError, getattr, a, "length") self.assertRaises(AttributeError, setattr, a, "length", 5) self.assertEqual(255 * 1024, a.chunk_size) self.assertRaises(AttributeError, setattr, a, "chunk_size", 5) self.assertRaises(AttributeError, getattr, a, "upload_date") self.assertRaises(AttributeError, setattr, a, "upload_date", 5) self.assertRaises(AttributeError, getattr, a, "aliases") a.aliases = ["foo"] self.assertEqual(["foo"], a.aliases) self.assertRaises(AttributeError, getattr, a, "metadata") a.metadata = {"foo": 1} self.assertEqual({"foo": 1}, a.metadata) self.assertRaises(AttributeError, setattr, a, "md5", 5) a.close() a.forty_two = 42 self.assertEqual(42, a.forty_two) self.assertTrue(isinstance(a._id, ObjectId)) self.assertRaises(AttributeError, setattr, a, "_id", 5) self.assertEqual("my_file", a.filename) self.assertEqual("my_file", a.name) self.assertEqual("text/html", a.content_type) self.assertEqual(0, a.length) self.assertRaises(AttributeError, setattr, a, "length", 5) self.assertEqual(255 * 1024, a.chunk_size) self.assertRaises(AttributeError, setattr, a, "chunk_size", 5) self.assertTrue(isinstance(a.upload_date, datetime.datetime)) self.assertRaises(AttributeError, setattr, a, "upload_date", 5) self.assertEqual(["foo"], a.aliases) self.assertEqual({"foo": 1}, a.metadata) self.assertEqual("d41d8cd98f00b204e9800998ecf8427e", a.md5) self.assertRaises(AttributeError, setattr, a, "md5", 5) # Make sure custom attributes that were set both before and after # a.close() are reflected in b. PYTHON-411. 
b = GridFS(self.db).get_last_version(filename=a.filename) self.assertEqual(a.metadata, b.metadata) self.assertEqual(a.aliases, b.aliases) self.assertEqual(a.forty_two, b.forty_two) def test_grid_out_default_opts(self): self.assertRaises(TypeError, GridOut, "foo") gout = GridOut(self.db.fs, 5) with self.assertRaises(NoFile): gout.name a = GridIn(self.db.fs) a.close() b = GridOut(self.db.fs, a._id) self.assertEqual(a._id, b._id) self.assertEqual(0, b.length) self.assertEqual(None, b.content_type) self.assertEqual(None, b.name) self.assertEqual(None, b.filename) self.assertEqual(255 * 1024, b.chunk_size) self.assertTrue(isinstance(b.upload_date, datetime.datetime)) self.assertEqual(None, b.aliases) self.assertEqual(None, b.metadata) self.assertEqual("d41d8cd98f00b204e9800998ecf8427e", b.md5) for attr in ["_id", "name", "content_type", "length", "chunk_size", "upload_date", "aliases", "metadata", "md5"]: self.assertRaises(AttributeError, setattr, b, attr, 5) def test_grid_out_cursor_options(self): self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {}, projection={"filename": 1}) cursor = GridOutCursor(self.db.fs, {}) cursor_clone = cursor.clone() cursor_dict = cursor.__dict__.copy() cursor_dict.pop('_Cursor__session') cursor_clone_dict = cursor_clone.__dict__.copy() cursor_clone_dict.pop('_Cursor__session') self.assertEqual(cursor_dict, cursor_clone_dict) self.assertRaises(NotImplementedError, cursor.add_option, 0) self.assertRaises(NotImplementedError, cursor.remove_option, 0) def test_grid_out_custom_opts(self): one = GridIn(self.db.fs, _id=5, filename="my_file", contentType="text/html", chunkSize=1000, aliases=["foo"], metadata={"foo": 1, "bar": 2}, bar=3, baz="hello") one.write(b"hello world") one.close() two = GridOut(self.db.fs, 5) self.assertEqual("my_file", two.name) self.assertEqual("my_file", two.filename) self.assertEqual(5, two._id) self.assertEqual(11, two.length) self.assertEqual("text/html", two.content_type) self.assertEqual(1000, two.chunk_size) self.assertTrue(isinstance(two.upload_date, datetime.datetime)) self.assertEqual(["foo"], two.aliases) self.assertEqual({"foo": 1, "bar": 2}, two.metadata) self.assertEqual(3, two.bar) self.assertEqual("5eb63bbbe01eeed093cb22bb8f5acdc3", two.md5) for attr in ["_id", "name", "content_type", "length", "chunk_size", "upload_date", "aliases", "metadata", "md5"]: self.assertRaises(AttributeError, setattr, two, attr, 5) def test_grid_out_file_document(self): one = GridIn(self.db.fs) one.write(b"foo bar") one.close() two = GridOut(self.db.fs, file_document=self.db.fs.files.find_one()) self.assertEqual(b"foo bar", two.read()) three = GridOut(self.db.fs, 5, file_document=self.db.fs.files.find_one()) self.assertEqual(b"foo bar", three.read()) four = GridOut(self.db.fs, file_document={}) with self.assertRaises(NoFile): four.name def test_write_file_like(self): one = GridIn(self.db.fs) one.write(b"hello world") one.close() two = GridOut(self.db.fs, one._id) three = GridIn(self.db.fs) three.write(two) three.close() four = GridOut(self.db.fs, three._id) self.assertEqual(b"hello world", four.read()) five = GridIn(self.db.fs, chunk_size=2) five.write(b"hello") buffer = StringIO(b" world") five.write(buffer) five.write(b" and mongodb") five.close() self.assertEqual(b"hello world and mongodb", GridOut(self.db.fs, five._id).read()) def test_write_lines(self): a = GridIn(self.db.fs) a.writelines([b"hello ", b"world"]) a.close() self.assertEqual(b"hello world", GridOut(self.db.fs, a._id).read()) def test_close(self): f = GridIn(self.db.fs) 
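# --- Illustrative sketch, not part of the test suite -----------------------
# GridIn.write() also accepts file-like objects, so data can be copied in
# from any readable stream (here an in-memory BytesIO buffer).  Assumes a
# local server; names and sizes are placeholders.
import io

from gridfs.grid_file import GridIn, GridOut
from pymongo import MongoClient

db = MongoClient("localhost", 27017).gridstream_sketch

source = io.BytesIO(b"streamed straight from a buffer")
fin = GridIn(db.fs, filename="buffered.bin", chunk_size=8)
fin.write(source)              # reads the stream until EOF, chunking as it goes
fin.close()

assert GridOut(db.fs, fin._id).read() == b"streamed straight from a buffer"
# ---------------------------------------------------------------------------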
f.close() self.assertRaises(ValueError, f.write, "test") f.close() def test_multi_chunk_file(self): random_string = b'a' * (DEFAULT_CHUNK_SIZE + 1000) f = GridIn(self.db.fs) f.write(random_string) f.close() self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(2, self.db.fs.chunks.count_documents({})) g = GridOut(self.db.fs, f._id) self.assertEqual(random_string, g.read()) def test_small_chunks(self): self.files = 0 self.chunks = 0 def helper(data): f = GridIn(self.db.fs, chunkSize=1) f.write(data) f.close() self.files += 1 self.chunks += len(data) self.assertEqual(self.files, self.db.fs.files.count_documents({})) self.assertEqual(self.chunks, self.db.fs.chunks.count_documents({})) g = GridOut(self.db.fs, f._id) self.assertEqual(data, g.read()) g = GridOut(self.db.fs, f._id) self.assertEqual(data, g.read(10) + g.read(10)) return True qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20))) def test_seek(self): f = GridIn(self.db.fs, chunkSize=3) f.write(b"hello world") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual(b"hello world", g.read()) g.seek(0) self.assertEqual(b"hello world", g.read()) g.seek(1) self.assertEqual(b"ello world", g.read()) self.assertRaises(IOError, g.seek, -1) g.seek(-3, _SEEK_END) self.assertEqual(b"rld", g.read()) g.seek(0, _SEEK_END) self.assertEqual(b"", g.read()) self.assertRaises(IOError, g.seek, -100, _SEEK_END) g.seek(3) g.seek(3, _SEEK_CUR) self.assertEqual(b"world", g.read()) self.assertRaises(IOError, g.seek, -100, _SEEK_CUR) def test_tell(self): f = GridIn(self.db.fs, chunkSize=3) f.write(b"hello world") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual(0, g.tell()) g.read(0) self.assertEqual(0, g.tell()) g.read(1) self.assertEqual(1, g.tell()) g.read(2) self.assertEqual(3, g.tell()) g.read() self.assertEqual(g.length, g.tell()) def test_multiple_reads(self): f = GridIn(self.db.fs, chunkSize=3) f.write(b"hello world") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual(b"he", g.read(2)) self.assertEqual(b"ll", g.read(2)) self.assertEqual(b"o ", g.read(2)) self.assertEqual(b"wo", g.read(2)) self.assertEqual(b"rl", g.read(2)) self.assertEqual(b"d", g.read(2)) self.assertEqual(b"", g.read(2)) def test_readline(self): f = GridIn(self.db.fs, chunkSize=5) f.write((b"""Hello world, How are you? Hope all is well. Bye""")) f.close() # Try read(), then readline(). g = GridOut(self.db.fs, f._id) self.assertEqual(b"H", g.read(1)) self.assertEqual(b"ello world,\n", g.readline()) self.assertEqual(b"How a", g.readline(5)) self.assertEqual(b"", g.readline(0)) self.assertEqual(b"re you?\n", g.readline()) self.assertEqual(b"Hope all is well.\n", g.readline(1000)) self.assertEqual(b"Bye", g.readline()) self.assertEqual(b"", g.readline()) # Try readline() first, then read(). g = GridOut(self.db.fs, f._id) self.assertEqual(b"He", g.readline(2)) self.assertEqual(b"l", g.read(1)) self.assertEqual(b"lo", g.readline(2)) self.assertEqual(b" world,\n", g.readline()) # Only readline(). 
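# --- Illustrative sketch, not part of the test suite -----------------------
# GridOut supports random access much like an ordinary file object: seek(),
# tell(), read(n) and readline().  Assumes a local server; the contents are
# placeholders.
import os

from gridfs.grid_file import GridIn, GridOut
from pymongo import MongoClient

db = MongoClient("localhost", 27017).gridseek_sketch

fin = GridIn(db.fs, chunkSize=4)
fin.write(b"first line\nsecond line\n")
fin.close()

fout = GridOut(db.fs, fin._id)
assert fout.readline() == b"first line\n"
pos = fout.tell()              # bytes consumed so far
fout.seek(-5, os.SEEK_END)     # jump near the end of the stored file
assert fout.read() == b"line\n"
fout.seek(pos)                 # back to where readline() left off
# ---------------------------------------------------------------------------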
g = GridOut(self.db.fs, f._id) self.assertEqual(b"H", g.readline(1)) self.assertEqual(b"e", g.readline(1)) self.assertEqual(b"llo world,\n", g.readline()) def test_iterator(self): f = GridIn(self.db.fs) f.close() g = GridOut(self.db.fs, f._id) self.assertEqual([], list(g)) f = GridIn(self.db.fs) f.write(b"hello world") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual([b"hello world"], list(g)) self.assertEqual(b"hello", g.read(5)) self.assertEqual([b"hello world"], list(g)) self.assertEqual(b" worl", g.read(5)) f = GridIn(self.db.fs, chunk_size=2) f.write(b"hello world") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual([b"he", b"ll", b"o ", b"wo", b"rl", b"d"], list(g)) def test_read_unaligned_buffer_size(self): in_data = (b"This is a text that doesn't " b"quite fit in a single 16-byte chunk.") f = GridIn(self.db.fs, chunkSize=16) f.write(in_data) f.close() g = GridOut(self.db.fs, f._id) out_data = b'' while 1: s = g.read(13) if not s: break out_data += s self.assertEqual(in_data, out_data) def test_readchunk(self): in_data = b'a' * 10 f = GridIn(self.db.fs, chunkSize=3) f.write(in_data) f.close() g = GridOut(self.db.fs, f._id) self.assertEqual(3, len(g.readchunk())) self.assertEqual(2, len(g.read(2))) self.assertEqual(1, len(g.readchunk())) self.assertEqual(3, len(g.read(3))) self.assertEqual(1, len(g.readchunk())) self.assertEqual(0, len(g.readchunk())) def test_write_unicode(self): f = GridIn(self.db.fs) self.assertRaises(TypeError, f.write, u"foo") f = GridIn(self.db.fs, encoding="utf-8") f.write(u"foo") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual(b"foo", g.read()) f = GridIn(self.db.fs, encoding="iso-8859-1") f.write(u"aé") f.close() g = GridOut(self.db.fs, f._id) self.assertEqual(u"aé".encode("iso-8859-1"), g.read()) def test_set_after_close(self): f = GridIn(self.db.fs, _id="foo", bar="baz") self.assertEqual("foo", f._id) self.assertEqual("baz", f.bar) self.assertRaises(AttributeError, getattr, f, "baz") self.assertRaises(AttributeError, getattr, f, "uploadDate") self.assertRaises(AttributeError, setattr, f, "_id", 5) f.bar = "foo" f.baz = 5 self.assertEqual("foo", f._id) self.assertEqual("foo", f.bar) self.assertEqual(5, f.baz) self.assertRaises(AttributeError, getattr, f, "uploadDate") f.close() self.assertEqual("foo", f._id) self.assertEqual("foo", f.bar) self.assertEqual(5, f.baz) self.assertTrue(f.uploadDate) self.assertRaises(AttributeError, setattr, f, "_id", 5) f.bar = "a" f.baz = "b" self.assertRaises(AttributeError, setattr, f, "upload_date", 5) g = GridOut(self.db.fs, f._id) self.assertEqual("a", g.bar) self.assertEqual("b", g.baz) # Versions 2.0.1 and older saved a _closed field for some reason. self.assertRaises(AttributeError, getattr, g, "_closed") def test_context_manager(self): contents = b"Imagine this is some important data..." 
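# --- Illustrative sketch, not part of the test suite -----------------------
# Iterating a GridOut yields the stored chunks one at a time, a simple way to
# stream a large file without holding it all in memory.  Assumes a local
# server; the chunk size and payload are placeholders.
from gridfs.grid_file import GridIn, GridOut
from pymongo import MongoClient

db = MongoClient("localhost", 27017).griditer_sketch

fin = GridIn(db.fs, filename="big.bin", chunk_size=4)
fin.write(b"abcdefghij")       # stored as chunks of 4, 4 and 2 bytes
fin.close()

total = 0
for chunk in GridOut(db.fs, fin._id):
    total += len(chunk)        # each chunk is a bytes object
assert total == 10
# ---------------------------------------------------------------------------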
with GridIn(self.db.fs, filename="important") as infile: infile.write(contents) with GridOut(self.db.fs, infile._id) as outfile: self.assertEqual(contents, outfile.read()) def test_prechunked_string(self): def write_me(s, chunk_size): buf = StringIO(s) infile = GridIn(self.db.fs) while True: to_write = buf.read(chunk_size) if to_write == b'': break infile.write(to_write) infile.close() buf.close() outfile = GridOut(self.db.fs, infile._id) data = outfile.read() self.assertEqual(s, data) s = b'x' * DEFAULT_CHUNK_SIZE * 4 # Test with default chunk size write_me(s, DEFAULT_CHUNK_SIZE) # Multiple write_me(s, DEFAULT_CHUNK_SIZE * 3) # Custom write_me(s, 262300) def test_grid_out_lazy_connect(self): fs = self.db.fs outfile = GridOut(fs, file_id=-1) self.assertRaises(NoFile, outfile.read) self.assertRaises(NoFile, getattr, outfile, 'filename') infile = GridIn(fs, filename=1) infile.close() outfile = GridOut(fs, infile._id) outfile.read() outfile.filename outfile = GridOut(fs, infile._id) outfile.readchunk() def test_grid_in_lazy_connect(self): client = MongoClient('badhost', connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) self.assertRaises(ServerSelectionTimeoutError, infile.write, b'data') self.assertRaises(ServerSelectionTimeoutError, infile.close) def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): GridIn(rs_or_single_client(w=0).pymongo_test.fs) def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. # Use 102 batches to cause a single getMore. chunk_size = 1024 data = b'd' * (102 * chunk_size) listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test with GridIn(db.fs, chunk_size=chunk_size) as infile: infile.write(data) with GridOut(db.fs, infile._id) as outfile: self.assertEqual(len(outfile.readchunk()), chunk_size) # Kill the cursor to simulate the cursor timing out on the server # when an application spends a long time between two calls to # readchunk(). client._close_cursor_now( outfile._GridOut__chunk_iter._cursor.cursor_id, _CursorAddress(client.address, db.fs.chunks.full_name)) # Read the rest of the file without error. self.assertEqual(len(outfile.read()), len(data) - chunk_size) # Paranoid, ensure that a getMore was actually sent. self.assertIn("getMore", listener.started_command_names()) def test_zip(self): zf = StringIO() z = zipfile.ZipFile(zf, "w") z.writestr("test.txt", b"hello world") z.close() zf.seek(0) f = GridIn(self.db.fs, filename="test.zip") f.write(zf) f.close() self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) g = GridOut(self.db.fs, f._id) z = zipfile.ZipFile(g) self.assertSequenceEqual(z.namelist(), ["test.txt"]) self.assertEqual(z.read("test.txt"), b"hello world") if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_gridfs.py000066400000000000000000000472731374256237000171600ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
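# --- Illustrative sketch, not part of the test suite -----------------------
# The high-level gridfs.GridFS wrapper tested below: put() returns the new
# file's _id, get() returns a GridOut, delete() removes the files and chunks
# documents.  Assumes a local server; the database name is a placeholder.
import gridfs
from pymongo import MongoClient

db = MongoClient("localhost", 27017).gridfs_sketch
fs = gridfs.GridFS(db)

file_id = fs.put(b"hello world", filename="hello.txt")
assert fs.get(file_id).read() == b"hello world"
assert fs.exists(filename="hello.txt")

fs.delete(file_id)
assert not fs.exists(file_id)
# ---------------------------------------------------------------------------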
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the gridfs package. """ import sys sys.path[0:0] = [""] import datetime import threading import time import gridfs from bson.binary import Binary from bson.py3compat import StringIO, string_type from pymongo.mongo_client import MongoClient from pymongo.errors import (ConfigurationError, ConnectionFailure, ServerSelectionTimeoutError) from pymongo.read_preferences import ReadPreference from gridfs.errors import CorruptGridFile, FileExists, NoFile from test.test_replica_set_client import TestReplicaSetClientBase from test import (client_context, unittest, IntegrationTest) from test.utils import (ignore_deprecations, joinall, one, rs_client, rs_or_single_client, single_client) class JustWrite(threading.Thread): def __init__(self, fs, n): threading.Thread.__init__(self) self.fs = fs self.n = n self.setDaemon(True) def run(self): for _ in range(self.n): file = self.fs.new_file(filename="test") file.write(b"hello") file.close() class JustRead(threading.Thread): def __init__(self, fs, n, results): threading.Thread.__init__(self) self.fs = fs self.n = n self.results = results self.setDaemon(True) def run(self): for _ in range(self.n): file = self.fs.get("test") data = file.read() self.results.append(data) assert data == b"hello" class TestGridfsNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): cls.db = MongoClient(connect=False).pymongo_test def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") self.assertRaises(TypeError, gridfs.GridFS, self.db, 5) class TestGridfs(IntegrationTest): @classmethod def setUpClass(cls): super(TestGridfs, cls).setUpClass() cls.fs = gridfs.GridFS(cls.db) cls.alt = gridfs.GridFS(cls.db, "alt") def setUp(self): self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") self.db.drop_collection("alt.files") self.db.drop_collection("alt.chunks") def test_basic(self): oid = self.fs.put(b"hello world") self.assertEqual(b"hello world", self.fs.get(oid).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) self.assertRaises(NoFile, self.fs.get, oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) self.assertRaises(NoFile, self.fs.get, "foo") oid = self.fs.put(b"hello world", _id="foo") self.assertEqual("foo", oid) self.assertEqual(b"hello world", self.fs.get("foo").read()) def test_multi_chunk_delete(self): self.db.fs.drop() self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) gfs = gridfs.GridFS(self.db) oid = gfs.put(b"hello", chunkSize=1) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(5, self.db.fs.chunks.count_documents({})) gfs.delete(oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) def test_list(self): self.assertEqual([], self.fs.list()) self.fs.put(b"hello world") self.assertEqual([], self.fs.list()) # PYTHON-598: in server versions before 2.5.x, creating an index on # filename, uploadDate 
causes list() to include None. self.fs.get_last_version() self.assertEqual([], self.fs.list()) self.fs.put(b"", filename="mike") self.fs.put(b"foo", filename="test") self.fs.put(b"", filename="hello world") self.assertEqual(set(["mike", "test", "hello world"]), set(self.fs.list())) def test_empty_file(self): oid = self.fs.put(b"") self.assertEqual(b"", self.fs.get(oid).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) raw = self.db.fs.files.find_one() self.assertEqual(0, raw["length"]) self.assertEqual(oid, raw["_id"]) self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) self.assertEqual(255 * 1024, raw["chunkSize"]) self.assertTrue(isinstance(raw["md5"], string_type)) def test_corrupt_chunk(self): files_id = self.fs.put(b'foobar') self.db.fs.chunks.update_one({'files_id': files_id}, {'$set': {'data': Binary(b'foo', 0)}}) try: out = self.fs.get(files_id) self.assertRaises(CorruptGridFile, out.read) out = self.fs.get(files_id) self.assertRaises(CorruptGridFile, out.readline) finally: self.fs.delete(files_id) def test_put_ensures_index(self): # setUp has dropped collections. names = self.db.list_collection_names() self.assertFalse([name for name in names if name.startswith('fs')]) chunks = self.db.fs.chunks files = self.db.fs.files self.fs.put(b"junk") self.assertTrue(any( info.get('key') == [('files_id', 1), ('n', 1)] for info in chunks.index_information().values())) self.assertTrue(any( info.get('key') == [('filename', 1), ('uploadDate', 1)] for info in files.index_information().values())) def test_alt_collection(self): oid = self.alt.put(b"hello world") self.assertEqual(b"hello world", self.alt.get(oid).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) self.alt.delete(oid) self.assertRaises(NoFile, self.alt.get, oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) self.assertRaises(NoFile, self.alt.get, "foo") oid = self.alt.put(b"hello world", _id="foo") self.assertEqual("foo", oid) self.assertEqual(b"hello world", self.alt.get("foo").read()) self.alt.put(b"", filename="mike") self.alt.put(b"foo", filename="test") self.alt.put(b"", filename="hello world") self.assertEqual(set(["mike", "test", "hello world"]), set(self.alt.list())) def test_threaded_reads(self): self.fs.put(b"hello", _id="test") threads = [] results = [] for i in range(10): threads.append(JustRead(self.fs, 10, results)) threads[i].start() joinall(threads) self.assertEqual( 100 * [b'hello'], results ) def test_threaded_writes(self): threads = [] for i in range(10): threads.append(JustWrite(self.fs, 10)) threads[i].start() joinall(threads) f = self.fs.get_last_version("test") self.assertEqual(f.read(), b"hello") # Should have created 100 versions of 'test' file self.assertEqual( 100, self.db.fs.files.count_documents({'filename': 'test'}) ) def test_get_last_version(self): one = self.fs.put(b"foo", filename="test") time.sleep(0.01) two = self.fs.new_file(filename="test") two.write(b"bar") two.close() time.sleep(0.01) two = two._id three = self.fs.put(b"baz", filename="test") self.assertEqual(b"baz", self.fs.get_last_version("test").read()) self.fs.delete(three) self.assertEqual(b"bar", self.fs.get_last_version("test").read()) self.fs.delete(two) self.assertEqual(b"foo", self.fs.get_last_version("test").read()) self.fs.delete(one) self.assertRaises(NoFile, self.fs.get_last_version, "test") def 
test_get_last_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author") self.assertEqual(b"bar", self.fs.get_last_version(author="author").read()) self.fs.delete(two) self.assertEqual(b"foo", self.fs.get_last_version(author="author").read()) self.fs.delete(one) one = self.fs.put(b"foo", filename="test", author="author1") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author2") self.assertEqual(b"foo", self.fs.get_last_version(author="author1").read()) self.assertEqual(b"bar", self.fs.get_last_version(author="author2").read()) self.assertEqual(b"bar", self.fs.get_last_version(filename="test").read()) self.assertRaises(NoFile, self.fs.get_last_version, author="author3") self.assertRaises(NoFile, self.fs.get_last_version, filename="nottest", author="author1") self.fs.delete(one) self.fs.delete(two) def test_get_version(self): self.fs.put(b"foo", filename="test") time.sleep(0.01) self.fs.put(b"bar", filename="test") time.sleep(0.01) self.fs.put(b"baz", filename="test") time.sleep(0.01) self.assertEqual(b"foo", self.fs.get_version("test", 0).read()) self.assertEqual(b"bar", self.fs.get_version("test", 1).read()) self.assertEqual(b"baz", self.fs.get_version("test", 2).read()) self.assertEqual(b"baz", self.fs.get_version("test", -1).read()) self.assertEqual(b"bar", self.fs.get_version("test", -2).read()) self.assertEqual(b"foo", self.fs.get_version("test", -3).read()) self.assertRaises(NoFile, self.fs.get_version, "test", 3) self.assertRaises(NoFile, self.fs.get_version, "test", -4) def test_get_version_with_metadata(self): one = self.fs.put(b"foo", filename="test", author="author1") time.sleep(0.01) two = self.fs.put(b"bar", filename="test", author="author1") time.sleep(0.01) three = self.fs.put(b"baz", filename="test", author="author2") self.assertEqual(b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read()) self.assertEqual(b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read()) self.assertEqual(b"foo", self.fs.get_version(filename="test", author="author1", version=0).read()) self.assertEqual(b"bar", self.fs.get_version(filename="test", author="author1", version=1).read()) self.assertEqual(b"baz", self.fs.get_version(filename="test", author="author2", version=0).read()) self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read()) self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read()) self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author3") self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2) self.fs.delete(one) self.fs.delete(two) self.fs.delete(three) def test_put_filelike(self): oid = self.fs.put(StringIO(b"hello world"), chunk_size=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) self.assertEqual(b"hello world", self.fs.get(oid).read()) def test_file_exists(self): oid = self.fs.put(b"hello") self.assertRaises(FileExists, self.fs.put, b"world", _id=oid) one = self.fs.new_file(_id=123) one.write(b"some content") one.close() two = self.fs.new_file(_id=123) self.assertRaises(FileExists, two.write, b'x' * 262146) def test_exists(self): oid = self.fs.put(b"hello") self.assertTrue(self.fs.exists(oid)) self.assertTrue(self.fs.exists({"_id": oid})) self.assertTrue(self.fs.exists(_id=oid)) self.assertFalse(self.fs.exists(filename="mike")) self.assertFalse(self.fs.exists("mike")) oid = 
self.fs.put(b"hello", filename="mike", foo=12) self.assertTrue(self.fs.exists(oid)) self.assertTrue(self.fs.exists({"_id": oid})) self.assertTrue(self.fs.exists(_id=oid)) self.assertTrue(self.fs.exists(filename="mike")) self.assertTrue(self.fs.exists({"filename": "mike"})) self.assertTrue(self.fs.exists(foo=12)) self.assertTrue(self.fs.exists({"foo": 12})) self.assertTrue(self.fs.exists(foo={"$gt": 11})) self.assertTrue(self.fs.exists({"foo": {"$gt": 11}})) self.assertFalse(self.fs.exists(foo=13)) self.assertFalse(self.fs.exists({"foo": 13})) self.assertFalse(self.fs.exists(foo={"$gt": 12})) self.assertFalse(self.fs.exists({"foo": {"$gt": 12}})) def test_put_unicode(self): self.assertRaises(TypeError, self.fs.put, u"hello") oid = self.fs.put(u"hello", encoding="utf-8") self.assertEqual(b"hello", self.fs.get(oid).read()) self.assertEqual("utf-8", self.fs.get(oid).encoding) oid = self.fs.put(u"aé", encoding="iso-8859-1") self.assertEqual(u"aé".encode("iso-8859-1"), self.fs.get(oid).read()) self.assertEqual("iso-8859-1", self.fs.get(oid).encoding) def test_missing_length_iter(self): # Test fix that guards against PHP-237 self.fs.put(b"", filename="empty") doc = self.db.fs.files.find_one({"filename": "empty"}) doc.pop("length") self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) f = self.fs.get_last_version(filename="empty") def iterate_file(grid_file): for chunk in grid_file: pass return True self.assertTrue(iterate_file(f)) def test_gridfs_lazy_connect(self): client = MongoClient('badhost', connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) self.assertRaises(ServerSelectionTimeoutError, gfs.list) fs = gridfs.GridFS(db) f = fs.new_file() self.assertRaises(ServerSelectionTimeoutError, f.close) @ignore_deprecations def test_gridfs_find(self): self.fs.put(b"test2", filename="two") time.sleep(0.01) self.fs.put(b"test2+", filename="two") time.sleep(0.01) self.fs.put(b"test1", filename="one") time.sleep(0.01) self.fs.put(b"test2++", filename="two") self.assertEqual(3, self.fs.find({"filename": "two"}).count()) self.assertEqual(4, self.fs.find().count()) cursor = self.fs.find( no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) gout = next(cursor) self.assertEqual(b"test1", gout.read()) cursor.rewind() gout = next(cursor) self.assertEqual(b"test1", gout.read()) gout = next(cursor) self.assertEqual(b"test2+", gout.read()) self.assertRaises(StopIteration, cursor.__next__) cursor.close() self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) def test_gridfs_find_one(self): self.assertEqual(None, self.fs.find_one()) id1 = self.fs.put(b'test1', filename='file1') self.assertEqual(b'test1', self.fs.find_one().read()) id2 = self.fs.put(b'test2', filename='file2', meta='data') self.assertEqual(b'test1', self.fs.find_one(id1).read()) self.assertEqual(b'test2', self.fs.find_one(id2).read()) self.assertEqual(b'test1', self.fs.find_one({'filename': 'file1'}).read()) self.assertEqual('data', self.fs.find_one(id2).meta) def test_grid_in_non_int_chunksize(self): # Lua, and perhaps other buggy GridFS clients, store size as a float. data = b'data' self.fs.put(data, filename='f') self.db.fs.files.update_one({'filename': 'f'}, {'$set': {'chunkSize': 100.0}}) self.assertEqual(data, self.fs.get_version('f').read()) def test_unacknowledged(self): # w=0 is prohibited. 
with self.assertRaises(ConfigurationError): gridfs.GridFS(rs_or_single_client(w=0).pymongo_test) def test_md5(self): gin = self.fs.new_file() gin.write(b"includes md5 sum") gin.close() self.assertIsNotNone(gin.md5) md5sum = gin.md5 gout = self.fs.get(gin._id) self.assertIsNotNone(gout.md5) self.assertEqual(md5sum, gout.md5) _id = self.fs.put(b"also includes md5 sum") gout = self.fs.get(_id) self.assertIsNotNone(gout.md5) fs = gridfs.GridFS(self.db, disable_md5=True) gin = fs.new_file() gin.write(b"no md5 sum") gin.close() self.assertIsNone(gin.md5) gout = self.fs.get(gin._id) self.assertIsNone(gout.md5) _id = fs.put(b"still no md5 sum") gout = self.fs.get(_id) self.assertIsNone(gout.md5) class TestGridfsReplicaSet(TestReplicaSetClientBase): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): super(TestGridfsReplicaSet, cls).setUpClass() @classmethod def tearDownClass(cls): client_context.client.drop_database('gfsreplica') def test_gridfs_replica_set(self): rsc = rs_client( w=self.w, read_preference=ReadPreference.SECONDARY) fs = gridfs.GridFS(rsc.gfsreplica, 'gfsreplicatest') gin = fs.new_file() self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) oid = fs.put(b'foo') content = fs.get(oid).read() self.assertEqual(b'foo', content) def test_gridfs_secondary(self): primary_host, primary_port = self.primary primary_connection = single_client(primary_host, primary_port) secondary_host, secondary_port = one(self.secondaries) secondary_connection = single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY) # Should detect it's connected to secondary and not attempt to # create index fs = gridfs.GridFS(secondary_connection.gfsreplica, 'gfssecondarytest') # This won't detect secondary, raises error self.assertRaises(ConnectionFailure, fs.put, b'foo') def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.secondaries) client = single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False) # Still no connection. fs = gridfs.GridFS(client.gfsreplica, 'gfssecondarylazytest') # Connects, doesn't create index. self.assertRaises(NoFile, fs.get_last_version) self.assertRaises(ConnectionFailure, fs.put, 'data') if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_gridfs_bucket.py000066400000000000000000000516101374256237000205030ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the gridfs package. 
""" import datetime import itertools import threading import time import gridfs from bson.binary import Binary from bson.int64 import Int64 from bson.objectid import ObjectId from bson.py3compat import StringIO, string_type from bson.son import SON from gridfs.errors import NoFile, CorruptGridFile from pymongo.errors import (ConfigurationError, ConnectionFailure, ServerSelectionTimeoutError) from pymongo.mongo_client import MongoClient from pymongo.read_preferences import ReadPreference from test import (client_context, unittest, IntegrationTest) from test.test_replica_set_client import TestReplicaSetClientBase from test.utils import (ignore_deprecations, joinall, one, rs_client, rs_or_single_client, single_client) class JustWrite(threading.Thread): def __init__(self, gfs, num): threading.Thread.__init__(self) self.gfs = gfs self.num = num self.setDaemon(True) def run(self): for _ in range(self.num): file = self.gfs.open_upload_stream("test") file.write(b"hello") file.close() class JustRead(threading.Thread): def __init__(self, gfs, num, results): threading.Thread.__init__(self) self.gfs = gfs self.num = num self.results = results self.setDaemon(True) def run(self): for _ in range(self.num): file = self.gfs.open_download_stream_by_name("test") data = file.read() self.results.append(data) assert data == b"hello" class TestGridfs(IntegrationTest): @classmethod def setUpClass(cls): super(TestGridfs, cls).setUpClass() cls.fs = gridfs.GridFSBucket(cls.db) cls.alt = gridfs.GridFSBucket( cls.db, bucket_name="alt") def setUp(self): self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") self.db.drop_collection("alt.files") self.db.drop_collection("alt.chunks") def test_basic(self): oid = self.fs.upload_from_stream("test_filename", b"hello world") self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) self.fs.delete(oid) self.assertRaises(NoFile, self.fs.open_download_stream, oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) def test_multi_chunk_delete(self): self.db.fs.drop() self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) gfs = gridfs.GridFSBucket(self.db) oid = gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(5, self.db.fs.chunks.count_documents({})) gfs.delete(oid) self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) def test_empty_file(self): oid = self.fs.upload_from_stream("test_filename", b"") self.assertEqual(b"", self.fs.open_download_stream(oid).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) raw = self.db.fs.files.find_one() self.assertEqual(0, raw["length"]) self.assertEqual(oid, raw["_id"]) self.assertTrue(isinstance(raw["uploadDate"], datetime.datetime)) self.assertEqual(255 * 1024, raw["chunkSize"]) self.assertTrue(isinstance(raw["md5"], string_type)) def test_corrupt_chunk(self): files_id = self.fs.upload_from_stream("test_filename", b'foobar') self.db.fs.chunks.update_one({'files_id': files_id}, {'$set': {'data': Binary(b'foo', 0)}}) try: out = self.fs.open_download_stream(files_id) self.assertRaises(CorruptGridFile, out.read) out = 
self.fs.open_download_stream(files_id) self.assertRaises(CorruptGridFile, out.readline) finally: self.fs.delete(files_id) def test_upload_ensures_index(self): # setUp has dropped collections. names = self.db.list_collection_names() self.assertFalse([name for name in names if name.startswith('fs')]) chunks = self.db.fs.chunks files = self.db.fs.files self.fs.upload_from_stream("filename", b"junk") self.assertTrue(any( info.get('key') == [('files_id', 1), ('n', 1)] for info in chunks.index_information().values())) self.assertTrue(any( info.get('key') == [('filename', 1), ('uploadDate', 1)] for info in files.index_information().values())) def test_ensure_index_shell_compat(self): files = self.db.fs.files for i, j in itertools.combinations_with_replacement( [1, 1.0, Int64(1)], 2): # Create the index with different numeric types (as might be done # from the mongo shell). shell_index = [('filename', i), ('uploadDate', j)] self.db.command('createIndexes', files.name, indexes=[{'key': SON(shell_index), 'name': 'filename_1.0_uploadDate_1.0'}]) # No error. self.fs.upload_from_stream("filename", b"data") self.assertTrue(any( info.get('key') == [('filename', 1), ('uploadDate', 1)] for info in files.index_information().values())) files.drop() def test_alt_collection(self): oid = self.alt.upload_from_stream("test_filename", b"hello world") self.assertEqual(b"hello world", self.alt.open_download_stream(oid).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) self.alt.delete(oid) self.assertRaises(NoFile, self.alt.open_download_stream, oid) self.assertEqual(0, self.db.alt.files.count_documents({})) self.assertEqual(0, self.db.alt.chunks.count_documents({})) self.assertRaises(NoFile, self.alt.open_download_stream, "foo") self.alt.upload_from_stream("foo", b"hello world") self.assertEqual(b"hello world", self.alt.open_download_stream_by_name("foo").read()) self.alt.upload_from_stream("mike", b"") self.alt.upload_from_stream("test", b"foo") self.alt.upload_from_stream("hello world", b"") self.assertEqual(set(["mike", "test", "hello world", "foo"]), set(k["filename"] for k in list( self.db.alt.files.find()))) def test_threaded_reads(self): self.fs.upload_from_stream("test", b"hello") threads = [] results = [] for i in range(10): threads.append(JustRead(self.fs, 10, results)) threads[i].start() joinall(threads) self.assertEqual( 100 * [b'hello'], results ) def test_threaded_writes(self): threads = [] for i in range(10): threads.append(JustWrite(self.fs, 10)) threads[i].start() joinall(threads) fstr = self.fs.open_download_stream_by_name("test") self.assertEqual(fstr.read(), b"hello") # Should have created 100 versions of 'test' file self.assertEqual( 100, self.db.fs.files.count_documents({'filename': 'test'}) ) def test_get_last_version(self): one = self.fs.upload_from_stream("test", b"foo") time.sleep(0.01) two = self.fs.open_upload_stream("test") two.write(b"bar") two.close() time.sleep(0.01) two = two._id three = self.fs.upload_from_stream("test", b"baz") self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test").read()) self.fs.delete(three) self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test").read()) self.fs.delete(two) self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test").read()) self.fs.delete(one) self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test") def test_get_version(self): self.fs.upload_from_stream("test", b"foo") time.sleep(0.01) 
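# --- Illustrative sketch, not part of the test suite -----------------------
# GridFSBucket stores every upload under the same filename as a new revision;
# open_download_stream_by_name() picks one with `revision` (0 is the oldest,
# -1 the newest).  Assumes a local server; names are placeholders.
import time

import gridfs
from pymongo import MongoClient

db = MongoClient("localhost", 27017).bucket_rev_sketch
bucket = gridfs.GridFSBucket(db)

bucket.upload_from_stream("config.json", b'{"v": 1}')
time.sleep(0.01)               # keep the uploadDate ordering unambiguous
bucket.upload_from_stream("config.json", b'{"v": 2}')

assert bucket.open_download_stream_by_name(
    "config.json", revision=0).read() == b'{"v": 1}'
assert bucket.open_download_stream_by_name(
    "config.json", revision=-1).read() == b'{"v": 2}'
# ---------------------------------------------------------------------------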
self.fs.upload_from_stream("test", b"bar") time.sleep(0.01) self.fs.upload_from_stream("test", b"baz") time.sleep(0.01) self.assertEqual(b"foo", self.fs.open_download_stream_by_name( "test", revision=0).read()) self.assertEqual(b"bar", self.fs.open_download_stream_by_name( "test", revision=1).read()) self.assertEqual(b"baz", self.fs.open_download_stream_by_name( "test", revision=2).read()) self.assertEqual(b"baz", self.fs.open_download_stream_by_name( "test", revision=-1).read()) self.assertEqual(b"bar", self.fs.open_download_stream_by_name( "test", revision=-2).read()) self.assertEqual(b"foo", self.fs.open_download_stream_by_name( "test", revision=-3).read()) self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=3) self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=-4) def test_upload_from_stream(self): oid = self.fs.upload_from_stream("test_file", StringIO(b"hello world"), chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) def test_upload_from_stream_with_id(self): oid = ObjectId() self.fs.upload_from_stream_with_id(oid, "test_file_custom_id", StringIO(b"custom id"), chunk_size_bytes=1) self.assertEqual(b"custom id", self.fs.open_download_stream(oid).read()) def test_open_upload_stream(self): gin = self.fs.open_upload_stream("from_stream") gin.write(b"from stream") gin.close() self.assertEqual(b"from stream", self.fs.open_download_stream(gin._id).read()) def test_open_upload_stream_with_id(self): oid = ObjectId() gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id") gin.write(b"from stream with custom id") gin.close() self.assertEqual(b"from stream with custom id", self.fs.open_download_stream(oid).read()) def test_missing_length_iter(self): # Test fix that guards against PHP-237 self.fs.upload_from_stream("empty", b"") doc = self.db.fs.files.find_one({"filename": "empty"}) doc.pop("length") self.db.fs.files.replace_one({"_id": doc["_id"]}, doc) fstr = self.fs.open_download_stream_by_name("empty") def iterate_file(grid_file): for _ in grid_file: pass return True self.assertTrue(iterate_file(fstr)) def test_gridfs_lazy_connect(self): client = MongoClient('badhost', connect=False, serverSelectionTimeoutMS=0) cdb = client.db gfs = gridfs.GridFSBucket(cdb) self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0) gfs = gridfs.GridFSBucket(cdb) self.assertRaises( ServerSelectionTimeoutError, gfs.upload_from_stream, "test", b"") # Still no connection. @ignore_deprecations def test_gridfs_find(self): self.fs.upload_from_stream("two", b"test2") time.sleep(0.01) self.fs.upload_from_stream("two", b"test2+") time.sleep(0.01) self.fs.upload_from_stream("one", b"test1") time.sleep(0.01) self.fs.upload_from_stream("two", b"test2++") self.assertEqual(3, self.fs.find({"filename": "two"}).count()) self.assertEqual(4, self.fs.find({}).count()) cursor = self.fs.find( {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2) gout = next(cursor) self.assertEqual(b"test1", gout.read()) cursor.rewind() gout = next(cursor) self.assertEqual(b"test1", gout.read()) gout = next(cursor) self.assertEqual(b"test2+", gout.read()) self.assertRaises(StopIteration, cursor.__next__) cursor.close() self.assertRaises(TypeError, self.fs.find, {}, {"_id": True}) def test_grid_in_non_int_chunksize(self): # Lua, and perhaps other buggy GridFS clients, store size as a float. 
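# --- Illustrative sketch, not part of the test suite -----------------------
# open_upload_stream() returns a writable stream for incremental uploads, and
# the *_with_id variants let the caller choose the file's _id up front.
# Assumes a local server; names and the chunk size are placeholders.
import gridfs
from bson.objectid import ObjectId
from pymongo import MongoClient

db = MongoClient("localhost", 27017).bucket_up_sketch
bucket = gridfs.GridFSBucket(db)

file_id = ObjectId()
stream = bucket.open_upload_stream_with_id(file_id, "audit.log",
                                            chunk_size_bytes=16)
stream.write(b"first entry\n")
stream.write(b"second entry\n")
stream.close()                 # writes the files document; upload now visible

assert (bucket.open_download_stream(file_id).read()
        == b"first entry\nsecond entry\n")
# ---------------------------------------------------------------------------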
data = b'data' self.fs.upload_from_stream('f', data) self.db.fs.files.update_one({'filename': 'f'}, {'$set': {'chunkSize': 100.0}}) self.assertEqual(data, self.fs.open_download_stream_by_name('f').read()) def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test) def test_rename(self): _id = self.fs.upload_from_stream("first_name", b'testing') self.assertEqual(b'testing', self.fs.open_download_stream_by_name( "first_name").read()) self.fs.rename(_id, "second_name") self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "first_name") self.assertEqual(b"testing", self.fs.open_download_stream_by_name( "second_name").read()) def test_abort(self): gin = self.fs.open_upload_stream("test_filename", chunk_size_bytes=5) gin.write(b"test1") gin.write(b"test2") gin.write(b"test3") self.assertEqual(3, self.db.fs.chunks.count_documents( {"files_id": gin._id})) gin.abort() self.assertTrue(gin.closed) self.assertRaises(ValueError, gin.write, b"test4") self.assertEqual(0, self.db.fs.chunks.count_documents( {"files_id": gin._id})) def test_download_to_stream(self): file1 = StringIO(b"hello world") # Test with one chunk. oid = self.fs.upload_from_stream("one_chunk", file1) self.assertEqual(1, self.db.fs.chunks.count_documents({})) file2 = StringIO() self.fs.download_to_stream(oid, file2) file1.seek(0) file2.seek(0) self.assertEqual(file1.read(), file2.read()) # Test with many chunks. self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") file1.seek(0) oid = self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) file2 = StringIO() self.fs.download_to_stream(oid, file2) file1.seek(0) file2.seek(0) self.assertEqual(file1.read(), file2.read()) def test_download_to_stream_by_name(self): file1 = StringIO(b"hello world") # Test with one chunk. oid = self.fs.upload_from_stream("one_chunk", file1) self.assertEqual(1, self.db.fs.chunks.count_documents({})) file2 = StringIO() self.fs.download_to_stream_by_name("one_chunk", file2) file1.seek(0) file2.seek(0) self.assertEqual(file1.read(), file2.read()) # Test with many chunks. 
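# --- Editor's illustrative sketch (not part of the original test suite) ---
# test_rename, test_abort and test_download_to_stream above exercise the rest
# of the bucket API: renaming a stored file, aborting an in-progress upload
# (which removes the chunks written so far), and streaming a stored file into
# an arbitrary writable object.  A hedged sketch under the same
# localhost/scratch-database assumptions as the previous example:

from io import BytesIO

import gridfs
from gridfs.errors import NoFile
from pymongo import MongoClient

bucket = gridfs.GridFSBucket(MongoClient().gridfs_demo, bucket_name="sketch")

file_id = bucket.upload_from_stream("old_name", b"payload")
bucket.rename(file_id, "new_name")
try:
    bucket.open_download_stream_by_name("old_name")
except NoFile:
    pass                                    # the old name is gone after rename

grid_in = bucket.open_upload_stream("partial", chunk_size_bytes=4)
grid_in.write(b"will be discarded")
grid_in.abort()                             # drops the chunks written so far

sink = BytesIO()
bucket.download_to_stream(file_id, sink)    # download by _id still works
assert sink.getvalue() == b"payload"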
self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") file1.seek(0) self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) file2 = StringIO() self.fs.download_to_stream_by_name("many_chunks", file2) file1.seek(0) file2.seek(0) self.assertEqual(file1.read(), file2.read()) def test_md5(self): gin = self.fs.open_upload_stream("has md5") gin.write(b"includes md5 sum") gin.close() self.assertIsNotNone(gin.md5) md5sum = gin.md5 gout = self.fs.open_download_stream(gin._id) self.assertIsNotNone(gout.md5) self.assertEqual(md5sum, gout.md5) gin = self.fs.open_upload_stream_with_id(ObjectId(), "also has md5") gin.write(b"also includes md5 sum") gin.close() self.assertIsNotNone(gin.md5) md5sum = gin.md5 gout = self.fs.open_download_stream(gin._id) self.assertIsNotNone(gout.md5) self.assertEqual(md5sum, gout.md5) fs = gridfs.GridFSBucket(self.db, disable_md5=True) gin = fs.open_upload_stream("no md5") gin.write(b"no md5 sum") gin.close() self.assertIsNone(gin.md5) gout = fs.open_download_stream(gin._id) self.assertIsNone(gout.md5) gin = fs.open_upload_stream_with_id(ObjectId(), "also no md5") gin.write(b"also no md5 sum") gin.close() self.assertIsNone(gin.md5) gout = fs.open_download_stream(gin._id) self.assertIsNone(gout.md5) class TestGridfsBucketReplicaSet(TestReplicaSetClientBase): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): super(TestGridfsBucketReplicaSet, cls).setUpClass() @classmethod def tearDownClass(cls): client_context.client.drop_database('gfsbucketreplica') def test_gridfs_replica_set(self): rsc = rs_client( w=self.w, read_preference=ReadPreference.SECONDARY) gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, 'gfsbucketreplicatest') oid = gfs.upload_from_stream("test_filename", b'foo') content = gfs.open_download_stream(oid).read() self.assertEqual(b'foo', content) def test_gridfs_secondary(self): primary_host, primary_port = self.primary primary_connection = single_client(primary_host, primary_port) secondary_host, secondary_port = one(self.secondaries) secondary_connection = single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY) # Should detect it's connected to secondary and not attempt to # create index gfs = gridfs.GridFSBucket( secondary_connection.gfsbucketreplica, 'gfsbucketsecondarytest') # This won't detect secondary, raises error self.assertRaises(ConnectionFailure, gfs.upload_from_stream, "test_filename", b'foo') def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.secondaries) client = single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False) # Still no connection. gfs = gridfs.GridFSBucket( client.gfsbucketreplica, 'gfsbucketsecondarylazytest') # Connects, doesn't create index. self.assertRaises(NoFile, gfs.open_download_stream_by_name, "test_filename") self.assertRaises(ConnectionFailure, gfs.upload_from_stream, "test_filename", b'data') if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_gridfs_spec.py000066400000000000000000000204761374256237000201660ustar00rootroot00000000000000# Copyright 2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test GridFSBucket class.""" import copy import datetime import os import sys import re from json import loads import gridfs sys.path[0:0] = [""] from bson import Binary from bson.int64 import Int64 from bson.json_util import object_hook from bson.py3compat import bytes_from_hex from gridfs.errors import NoFile, CorruptGridFile from test import (unittest, IntegrationTest) # Commands. _COMMANDS = {"delete": lambda coll, doc: [coll.delete_many(d["q"]) for d in doc['deletes']], "insert": lambda coll, doc: coll.insert_many(doc['documents']), "update": lambda coll, doc: [coll.update_many(u["q"], u["u"]) for u in doc['updates']] } # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'gridfs') def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. Special case for _id. if camel == "id": return "file_id" snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() class TestAllScenarios(IntegrationTest): @classmethod def setUpClass(cls): super(TestAllScenarios, cls).setUpClass() cls.fs = gridfs.GridFSBucket(cls.db) cls.str_to_cmd = { "upload": cls.fs.upload_from_stream, "download": cls.fs.open_download_stream, "delete": cls.fs.delete, "download_by_name": cls.fs.open_download_stream_by_name} def init_db(self, data, test): self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") self.db.drop_collection("expected.files") self.db.drop_collection("expected.chunks") # Read in data. if data['files']: self.db.fs.files.insert_many(data['files']) self.db.expected.files.insert_many(data['files']) if data['chunks']: self.db.fs.chunks.insert_many(data['chunks']) self.db.expected.chunks.insert_many(data['chunks']) # Make initial modifications. if "arrange" in test: for cmd in test['arrange'].get('data', []): for key in cmd.keys(): if key in _COMMANDS: coll = self.db.get_collection(cmd[key]) _COMMANDS[key](coll, cmd) def init_expected_db(self, test, result): # Modify outcome DB. for cmd in test['assert'].get('data', []): for key in cmd.keys(): if key in _COMMANDS: # Replace wildcards in inserts. for doc in cmd.get('documents', []): keylist = doc.keys() for dockey in copy.deepcopy(list(keylist)): if "result" in str(doc[dockey]): doc[dockey] = result if "actual" in str(doc[dockey]): # Avoid duplicate doc.pop(dockey) # Move contentType to metadata. if dockey == "contentType": doc["metadata"] = {dockey: doc.pop(dockey)} coll = self.db.get_collection(cmd[key]) _COMMANDS[key](coll, cmd) if test['assert'].get('result') == "&result": test['assert']['result'] = result def sorted_list(self, coll, ignore_id): to_sort = [] for doc in coll.find(): docstr = "{" if ignore_id: # Cannot compare _id in chunks collection. doc.pop("_id") for k in sorted(doc.keys()): if k == "uploadDate": # Can't compare datetime. self.assertTrue(isinstance(doc[k], datetime.datetime)) else: docstr += "%s:%s " % (k, repr(doc[k])) to_sort.append(docstr + "}") return to_sort def create_test(scenario_def): def run_scenario(self): # Run tests. 
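# --- Editor's illustrative sketch (not part of the original test suite) ---
# The spec harness above converts camelCase argument names from the JSON test
# files into the snake_case keyword arguments PyMongo expects, with "id"
# special-cased to "file_id".  The helper is pure string manipulation, so it
# can be exercised without a server; this mirrors the camel_to_snake defined
# in the harness above:

import re


def camel_to_snake(camel):
    if camel == "id":
        return "file_id"
    snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()


assert camel_to_snake("chunkSizeBytes") == "chunk_size_bytes"
assert camel_to_snake("filename") == "filename"
assert camel_to_snake("id") == "file_id"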
self.assertTrue(scenario_def['tests'], "tests cannot be empty") for test in scenario_def['tests']: self.init_db(scenario_def['data'], test) # Run GridFs Operation. operation = self.str_to_cmd[test['act']['operation']] args = test['act']['arguments'] extra_opts = args.pop("options", {}) if "contentType" in extra_opts: extra_opts["metadata"] = { "contentType": extra_opts.pop("contentType")} args.update(extra_opts) converted_args = dict((camel_to_snake(c), v) for c, v in args.items()) expect_error = test['assert'].get("error", False) result = None error = None try: result = operation(**converted_args) if 'download' in test['act']['operation']: result = Binary(result.read()) except Exception as exc: if not expect_error: raise error = exc self.init_expected_db(test, result) # Asserts. errors = {"FileNotFound": NoFile, "ChunkIsMissing": CorruptGridFile, "ExtraChunk": CorruptGridFile, "ChunkIsWrongSize": CorruptGridFile, "RevisionNotFound": NoFile} if expect_error: self.assertIsNotNone(error) self.assertIsInstance(error, errors[test['assert']['error']], test['description']) else: self.assertIsNone(error) if 'result' in test['assert']: if test['assert']['result'] == 'void': test['assert']['result'] = None self.assertEqual(result, test['assert'].get('result')) if 'data' in test['assert']: # Create alphabetized list self.assertEqual( set(self.sorted_list(self.db.fs.chunks, True)), set(self.sorted_list(self.db.expected.chunks, True))) self.assertEqual( set(self.sorted_list(self.db.fs.files, False)), set(self.sorted_list(self.db.expected.files, False))) return run_scenario def _object_hook(dct): if 'length' in dct: dct['length'] = Int64(dct['length']) return object_hook(dct) def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = loads( scenario_stream.read(), object_hook=_object_hook) # Because object_hook is already defined by bson.json_util, # and everything is named 'data' def str2hex(jsn): for key, val in jsn.items(): if key in ("data", "source", "result"): if "$hex" in val: jsn[key] = Binary(bytes_from_hex(val['$hex'])) if isinstance(jsn[key], dict): str2hex(jsn[key]) if isinstance(jsn[key], list): for k in jsn[key]: str2hex(k) str2hex(scenario_def) # Construct test from scenario. new_test = create_test(scenario_def) test_name = 'test_%s' % ( os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_heartbeat_monitoring.py000066400000000000000000000074731374256237000221040ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test the monitoring of the server heartbeats.""" import sys import threading sys.path[0:0] = [""] from pymongo.errors import ConnectionFailure from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from test import unittest, client_knobs from test.utils import (HeartbeatEventListener, MockPool, single_client, wait_until) class TestHeartbeatMonitoring(unittest.TestCase): def create_mock_monitor(self, responses, uri, expected_results): listener = HeartbeatEventListener() with client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1): class MockMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): if isinstance(responses[1], Exception): raise responses[1] return IsMaster(responses[1]), 99 m = single_client( h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool) expected_len = len(expected_results) # Wait for *at least* expected_len number of results. The # monitor thread may run multiple times during the execution # of this test. wait_until( lambda: len(listener.results) >= expected_len, "publish all events") try: # zip gives us len(expected_results) pairs. for expected, actual in zip(expected_results, listener.results): self.assertEqual(expected, actual.__class__.__name__) self.assertEqual(actual.connection_id, responses[0]) if expected != 'ServerHeartbeatStartedEvent': if isinstance(actual.reply, IsMaster): self.assertEqual(actual.duration, 99) self.assertEqual(actual.reply._doc, responses[1]) else: self.assertEqual(actual.reply, responses[1]) finally: m.close() def test_standalone(self): responses = (('a', 27017), { "ismaster": True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1 }) uri = "mongodb://a:27017" expected_results = ['ServerHeartbeatStartedEvent', 'ServerHeartbeatSucceededEvent'] self.create_mock_monitor(responses, uri, expected_results) def test_standalone_error(self): responses = (('a', 27017), ConnectionFailure("SPECIAL MESSAGE")) uri = "mongodb://a:27017" # _check_with_socket failing results in a second attempt. expected_results = ['ServerHeartbeatStartedEvent', 'ServerHeartbeatFailedEvent', 'ServerHeartbeatStartedEvent', 'ServerHeartbeatFailedEvent'] self.create_mock_monitor(responses, uri, expected_results) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_json_util.py000066400000000000000000000433341374256237000177020ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test some utilities for working with JSON and PyMongo.""" import datetime import json import re import sys import uuid sys.path[0:0] = [""] from bson import json_util, EPOCH_AWARE, SON from bson.json_util import (DatetimeRepresentation, STRICT_JSON_OPTIONS) from bson.binary import (ALL_UUID_REPRESENTATIONS, Binary, MD5_SUBTYPE, USER_DEFINED_SUBTYPE, UuidRepresentation, STANDARD) from bson.code import Code from bson.dbref import DBRef from bson.int64 import Int64 from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId from bson.regex import Regex from bson.timestamp import Timestamp from bson.tz_util import FixedOffset, utc from test import unittest, IntegrationTest PY3 = sys.version_info[0] == 3 class TestJsonUtil(unittest.TestCase): def round_tripped(self, doc, **kwargs): return json_util.loads(json_util.dumps(doc, **kwargs), **kwargs) def round_trip(self, doc, **kwargs): self.assertEqual(doc, self.round_tripped(doc, **kwargs)) def test_basic(self): self.round_trip({"hello": "world"}) def test_objectid(self): self.round_trip({"id": ObjectId()}) def test_dbref(self): self.round_trip({"ref": DBRef("foo", 5)}) self.round_trip({"ref": DBRef("foo", 5, "db")}) self.round_trip({"ref": DBRef("foo", ObjectId())}) # Check order. self.assertEqual( '{"$ref": "collection", "$id": 1, "$db": "db"}', json_util.dumps(DBRef('collection', 1, 'db'))) def test_datetime(self): # only millis, not micros self.round_trip({"date": datetime.datetime(2009, 12, 9, 15, 49, 45, 191000, utc)}) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000+0000"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000000+0000"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000+00:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000000+00:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000000+00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000Z"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000000Z"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00Z"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) # No explicit offset jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T00:00:00.000000"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) # Localtime behind UTC jsn = '{"dt": { "$date" : "1969-12-31T16:00:00.000-0800"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1969-12-31T16:00:00.000000-0800"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1969-12-31T16:00:00.000-08:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1969-12-31T16:00:00.000000-08:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1969-12-31T16:00:00.000000-08"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) # Localtime ahead of UTC jsn = '{"dt": { "$date" : "1970-01-01T01:00:00.000+0100"}}' 
self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T01:00:00.000000+0100"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T01:00:00.000+01:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T01:00:00.000000+01:00"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) jsn = '{"dt": { "$date" : "1970-01-01T01:00:00.000000+01"}}' self.assertEqual(EPOCH_AWARE, json_util.loads(jsn)["dt"]) dtm = datetime.datetime(1, 1, 1, 1, 1, 1, 0, utc) jsn = '{"dt": {"$date": -62135593139000}}' self.assertEqual(dtm, json_util.loads(jsn)["dt"]) jsn = '{"dt": {"$date": {"$numberLong": "-62135593139000"}}}' self.assertEqual(dtm, json_util.loads(jsn)["dt"]) # Test dumps format pre_epoch = {"dt": datetime.datetime(1, 1, 1, 1, 1, 1, 10000, utc)} post_epoch = {"dt": datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc)} self.assertEqual( '{"dt": {"$date": -62135593138990}}', json_util.dumps(pre_epoch)) self.assertEqual( '{"dt": {"$date": 63075661010}}', json_util.dumps(post_epoch)) self.assertEqual( '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', json_util.dumps(pre_epoch, json_options=STRICT_JSON_OPTIONS)) self.assertEqual( '{"dt": {"$date": "1972-01-01T01:01:01.010Z"}}', json_util.dumps(post_epoch, json_options=STRICT_JSON_OPTIONS)) number_long_options = json_util.JSONOptions( datetime_representation=DatetimeRepresentation.NUMBERLONG) self.assertEqual( '{"dt": {"$date": {"$numberLong": "63075661010"}}}', json_util.dumps(post_epoch, json_options=number_long_options)) self.assertEqual( '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', json_util.dumps(pre_epoch, json_options=number_long_options)) # ISO8601 mode assumes naive datetimes are UTC pre_epoch_naive = {"dt": datetime.datetime(1, 1, 1, 1, 1, 1, 10000)} post_epoch_naive = { "dt": datetime.datetime(1972, 1, 1, 1, 1, 1, 10000)} self.assertEqual( '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', json_util.dumps(pre_epoch_naive, json_options=STRICT_JSON_OPTIONS)) self.assertEqual( '{"dt": {"$date": "1972-01-01T01:01:01.010Z"}}', json_util.dumps(post_epoch_naive, json_options=STRICT_JSON_OPTIONS)) # Test tz_aware and tzinfo options self.assertEqual( datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc), json_util.loads( '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}')["dt"]) self.assertEqual( datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc), json_util.loads( '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}', json_options=json_util.JSONOptions(tz_aware=True, tzinfo=utc))["dt"]) self.assertEqual( datetime.datetime(1972, 1, 1, 1, 1, 1, 10000), json_util.loads( '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}', json_options=json_util.JSONOptions(tz_aware=False))["dt"]) self.round_trip(pre_epoch_naive, json_options=json_util.JSONOptions( tz_aware=False)) # Test a non-utc timezone pacific = FixedOffset(-8 * 60, 'US/Pacific') aware_datetime = {"dt": datetime.datetime(2002, 10, 27, 6, 0, 0, 10000, pacific)} self.assertEqual( '{"dt": {"$date": "2002-10-27T06:00:00.010-0800"}}', json_util.dumps(aware_datetime, json_options=STRICT_JSON_OPTIONS)) self.round_trip(aware_datetime, json_options=json_util.JSONOptions( tz_aware=True, tzinfo=pacific)) self.round_trip(aware_datetime, json_options=json_util.JSONOptions( datetime_representation=DatetimeRepresentation.ISO8601, tz_aware=True, tzinfo=pacific)) def test_regex_object_hook(self): # Extended JSON format regular expression. 
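# --- Editor's illustrative sketch (not part of the original test suite) ---
# The datetime assertions above boil down to: the legacy default emits
# {"$date": <milliseconds>}, strict/extended JSON emits ISO-8601 strings for
# post-epoch dates and {"$numberLong": ...} for pre-epoch ones, and the
# representation is selected through JSONOptions.  A small standalone sketch
# (no server needed):

import datetime

from bson import json_util
from bson.json_util import (DatetimeRepresentation,
                            JSONOptions,
                            STRICT_JSON_OPTIONS)
from bson.tz_util import utc

doc = {"dt": datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc)}

print(json_util.dumps(doc))                 # {"dt": {"$date": 63075661010}}
print(json_util.dumps(doc, json_options=STRICT_JSON_OPTIONS))
# {"dt": {"$date": "1972-01-01T01:01:01.010Z"}}

numberlong = JSONOptions(
    datetime_representation=DatetimeRepresentation.NUMBERLONG)
print(json_util.dumps(doc, json_options=numberlong))
# {"dt": {"$date": {"$numberLong": "63075661010"}}}

# loads() accepts any of the above and round-trips back to an aware datetime.
assert json_util.loads(json_util.dumps(doc))["dt"] == doc["dt"]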
pat = 'a*b' json_re = '{"$regex": "%s", "$options": "u"}' % pat loaded = json_util.object_hook(json.loads(json_re)) self.assertTrue(isinstance(loaded, Regex)) self.assertEqual(pat, loaded.pattern) self.assertEqual(re.U, loaded.flags) def test_regex(self): for regex_instance in ( re.compile("a*b", re.IGNORECASE), Regex("a*b", re.IGNORECASE)): res = self.round_tripped({"r": regex_instance})["r"] self.assertEqual("a*b", res.pattern) res = self.round_tripped({"r": Regex("a*b", re.IGNORECASE)})["r"] self.assertEqual("a*b", res.pattern) self.assertEqual(re.IGNORECASE, res.flags) unicode_options = re.I|re.M|re.S|re.U|re.X regex = re.compile("a*b", unicode_options) res = self.round_tripped({"r": regex})["r"] self.assertEqual(unicode_options, res.flags) # Some tools may not add $options if no flags are set. res = json_util.loads('{"r": {"$regex": "a*b"}}')['r'] self.assertEqual(0, res.flags) self.assertEqual( Regex('.*', 'ilm'), json_util.loads( '{"r": {"$regex": ".*", "$options": "ilm"}}')['r']) # Check order. self.assertEqual( '{"$regex": ".*", "$options": "mx"}', json_util.dumps(Regex('.*', re.M | re.X))) self.assertEqual( '{"$regex": ".*", "$options": "mx"}', json_util.dumps(re.compile(b'.*', re.M | re.X))) def test_minkey(self): self.round_trip({"m": MinKey()}) def test_maxkey(self): self.round_trip({"m": MaxKey()}) def test_timestamp(self): dct = {"ts": Timestamp(4, 13)} res = json_util.dumps(dct, default=json_util.default) rtdct = json_util.loads(res) self.assertEqual(dct, rtdct) self.assertEqual('{"ts": {"$timestamp": {"t": 4, "i": 13}}}', res) def test_uuid(self): doc = {'uuid': uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')} self.round_trip(doc) self.assertEqual( '{"uuid": {"$uuid": "f47ac10b58cc4372a5670e02b2c3d479"}}', json_util.dumps(doc)) self.assertEqual( '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}', json_util.dumps( doc, json_options=json_util.STRICT_JSON_OPTIONS)) self.assertEqual( '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}', json_util.dumps( doc, json_options=json_util.JSONOptions( strict_uuid=True, uuid_representation=STANDARD))) self.assertEqual( doc, json_util.loads( '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}')) for uuid_representation in (set(ALL_UUID_REPRESENTATIONS) - {UuidRepresentation.UNSPECIFIED}): options = json_util.JSONOptions( strict_uuid=True, uuid_representation=uuid_representation) self.round_trip(doc, json_options=options) # Ignore UUID representation when decoding BSON binary subtype 4. self.assertEqual(doc, json_util.loads( '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}', json_options=options)) def test_uuid_uuid_rep_unspecified(self): _uuid = uuid.uuid4() options = json_util.JSONOptions( strict_uuid=True, uuid_representation=UuidRepresentation.UNSPECIFIED) # Cannot directly encode native UUIDs with UNSPECIFIED. doc = {'uuid': _uuid} with self.assertRaises(ValueError): json_util.dumps(doc, json_options=options) # All UUID subtypes are decoded as Binary with UNSPECIFIED. 
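# --- Editor's illustrative sketch (not part of the original test suite) ---
# The UUID tests above show that how a native uuid.UUID is rendered depends on
# the JSONOptions in effect: the legacy default uses {"$uuid": ...}, strict
# mode encodes it as $binary with subtype 03 (PYTHON_LEGACY), and the STANDARD
# representation uses subtype 04.  A no-server sketch:

import uuid

from bson import json_util
from bson.binary import STANDARD

u = {"uuid": uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")}

print(json_util.dumps(u))
# {"uuid": {"$uuid": "f47ac10b58cc4372a5670e02b2c3d479"}}

print(json_util.dumps(u, json_options=json_util.STRICT_JSON_OPTIONS))
# {"uuid": {"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}

standard = json_util.JSONOptions(strict_uuid=True, uuid_representation=STANDARD)
print(json_util.dumps(u, json_options=standard))
# {"uuid": {"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}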
# subtype 3 doc = {'uuid': Binary(_uuid.bytes, subtype=3)} ext_json_str = json_util.dumps(doc) self.assertEqual( doc, json_util.loads(ext_json_str, json_options=options)) # subtype 4 doc = {'uuid': Binary(_uuid.bytes, subtype=4)} ext_json_str = json_util.dumps(doc) self.assertEqual( doc, json_util.loads(ext_json_str, json_options=options)) # $uuid-encoded fields doc = {'uuid': Binary(_uuid.bytes, subtype=4)} ext_json_str = json_util.dumps({'uuid': _uuid}) self.assertEqual( doc, json_util.loads(ext_json_str, json_options=options)) def test_binary(self): if PY3: bin_type_dict = {"bin": b"\x00\x01\x02\x03\x04"} else: bin_type_dict = {"bin": Binary(b"\x00\x01\x02\x03\x04")} md5_type_dict = { "md5": Binary(b' n7\x18\xaf\t/\xd1\xd1/\x80\xca\xe7q\xcc\xac', MD5_SUBTYPE)} custom_type_dict = {"custom": Binary(b"hello", USER_DEFINED_SUBTYPE)} self.round_trip(bin_type_dict) self.round_trip(md5_type_dict) self.round_trip(custom_type_dict) # Binary with subtype 0 is decoded into bytes in Python 3. bin = json_util.loads( '{"bin": {"$binary": "AAECAwQ=", "$type": "00"}}')['bin'] if PY3: self.assertEqual(type(bin), bytes) else: self.assertEqual(type(bin), Binary) # PYTHON-443 ensure old type formats are supported json_bin_dump = json_util.dumps(bin_type_dict) self.assertTrue('"$type": "00"' in json_bin_dump) self.assertEqual(bin_type_dict, json_util.loads('{"bin": {"$type": 0, "$binary": "AAECAwQ="}}')) json_bin_dump = json_util.dumps(md5_type_dict) # Check order. self.assertEqual( '{"md5": {"$binary": "IG43GK8JL9HRL4DK53HMrA==",' + ' "$type": "05"}}', json_bin_dump) self.assertEqual(md5_type_dict, json_util.loads('{"md5": {"$type": 5, "$binary":' ' "IG43GK8JL9HRL4DK53HMrA=="}}')) json_bin_dump = json_util.dumps(custom_type_dict) self.assertTrue('"$type": "80"' in json_bin_dump) self.assertEqual(custom_type_dict, json_util.loads('{"custom": {"$type": 128, "$binary":' ' "aGVsbG8="}}')) # Handle mongoexport where subtype >= 128 self.assertEqual(128, json_util.loads('{"custom": {"$type": "ffffff80", "$binary":' ' "aGVsbG8="}}')['custom'].subtype) self.assertEqual(255, json_util.loads('{"custom": {"$type": "ffffffff", "$binary":' ' "aGVsbG8="}}')['custom'].subtype) def test_code(self): self.round_trip({"code": Code("function x() { return 1; }")}) code = Code("return z", z=2) res = json_util.dumps(code) self.assertEqual(code, json_util.loads(res)) # Check order. 
self.assertEqual('{"$code": "return z", "$scope": {"z": 2}}', res) no_scope = Code('function() {}') self.assertEqual( '{"$code": "function() {}"}', json_util.dumps(no_scope)) def test_undefined(self): jsn = '{"name": {"$undefined": true}}' self.assertIsNone(json_util.loads(jsn)['name']) def test_numberlong(self): jsn = '{"weight": {"$numberLong": "65535"}}' self.assertEqual(json_util.loads(jsn)['weight'], Int64(65535)) self.assertEqual(json_util.dumps({"weight": Int64(65535)}), '{"weight": 65535}') json_options = json_util.JSONOptions(strict_number_long=True) self.assertEqual(json_util.dumps({"weight": Int64(65535)}, json_options=json_options), jsn) def test_loads_document_class(self): # document_class dict should always work self.assertEqual({"foo": "bar"}, json_util.loads( '{"foo": "bar"}', json_options=json_util.JSONOptions(document_class=dict))) self.assertEqual(SON([("foo", "bar"), ("b", 1)]), json_util.loads( '{"foo": "bar", "b": 1}', json_options=json_util.JSONOptions(document_class=SON))) class TestJsonUtilRoundtrip(IntegrationTest): def test_cursor(self): db = self.db db.drop_collection("test") docs = [ {'foo': [1, 2]}, {'bar': {'hello': 'world'}}, {'code': Code("function x() { return 1; }")}, {'bin': Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, {'dbref': {'_ref': DBRef('simple', ObjectId('509b8db456c02c5ab7e63c34'))}} ] db.test.insert_many(docs) reloaded_docs = json_util.loads(json_util.dumps(db.test.find())) for doc in docs: self.assertTrue(doc in reloaded_docs) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_legacy_api.py000066400000000000000000002627441374256237000200010ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test various legacy / deprecated API features.""" import itertools import sys import threading import time import uuid import warnings sys.path[0:0] = [""] from bson.binary import PYTHON_LEGACY, STANDARD from bson.code import Code from bson.codec_options import CodecOptions from bson.objectid import ObjectId from bson.py3compat import string_type from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEOHAYSTACK from pymongo.database import Database from pymongo.common import partition_node from pymongo.errors import (BulkWriteError, ConfigurationError, CursorNotFound, DocumentTooLarge, DuplicateKeyError, InvalidDocument, InvalidOperation, OperationFailure, WriteConcernError, WTimeoutError) from pymongo.message import _CursorAddress from pymongo.operations import IndexModel from pymongo.son_manipulator import (AutoReference, NamespaceInjector, ObjectIdShuffler, SONManipulator) from pymongo.write_concern import WriteConcern from test import client_context, qcheck, unittest, SkipTest from test.test_client import IntegrationTest from test.test_bulk import BulkTestBase, BulkAuthorizationTestBase from test.utils import (DeprecationFilter, joinall, oid_generated_on_process, rs_or_single_client, rs_or_single_client_noauth, single_client, wait_until) class TestDeprecations(IntegrationTest): @classmethod def setUpClass(cls): super(TestDeprecations, cls).setUpClass() cls.deprecation_filter = DeprecationFilter("error") @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() def test_save_deprecation(self): self.assertRaises( DeprecationWarning, lambda: self.db.test.save({})) def test_insert_deprecation(self): self.assertRaises( DeprecationWarning, lambda: self.db.test.insert({})) def test_update_deprecation(self): self.assertRaises( DeprecationWarning, lambda: self.db.test.update({}, {})) def test_remove_deprecation(self): self.assertRaises( DeprecationWarning, lambda: self.db.test.remove({})) def test_find_and_modify_deprecation(self): self.assertRaises( DeprecationWarning, lambda: self.db.test.find_and_modify({'i': 5}, {})) def test_add_son_manipulator_deprecation(self): db = self.client.pymongo_test self.assertRaises(DeprecationWarning, lambda: db.add_son_manipulator(AutoReference(db))) def test_ensure_index_deprecation(self): try: self.assertRaises( DeprecationWarning, lambda: self.db.test.ensure_index('i')) finally: self.db.test.drop() def test_reindex_deprecation(self): self.assertRaises(DeprecationWarning, lambda: self.db.test.reindex()) def test_geoHaystack_deprecation(self): self.addCleanup(self.db.test.drop) keys = [("pos", GEOHAYSTACK), ("type", ASCENDING)] self.assertRaises( DeprecationWarning, self.db.test.create_index, keys, bucketSize=1) indexes = [IndexModel(keys, bucketSize=1)] self.assertRaises( DeprecationWarning, self.db.test.create_indexes, indexes) class TestLegacy(IntegrationTest): @classmethod def setUpClass(cls): super(TestLegacy, cls).setUpClass() cls.w = client_context.w cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() def test_insert_find_one(self): # Tests legacy insert. 
db = self.db db.test.drop() self.assertEqual(0, len(list(db.test.find()))) doc = {"hello": u"world"} _id = db.test.insert(doc) self.assertEqual(1, len(list(db.test.find()))) self.assertEqual(doc, db.test.find_one()) self.assertEqual(doc["_id"], _id) self.assertTrue(isinstance(_id, ObjectId)) doc_class = dict # Work around http://bugs.jython.org/issue1728 if (sys.platform.startswith('java') and sys.version_info[:3] >= (2, 5, 2)): doc_class = SON db = self.client.get_database( db.name, codec_options=CodecOptions(document_class=doc_class)) def remove_insert_find_one(doc): db.test.remove({}) db.test.insert(doc) # SON equality is order sensitive. return db.test.find_one() == doc.to_dict() qcheck.check_unittest(self, remove_insert_find_one, qcheck.gen_mongo_dict(3)) def test_generator_insert(self): # Only legacy insert currently supports insert from a generator. db = self.db db.test.remove({}) self.assertEqual(db.test.find().count(), 0) db.test.insert(({'a': i} for i in range(5)), manipulate=False) self.assertEqual(5, db.test.count()) db.test.remove({}) db.test.insert(({'a': i} for i in range(5)), manipulate=True) self.assertEqual(5, db.test.count()) db.test.remove({}) def test_insert_multiple(self): # Tests legacy insert. db = self.db db.drop_collection("test") doc1 = {"hello": u"world"} doc2 = {"hello": u"mike"} self.assertEqual(db.test.find().count(), 0) ids = db.test.insert([doc1, doc2]) self.assertEqual(db.test.find().count(), 2) self.assertEqual(doc1, db.test.find_one({"hello": u"world"})) self.assertEqual(doc2, db.test.find_one({"hello": u"mike"})) self.assertEqual(2, len(ids)) self.assertEqual(doc1["_id"], ids[0]) self.assertEqual(doc2["_id"], ids[1]) ids = db.test.insert([{"hello": 1}]) self.assertTrue(isinstance(ids, list)) self.assertEqual(1, len(ids)) self.assertRaises(InvalidOperation, db.test.insert, []) # Generator that raises StopIteration on first call to next(). self.assertRaises(InvalidOperation, db.test.insert, (i for i in [])) def test_insert_multiple_with_duplicate(self): # Tests legacy insert. db = self.db db.drop_collection("test_insert_multiple_with_duplicate") collection = db.test_insert_multiple_with_duplicate collection.create_index([('i', ASCENDING)], unique=True) # No error collection.insert([{'i': i} for i in range(5, 10)], w=0) wait_until(lambda: 5 == collection.count(), 'insert 5 documents') db.drop_collection("test_insert_multiple_with_duplicate") collection.create_index([('i', ASCENDING)], unique=True) # No error collection.insert([{'i': 1}] * 2, w=0) wait_until(lambda: 1 == collection.count(), 'insert 1 document') self.assertRaises( DuplicateKeyError, lambda: collection.insert([{'i': 2}] * 2), ) db.drop_collection("test_insert_multiple_with_duplicate") db = self.client.get_database( db.name, write_concern=WriteConcern(w=0)) collection = db.test_insert_multiple_with_duplicate collection.create_index([('i', ASCENDING)], unique=True) # No error. collection.insert([{'i': 1}] * 2) wait_until(lambda: 1 == collection.count(), 'insert 1 document') # Implied acknowledged. self.assertRaises( DuplicateKeyError, lambda: collection.insert([{'i': 2}] * 2, fsync=True), ) # Explicit acknowledged. self.assertRaises( DuplicateKeyError, lambda: collection.insert([{'i': 2}] * 2, w=1)) db.drop_collection("test_insert_multiple_with_duplicate") @client_context.require_replica_set def test_insert_prefers_write_errors(self): # Tests legacy insert. 
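# --- Editor's illustrative sketch (not part of the original test suite) ---
# The legacy tests above exercise Collection.insert(), which PyMongo 3.x keeps
# only for backwards compatibility (hence the DeprecationWarning assertions in
# TestDeprecations).  The supported replacements are insert_one() and
# insert_many().  A hedged sketch assuming a local server and a scratch
# database named "legacy_demo":

from pymongo import MongoClient

coll = MongoClient().legacy_demo.test
coll.drop()

result = coll.insert_one({"hello": "world"})
print(result.inserted_id)                   # the generated ObjectId

many = coll.insert_many([{"i": i} for i in range(5)])
print(len(many.inserted_ids))               # 5

# Unlike legacy insert(), the new methods take no per-call w/wtimeout
# arguments; write concern is configured on the client, database or collection.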
collection = self.db.test_insert_prefers_write_errors self.db.drop_collection(collection.name) collection.insert_one({'_id': 1}) large = 's' * 1024 * 1024 * 15 with self.assertRaises(DuplicateKeyError): collection.insert( [{'_id': 1, 's': large}, {'_id': 2, 's': large}]) self.assertEqual(1, collection.count()) with self.assertRaises(DuplicateKeyError): collection.insert( [{'_id': 1, 's': large}, {'_id': 2, 's': large}], continue_on_error=True) self.assertEqual(2, collection.count()) collection.delete_one({'_id': 2}) # A writeError followed by a writeConcernError should prefer to raise # the writeError. with self.assertRaises(DuplicateKeyError): collection.insert( [{'_id': 1, 's': large}, {'_id': 2, 's': large}], continue_on_error=True, w=len(client_context.nodes) + 10, wtimeout=1) self.assertEqual(2, collection.count()) collection.delete_many({}) with self.assertRaises(WriteConcernError): collection.insert( [{'_id': 1, 's': large}, {'_id': 2, 's': large}], continue_on_error=True, w=len(client_context.nodes) + 10, wtimeout=1) self.assertEqual(2, collection.count()) def test_insert_iterables(self): # Tests legacy insert. db = self.db self.assertRaises(TypeError, db.test.insert, 4) self.assertRaises(TypeError, db.test.insert, None) self.assertRaises(TypeError, db.test.insert, True) db.drop_collection("test") self.assertEqual(db.test.find().count(), 0) db.test.insert(({"hello": u"world"}, {"hello": u"world"})) self.assertEqual(db.test.find().count(), 2) db.drop_collection("test") self.assertEqual(db.test.find().count(), 0) db.test.insert(map(lambda x: {"hello": "world"}, itertools.repeat(None, 10))) self.assertEqual(db.test.find().count(), 10) def test_insert_manipulate_false(self): # Test two aspects of legacy insert with manipulate=False: # 1. The return value is None or [None] as appropriate. # 2. _id is not set on the passed-in document object. collection = self.db.test_insert_manipulate_false collection.drop() oid = ObjectId() doc = {'a': oid} try: # The return value is None. self.assertTrue(collection.insert(doc, manipulate=False) is None) # insert() shouldn't set _id on the passed-in document object. self.assertEqual({'a': oid}, doc) # Bulk insert. The return value is a list of None. self.assertEqual([None], collection.insert([{}], manipulate=False)) docs = [{}, {}] ids = collection.insert(docs, manipulate=False) self.assertEqual([None, None], ids) self.assertEqual([{}, {}], docs) finally: collection.drop() def test_continue_on_error(self): # Tests legacy insert. db = self.db db.drop_collection("test_continue_on_error") collection = db.test_continue_on_error oid = collection.insert({"one": 1}) self.assertEqual(1, collection.count()) docs = [] docs.append({"_id": oid, "two": 2}) # Duplicate _id. 
docs.append({"three": 3}) docs.append({"four": 4}) docs.append({"five": 5}) with self.assertRaises(DuplicateKeyError): collection.insert(docs, manipulate=False) self.assertEqual(1, collection.count()) with self.assertRaises(DuplicateKeyError): collection.insert(docs, manipulate=False, continue_on_error=True) self.assertEqual(4, collection.count()) collection.remove({}, w=client_context.w) oid = collection.insert({"_id": oid, "one": 1}, w=0) wait_until(lambda: 1 == collection.count(), 'insert 1 document') docs[0].pop("_id") docs[2]["_id"] = oid with self.assertRaises(DuplicateKeyError): collection.insert(docs, manipulate=False) self.assertEqual(3, collection.count()) collection.insert(docs, manipulate=False, continue_on_error=True, w=0) wait_until(lambda: 6 == collection.count(), 'insert 3 documents') def test_acknowledged_insert(self): # Tests legacy insert. db = self.db db.drop_collection("test_acknowledged_insert") collection = db.test_acknowledged_insert a = {"hello": "world"} collection.insert(a) collection.insert(a, w=0) self.assertRaises(OperationFailure, collection.insert, a) def test_insert_adds_id(self): # Tests legacy insert. doc = {"hello": "world"} self.db.test.insert(doc) self.assertTrue("_id" in doc) docs = [{"hello": "world"}, {"hello": "world"}] self.db.test.insert(docs) for doc in docs: self.assertTrue("_id" in doc) def test_insert_large_batch(self): # Tests legacy insert. db = self.client.test_insert_large_batch self.addCleanup(self.client.drop_database, 'test_insert_large_batch') max_bson_size = self.client.max_bson_size # Write commands are limited to 16MB + 16k per batch big_string = 'x' * int(max_bson_size / 2) # Batch insert that requires 2 batches. successful_insert = [{'x': big_string}, {'x': big_string}, {'x': big_string}, {'x': big_string}] db.collection_0.insert(successful_insert, w=1) self.assertEqual(4, db.collection_0.count()) db.collection_0.drop() # Test that inserts fail after first error. insert_second_fails = [{'_id': 'id0', 'x': big_string}, {'_id': 'id0', 'x': big_string}, {'_id': 'id1', 'x': big_string}, {'_id': 'id2', 'x': big_string}] with self.assertRaises(DuplicateKeyError): db.collection_1.insert(insert_second_fails) self.assertEqual(1, db.collection_1.count()) db.collection_1.drop() # 2 batches, 2nd insert fails, don't continue on error. self.assertTrue(db.collection_2.insert(insert_second_fails, w=0)) wait_until(lambda: 1 == db.collection_2.count(), 'insert 1 document', timeout=60) db.collection_2.drop() # 2 batches, ids of docs 0 and 1 are dupes, ids of docs 2 and 3 are # dupes. Acknowledged, continue on error. insert_two_failures = [{'_id': 'id0', 'x': big_string}, {'_id': 'id0', 'x': big_string}, {'_id': 'id1', 'x': big_string}, {'_id': 'id1', 'x': big_string}] with self.assertRaises(OperationFailure) as context: db.collection_3.insert(insert_two_failures, continue_on_error=True, w=1) self.assertIn('id1', str(context.exception)) # Only the first and third documents should be inserted. self.assertEqual(2, db.collection_3.count()) db.collection_3.drop() # 2 batches, 2 errors, unacknowledged, continue on error. db.collection_4.insert(insert_two_failures, continue_on_error=True, w=0) # Only the first and third documents are inserted. wait_until(lambda: 2 == db.collection_4.count(), 'insert 2 documents', timeout=60) db.collection_4.drop() def test_bad_dbref(self): # Requires the legacy API to test. c = self.db.test c.drop() # Incomplete DBRefs. 
self.assertRaises( InvalidDocument, c.insert_one, {'ref': {'$ref': 'collection'}}) self.assertRaises( InvalidDocument, c.insert_one, {'ref': {'$id': ObjectId()}}) ref_only = {'ref': {'$ref': 'collection'}} id_only = {'ref': {'$id': ObjectId()}} def test_update(self): # Tests legacy update. db = self.db db.drop_collection("test") id1 = db.test.save({"x": 5}) db.test.update({}, {"$inc": {"x": 1}}) self.assertEqual(db.test.find_one(id1)["x"], 6) id2 = db.test.save({"x": 1}) db.test.update({"x": 6}, {"$inc": {"x": 1}}) self.assertEqual(db.test.find_one(id1)["x"], 7) self.assertEqual(db.test.find_one(id2)["x"], 1) def test_update_manipulate(self): # Tests legacy update. db = self.db db.drop_collection("test") db.test.insert({'_id': 1}) db.test.update({'_id': 1}, {'a': 1}, manipulate=True) self.assertEqual( {'_id': 1, 'a': 1}, db.test.find_one()) class AddField(SONManipulator): def transform_incoming(self, son, dummy): son['field'] = 'value' return son db.add_son_manipulator(AddField()) db.test.update({'_id': 1}, {'a': 2}, manipulate=False) self.assertEqual( {'_id': 1, 'a': 2}, db.test.find_one()) db.test.update({'_id': 1}, {'a': 3}, manipulate=True) self.assertEqual( {'_id': 1, 'a': 3, 'field': 'value'}, db.test.find_one()) def test_update_nmodified(self): # Tests legacy update. db = self.db db.drop_collection("test") ismaster = self.client.admin.command('ismaster') used_write_commands = (ismaster.get("maxWireVersion", 0) > 1) db.test.insert({'_id': 1}) result = db.test.update({'_id': 1}, {'$set': {'x': 1}}) if used_write_commands: self.assertEqual(1, result['nModified']) else: self.assertFalse('nModified' in result) # x is already 1. result = db.test.update({'_id': 1}, {'$set': {'x': 1}}) if used_write_commands: self.assertEqual(0, result['nModified']) else: self.assertFalse('nModified' in result) def test_multi_update(self): # Tests legacy update. db = self.db db.drop_collection("test") db.test.save({"x": 4, "y": 3}) db.test.save({"x": 5, "y": 5}) db.test.save({"x": 4, "y": 4}) db.test.update({"x": 4}, {"$set": {"y": 5}}, multi=True) self.assertEqual(3, db.test.count()) for doc in db.test.find(): self.assertEqual(5, doc["y"]) self.assertEqual(2, db.test.update({"x": 4}, {"$set": {"y": 6}}, multi=True)["n"]) def test_upsert(self): # Tests legacy update. db = self.db db.drop_collection("test") db.test.update({"page": "/"}, {"$inc": {"count": 1}}, upsert=True) db.test.update({"page": "/"}, {"$inc": {"count": 1}}, upsert=True) self.assertEqual(1, db.test.count()) self.assertEqual(2, db.test.find_one()["count"]) def test_acknowledged_update(self): # Tests legacy update. db = self.db db.drop_collection("test_acknowledged_update") collection = db.test_acknowledged_update collection.create_index("x", unique=True) collection.insert({"x": 5}) _id = collection.insert({"x": 4}) self.assertEqual( None, collection.update({"_id": _id}, {"$inc": {"x": 1}}, w=0)) self.assertRaises(DuplicateKeyError, collection.update, {"_id": _id}, {"$inc": {"x": 1}}) self.assertEqual(1, collection.update({"_id": _id}, {"$inc": {"x": 2}})["n"]) self.assertEqual(0, collection.update({"_id": "foo"}, {"$inc": {"x": 2}})["n"]) db.drop_collection("test_acknowledged_update") def test_update_backward_compat(self): # MongoDB versions >= 2.6.0 don't return the updatedExisting field # and return upsert _id in an array subdocument. This test should # pass regardless of server version or type (mongod/s). # Tests legacy update. 
c = self.db.test c.drop() oid = ObjectId() res = c.update({'_id': oid}, {'$set': {'a': 'a'}}, upsert=True) self.assertFalse(res.get('updatedExisting')) self.assertEqual(oid, res.get('upserted')) res = c.update({'_id': oid}, {'$set': {'b': 'b'}}) self.assertTrue(res.get('updatedExisting')) def test_save(self): # Tests legacy save. self.db.drop_collection("test_save") collection = self.db.test_save # Save a doc with autogenerated id _id = collection.save({"hello": "world"}) self.assertEqual(collection.find_one()["_id"], _id) self.assertTrue(isinstance(_id, ObjectId)) # Save a doc with explicit id collection.save({"_id": "explicit_id", "hello": "bar"}) doc = collection.find_one({"_id": "explicit_id"}) self.assertEqual(doc['_id'], 'explicit_id') self.assertEqual(doc['hello'], 'bar') # Save docs with _id field already present (shouldn't create new docs) self.assertEqual(2, collection.count()) collection.save({'_id': _id, 'hello': 'world'}) self.assertEqual(2, collection.count()) collection.save({'_id': 'explicit_id', 'hello': 'baz'}) self.assertEqual(2, collection.count()) self.assertEqual( 'baz', collection.find_one({'_id': 'explicit_id'})['hello'] ) # Acknowledged mode. collection.create_index("hello", unique=True) # No exception, even though we duplicate the first doc's "hello" value collection.save({'_id': 'explicit_id', 'hello': 'world'}, w=0) self.assertRaises( DuplicateKeyError, collection.save, {'_id': 'explicit_id', 'hello': 'world'}) self.db.drop_collection("test") def test_save_with_invalid_key(self): if client_context.version.at_least(3, 5, 8): raise SkipTest("MongoDB >= 3.5.8 allows dotted fields in updates") # Tests legacy save. self.db.drop_collection("test") self.assertTrue(self.db.test.insert({"hello": "world"})) doc = self.db.test.find_one() doc['a.b'] = 'c' self.assertRaises(OperationFailure, self.db.test.save, doc) def test_acknowledged_save(self): # Tests legacy save. db = self.db db.drop_collection("test_acknowledged_save") collection = db.test_acknowledged_save collection.create_index("hello", unique=True) collection.save({"hello": "world"}) collection.save({"hello": "world"}, w=0) self.assertRaises(DuplicateKeyError, collection.save, {"hello": "world"}) db.drop_collection("test_acknowledged_save") def test_save_adds_id(self): # Tests legacy save. doc = {"hello": "jesse"} self.db.test.save(doc) self.assertTrue("_id" in doc) def test_save_returns_id(self): doc = {"hello": "jesse"} _id = self.db.test.save(doc) self.assertTrue(isinstance(_id, ObjectId)) self.assertEqual(_id, doc["_id"]) doc["hi"] = "bernie" _id = self.db.test.save(doc) self.assertTrue(isinstance(_id, ObjectId)) self.assertEqual(_id, doc["_id"]) def test_remove_one(self): # Tests legacy remove. self.db.test.remove() self.assertEqual(0, self.db.test.count()) self.db.test.insert({"x": 1}) self.db.test.insert({"y": 1}) self.db.test.insert({"z": 1}) self.assertEqual(3, self.db.test.count()) self.db.test.remove(multi=False) self.assertEqual(2, self.db.test.count()) self.db.test.remove() self.assertEqual(0, self.db.test.count()) def test_remove_all(self): # Tests legacy remove. self.db.test.remove() self.assertEqual(0, self.db.test.count()) self.db.test.insert({"x": 1}) self.db.test.insert({"y": 1}) self.assertEqual(2, self.db.test.count()) self.db.test.remove() self.assertEqual(0, self.db.test.count()) def test_remove_non_objectid(self): # Tests legacy remove. 
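# --- Editor's illustrative sketch (not part of the original test suite) ---
# The legacy save() exercised above ("insert if the document has no _id,
# otherwise overwrite the document with that _id") maps onto
# replace_one(..., upsert=True) in the modern API, and legacy update()'s
# multi=True flag onto update_many().  A hedged sketch, local server and
# "legacy_demo" database assumed:

from pymongo import MongoClient

coll = MongoClient().legacy_demo.save_demo
coll.drop()

doc = {"_id": "explicit_id", "hello": "bar"}
coll.replace_one({"_id": doc["_id"]}, doc, upsert=True)    # behaves like save()
coll.replace_one({"_id": doc["_id"]}, {"hello": "baz"}, upsert=True)
assert coll.count_documents({}) == 1

coll.insert_many([{"x": 4}, {"x": 4}, {"x": 5}])
result = coll.update_many({"x": 4}, {"$set": {"y": 6}})    # legacy multi=True
assert result.modified_count == 2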
db = self.db db.drop_collection("test") db.test.insert_one({"_id": 5}) self.assertEqual(1, db.test.count()) db.test.remove(5) self.assertEqual(0, db.test.count()) def test_write_large_document(self): # Tests legacy insert, save, and update. max_size = self.db.client.max_bson_size half_size = int(max_size / 2) self.assertEqual(max_size, 16777216) self.assertRaises(OperationFailure, self.db.test.insert, {"foo": "x" * max_size}) self.assertRaises(OperationFailure, self.db.test.save, {"foo": "x" * max_size}) self.assertRaises(OperationFailure, self.db.test.insert, [{"x": 1}, {"foo": "x" * max_size}]) self.db.test.insert([{"foo": "x" * half_size}, {"foo": "x" * half_size}]) self.db.test.insert({"bar": "x"}) # Use w=0 here to test legacy doc size checking in all server versions self.assertRaises(DocumentTooLarge, self.db.test.update, {"bar": "x"}, {"bar": "x" * (max_size - 14)}, w=0) # This will pass with OP_UPDATE or the update command. self.db.test.update({"bar": "x"}, {"bar": "x" * (max_size - 32)}) def test_last_error_options(self): # Tests legacy write methods. self.db.test.save({"x": 1}, w=1, wtimeout=1) self.db.test.insert({"x": 1}, w=1, wtimeout=1) self.db.test.remove({"x": 1}, w=1, wtimeout=1) self.db.test.update({"x": 1}, {"y": 2}, w=1, wtimeout=1) if client_context.replica_set_name: # client_context.w is the number of hosts in the replica set w = client_context.w + 1 # MongoDB 2.8+ raises error code 100, CannotSatisfyWriteConcern, # if w > number of members. Older versions just time out after 1 ms # as if they had enough secondaries but some are lagging. They # return an error with 'wtimeout': True and no code. def wtimeout_err(f, *args, **kwargs): try: f(*args, **kwargs) except WTimeoutError as exc: self.assertIsNotNone(exc.details) except OperationFailure as exc: self.assertIsNotNone(exc.details) self.assertEqual(100, exc.code, "Unexpected error: %r" % exc) else: self.fail("%s should have failed" % f) coll = self.db.test wtimeout_err(coll.save, {"x": 1}, w=w, wtimeout=1) wtimeout_err(coll.insert, {"x": 1}, w=w, wtimeout=1) wtimeout_err(coll.update, {"x": 1}, {"y": 2}, w=w, wtimeout=1) wtimeout_err(coll.remove, {"x": 1}, w=w, wtimeout=1) # can't use fsync and j options together self.assertRaises(ConfigurationError, self.db.test.insert, {"_id": 1}, j=True, fsync=True) def test_find_and_modify(self): c = self.db.test c.drop() c.insert({'_id': 1, 'i': 1}) # Test that we raise DuplicateKeyError when appropriate. c.ensure_index('i', unique=True) self.assertRaises(DuplicateKeyError, c.find_and_modify, query={'i': 1, 'j': 1}, update={'$set': {'k': 1}}, upsert=True) c.drop_indexes() # Test correct findAndModify self.assertEqual({'_id': 1, 'i': 1}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}})) self.assertEqual({'_id': 1, 'i': 3}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True)) self.assertEqual({'_id': 1, 'i': 3}, c.find_and_modify({'_id': 1}, remove=True)) self.assertEqual(None, c.find_one({'_id': 1})) self.assertEqual(None, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}})) self.assertEqual(None, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, upsert=True)) self.assertEqual({'_id': 1, 'i': 2}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, upsert=True, new=True)) self.assertEqual({'_id': 1, 'i': 2}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, fields=['i'])) self.assertEqual({'_id': 1, 'i': 4}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True, fields={'i': 1})) # Test with full_response=True. 
result = c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True, upsert=True, full_response=True, fields={'i': 1}) self.assertEqual({'_id': 1, 'i': 5}, result["value"]) self.assertEqual(True, result["lastErrorObject"]["updatedExisting"]) result = c.find_and_modify({'_id': 2}, {'$inc': {'i': 1}}, new=True, upsert=True, full_response=True, fields={'i': 1}) self.assertEqual({'_id': 2, 'i': 1}, result["value"]) self.assertEqual(False, result["lastErrorObject"]["updatedExisting"]) class ExtendedDict(dict): pass result = c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True, fields={'i': 1}) self.assertFalse(isinstance(result, ExtendedDict)) c = self.db.get_collection( "test", codec_options=CodecOptions(document_class=ExtendedDict)) result = c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True, fields={'i': 1}) self.assertTrue(isinstance(result, ExtendedDict)) def test_find_and_modify_with_sort(self): c = self.db.test c.drop() for j in range(5): c.insert({'j': j, 'i': 0}) sort = {'j': DESCENDING} self.assertEqual(4, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) sort = {'j': ASCENDING} self.assertEqual(0, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) sort = [('j', DESCENDING)] self.assertEqual(4, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) sort = [('j', ASCENDING)] self.assertEqual(0, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) sort = SON([('j', DESCENDING)]) self.assertEqual(4, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) sort = SON([('j', ASCENDING)]) self.assertEqual(0, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) try: from collections import OrderedDict sort = OrderedDict([('j', DESCENDING)]) self.assertEqual(4, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) sort = OrderedDict([('j', ASCENDING)]) self.assertEqual(0, c.find_and_modify({}, {'$inc': {'i': 1}}, sort=sort)['j']) except ImportError: pass # Test that a standard dict with two keys is rejected. 
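# --- Editor's illustrative sketch (not part of the original test suite) ---
# find_and_modify(), tested above, was split into find_one_and_update(),
# find_one_and_replace() and find_one_and_delete().  The legacy new=True flag
# became return_document=ReturnDocument.AFTER and fields= became projection=.
# Sketch under the usual local-server / "legacy_demo" assumptions:

from pymongo import MongoClient, ReturnDocument

coll = MongoClient().legacy_demo.fam_demo
coll.drop()
coll.insert_one({"_id": 1, "i": 1})

before = coll.find_one_and_update({"_id": 1}, {"$inc": {"i": 1}})
assert before["i"] == 1                      # default returns the old document

after = coll.find_one_and_update(
    {"_id": 1}, {"$inc": {"i": 1}},
    projection={"i": 1},
    return_document=ReturnDocument.AFTER)
assert after["i"] == 3

removed = coll.find_one_and_delete({"_id": 1})
assert removed["i"] == 3 and coll.find_one({"_id": 1}) is None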
sort = {'j': DESCENDING, 'foo': DESCENDING} self.assertRaises(TypeError, c.find_and_modify, {}, {'$inc': {'i': 1}}, sort=sort) def test_find_and_modify_with_manipulator(self): class AddCollectionNameManipulator(SONManipulator): def will_copy(self): return True def transform_incoming(self, son, dummy): copy = SON(son) if 'collection' in copy: del copy['collection'] return copy def transform_outgoing(self, son, collection): copy = SON(son) copy['collection'] = collection.name return copy db = self.client.pymongo_test db.add_son_manipulator(AddCollectionNameManipulator()) c = db.test c.drop() c.insert({'_id': 1, 'i': 1}) # Test correct findAndModify # With manipulators self.assertEqual({'_id': 1, 'i': 1, 'collection': 'test'}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, manipulate=True)) self.assertEqual({'_id': 1, 'i': 3, 'collection': 'test'}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True, manipulate=True)) # With out manipulators self.assertEqual({'_id': 1, 'i': 3}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}})) self.assertEqual({'_id': 1, 'i': 5}, c.find_and_modify({'_id': 1}, {'$inc': {'i': 1}}, new=True)) @client_context.require_version_max(4, 1, 0, -1) def test_group(self): db = self.db db.drop_collection("test") self.assertEqual([], db.test.group([], {}, {"count": 0}, "function (obj, prev) { prev.count++; }" )) db.test.insert_many([{"a": 2}, {"b": 5}, {"a": 1}]) self.assertEqual([{"count": 3}], db.test.group([], {}, {"count": 0}, "function (obj, prev) { prev.count++; }" )) self.assertEqual([{"count": 1}], db.test.group([], {"a": {"$gt": 1}}, {"count": 0}, "function (obj, prev) { prev.count++; }" )) db.test.insert_one({"a": 2, "b": 3}) self.assertEqual([{"a": 2, "count": 2}, {"a": None, "count": 1}, {"a": 1, "count": 1}], db.test.group(["a"], {}, {"count": 0}, "function (obj, prev) { prev.count++; }" )) # modifying finalize self.assertEqual([{"a": 2, "count": 3}, {"a": None, "count": 2}, {"a": 1, "count": 2}], db.test.group(["a"], {}, {"count": 0}, "function (obj, prev) " "{ prev.count++; }", "function (obj) { obj.count++; }")) # returning finalize self.assertEqual([2, 1, 1], db.test.group(["a"], {}, {"count": 0}, "function (obj, prev) " "{ prev.count++; }", "function (obj) { return obj.count; }")) # keyf self.assertEqual([2, 2], db.test.group("function (obj) { if (obj.a == 2) " "{ return {a: true} }; " "return {b: true}; }", {}, {"count": 0}, "function (obj, prev) " "{ prev.count++; }", "function (obj) { return obj.count; }")) # no key self.assertEqual([{"count": 4}], db.test.group(None, {}, {"count": 0}, "function (obj, prev) { prev.count++; }" )) self.assertRaises(OperationFailure, db.test.group, [], {}, {}, "5 ++ 5") @client_context.require_version_max(4, 1, 0, -1) def test_group_with_scope(self): db = self.db db.drop_collection("test") db.test.insert_many([{"a": 1}, {"b": 1}]) reduce_function = "function (obj, prev) { prev.count += inc_value; }" self.assertEqual(2, db.test.group([], {}, {"count": 0}, Code(reduce_function, {"inc_value": 1}))[0]['count']) self.assertEqual(4, db.test.group([], {}, {"count": 0}, Code(reduce_function, {"inc_value": 2}))[0]['count']) self.assertEqual(1, db.test.group([], {}, {"count": 0}, Code(reduce_function, {"inc_value": 0.5}))[0]['count']) self.assertEqual(2, db.test.group( [], {}, {"count": 0}, Code(reduce_function, {"inc_value": 1}))[0]['count']) self.assertEqual(4, db.test.group( [], {}, {"count": 0}, Code(reduce_function, {"inc_value": 2}))[0]['count']) self.assertEqual(1, db.test.group( [], {}, {"count": 0}, 
Code(reduce_function, {"inc_value": 0.5}))[0]['count']) @client_context.require_version_max(4, 1, 0, -1) def test_group_uuid_representation(self): db = self.db coll = db.uuid coll.drop() uu = uuid.uuid4() coll.insert_one({"_id": uu, "a": 2}) coll.insert_one({"_id": uuid.uuid4(), "a": 1}) reduce = "function (obj, prev) { prev.count++; }" coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=STANDARD)) self.assertEqual([], coll.group([], {"_id": uu}, {"count": 0}, reduce)) coll = self.db.get_collection( "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) self.assertEqual([{"count": 1}], coll.group([], {"_id": uu}, {"count": 0}, reduce)) def test_last_status(self): # Tests many legacy API elements. # We must call getlasterror on same socket as the last operation. db = rs_or_single_client(maxPoolSize=1).pymongo_test collection = db.test_last_status collection.remove({}) collection.save({"i": 1}) collection.update({"i": 1}, {"$set": {"i": 2}}, w=0) # updatedExisting is always false on mongos after an OP_MSG # unacknowledged write. if not (client_context.version >= (3, 6) and client_context.is_mongos): self.assertTrue(db.last_status()["updatedExisting"]) wait_until(lambda: collection.find_one({"i": 2}), "found updated w=0 doc") collection.update({"i": 1}, {"$set": {"i": 500}}, w=0) self.assertFalse(db.last_status()["updatedExisting"]) def test_auto_ref_and_deref(self): # Legacy API. db = self.client.pymongo_test db.add_son_manipulator(AutoReference(db)) db.add_son_manipulator(NamespaceInjector()) db.test.a.remove({}) db.test.b.remove({}) db.test.c.remove({}) a = {"hello": u"world"} db.test.a.save(a) b = {"test": a} db.test.b.save(b) c = {"another test": b} db.test.c.save(c) a["hello"] = "mike" db.test.a.save(a) self.assertEqual(db.test.a.find_one(), a) self.assertEqual(db.test.b.find_one()["test"], a) self.assertEqual(db.test.c.find_one()["another test"]["test"], a) self.assertEqual(db.test.b.find_one(), b) self.assertEqual(db.test.c.find_one()["another test"], b) self.assertEqual(db.test.c.find_one(), c) def test_auto_ref_and_deref_list(self): # Legacy API. db = self.client.pymongo_test db.add_son_manipulator(AutoReference(db)) db.add_son_manipulator(NamespaceInjector()) db.drop_collection("users") db.drop_collection("messages") message_1 = {"title": "foo"} db.messages.save(message_1) message_2 = {"title": "bar"} db.messages.save(message_2) user = {"messages": [message_1, message_2]} db.users.save(user) db.messages.update(message_1, {"title": "buzz"}) self.assertEqual("buzz", db.users.find_one()["messages"][0]["title"]) self.assertEqual("bar", db.users.find_one()["messages"][1]["title"]) def test_object_to_dict_transformer(self): # PYTHON-709: Some users rely on their custom SONManipulators to run # before any other checks, so they can insert non-dict objects and # have them dictified before the _id is inserted or any other # processing. # Tests legacy API elements. 
class Thing(object): def __init__(self, value): self.value = value class ThingTransformer(SONManipulator): def transform_incoming(self, thing, dummy): return {'value': thing.value} db = self.client.foo db.add_son_manipulator(ThingTransformer()) t = Thing('value') db.test.remove() db.test.insert([t]) out = db.test.find_one() self.assertEqual('value', out.get('value')) def test_son_manipulator_outgoing(self): class Thing(object): def __init__(self, value): self.value = value class ThingTransformer(SONManipulator): def transform_outgoing(self, doc, collection): # We don't want this applied to the command return # value in pymongo.cursor.Cursor. if 'value' in doc: return Thing(doc['value']) return doc db = self.client.foo db.add_son_manipulator(ThingTransformer()) db.test.delete_many({}) db.test.insert_one({'value': 'value'}) out = db.test.find_one() self.assertTrue(isinstance(out, Thing)) self.assertEqual('value', out.value) out = next(db.test.aggregate([], cursor={})) self.assertTrue(isinstance(out, Thing)) self.assertEqual('value', out.value) def test_son_manipulator_inheritance(self): # Tests legacy API elements. class Thing(object): def __init__(self, value): self.value = value class ThingTransformer(SONManipulator): def transform_incoming(self, thing, dummy): return {'value': thing.value} def transform_outgoing(self, son, dummy): return Thing(son['value']) class Child(ThingTransformer): pass db = self.client.foo db.add_son_manipulator(Child()) t = Thing('value') db.test.remove() db.test.insert([t]) out = db.test.find_one() self.assertTrue(isinstance(out, Thing)) self.assertEqual('value', out.value) def test_disabling_manipulators(self): class IncByTwo(SONManipulator): def transform_outgoing(self, son, collection): if 'foo' in son: son['foo'] += 2 return son db = self.client.pymongo_test db.add_son_manipulator(IncByTwo()) c = db.test c.drop() c.insert({'foo': 0}) self.assertEqual(2, c.find_one()['foo']) self.assertEqual(0, c.find_one(manipulate=False)['foo']) self.assertEqual(2, c.find_one(manipulate=True)['foo']) c.drop() def test_manipulator_properties(self): db = self.client.foo self.assertEqual([], db.incoming_manipulators) self.assertEqual([], db.incoming_copying_manipulators) self.assertEqual([], db.outgoing_manipulators) self.assertEqual([], db.outgoing_copying_manipulators) db.add_son_manipulator(AutoReference(db)) db.add_son_manipulator(NamespaceInjector()) db.add_son_manipulator(ObjectIdShuffler()) self.assertEqual(1, len(db.incoming_manipulators)) self.assertEqual(db.incoming_manipulators, ['NamespaceInjector']) self.assertEqual(2, len(db.incoming_copying_manipulators)) for name in db.incoming_copying_manipulators: self.assertTrue(name in ('ObjectIdShuffler', 'AutoReference')) self.assertEqual([], db.outgoing_manipulators) self.assertEqual(['AutoReference'], db.outgoing_copying_manipulators) def test_ensure_index(self): db = self.db self.assertRaises(TypeError, db.test.ensure_index, {"hello": 1}) self.assertRaises(TypeError, db.test.ensure_index, {"hello": 1}, cache_for='foo') db.test.drop_indexes() self.assertEqual("goodbye_1", db.test.ensure_index("goodbye")) self.assertEqual(None, db.test.ensure_index("goodbye")) db.test.drop_indexes() self.assertEqual("foo", db.test.ensure_index("goodbye", name="foo")) self.assertEqual(None, db.test.ensure_index("goodbye", name="foo")) db.test.drop_indexes() self.assertEqual("goodbye_1", db.test.ensure_index("goodbye")) self.assertEqual(None, db.test.ensure_index("goodbye")) db.test.drop_index("goodbye_1") self.assertEqual("goodbye_1", 
db.test.ensure_index("goodbye")) self.assertEqual(None, db.test.ensure_index("goodbye")) db.drop_collection("test") self.assertEqual("goodbye_1", db.test.ensure_index("goodbye")) self.assertEqual(None, db.test.ensure_index("goodbye")) db.test.drop_index("goodbye_1") self.assertEqual("goodbye_1", db.test.ensure_index("goodbye")) self.assertEqual(None, db.test.ensure_index("goodbye")) db.test.drop_index("goodbye_1") self.assertEqual("goodbye_1", db.test.ensure_index("goodbye", cache_for=1)) time.sleep(1.2) self.assertEqual("goodbye_1", db.test.ensure_index("goodbye")) # Make sure the expiration time is updated. self.assertEqual(None, db.test.ensure_index("goodbye")) # Clean up indexes for later tests db.test.drop_indexes() @client_context.require_version_max(4, 1) # PYTHON-1734 def test_ensure_index_threaded(self): coll = self.db.threaded_index_creation index_docs = [] class Indexer(threading.Thread): def run(self): coll.ensure_index('foo0') coll.ensure_index('foo1') coll.ensure_index('foo2') index_docs.append(coll.index_information()) try: threads = [] for _ in range(10): t = Indexer() t.setDaemon(True) threads.append(t) for thread in threads: thread.start() joinall(threads) first = index_docs[0] for index_doc in index_docs[1:]: self.assertEqual(index_doc, first) finally: coll.drop() def test_ensure_purge_index_threaded(self): coll = self.db.threaded_index_creation class Indexer(threading.Thread): def run(self): coll.ensure_index('foo') try: coll.drop_index('foo') except OperationFailure: # The index may have already been dropped. pass coll.ensure_index('foo') coll.drop_indexes() coll.create_index('foo') try: threads = [] for _ in range(10): t = Indexer() t.setDaemon(True) threads.append(t) for thread in threads: thread.start() joinall(threads) self.assertTrue('foo_1' in coll.index_information()) finally: coll.drop() @client_context.require_version_max(4, 1) # PYTHON-1734 def test_ensure_unique_index_threaded(self): coll = self.db.test_unique_threaded coll.drop() coll.insert_many([{'foo': i} for i in range(10000)]) class Indexer(threading.Thread): def run(self): try: coll.ensure_index('foo', unique=True) coll.insert_one({'foo': 'bar'}) coll.insert_one({'foo': 'bar'}) except OperationFailure: pass threads = [] for _ in range(10): t = Indexer() t.setDaemon(True) threads.append(t) for i in range(10): threads[i].start() joinall(threads) self.assertEqual(10001, coll.count()) coll.drop() def test_kill_cursors_with_cursoraddress(self): coll = self.client.pymongo_test.test coll.drop() coll.insert_many([{'_id': i} for i in range(200)]) cursor = coll.find().batch_size(1) next(cursor) self.client.kill_cursors( [cursor.cursor_id], _CursorAddress(self.client.address, coll.full_name)) # Prevent killcursors from reaching the server while a getmore is in # progress -- the server logs "Assertion: 16089:Cannot kill active # cursor." time.sleep(2) def raises_cursor_not_found(): try: next(cursor) return False except CursorNotFound: return True wait_until(raises_cursor_not_found, 'close cursor') def test_kill_cursors_with_tuple(self): # Some evergreen distros (Debian 7.1) still test against 3.6.5 where # OP_KILL_CURSORS does not work. 
if (client_context.is_mongos and client_context.auth_enabled and (3, 6, 0) <= client_context.version < (3, 6, 6)): raise SkipTest("SERVER-33553 This server version does not support " "OP_KILL_CURSORS") coll = self.client.pymongo_test.test coll.drop() coll.insert_many([{'_id': i} for i in range(200)]) cursor = coll.find().batch_size(1) next(cursor) self.client.kill_cursors( [cursor.cursor_id], self.client.address) # Prevent killcursors from reaching the server while a getmore is in # progress -- the server logs "Assertion: 16089:Cannot kill active # cursor." time.sleep(2) def raises_cursor_not_found(): try: next(cursor) return False except CursorNotFound: return True wait_until(raises_cursor_not_found, 'close cursor') class TestLegacyBulk(BulkTestBase): @classmethod def setUpClass(cls): super(TestLegacyBulk, cls).setUpClass() cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() def test_empty(self): bulk = self.coll.initialize_ordered_bulk_op() self.assertRaises(InvalidOperation, bulk.execute) def test_find(self): # find() requires a selector. bulk = self.coll.initialize_ordered_bulk_op() self.assertRaises(TypeError, bulk.find) self.assertRaises(TypeError, bulk.find, 'foo') # No error. bulk.find({}) @client_context.require_version_min(3, 1, 9, -1) def test_bypass_document_validation_bulk_op(self): # Test insert self.coll.insert_one({"z": 0}) self.db.command(SON([("collMod", "test"), ("validator", {"z": {"$gte": 0}})])) bulk = self.coll.initialize_ordered_bulk_op( bypass_document_validation=False) bulk.insert({"z": -1}) # error self.assertRaises(BulkWriteError, bulk.execute) self.assertEqual(0, self.coll.count({"z": -1})) bulk = self.coll.initialize_ordered_bulk_op( bypass_document_validation=True) bulk.insert({"z": -1}) bulk.execute() self.assertEqual(1, self.coll.count({"z": -1})) self.coll.insert_one({"z": 0}) self.db.command(SON([("collMod", "test"), ("validator", {"z": {"$gte": 0}})])) bulk = self.coll.initialize_unordered_bulk_op( bypass_document_validation=False) bulk.insert({"z": -1}) # error self.assertRaises(BulkWriteError, bulk.execute) self.assertEqual(1, self.coll.count({"z": -1})) bulk = self.coll.initialize_unordered_bulk_op( bypass_document_validation=True) bulk.insert({"z": -1}) bulk.execute() self.assertEqual(2, self.coll.count({"z": -1})) self.coll.drop() def test_insert(self): expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 1, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } bulk = self.coll.initialize_ordered_bulk_op() self.assertRaises(TypeError, bulk.insert, 1) # find() before insert() is prohibited. self.assertRaises(AttributeError, lambda: bulk.find({}).insert({})) # We don't allow multiple documents per call. 
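# (Illustrative aside: each bulk.insert() call queues exactly one document,
# as the assertions below check. To queue several documents, call insert()
# once per document, or bypass the legacy builder with insert_many() /
# bulk_write(). A hedged sketch, with `coll` standing in for any collection:
#
#     from pymongo import InsertOne
#     coll.insert_many([{'a': 1}, {'a': 2}])
#     coll.bulk_write([InsertOne({'a': 1}), InsertOne({'a': 2})])
# )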
self.assertRaises(TypeError, bulk.insert, [{}, {}]) self.assertRaises(TypeError, bulk.insert, ({} for _ in range(2))) bulk.insert({}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(1, self.coll.count()) doc = self.coll.find_one() self.assertTrue(oid_generated_on_process(doc['_id'])) bulk = self.coll.initialize_unordered_bulk_op() bulk.insert({}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(2, self.coll.count()) def test_insert_check_keys(self): bulk = self.coll.initialize_ordered_bulk_op() bulk.insert({'$dollar': 1}) self.assertRaises(InvalidDocument, bulk.execute) bulk = self.coll.initialize_ordered_bulk_op() bulk.insert({'a.b': 1}) self.assertRaises(InvalidDocument, bulk.execute) def test_update(self): expected = { 'nMatched': 2, 'nModified': 2, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_ordered_bulk_op() # update() requires find() first. self.assertRaises( AttributeError, lambda: bulk.update({'$set': {'x': 1}})) self.assertRaises(TypeError, bulk.find({}).update, 1) self.assertRaises(ValueError, bulk.find({}).update, {}) # All fields must be $-operators. self.assertRaises(ValueError, bulk.find({}).update, {'foo': 'bar'}) bulk.find({}).update({'$set': {'foo': 'bar'}}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.find({'foo': 'bar'}).count(), 2) # All fields must be $-operators -- validated server-side. bulk = self.coll.initialize_ordered_bulk_op() updates = SON([('$set', {'x': 1}), ('y', 1)]) bulk.find({}).update(updates) self.assertRaises(BulkWriteError, bulk.execute) self.coll.delete_many({}) self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({}).update({'$set': {'bim': 'baz'}}) result = bulk.execute() self.assertEqualResponse( {'nMatched': 2, 'nModified': 2, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) self.assertEqual(self.coll.find({'bim': 'baz'}).count(), 2) self.coll.insert_one({'x': 1}) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({'x': 1}).update({'$set': {'x': 42}}) result = bulk.execute() self.assertEqualResponse( {'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) self.assertEqual(1, self.coll.find({'x': 42}).count()) # Second time, x is already 42 so nModified is 0. bulk = self.coll.initialize_unordered_bulk_op() bulk.find({'x': 42}).update({'$set': {'x': 42}}) result = bulk.execute() self.assertEqualResponse( {'nMatched': 1, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) def test_update_one(self): expected = { 'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_ordered_bulk_op() # update_one() requires find() first. 
self.assertRaises( AttributeError, lambda: bulk.update_one({'$set': {'x': 1}})) self.assertRaises(TypeError, bulk.find({}).update_one, 1) self.assertRaises(ValueError, bulk.find({}).update_one, {}) self.assertRaises(ValueError, bulk.find({}).update_one, {'foo': 'bar'}) bulk.find({}).update_one({'$set': {'foo': 'bar'}}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.find({'foo': 'bar'}).count(), 1) self.coll.delete_many({}) self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({}).update_one({'$set': {'bim': 'baz'}}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.find({'bim': 'baz'}).count(), 1) # All fields must be $-operators -- validated server-side. bulk = self.coll.initialize_ordered_bulk_op() updates = SON([('$set', {'x': 1}), ('y', 1)]) bulk.find({}).update_one(updates) self.assertRaises(BulkWriteError, bulk.execute) def test_replace_one(self): expected = { 'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_ordered_bulk_op() self.assertRaises(TypeError, bulk.find({}).replace_one, 1) self.assertRaises(ValueError, bulk.find({}).replace_one, {'$set': {'foo': 'bar'}}) bulk.find({}).replace_one({'foo': 'bar'}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.find({'foo': 'bar'}).count(), 1) self.coll.delete_many({}) self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({}).replace_one({'bim': 'baz'}) result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.find({'bim': 'baz'}).count(), 1) def test_remove(self): # Test removing all documents, ordered. expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 2, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_ordered_bulk_op() # remove() must be preceded by find(). self.assertRaises(AttributeError, lambda: bulk.remove()) bulk.find({}).remove() result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.count(), 0) # Test removing some documents, ordered. self.coll.insert_many([{}, {'x': 1}, {}, {'x': 1}]) bulk = self.coll.initialize_ordered_bulk_op() bulk.find({'x': 1}).remove() result = bulk.execute() self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 2, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) self.assertEqual(self.coll.count(), 2) self.coll.delete_many({}) # Test removing all documents, unordered. self.coll.insert_many([{}, {}]) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({}).remove() result = bulk.execute() self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 2, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) # Test removing some documents, unordered. 
self.assertEqual(self.coll.count(), 0) self.coll.insert_many([{}, {'x': 1}, {}, {'x': 1}]) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({'x': 1}).remove() result = bulk.execute() self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 2, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) self.assertEqual(self.coll.count(), 2) self.coll.delete_many({}) def test_remove_one(self): bulk = self.coll.initialize_ordered_bulk_op() # remove_one() must be preceded by find(). self.assertRaises(AttributeError, lambda: bulk.remove_one()) # Test removing one document, empty selector. # First ordered, then unordered. self.coll.insert_many([{}, {}]) expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 1, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': [] } bulk.find({}).remove_one() result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.count(), 1) self.coll.insert_one({}) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({}).remove_one() result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual(self.coll.count(), 1) # Test removing one document, with a selector. # First ordered, then unordered. self.coll.insert_one({'x': 1}) bulk = self.coll.initialize_ordered_bulk_op() bulk.find({'x': 1}).remove_one() result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual([{}], list(self.coll.find({}, {'_id': False}))) self.coll.insert_one({'x': 1}) bulk = self.coll.initialize_unordered_bulk_op() bulk.find({'x': 1}).remove_one() result = bulk.execute() self.assertEqualResponse(expected, result) self.assertEqual([{}], list(self.coll.find({}, {'_id': False}))) def test_upsert(self): bulk = self.coll.initialize_ordered_bulk_op() # upsert() requires find() first. self.assertRaises( AttributeError, lambda: bulk.upsert()) expected = { 'nMatched': 0, 'nModified': 0, 'nUpserted': 1, 'nInserted': 0, 'nRemoved': 0, 'upserted': [{'index': 0, '_id': '...'}] } bulk.find({}).upsert().replace_one({'foo': 'bar'}) result = bulk.execute() self.assertEqualResponse(expected, result) bulk = self.coll.initialize_ordered_bulk_op() bulk.find({}).upsert().update_one({'$set': {'bim': 'baz'}}) result = bulk.execute() self.assertEqualResponse( {'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) self.assertEqual(self.coll.find({'bim': 'baz'}).count(), 1) bulk = self.coll.initialize_ordered_bulk_op() bulk.find({}).upsert().update({'$set': {'bim': 'bop'}}) # Non-upsert, no matches. 
bulk.find({'x': 1}).update({'$set': {'x': 2}}) result = bulk.execute() self.assertEqualResponse( {'nMatched': 1, 'nModified': 1, 'nUpserted': 0, 'nInserted': 0, 'nRemoved': 0, 'upserted': [], 'writeErrors': [], 'writeConcernErrors': []}, result) self.assertEqual(self.coll.find({'bim': 'bop'}).count(), 1) self.assertEqual(self.coll.find({'x': 2}).count(), 0) def test_upsert_large(self): big = 'a' * (client_context.client.max_bson_size - 37) bulk = self.coll.initialize_ordered_bulk_op() bulk.find({'x': 1}).upsert().update({'$set': {'s': big}}) result = bulk.execute() self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 1, 'nInserted': 0, 'nRemoved': 0, 'upserted': [{'index': 0, '_id': '...'}]}, result) self.assertEqual(1, self.coll.find({'x': 1}).count()) def test_client_generated_upsert_id(self): batch = self.coll.initialize_ordered_bulk_op() batch.find({'_id': 0}).upsert().update_one({'$set': {'a': 0}}) batch.find({'a': 1}).upsert().replace_one({'_id': 1}) # This is just here to make the counts right in all cases. batch.find({'_id': 2}).upsert().replace_one({'_id': 2}) result = batch.execute() self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 3, 'nInserted': 0, 'nRemoved': 0, 'upserted': [{'index': 0, '_id': 0}, {'index': 1, '_id': 1}, {'index': 2, '_id': 2}]}, result) def test_single_ordered_batch(self): batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1}) batch.find({'a': 1}).update_one({'$set': {'b': 1}}) batch.find({'a': 2}).upsert().update_one({'$set': {'b': 2}}) batch.insert({'a': 3}) batch.find({'a': 3}).remove() result = batch.execute() self.assertEqualResponse( {'nMatched': 1, 'nModified': 1, 'nUpserted': 1, 'nInserted': 2, 'nRemoved': 1, 'upserted': [{'index': 2, '_id': '...'}]}, result) def test_single_error_ordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) batch = self.coll.initialize_ordered_bulk_op() batch.insert({'b': 1, 'a': 1}) batch.find({'b': 2}).upsert().update_one({'$set': {'a': 1}}) batch.insert({'b': 3, 'a': 2}) try: batch.execute() except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 1, 'nRemoved': 0, 'upserted': [], 'writeConcernErrors': [], 'writeErrors': [ {'index': 1, 'code': 11000, 'errmsg': '...', 'op': {'q': {'b': 2}, 'u': {'$set': {'a': 1}}, 'multi': False, 'upsert': True}}]}, result) def test_multiple_error_ordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) batch = self.coll.initialize_ordered_bulk_op() batch.insert({'b': 1, 'a': 1}) batch.find({'b': 2}).upsert().update_one({'$set': {'a': 1}}) batch.find({'b': 3}).upsert().update_one({'$set': {'a': 2}}) batch.find({'b': 2}).upsert().update_one({'$set': {'a': 1}}) batch.insert({'b': 4, 'a': 3}) batch.insert({'b': 5, 'a': 1}) try: batch.execute() except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 1, 'nRemoved': 0, 'upserted': [], 'writeConcernErrors': [], 'writeErrors': [ {'index': 1, 'code': 11000, 'errmsg': '...', 'op': {'q': {'b': 2}, 'u': {'$set': {'a': 1}}, 'multi': False, 'upsert': True}}]}, result) def test_single_unordered_batch(self): batch = self.coll.initialize_unordered_bulk_op() batch.insert({'a': 1}) batch.find({'a': 
1}).update_one({'$set': {'b': 1}}) batch.find({'a': 2}).upsert().update_one({'$set': {'b': 2}}) batch.insert({'a': 3}) batch.find({'a': 3}).remove() result = batch.execute() self.assertEqualResponse( {'nMatched': 1, 'nModified': 1, 'nUpserted': 1, 'nInserted': 2, 'nRemoved': 1, 'upserted': [{'index': 2, '_id': '...'}], 'writeErrors': [], 'writeConcernErrors': []}, result) def test_single_error_unordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) batch = self.coll.initialize_unordered_bulk_op() batch.insert({'b': 1, 'a': 1}) batch.find({'b': 2}).upsert().update_one({'$set': {'a': 1}}) batch.insert({'b': 3, 'a': 2}) try: batch.execute() except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 2, 'nRemoved': 0, 'upserted': [], 'writeConcernErrors': [], 'writeErrors': [ {'index': 1, 'code': 11000, 'errmsg': '...', 'op': {'q': {'b': 2}, 'u': {'$set': {'a': 1}}, 'multi': False, 'upsert': True}}]}, result) def test_multiple_error_unordered_batch(self): self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) batch = self.coll.initialize_unordered_bulk_op() batch.insert({'b': 1, 'a': 1}) batch.find({'b': 2}).upsert().update_one({'$set': {'a': 3}}) batch.find({'b': 3}).upsert().update_one({'$set': {'a': 4}}) batch.find({'b': 4}).upsert().update_one({'$set': {'a': 3}}) batch.insert({'b': 5, 'a': 2}) batch.insert({'b': 6, 'a': 1}) try: batch.execute() except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") # Assume the update at index 1 runs before the update at index 3, # although the spec does not require it. Same for inserts. 
self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 2, 'nInserted': 2, 'nRemoved': 0, 'upserted': [ {'index': 1, '_id': '...'}, {'index': 2, '_id': '...'}], 'writeConcernErrors': [], 'writeErrors': [ {'index': 3, 'code': 11000, 'errmsg': '...', 'op': {'q': {'b': 4}, 'u': {'$set': {'a': 3}}, 'multi': False, 'upsert': True}}, {'index': 5, 'code': 11000, 'errmsg': '...', 'op': {'_id': '...', 'b': 6, 'a': 1}}]}, result) def test_large_inserts_ordered(self): big = 'x' * self.coll.database.client.max_bson_size batch = self.coll.initialize_ordered_bulk_op() batch.insert({'b': 1, 'a': 1}) batch.insert({'big': big}) batch.insert({'b': 2, 'a': 2}) try: batch.execute() except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqual(1, result['nInserted']) self.coll.delete_many({}) big = 'x' * (1024 * 1024 * 4) batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1, 'big': big}) batch.insert({'a': 2, 'big': big}) batch.insert({'a': 3, 'big': big}) batch.insert({'a': 4, 'big': big}) batch.insert({'a': 5, 'big': big}) batch.insert({'a': 6, 'big': big}) result = batch.execute() self.assertEqual(6, result['nInserted']) self.assertEqual(6, self.coll.count()) def test_large_inserts_unordered(self): big = 'x' * self.coll.database.client.max_bson_size batch = self.coll.initialize_unordered_bulk_op() batch.insert({'b': 1, 'a': 1}) batch.insert({'big': big}) batch.insert({'b': 2, 'a': 2}) try: batch.execute() except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqual(2, result['nInserted']) self.coll.delete_many({}) big = 'x' * (1024 * 1024 * 4) batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1, 'big': big}) batch.insert({'a': 2, 'big': big}) batch.insert({'a': 3, 'big': big}) batch.insert({'a': 4, 'big': big}) batch.insert({'a': 5, 'big': big}) batch.insert({'a': 6, 'big': big}) result = batch.execute() self.assertEqual(6, result['nInserted']) self.assertEqual(6, self.coll.count()) def test_numerous_inserts(self): # Ensure we don't exceed server's 1000-document batch size limit. n_docs = 2100 batch = self.coll.initialize_unordered_bulk_op() for _ in range(n_docs): batch.insert({}) result = batch.execute() self.assertEqual(n_docs, result['nInserted']) self.assertEqual(n_docs, self.coll.count()) # Same with ordered bulk. 
self.coll.delete_many({}) batch = self.coll.initialize_ordered_bulk_op() for _ in range(n_docs): batch.insert({}) result = batch.execute() self.assertEqual(n_docs, result['nInserted']) self.assertEqual(n_docs, self.coll.count()) def test_multiple_execution(self): batch = self.coll.initialize_ordered_bulk_op() batch.insert({}) batch.execute() self.assertRaises(InvalidOperation, batch.execute) def test_generator_insert(self): def gen(): yield {'a': 1, 'b': 1} yield {'a': 1, 'b': 2} yield {'a': 2, 'b': 3} yield {'a': 3, 'b': 5} yield {'a': 5, 'b': 8} result = self.coll.insert_many(gen()) self.assertEqual(5, len(result.inserted_ids)) class TestLegacyBulkNoResults(BulkTestBase): @classmethod def setUpClass(cls): super(TestLegacyBulkNoResults, cls).setUpClass() cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() def tearDown(self): self.coll.delete_many({}) def test_no_results_ordered_success(self): batch = self.coll.initialize_ordered_bulk_op() batch.insert({'_id': 1}) batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) batch.find({'_id': 1}).remove_one() self.assertTrue(batch.execute({'w': 0}) is None) wait_until(lambda: 2 == self.coll.count(), 'insert 2 documents') wait_until(lambda: self.coll.find_one({'_id': 1}) is None, 'removed {"_id": 1}') def test_no_results_ordered_failure(self): batch = self.coll.initialize_ordered_bulk_op() batch.insert({'_id': 1}) batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) # Fails with duplicate key error. batch.insert({'_id': 1}) # Should not be executed since the batch is ordered. batch.find({'_id': 1}).remove_one() self.assertTrue(batch.execute({'w': 0}) is None) wait_until(lambda: 3 == self.coll.count(), 'insert 3 documents') self.assertEqual({'_id': 1}, self.coll.find_one({'_id': 1})) def test_no_results_unordered_success(self): batch = self.coll.initialize_unordered_bulk_op() batch.insert({'_id': 1}) batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) batch.find({'_id': 1}).remove_one() self.assertTrue(batch.execute({'w': 0}) is None) wait_until(lambda: 2 == self.coll.count(), 'insert 2 documents') wait_until(lambda: self.coll.find_one({'_id': 1}) is None, 'removed {"_id": 1}') def test_no_results_unordered_failure(self): batch = self.coll.initialize_unordered_bulk_op() batch.insert({'_id': 1}) batch.find({'_id': 3}).upsert().update_one({'$set': {'b': 1}}) batch.insert({'_id': 2}) # Fails with duplicate key error. batch.insert({'_id': 1}) # Should be executed since the batch is unordered. batch.find({'_id': 1}).remove_one() self.assertTrue(batch.execute({'w': 0}) is None) wait_until(lambda: 2 == self.coll.count(), 'insert 2 documents') wait_until(lambda: self.coll.find_one({'_id': 1}) is None, 'removed {"_id": 1}') class TestLegacyBulkWriteConcern(BulkTestBase): @classmethod def setUpClass(cls): super(TestLegacyBulkWriteConcern, cls).setUpClass() cls.w = client_context.w cls.secondary = None if cls.w > 1: for member in client_context.ismaster['hosts']: if member != client_context.ismaster['primary']: cls.secondary = single_client(*partition_node(member)) break # We tested wtimeout errors by specifying a write concern greater than # the number of members, but in MongoDB 2.7.8+ this causes a different # sort of error, "Not enough data-bearing nodes". In recent servers we # use a failpoint to pause replication on a secondary. 
cls.need_replication_stopped = client_context.version.at_least(2, 7, 8) cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() if cls.secondary: cls.secondary.close() def cause_wtimeout(self, batch): if self.need_replication_stopped: if not client_context.test_commands_enabled: raise SkipTest("Test commands must be enabled.") self.secondary.admin.command('configureFailPoint', 'rsSyncApplyStop', mode='alwaysOn') try: return batch.execute({'w': self.w, 'wtimeout': 1}) finally: self.secondary.admin.command('configureFailPoint', 'rsSyncApplyStop', mode='off') else: return batch.execute({'w': self.w + 1, 'wtimeout': 1}) def test_fsync_and_j(self): batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1}) self.assertRaises( ConfigurationError, batch.execute, {'fsync': True, 'j': True}) @client_context.require_replica_set def test_write_concern_failure_ordered(self): # Ensure we don't raise on wnote. batch = self.coll.initialize_ordered_bulk_op() batch.find({"something": "that does no exist"}).remove() self.assertTrue(batch.execute({"w": self.w})) batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1}) batch.insert({'a': 2}) # Replication wtimeout is a 'soft' error. # It shouldn't stop batch processing. try: self.cause_wtimeout(batch) except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 0, 'nInserted': 2, 'nRemoved': 0, 'upserted': [], 'writeErrors': []}, result) # When talking to legacy servers there will be a # write concern error for each operation. self.assertTrue(len(result['writeConcernErrors']) > 0) failed = result['writeConcernErrors'][0] self.assertEqual(64, failed['code']) self.assertTrue(isinstance(failed['errmsg'], string_type)) self.coll.delete_many({}) self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) # Fail due to write concern support as well # as duplicate key error on ordered batch. batch = self.coll.initialize_ordered_bulk_op() batch.insert({'a': 1}) batch.find({'a': 3}).upsert().replace_one({'b': 1}) batch.insert({'a': 1}) batch.insert({'a': 2}) try: self.cause_wtimeout(batch) except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqualResponse( {'nMatched': 0, 'nModified': 0, 'nUpserted': 1, 'nInserted': 1, 'nRemoved': 0, 'upserted': [{'index': 1, '_id': '...'}], 'writeErrors': [ {'index': 2, 'code': 11000, 'errmsg': '...', 'op': {'_id': '...', 'a': 1}}]}, result) self.assertTrue(len(result['writeConcernErrors']) > 1) failed = result['writeErrors'][0] self.assertTrue("duplicate" in failed['errmsg']) @client_context.require_replica_set def test_write_concern_failure_unordered(self): # Ensure we don't raise on wnote. batch = self.coll.initialize_unordered_bulk_op() batch.find({"something": "that does no exist"}).remove() self.assertTrue(batch.execute({"w": self.w})) batch = self.coll.initialize_unordered_bulk_op() batch.insert({'a': 1}) batch.find({'a': 3}).upsert().update_one({'$set': {'a': 3, 'b': 1}}) batch.insert({'a': 2}) # Replication wtimeout is a 'soft' error. # It shouldn't stop batch processing. 
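# (Illustrative aside: a replication wtimeout surfaces under
# 'writeConcernErrors' in the BulkWriteError details while the writes
# themselves still apply, so counts such as 'nInserted' remain accurate.
# A hedged sketch of inspecting that shape; `unsatisfiable_w` is a
# placeholder for whatever write concern cannot be met within the timeout:
#
#     try:
#         bulk.execute({'w': unsatisfiable_w, 'wtimeout': 1})
#     except BulkWriteError as exc:
#         assert exc.details['writeErrors'] == []
#         assert exc.details['writeConcernErrors']   # soft errors only
# )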
try: self.cause_wtimeout(batch) except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqual(2, result['nInserted']) self.assertEqual(1, result['nUpserted']) self.assertEqual(0, len(result['writeErrors'])) # When talking to legacy servers there will be a # write concern error for each operation. self.assertTrue(len(result['writeConcernErrors']) > 1) self.coll.delete_many({}) self.coll.create_index('a', unique=True) self.addCleanup(self.coll.drop_index, [('a', 1)]) # Fail due to write concern support as well # as duplicate key error on unordered batch. batch = self.coll.initialize_unordered_bulk_op() batch.insert({'a': 1}) batch.find({'a': 3}).upsert().update_one({'$set': {'a': 3, 'b': 1}}) batch.insert({'a': 1}) batch.insert({'a': 2}) try: self.cause_wtimeout(batch) except BulkWriteError as exc: result = exc.details self.assertEqual(exc.code, 65) else: self.fail("Error not raised") self.assertEqual(2, result['nInserted']) self.assertEqual(1, result['nUpserted']) self.assertEqual(1, len(result['writeErrors'])) # When talking to legacy servers there will be a # write concern error for each operation. self.assertTrue(len(result['writeConcernErrors']) > 1) failed = result['writeErrors'][0] self.assertEqual(2, failed['index']) self.assertEqual(11000, failed['code']) self.assertTrue(isinstance(failed['errmsg'], string_type)) self.assertEqual(1, failed['op']['a']) failed = result['writeConcernErrors'][0] self.assertEqual(64, failed['code']) self.assertTrue(isinstance(failed['errmsg'], string_type)) upserts = result['upserted'] self.assertEqual(1, len(upserts)) self.assertEqual(1, upserts[0]['index']) self.assertTrue(upserts[0].get('_id')) class TestLegacyBulkAuthorization(BulkAuthorizationTestBase): @classmethod def setUpClass(cls): super(TestLegacyBulkAuthorization, cls).setUpClass() cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. cli = rs_or_single_client_noauth() db = cli.pymongo_test coll = db.test db.authenticate('readonly', 'pw') bulk = coll.initialize_ordered_bulk_op() bulk.insert({'x': 1}) self.assertRaises(OperationFailure, bulk.execute) def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. cli = rs_or_single_client_noauth() db = cli.pymongo_test coll = db.test db.authenticate('noremove', 'pw') bulk = coll.initialize_ordered_bulk_op() bulk.insert({'x': 1}) bulk.find({'x': 2}).upsert().replace_one({'x': 2}) bulk.find({}).remove() # Prohibited. bulk.insert({'x': 3}) # Never attempted. self.assertRaises(OperationFailure, bulk.execute) self.assertEqual(set([1, 2]), set(self.coll.distinct('x'))) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_max_staleness.py000066400000000000000000000136201374256237000205350ustar00rootroot00000000000000# Copyright 2016 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Test maxStalenessSeconds support.""" import os import sys import time import warnings sys.path[0:0] = [""] from pymongo import MongoClient from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector from test import client_context, unittest from test.utils import rs_or_single_client from test.utils_selection_tests import create_selection_tests # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'max_staleness') class TestAllScenarios(create_selection_tests(_TEST_PATH)): pass class TestMaxStaleness(unittest.TestCase): def test_max_staleness(self): client = MongoClient() self.assertEqual(-1, client.read_preference.max_staleness) client = MongoClient("mongodb://a/?readPreference=secondary") self.assertEqual(-1, client.read_preference.max_staleness) # These tests are specified in max-staleness-tests.rst. with self.assertRaises(ConfigurationError): # Default read pref "primary" can't be used with max staleness. MongoClient("mongodb://a/?maxStalenessSeconds=120") with self.assertRaises(ConfigurationError): # Read pref "primary" can't be used with max staleness. MongoClient("mongodb://a/?readPreference=primary&" "maxStalenessSeconds=120") client = MongoClient("mongodb://host/?maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) client = MongoClient("mongodb://host/?readPreference=primary&" "maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) client = MongoClient("mongodb://host/?readPreference=secondary&" "maxStalenessSeconds=120") self.assertEqual(120, client.read_preference.max_staleness) client = MongoClient("mongodb://a/?readPreference=secondary&" "maxStalenessSeconds=1") self.assertEqual(1, client.read_preference.max_staleness) client = MongoClient("mongodb://a/?readPreference=secondary&" "maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) client = MongoClient(maxStalenessSeconds=-1, readPreference="nearest") self.assertEqual(-1, client.read_preference.max_staleness) with self.assertRaises(TypeError): # Prohibit None. MongoClient(maxStalenessSeconds=None, readPreference="nearest") def test_max_staleness_float(self): with self.assertRaises(TypeError) as ctx: rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") self.assertIn("must be an integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") client = MongoClient("mongodb://host/?maxStalenessSeconds=1.5" "&readPreference=nearest") # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) self.assertIn("must be an integer", str(ctx[0])) def test_max_staleness_zero(self): # Zero is too small. with self.assertRaises(ValueError) as ctx: rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") self.assertIn("must be a positive integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") client = MongoClient("mongodb://host/?maxStalenessSeconds=0" "&readPreference=nearest") # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) self.assertIn("must be a positive integer", str(ctx[0])) @client_context.require_version_min(3, 3, 6) # SERVER-8858 @client_context.require_replica_set def test_last_write_date(self): # From max-staleness-tests.rst, "Parse lastWriteDate". 
client = rs_or_single_client(heartbeatFrequencyMS=500) client.pymongo_test.test.insert_one({}) # Wait for the server description to be updated. time.sleep(1) server = client._topology.select_server(writable_server_selector) first = server.description.last_write_date self.assertTrue(first) # The first last_write_date may correspond to a internal server write, # sleep so that the next write does not occur within the same second. time.sleep(1) client.pymongo_test.test.insert_one({}) # Wait for the server description to be updated. time.sleep(1) server = client._topology.select_server(writable_server_selector) second = server.description.last_write_date self.assertGreater(second, first) self.assertLess(second, first + 10) @client_context.require_version_max(3, 3) def test_last_write_date_absent(self): # From max-staleness-tests.rst, "Absent lastWriteDate". client = rs_or_single_client() sd = client._topology.select_server(writable_server_selector) self.assertIsNone(sd.description.last_write_date) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_mongos_load_balancing.py000066400000000000000000000146001374256237000221650ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test MongoClient's mongos load balancing using a mock.""" import sys import threading sys.path[0:0] = [""] from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector from pymongo.topology_description import TOPOLOGY_TYPE from test import unittest, client_context, MockClientTest from test.pymongo_mocks import MockClient from test.utils import connected, wait_until @client_context.require_connection def setUpModule(): pass class SimpleOp(threading.Thread): def __init__(self, client): super(SimpleOp, self).__init__() self.client = client self.passed = False def run(self): self.client.db.command('ismaster') self.passed = True # No exception raised. def do_simple_op(client, nthreads): threads = [SimpleOp(client) for _ in range(nthreads)] for t in threads: t.start() for t in threads: t.join() for t in threads: assert t.passed def writable_addresses(topology): return set(server.description.address for server in topology.select_servers(writable_server_selector)) class TestMongosLoadBalancing(MockClientTest): def mock_client(self, **kwargs): mock_client = MockClient( standalones=[], members=[], mongoses=['a:1', 'b:2', 'c:3'], host='a:1,b:2,c:3', connect=False, **kwargs) self.addCleanup(mock_client.close) # Latencies in seconds. mock_client.mock_rtts['a:1'] = 0.020 mock_client.mock_rtts['b:2'] = 0.025 mock_client.mock_rtts['c:3'] = 0.045 return mock_client def test_lazy_connect(self): # While connected() ensures we can trigger connection from the main # thread and wait for the monitors, this test triggers connection from # several threads at once to check for data races. nthreads = 10 client = self.mock_client() self.assertEqual(0, len(client.nodes)) # Trigger initial connection. 
do_simple_op(client, nthreads) wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') def test_reconnect(self): nthreads = 10 client = connected(self.mock_client()) # connected() ensures we've contacted at least one mongos. Wait for # all of them. wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') # Trigger reconnect. client.close() do_simple_op(client, nthreads) wait_until(lambda: len(client.nodes) == 3, 'reconnect to all mongoses') def test_failover(self): nthreads = 10 client = connected(self.mock_client(localThresholdMS=0.001)) wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') # Our chosen mongos goes down. client.kill_host('a:1') # Trigger failover to higher-latency nodes. AutoReconnect should be # raised at most once in each thread. passed = [] def f(): try: client.db.command('ismaster') except AutoReconnect: # Second attempt succeeds. client.db.command('ismaster') passed.append(True) threads = [threading.Thread(target=f) for _ in range(nthreads)] for t in threads: t.start() for t in threads: t.join() self.assertEqual(nthreads, len(passed)) # Down host removed from list. self.assertEqual(2, len(client.nodes)) def test_local_threshold(self): client = connected(self.mock_client(localThresholdMS=30)) self.assertEqual(30, client.local_threshold_ms) wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') topology = client._topology # All are within a 30-ms latency window, see self.mock_client(). self.assertEqual(set([('a', 1), ('b', 2), ('c', 3)]), writable_addresses(topology)) # No error client.admin.command('ismaster') client = connected(self.mock_client(localThresholdMS=0)) self.assertEqual(0, client.local_threshold_ms) # No error client.db.command('ismaster') # Our chosen mongos goes down. client.kill_host('%s:%s' % next(iter(client.nodes))) try: client.db.command('ismaster') except: pass # We eventually connect to a new mongos. def connect_to_new_mongos(): try: return client.db.command('ismaster') except AutoReconnect: pass wait_until(connect_to_new_mongos, 'connect to a new mongos') def test_load_balancing(self): # Although the server selection JSON tests already prove that # select_servers works for sharded topologies, here we do an end-to-end # test of discovering servers' round trip times and configuring # localThresholdMS. client = connected(self.mock_client()) wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') # Prohibited for topology type Sharded. with self.assertRaises(InvalidOperation): client.address topology = client._topology self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type) # a and b are within the 15-ms latency window, see self.mock_client(). self.assertEqual(set([('a', 1), ('b', 2)]), writable_addresses(topology)) client.mock_rtts['a:1'] = 0.045 # Discover only b is within latency window. wait_until(lambda: set([('b', 2)]) == writable_addresses(topology), 'discover server "a" is too far') if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_monitor.py000066400000000000000000000050661374256237000173630ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the monitor module.""" import gc import sys from functools import partial sys.path[0:0] = [""] from pymongo.periodic_executor import _EXECUTORS from test import unittest, IntegrationTest from test.utils import (connected, ServerAndTopologyEventListener, single_client, wait_until) def unregistered(ref): gc.collect() return ref not in _EXECUTORS def get_executors(client): executors = [] for server in client._topology._servers.values(): executors.append(server._monitor._executor) executors.append(server._monitor._rtt_monitor._executor) executors.append(client._kill_cursors_executor) executors.append(client._topology._Topology__events_executor) return [e for e in executors if e is not None] def create_client(): listener = ServerAndTopologyEventListener() client = single_client(event_listeners=[listener]) connected(client) return client class TestMonitor(IntegrationTest): def test_cleanup_executors_on_client_del(self): client = create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) # Each executor stores a weakref to itself in _EXECUTORS. executor_refs = [ (r, r()._name) for r in _EXECUTORS.copy() if r() in executors] del executors del client for ref, name in executor_refs: wait_until(partial(unregistered, ref), 'unregister executor: %s' % (name,), timeout=5) def test_cleanup_executors_on_client_close(self): client = create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) client.close() for executor in executors: wait_until(lambda: executor._stopped, 'closed executor: %s' % (executor._name,), timeout=5) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_monitoring.py000066400000000000000000002150201374256237000200520ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
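# The tests below drive PyMongo's command monitoring API. As a minimal,
# self-contained sketch of the listener pattern they exercise (the class
# name here is illustrative, not part of the test suite):
from pymongo import monitoring


class _ExampleCommandLogger(monitoring.CommandListener):
    """Print one line per command started/succeeded/failed event."""

    def started(self, event):
        # event.command_name, request_id and database_name come from
        # CommandStartedEvent.
        print('started %s (request %s) on %s' % (
            event.command_name, event.request_id, event.database_name))

    def succeeded(self, event):
        # CommandSucceededEvent carries the reply and the duration in
        # microseconds.
        print('succeeded %s in %d us' % (
            event.command_name, event.duration_micros))

    def failed(self, event):
        # CommandFailedEvent carries the failure document.
        print('failed %s: %s' % (event.command_name, event.failure))


# Usage sketch: pass the listener when constructing a client, e.g.
#     client = pymongo.MongoClient(event_listeners=[_ExampleCommandLogger()])
#     client.admin.command('ismaster')   # emits started + succeeded events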
import copy import datetime import sys import time import warnings sys.path[0:0] = [""] from bson.objectid import ObjectId from bson.py3compat import text_type from bson.son import SON from pymongo import CursorType, monitoring, InsertOne, UpdateOne, DeleteOne from pymongo.command_cursor import CommandCursor from pymongo.errors import (AutoReconnect, NotMasterError, OperationFailure) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern from test import (client_context, client_knobs, PyMongoTestCase, sanitize_cmd, unittest) from test.utils import (EventListener, get_pool, ignore_deprecations, rs_or_single_client, single_client, wait_until) class TestCommandMonitoring(PyMongoTestCase): @classmethod @client_context.require_connection def setUpClass(cls): cls.listener = EventListener() cls.client = rs_or_single_client( event_listeners=[cls.listener], retryWrites=False) @classmethod def tearDownClass(cls): cls.client.close() def tearDown(self): self.listener.results.clear() def test_started_simple(self): self.client.pymongo_test.command('ismaster') results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand(SON([('ismaster', 1)]), started.command) self.assertEqual('ismaster', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) def test_succeeded_simple(self): self.client.pymongo_test.command('ismaster') results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertEqual('ismaster', succeeded.command_name) self.assertEqual(self.client.address, succeeded.connection_id) self.assertEqual(1, succeeded.reply.get('ok')) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertTrue(isinstance(succeeded.duration_micros, int)) def test_failed_simple(self): try: self.client.pymongo_test.command('oops!') except OperationFailure: pass results = self.listener.results started = results['started'][0] failed = results['failed'][0] self.assertEqual(0, len(results['succeeded'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertTrue( isinstance(failed, monitoring.CommandFailedEvent)) self.assertEqual('oops!', failed.command_name) self.assertEqual(self.client.address, failed.connection_id) self.assertEqual(0, failed.failure.get('ok')) self.assertTrue(isinstance(failed.request_id, int)) self.assertTrue(isinstance(failed.duration_micros, int)) def test_find_one(self): self.client.pymongo_test.test.find_one() results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('find', 'test'), ('filter', {}), ('limit', 1), ('singleBatch', True)]), started.command) self.assertEqual('find', started.command_name) self.assertEqual(self.client.address, started.connection_id) 
self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) def test_find_and_get_more(self): self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_many([{} for _ in range(10)]) self.listener.results.clear() cursor = self.client.pymongo_test.test.find( projection={'_id': False}, batch_size=4) for _ in range(4): next(cursor) cursor_id = cursor.cursor_id results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('find', 'test'), ('filter', {}), ('projection', {'_id': False}), ('batchSize', 4)]), started.command) self.assertEqual('find', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('find', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) csr = succeeded.reply["cursor"] self.assertEqual(csr["id"], cursor_id) self.assertEqual(csr["ns"], "pymongo_test.test") self.assertEqual(csr["firstBatch"], [{} for _ in range(4)]) self.listener.results.clear() # Next batch. Exhausting the cursor could cause a getMore # that returns id of 0 and no results. next(cursor) try: results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('getMore', cursor_id), ('collection', 'test'), ('batchSize', 4)]), started.command) self.assertEqual('getMore', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('getMore', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) csr = succeeded.reply["cursor"] self.assertEqual(csr["id"], cursor_id) self.assertEqual(csr["ns"], "pymongo_test.test") self.assertEqual(csr["nextBatch"], [{} for _ in range(4)]) finally: # Exhaust the cursor to avoid kill cursors. tuple(cursor) def test_find_with_explain(self): cmd = SON([('explain', SON([('find', 'test'), ('filter', {})]))]) self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_one({}) self.listener.results.clear() coll = self.client.pymongo_test.test # Test that we publish the unwrapped command. 
if self.client.is_mongos: coll = coll.with_options( read_preference=ReadPreference.PRIMARY_PREFERRED) res = coll.find().explain() results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand(cmd, started.command) self.assertEqual('explain', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('explain', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(self.client.address, succeeded.connection_id) self.assertEqual(res, succeeded.reply) def _test_find_options(self, query, expected_cmd): coll = self.client.pymongo_test.test coll.drop() coll.create_index('x') coll.insert_many([{'x': i} for i in range(5)]) # Test that we publish the unwrapped command. self.listener.results.clear() if self.client.is_mongos: coll = coll.with_options( read_preference=ReadPreference.PRIMARY_PREFERRED) cursor = coll.find(**query) next(cursor) try: results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand(expected_cmd, started.command) self.assertEqual('find', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('find', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(self.client.address, succeeded.connection_id) finally: # Exhaust the cursor to avoid kill cursors. tuple(cursor) def test_find_options(self): query = dict(filter={}, hint=[('x', 1)], max_time_ms=10000, max={'x': 10}, min={'x': -10}, return_key=True, show_record_id=True, projection={'x': False}, skip=1, no_cursor_timeout=True, sort=[('_id', 1)], allow_partial_results=True, comment='this is a test', batch_size=2) cmd = dict(find='test', filter={}, hint=SON([('x', 1)]), comment='this is a test', maxTimeMS=10000, max={'x': 10}, min={'x': -10}, returnKey=True, showRecordId=True, sort=SON([('_id', 1)]), projection={'x': False}, skip=1, batchSize=2, noCursorTimeout=True, allowPartialResults=True) if client_context.version < (4, 1, 0, -1): query['max_scan'] = 10 cmd['maxScan'] = 10 self._test_find_options(query, cmd) @client_context.require_version_max(3, 7, 2) def test_find_snapshot(self): # Test "snapshot" parameter separately, can't combine with "sort". query = dict(filter={}, snapshot=True) cmd = dict(find='test', filter={}, snapshot=True) self._test_find_options(query, cmd) def test_command_and_get_more(self): self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_many( [{'x': 1} for _ in range(10)]) self.listener.results.clear() coll = self.client.pymongo_test.test # Test that we publish the unwrapped command. 
if self.client.is_mongos: coll = coll.with_options( read_preference=ReadPreference.PRIMARY_PREFERRED) cursor = coll.aggregate( [{'$project': {'_id': False, 'x': 1}}], batchSize=4) for _ in range(4): next(cursor) cursor_id = cursor.cursor_id results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('aggregate', 'test'), ('pipeline', [{'$project': {'_id': False, 'x': 1}}]), ('cursor', {'batchSize': 4})]), started.command) self.assertEqual('aggregate', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('aggregate', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) expected_cursor = {'id': cursor_id, 'ns': 'pymongo_test.test', 'firstBatch': [{'x': 1} for _ in range(4)]} self.assertEqualCommand(expected_cursor, succeeded.reply.get('cursor')) self.listener.results.clear() next(cursor) try: results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('getMore', cursor_id), ('collection', 'test'), ('batchSize', 4)]), started.command) self.assertEqual('getMore', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('getMore', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) expected_result = { 'cursor': {'id': cursor_id, 'ns': 'pymongo_test.test', 'nextBatch': [{'x': 1} for _ in range(4)]}, 'ok': 1.0} self.assertEqualReply(expected_result, succeeded.reply) finally: # Exhaust the cursor to avoid kill cursors. 
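# (Leaving a partially iterated cursor alive would make the client send a
# killCursors command when the cursor is later closed or garbage collected,
# which would add unrelated events to the listener; draining the cursor in
# the finally block avoids that.)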
tuple(cursor) def test_get_more_failure(self): address = self.client.address coll = self.client.pymongo_test.test cursor_doc = {"id": 12345, "firstBatch": [], "ns": coll.full_name} cursor = CommandCursor(coll, cursor_doc, address) try: next(cursor) except Exception: pass results = self.listener.results started = results['started'][0] self.assertEqual(0, len(results['succeeded'])) failed = results['failed'][0] self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('getMore', 12345), ('collection', 'test')]), started.command) self.assertEqual('getMore', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(failed, monitoring.CommandFailedEvent)) self.assertTrue(isinstance(failed.duration_micros, int)) self.assertEqual('getMore', failed.command_name) self.assertTrue(isinstance(failed.request_id, int)) self.assertEqual(cursor.address, failed.connection_id) self.assertEqual(0, failed.failure.get("ok")) @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_not_master_error(self): address = next(iter(client_context.client.secondaries)) client = single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. client.admin.command('ismaster') self.listener.results.clear() error = None try: client.pymongo_test.test.find_one_and_delete({}) except NotMasterError as exc: error = exc.errors results = self.listener.results started = results['started'][0] failed = results['failed'][0] self.assertEqual(0, len(results['succeeded'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertTrue( isinstance(failed, monitoring.CommandFailedEvent)) self.assertEqual('findAndModify', failed.command_name) self.assertEqual(address, failed.connection_id) self.assertEqual(0, failed.failure.get('ok')) self.assertTrue(isinstance(failed.request_id, int)) self.assertTrue(isinstance(failed.duration_micros, int)) self.assertEqual(error, failed.failure) @client_context.require_no_mongos def test_exhaust(self): self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_many([{} for _ in range(10)]) self.listener.results.clear() cursor = self.client.pymongo_test.test.find( projection={'_id': False}, batch_size=5, cursor_type=CursorType.EXHAUST) next(cursor) cursor_id = cursor.cursor_id results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('find', 'test'), ('filter', {}), ('projection', {'_id': False}), ('batchSize', 5)]), started.command) self.assertEqual('find', started.command_name) self.assertEqual(cursor.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('find', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) expected_result = { 'cursor': {'id': cursor_id, 'ns': 'pymongo_test.test', 'firstBatch': [{} for _ in range(5)]}, 'ok': 1} self.assertEqualReply(expected_result, 
succeeded.reply) self.listener.results.clear() tuple(cursor) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( SON([('getMore', cursor_id), ('collection', 'test'), ('batchSize', 5)]), started.command) self.assertEqual('getMore', started.command_name) self.assertEqual(cursor.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('getMore', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) expected_result = { 'cursor': {'id': 0, 'ns': 'pymongo_test.test', 'nextBatch': [{} for _ in range(5)]}, 'ok': 1} self.assertEqualReply(expected_result, succeeded.reply) def test_kill_cursors(self): with client_knobs(kill_cursor_frequency=0.01): self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_many([{} for _ in range(10)]) cursor = self.client.pymongo_test.test.find().batch_size(5) next(cursor) cursor_id = cursor.cursor_id self.listener.results.clear() cursor.close() time.sleep(2) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) # There could be more than one cursor_id here depending on # when the thread last ran. self.assertIn(cursor_id, started.command['cursors']) self.assertEqual('killCursors', started.command_name) self.assertIs(type(started.connection_id), tuple) self.assertEqual(cursor.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) self.assertEqual('killCursors', succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertIs(type(succeeded.connection_id), tuple) self.assertEqual(cursor.address, succeeded.connection_id) # There could be more than one cursor_id here depending on # when the thread last ran. 
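# (The killCursors reply reports each requested id under one of several
# arrays, e.g. 'cursorsKilled' or 'cursorsUnknown'. If the periodic
# kill-cursors thread already closed this cursor, the server reports the id
# as unknown rather than killed, so either outcome is accepted below.)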
self.assertTrue(cursor_id in succeeded.reply['cursorsUnknown'] or cursor_id in succeeded.reply['cursorsKilled']) def test_non_bulk_writes(self): coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() # Implied write concern insert_one res = coll.insert_one({'x': 1}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': res.inserted_id, 'x': 1}])]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) # Unacknowledged insert_one self.listener.results.clear() coll = coll.with_options(write_concern=WriteConcern(w=0)) res = coll.insert_one({'x': 1}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': res.inserted_id, 'x': 1}]), ('writeConcern', {'w': 0})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) self.assertEqualReply(succeeded.reply, {'ok': 1}) # Explicit write concern insert_one self.listener.results.clear() coll = coll.with_options(write_concern=WriteConcern(w=1)) res = coll.insert_one({'x': 1}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': res.inserted_id, 'x': 1}]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, 
reply.get('n')) # delete_many self.listener.results.clear() res = coll.delete_many({'x': 1}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('delete', coll.name), ('ordered', True), ('deletes', [SON([('q', {'x': 1}), ('limit', 0)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('delete', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(res.deleted_count, reply.get('n')) # replace_one self.listener.results.clear() oid = ObjectId() res = coll.replace_one({'_id': oid}, {'_id': oid, 'x': 1}, upsert=True) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'_id': oid}), ('u', {'_id': oid, 'x': 1}), ('multi', False), ('upsert', True)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('update', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) self.assertEqual([{'index': 0, '_id': oid}], reply.get('upserted')) # update_one self.listener.results.clear() res = coll.update_one({'x': 1}, {'$inc': {'x': 1}}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'x': 1}), ('u', {'$inc': {'x': 1}}), ('multi', False), ('upsert', False)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('update', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, 
reply.get('n')) # update_many self.listener.results.clear() res = coll.update_many({'x': 2}, {'$inc': {'x': 1}}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'x': 2}), ('u', {'$inc': {'x': 1}}), ('multi', True), ('upsert', False)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('update', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) # delete_one self.listener.results.clear() res = coll.delete_one({'x': 3}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('delete', coll.name), ('ordered', True), ('deletes', [SON([('q', {'x': 3}), ('limit', 1)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('delete', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) self.assertEqual(0, coll.count_documents({})) # write errors coll.insert_one({'_id': 1}) try: self.listener.results.clear() coll.insert_one({'_id': 1}) except OperationFailure: pass results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': 1}]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(0, reply.get('n')) errors = reply.get('writeErrors') self.assertIsInstance(errors, list) error = errors[0] self.assertEqual(0, 
error.get('index')) self.assertIsInstance(error.get('code'), int) self.assertIsInstance(error.get('errmsg'), text_type) def test_legacy_writes(self): with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() # Implied write concern insert _id = coll.insert({'x': 1}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': _id, 'x': 1}])]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) # Unacknowledged insert self.listener.results.clear() _id = coll.insert({'x': 1}, w=0) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': _id, 'x': 1}]), ('writeConcern', {'w': 0})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) self.assertEqual(succeeded.reply, {'ok': 1}) # Explicit write concern insert self.listener.results.clear() _id = coll.insert({'x': 1}, w=1) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': _id, 'x': 1}]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('insert', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) # remove all self.listener.results.clear() res = coll.remove({'x': 1}, 
w=1) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('delete', coll.name), ('ordered', True), ('deletes', [SON([('q', {'x': 1}), ('limit', 0)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('delete', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(res['n'], reply.get('n')) # upsert self.listener.results.clear() oid = ObjectId() coll.update({'_id': oid}, {'_id': oid, 'x': 1}, upsert=True, w=1) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'_id': oid}), ('u', {'_id': oid, 'x': 1}), ('multi', False), ('upsert', True)])]), ('writeConcern', {'w': 1})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('update', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) self.assertEqual([{'index': 0, '_id': oid}], reply.get('upserted')) # update one self.listener.results.clear() coll.update({'x': 1}, {'$inc': {'x': 1}}) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'x': 1}), ('u', {'$inc': {'x': 1}}), ('multi', False), ('upsert', False)])])]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('update', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) # update many self.listener.results.clear() coll.update({'x': 2}, {'$inc': {'x': 1}}, multi=True) results = self.listener.results started = 
results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'x': 2}), ('u', {'$inc': {'x': 1}}), ('multi', True), ('upsert', False)])])]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('update', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) # remove one self.listener.results.clear() coll.remove({'x': 3}, multi=False) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('delete', coll.name), ('ordered', True), ('deletes', [SON([('q', {'x': 3}), ('limit', 1)])])]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('delete', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply self.assertEqual(1, reply.get('ok')) self.assertEqual(1, reply.get('n')) self.assertEqual(0, coll.count_documents({})) def test_insert_many(self): # This always uses the bulk API. 
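# (insert_many always goes through the bulk write API, and the driver splits
# the operations into as many 'insert' commands as needed to respect the
# server's size limits. Six documents of roughly 4 MB each cannot fit in a
# single batch, so the loop below sees several started/succeeded pairs that
# share one operation_id and sums the reply 'n' values across batches.)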
coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() big = 'x' * (1024 * 1024 * 4) docs = [{'_id': i, 'big': big} for i in range(6)] coll.insert_many(docs) results = self.listener.results started = results['started'] succeeded = results['succeeded'] self.assertEqual(0, len(results['failed'])) documents = [] count = 0 operation_id = started[0].operation_id self.assertIsInstance(operation_id, int) for start, succeed in zip(started, succeeded): self.assertIsInstance(start, monitoring.CommandStartedEvent) cmd = sanitize_cmd(start.command) self.assertEqual(['insert', 'ordered', 'documents'], list(cmd.keys())) self.assertEqual(coll.name, cmd['insert']) self.assertIs(True, cmd['ordered']) documents.extend(cmd['documents']) self.assertEqual('pymongo_test', start.database_name) self.assertEqual('insert', start.command_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) self.assertIsInstance(succeed.duration_micros, int) self.assertEqual(start.command_name, succeed.command_name) self.assertEqual(start.request_id, succeed.request_id) self.assertEqual(start.connection_id, succeed.connection_id) self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) reply = succeed.reply self.assertEqual(1, reply.get('ok')) count += reply.get('n', 0) self.assertEqual(documents, docs) self.assertEqual(6, count) def test_legacy_insert_many(self): # On legacy servers this uses bulk OP_INSERT. with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() # Force two batches on legacy servers. 
big = 'x' * (1024 * 1024 * 12) docs = [{'_id': i, 'big': big} for i in range(6)] coll.insert(docs) results = self.listener.results started = results['started'] succeeded = results['succeeded'] self.assertEqual(0, len(results['failed'])) documents = [] count = 0 operation_id = started[0].operation_id self.assertIsInstance(operation_id, int) for start, succeed in zip(started, succeeded): self.assertIsInstance(start, monitoring.CommandStartedEvent) cmd = sanitize_cmd(start.command) self.assertEqual(['insert', 'ordered', 'documents'], list(cmd.keys())) self.assertEqual(coll.name, cmd['insert']) self.assertIs(True, cmd['ordered']) documents.extend(cmd['documents']) self.assertEqual('pymongo_test', start.database_name) self.assertEqual('insert', start.command_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) self.assertIsInstance(succeed.duration_micros, int) self.assertEqual(start.command_name, succeed.command_name) self.assertEqual(start.request_id, succeed.request_id) self.assertEqual(start.connection_id, succeed.connection_id) self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) reply = succeed.reply self.assertEqual(1, reply.get('ok')) count += reply.get('n', 0) self.assertEqual(documents, docs) self.assertEqual(6, count) def test_bulk_write(self): coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() coll.bulk_write([InsertOne({'_id': 1}), UpdateOne({'_id': 1}, {'$set': {'x': 1}}), DeleteOne({'_id': 1})]) results = self.listener.results started = results['started'] succeeded = results['succeeded'] self.assertEqual(0, len(results['failed'])) operation_id = started[0].operation_id pairs = list(zip(started, succeeded)) self.assertEqual(3, len(pairs)) for start, succeed in pairs: self.assertIsInstance(start, monitoring.CommandStartedEvent) self.assertEqual('pymongo_test', start.database_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) self.assertIsInstance(succeed.duration_micros, int) self.assertEqual(start.command_name, succeed.command_name) self.assertEqual(start.request_id, succeed.request_id) self.assertEqual(start.connection_id, succeed.connection_id) self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) expected = SON([('insert', coll.name), ('ordered', True), ('documents', [{'_id': 1}])]) self.assertEqualCommand(expected, started[0].command) expected = SON([('update', coll.name), ('ordered', True), ('updates', [SON([('q', {'_id': 1}), ('u', {'$set': {'x': 1}}), ('multi', False), ('upsert', False)])])]) self.assertEqualCommand(expected, started[1].command) expected = SON([('delete', coll.name), ('ordered', True), ('deletes', [SON([('q', {'_id': 1}), ('limit', 1)])])]) self.assertEqualCommand(expected, started[2].command) @client_context.require_failCommand_fail_point def test_bulk_write_command_network_error(self): coll = self.client.pymongo_test.test self.listener.results.clear() insert_network_error = { 'configureFailPoint': 'failCommand', 'mode': {'times': 1}, 'data': { 'failCommands': ['insert'], 'closeConnection': True, }, } with self.fail_point(insert_network_error): with self.assertRaises(AutoReconnect): coll.bulk_write([InsertOne({'_id': 1})]) failed = self.listener.results['failed'] self.assertEqual(1, 
len(failed)) event = failed[0] self.assertEqual(event.command_name, 'insert') self.assertIsInstance(event.failure, dict) self.assertEqual(event.failure['errtype'], 'AutoReconnect') self.assertTrue(event.failure['errmsg']) @client_context.require_failCommand_fail_point def test_bulk_write_command_error(self): coll = self.client.pymongo_test.test self.listener.results.clear() insert_command_error = { 'configureFailPoint': 'failCommand', 'mode': {'times': 1}, 'data': { 'failCommands': ['insert'], 'closeConnection': False, 'errorCode': 10107, # NotMaster }, } with self.fail_point(insert_command_error): with self.assertRaises(NotMasterError): coll.bulk_write([InsertOne({'_id': 1})]) failed = self.listener.results['failed'] self.assertEqual(1, len(failed)) event = failed[0] self.assertEqual(event.command_name, 'insert') self.assertIsInstance(event.failure, dict) self.assertEqual(event.failure['code'], 10107) self.assertTrue(event.failure['errmsg']) @client_context.require_version_max(3, 4, 99) def test_bulk_write_legacy_network_error(self): self.listener.results.clear() # Make the delete operation run on a closed connection. self.client.admin.command('ping') pool = get_pool(self.client) sock_info = pool.sockets[0] sock_info.sock.close() # Test legacy unacknowledged write network error. coll = self.client.pymongo_test.get_collection( 'test', write_concern=WriteConcern(w=0)) with self.assertRaises(AutoReconnect): coll.bulk_write([InsertOne({'_id': 1})], ordered=False) failed = self.listener.results['failed'] self.assertEqual(1, len(failed)) event = failed[0] self.assertEqual(event.command_name, 'insert') self.assertIsInstance(event.failure, dict) self.assertEqual(event.failure['errtype'], 'AutoReconnect') self.assertTrue(event.failure['errmsg']) def test_write_errors(self): coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() try: coll.bulk_write([InsertOne({'_id': 1}), InsertOne({'_id': 1}), InsertOne({'_id': 1}), DeleteOne({'_id': 1})], ordered=False) except OperationFailure: pass results = self.listener.results started = results['started'] succeeded = results['succeeded'] self.assertEqual(0, len(results['failed'])) operation_id = started[0].operation_id pairs = list(zip(started, succeeded)) errors = [] for start, succeed in pairs: self.assertIsInstance(start, monitoring.CommandStartedEvent) self.assertEqual('pymongo_test', start.database_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) self.assertIsInstance(succeed.duration_micros, int) self.assertEqual(start.command_name, succeed.command_name) self.assertEqual(start.request_id, succeed.request_id) self.assertEqual(start.connection_id, succeed.connection_id) self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) if 'writeErrors' in succeed.reply: errors.extend(succeed.reply['writeErrors']) self.assertEqual(2, len(errors)) fields = set(['index', 'code', 'errmsg']) for error in errors: self.assertTrue(fields.issubset(set(error))) def test_first_batch_helper(self): # Regardless of server version and use of helpers._first_batch # this test should still pass. 
self.listener.results.clear() tuple(self.client.pymongo_test.test.list_indexes()) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('listIndexes', 'test'), ('cursor', {})]) self.assertEqualCommand(expected, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('listIndexes', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) self.assertTrue('cursor' in succeeded.reply) self.assertTrue('ok' in succeeded.reply) self.listener.results.clear() self.client.pymongo_test.current_op(True) started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = SON([('currentOp', 1), ('$all', True)]) self.assertEqualCommand(expected, started.command) self.assertEqual('admin', started.database_name) self.assertEqual('currentOp', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) self.assertTrue('inprog' in succeeded.reply) self.assertTrue('ok' in succeeded.reply) if not client_context.is_mongos: with ignore_deprecations(): self.client.fsync(lock=True) self.listener.results.clear() self.client.unlock() # Wait for async unlock... 
wait_until( lambda: not self.client.is_locked, "unlock the database") started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) expected = {'fsyncUnlock': 1} self.assertEqualCommand(expected, started.command) self.assertEqual('admin', started.database_name) self.assertEqual('fsyncUnlock', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertIsInstance(succeeded.duration_micros, int) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) self.assertTrue('info' in succeeded.reply) self.assertTrue('ok' in succeeded.reply) def test_sensitive_commands(self): listeners = self.client._event_listeners self.listener.results.clear() cmd = SON([("getnonce", 1)]) listeners.publish_command_start( cmd, "pymongo_test", 12345, self.client.address) delta = datetime.timedelta(milliseconds=100) listeners.publish_command_success( delta, {'nonce': 'e474f4561c5eb40b', 'ok': 1.0}, "getnonce", 12345, self.client.address) results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertIsInstance(started, monitoring.CommandStartedEvent) self.assertEqual({}, started.command) self.assertEqual('pymongo_test', started.database_name) self.assertEqual('getnonce', started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) self.assertEqual(succeeded.duration_micros, 100000) self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) self.assertEqual({}, succeeded.reply) class TestGlobalListener(PyMongoTestCase): @classmethod @client_context.require_connection def setUpClass(cls): cls.listener = EventListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) cls.client = single_client() # Get one (authenticated) socket in the pool. 
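# (Sketch of the public listener API that this suite and TestGlobalListener
# exercise; the helper name below is made up and is not part of the test.
# Applications subclass monitoring.CommandListener and either pass instances
# to MongoClient(event_listeners=[...]) or call monitoring.register() to
# install a process-global listener, which is what setUpClass does above.)
def _example_register_global_listener():
    import logging
    from pymongo import monitoring

    class CommandLogger(monitoring.CommandListener):
        def started(self, event):
            logging.info("started %s with request id %s",
                         event.command_name, event.request_id)

        def succeeded(self, event):
            logging.info("%s succeeded in %s microseconds",
                         event.command_name, event.duration_micros)

        def failed(self, event):
            logging.info("%s failed: %s",
                         event.command_name, event.failure)

    # Listeners registered this way apply to every MongoClient created
    # afterwards in the process.
    monitoring.register(CommandLogger())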
cls.client.pymongo_test.command('ismaster') @classmethod def tearDownClass(cls): monitoring._LISTENERS = cls.saved_listeners cls.client.close() def setUp(self): self.listener.results.clear() def test_simple(self): self.client.pymongo_test.command('ismaster') results = self.listener.results started = results['started'][0] succeeded = results['succeeded'][0] self.assertEqual(0, len(results['failed'])) self.assertTrue( isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue( isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand(SON([('ismaster', 1)]), started.command) self.assertEqual('ismaster', started.command_name) self.assertEqual(self.client.address, started.connection_id) self.assertEqual('pymongo_test', started.database_name) self.assertTrue(isinstance(started.request_id, int)) class TestEventClasses(PyMongoTestCase): def test_command_event_repr(self): request_id, connection_id, operation_id = 1, ('localhost', 27017), 2 event = monitoring.CommandStartedEvent( {'isMaster': 1}, 'admin', request_id, connection_id, operation_id) self.assertEqual( repr(event), "") delta = datetime.timedelta(milliseconds=100) event = monitoring.CommandSucceededEvent( delta, {'ok': 1}, 'isMaster', request_id, connection_id, operation_id) self.assertEqual( repr(event), "") event = monitoring.CommandFailedEvent( delta, {'ok': 0}, 'isMaster', request_id, connection_id, operation_id) self.assertEqual( repr(event), "") def test_server_heartbeat_event_repr(self): connection_id = ('localhost', 27017) event = monitoring.ServerHeartbeatStartedEvent(connection_id) self.assertEqual( repr(event), "") delta = 0.1 event = monitoring.ServerHeartbeatSucceededEvent( delta, {'ok': 1}, connection_id) self.assertEqual( repr(event), "") event = monitoring.ServerHeartbeatFailedEvent( delta, 'ERROR', connection_id) self.assertEqual( repr(event), "") def test_server_event_repr(self): server_address = ('localhost', 27017) topology_id = ObjectId('000000000000000000000001') event = monitoring.ServerOpeningEvent(server_address, topology_id) self.assertEqual( repr(event), "") event = monitoring.ServerDescriptionChangedEvent( 'PREV', 'NEW', server_address, topology_id) self.assertEqual( repr(event), "") event = monitoring.ServerClosedEvent(server_address, topology_id) self.assertEqual( repr(event), "") def test_topology_event_repr(self): topology_id = ObjectId('000000000000000000000001') event = monitoring.TopologyOpenedEvent(topology_id) self.assertEqual( repr(event), "") event = monitoring.TopologyDescriptionChangedEvent( 'PREV', 'NEW', topology_id) self.assertEqual( repr(event), "") event = monitoring.TopologyClosedEvent(topology_id) self.assertEqual( repr(event), "") if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_monotonic.py000066400000000000000000000023311374256237000176710ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
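# (Context for the tests below, with a small usage sketch; the function name
# is illustrative only and not part of this module. pymongo.monotonic exposes
# `time`, a callable that prefers a monotonic clock when one is available, so
# elapsed-time measurements are unaffected by system clock adjustments.)
def _example_measure_elapsed():
    from pymongo.monotonic import time as monotonic_time

    start = monotonic_time()
    # ... the operation being timed would run here ...
    return monotonic_time() - start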
"""Test the monotonic module.""" import sys sys.path[0:0] = [""] from pymongo.monotonic import time as pymongo_time from test import unittest class TestMonotonic(unittest.TestCase): def test_monotonic_time(self): try: from monotonic import monotonic self.assertIs(monotonic, pymongo_time) except ImportError: if sys.version_info[:2] >= (3, 3): from time import monotonic self.assertIs(monotonic, pymongo_time) else: from time import time self.assertIs(time, pymongo_time) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_objectid.py000066400000000000000000000203201374256237000174450ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for the objectid module.""" import datetime import pickle import struct import sys sys.path[0:0] = [""] from bson.errors import InvalidId from bson.objectid import ObjectId, _MAX_COUNTER_VALUE from bson.py3compat import PY3, _unicode from bson.tz_util import (FixedOffset, utc) from test import SkipTest, unittest from test.utils import oid_generated_on_process def oid(x): return ObjectId() class TestObjectId(unittest.TestCase): def test_creation(self): self.assertRaises(TypeError, ObjectId, 4) self.assertRaises(TypeError, ObjectId, 175.0) self.assertRaises(TypeError, ObjectId, {"test": 4}) self.assertRaises(TypeError, ObjectId, ["something"]) self.assertRaises(InvalidId, ObjectId, "") self.assertRaises(InvalidId, ObjectId, "12345678901") self.assertRaises(InvalidId, ObjectId, "1234567890123") self.assertTrue(ObjectId()) self.assertTrue(ObjectId(b"123456789012")) a = ObjectId() self.assertTrue(ObjectId(a)) def test_unicode(self): a = ObjectId() self.assertEqual(a, ObjectId(_unicode(a))) self.assertEqual(ObjectId("123456789012123456789012"), ObjectId(u"123456789012123456789012")) self.assertRaises(InvalidId, ObjectId, u"hello") def test_from_hex(self): ObjectId("123456789012123456789012") self.assertRaises(InvalidId, ObjectId, "123456789012123456789G12") self.assertRaises(InvalidId, ObjectId, u"123456789012123456789G12") def test_repr_str(self): self.assertEqual(repr(ObjectId("1234567890abcdef12345678")), "ObjectId('1234567890abcdef12345678')") self.assertEqual(str(ObjectId("1234567890abcdef12345678")), "1234567890abcdef12345678") self.assertEqual(str(ObjectId(b"123456789012")), "313233343536373839303132") self.assertEqual(ObjectId("1234567890abcdef12345678").binary, b'\x124Vx\x90\xab\xcd\xef\x124Vx') self.assertEqual(str(ObjectId(b'\x124Vx\x90\xab\xcd\xef\x124Vx')), "1234567890abcdef12345678") def test_equality(self): a = ObjectId() self.assertEqual(a, ObjectId(a)) self.assertEqual(ObjectId(b"123456789012"), ObjectId(b"123456789012")) self.assertNotEqual(ObjectId(), ObjectId()) self.assertNotEqual(ObjectId(b"123456789012"), b"123456789012") # Explicitly test inequality self.assertFalse(a != ObjectId(a)) self.assertFalse(ObjectId(b"123456789012") != ObjectId(b"123456789012")) def test_binary_str_equivalence(self): a = ObjectId() self.assertEqual(a, ObjectId(a.binary)) self.assertEqual(a, 
ObjectId(str(a))) def test_generation_time(self): d1 = datetime.datetime.utcnow() d2 = ObjectId().generation_time self.assertEqual(utc, d2.tzinfo) d2 = d2.replace(tzinfo=None) self.assertTrue(d2 - d1 < datetime.timedelta(seconds=2)) def test_from_datetime(self): if 'PyPy 1.8.0' in sys.version: # See https://bugs.pypy.org/issue1092 raise SkipTest("datetime.timedelta is broken in pypy 1.8.0") d = datetime.datetime.utcnow() d = d - datetime.timedelta(microseconds=d.microsecond) oid = ObjectId.from_datetime(d) self.assertEqual(d, oid.generation_time.replace(tzinfo=None)) self.assertEqual("0" * 16, str(oid)[8:]) aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) oid = ObjectId.from_datetime(aware) self.assertEqual(as_utc, oid.generation_time) def test_pickling(self): orig = ObjectId() for protocol in [0, 1, 2, -1]: pkl = pickle.dumps(orig, protocol=protocol) self.assertEqual(orig, pickle.loads(pkl)) def test_pickle_backwards_compatability(self): # This string was generated by pickling an ObjectId in pymongo # version 1.9 pickled_with_1_9 = ( b"ccopy_reg\n_reconstructor\np0\n" b"(cbson.objectid\nObjectId\np1\nc__builtin__\n" b"object\np2\nNtp3\nRp4\n" b"(dp5\nS'_ObjectId__id'\np6\n" b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np7\nsb.") # We also test against a hardcoded "New" pickle format so that we # make sure we're backward compatible with the current version in # the future as well. pickled_with_1_10 = ( b"ccopy_reg\n_reconstructor\np0\n" b"(cbson.objectid\nObjectId\np1\nc__builtin__\n" b"object\np2\nNtp3\nRp4\n" b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np5\nb.") if PY3: # Have to load using 'latin-1' since these were pickled in python2.x. oid_1_9 = pickle.loads(pickled_with_1_9, encoding='latin-1') oid_1_10 = pickle.loads(pickled_with_1_10, encoding='latin-1') else: oid_1_9 = pickle.loads(pickled_with_1_9) oid_1_10 = pickle.loads(pickled_with_1_10) self.assertEqual(oid_1_9, ObjectId("4d9a66561376c00b88000000")) self.assertEqual(oid_1_9, oid_1_10) def test_random_bytes(self): self.assertTrue(oid_generated_on_process(ObjectId())) def test_is_valid(self): self.assertFalse(ObjectId.is_valid(None)) self.assertFalse(ObjectId.is_valid(4)) self.assertFalse(ObjectId.is_valid(175.0)) self.assertFalse(ObjectId.is_valid({"test": 4})) self.assertFalse(ObjectId.is_valid(["something"])) self.assertFalse(ObjectId.is_valid("")) self.assertFalse(ObjectId.is_valid("12345678901")) self.assertFalse(ObjectId.is_valid("1234567890123")) self.assertTrue(ObjectId.is_valid(b"123456789012")) self.assertTrue(ObjectId.is_valid("123456789012123456789012")) def test_counter_overflow(self): # Spec-test to check counter overflows from max value to 0. ObjectId._inc = _MAX_COUNTER_VALUE ObjectId() self.assertEqual(ObjectId._inc, 0) def test_timestamp_values(self): # Spec-test to check timestamp field is interpreted correctly. TEST_DATA = { 0x00000000: (1970, 1, 1, 0, 0, 0), 0x7FFFFFFF: (2038, 1, 19, 3, 14, 7), 0x80000000: (2038, 1, 19, 3, 14, 8), 0xFFFFFFFF: (2106, 2, 7, 6, 28, 15), } def generate_objectid_with_timestamp(timestamp): oid = ObjectId() _, trailing_bytes = struct.unpack(">IQ", oid.binary) new_oid = struct.pack(">IQ", timestamp, trailing_bytes) return ObjectId(new_oid) for tstamp, exp_datetime_args in TEST_DATA.items(): oid = generate_objectid_with_timestamp(tstamp) # 32-bit platforms may overflow in datetime.fromtimestamp. 
if tstamp > 0x7FFFFFFF and sys.maxsize < 2**32: try: oid.generation_time except (OverflowError, ValueError): continue self.assertEqual( oid.generation_time, datetime.datetime(*exp_datetime_args, tzinfo=utc)) def test_random_regenerated_on_pid_change(self): # Test that change of pid triggers new random number generation. random_original = ObjectId._random() ObjectId._pid += 1 random_new = ObjectId._random() self.assertNotEqual(random_original, random_new) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_ocsp_cache.py000066400000000000000000000115741374256237000177640ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the pymongo ocsp_support module.""" from collections import namedtuple from datetime import datetime, timedelta from os import urandom import random import sys from time import sleep sys.path[0:0] = [""] from pymongo.ocsp_cache import _OCSPCache from test import unittest class TestOcspCache(unittest.TestCase): @classmethod def setUpClass(cls): cls.MockHashAlgorithm = namedtuple( "MockHashAlgorithm", ['name']) cls.MockOcspRequest = namedtuple( "MockOcspRequest", ['hash_algorithm', 'issuer_name_hash', 'issuer_key_hash', 'serial_number']) cls.MockOcspResponse = namedtuple( "MockOcspResponse", ["this_update", "next_update"]) def setUp(self): self.cache = _OCSPCache() def _create_mock_request(self): hash_algorithm = self.MockHashAlgorithm( random.choice(['sha1', 'md5', 'sha256'])) issuer_name_hash = urandom(8) issuer_key_hash = urandom(8) serial_number = random.randint(0, 10**10) return self.MockOcspRequest( hash_algorithm=hash_algorithm, issuer_name_hash=issuer_name_hash, issuer_key_hash=issuer_key_hash, serial_number=serial_number) def _create_mock_response(self, this_update_delta_seconds, next_update_delta_seconds): now = datetime.utcnow() this_update = now + timedelta(seconds=this_update_delta_seconds) if next_update_delta_seconds is not None: next_update = now + timedelta(seconds=next_update_delta_seconds) else: next_update = None return self.MockOcspResponse( this_update=this_update, next_update=next_update) def _add_mock_cache_entry(self, mock_request, mock_response): key = self.cache._get_cache_key(mock_request) self.cache._data[key] = mock_response def test_simple(self): # Start with 1 valid entry in the cache. request = self._create_mock_request() response = self._create_mock_response(-10, +3600) self._add_mock_cache_entry(request, response) # Ensure entry can be retrieved. self.assertEqual(self.cache[request], response) # Valid entries with an earlier next_update have no effect. response_1 = self._create_mock_response(-20, +1800) self.cache[request] = response_1 self.assertEqual(self.cache[request], response) # Invalid entries with a later this_update have no effect. response_2 = self._create_mock_response(+20, +1800) self.cache[request] = response_2 self.assertEqual(self.cache[request], response) # Invalid entries with passed next_update have no effect. 
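# (Summary of the cache policy exercised by this test: a cached response is
# only replaced when the new response is currently valid, i.e. its
# this_update is in the past and its next_update is in the future, and its
# next_update is later than the cached one; a response with no next_update
# removes the entry entirely.)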
response_3 = self._create_mock_response(-10, -5) self.cache[request] = response_3 self.assertEqual(self.cache[request], response) # Valid entries with a later next_update update the cache. response_new = self._create_mock_response(-5, +7200) self.cache[request] = response_new self.assertEqual(self.cache[request], response_new) # Entries with an unset next_update purge the cache. response_notset = self._create_mock_response(-5, None) self.cache[request] = response_notset with self.assertRaises(KeyError): _ = self.cache[request] def test_invalidate(self): # Start with 1 valid entry in the cache. request = self._create_mock_request() response = self._create_mock_response(-10, +0.25) self._add_mock_cache_entry(request, response) # Ensure entry can be retrieved. self.assertEqual(self.cache[request], response) # Wait for entry to become invalid and ensure KeyError is raised. sleep(0.5) with self.assertRaises(KeyError): _ = self.cache[request] def test_non_existent(self): # Start with 1 valid entry in the cache. request = self._create_mock_request() response = self._create_mock_response(-10, +10) self._add_mock_cache_entry(request, response) # Attempt to retrieve non-existent entry must raise KeyError. with self.assertRaises(KeyError): _ = self.cache[self._create_mock_request()] if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_pooling.py000066400000000000000000000356361374256237000173510ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test built in connection-pooling with threads.""" import gc import random import socket import sys import threading import time from pymongo import MongoClient from pymongo.errors import (AutoReconnect, ConnectionFailure, DuplicateKeyError, ExceededMaxWaiters) sys.path[0:0] = [""] from pymongo.pool import Pool, PoolOptions from pymongo.socket_checker import SocketChecker from test import client_context, unittest from test.utils import (get_pool, joinall, delay, rs_or_single_client) @client_context.require_connection def setUpModule(): pass N = 10 DB = "pymongo-pooling-tests" def gc_collect_until_done(threads, timeout=60): start = time.time() running = list(threads) while running: assert (time.time() - start) < timeout, "Threads timed out" for t in running: t.join(0.1) if not t.is_alive(): running.remove(t) gc.collect() class MongoThread(threading.Thread): """A thread that uses a MongoClient.""" def __init__(self, client): super(MongoThread, self).__init__() self.daemon = True # Don't hang whole test if thread hangs. 
self.client = client self.db = self.client[DB] self.passed = False def run(self): self.run_mongo_thread() self.passed = True def run_mongo_thread(self): raise NotImplementedError class InsertOneAndFind(MongoThread): def run_mongo_thread(self): for _ in range(N): rand = random.randint(0, N) _id = self.db.sf.insert_one({"x": rand}).inserted_id assert rand == self.db.sf.find_one(_id)["x"] class Unique(MongoThread): def run_mongo_thread(self): for _ in range(N): self.db.unique.insert_one({}) # no error class NonUnique(MongoThread): def run_mongo_thread(self): for _ in range(N): try: self.db.unique.insert_one({"_id": "jesse"}) except DuplicateKeyError: pass else: raise AssertionError("Should have raised DuplicateKeyError") class Disconnect(MongoThread): def run_mongo_thread(self): for _ in range(N): self.client.close() class SocketGetter(MongoThread): """Utility for TestPooling. Checks out a socket and holds it forever. Used in test_no_wait_queue_timeout, test_wait_queue_multiple, and test_no_wait_queue_multiple. """ def __init__(self, client, pool): super(SocketGetter, self).__init__(client) self.state = 'init' self.pool = pool self.sock = None def run_mongo_thread(self): self.state = 'get_socket' # Pass 'checkout' so we can hold the socket. with self.pool.get_socket({}, checkout=True) as sock: self.sock = sock self.state = 'sock' def __del__(self): if self.sock: self.sock.close_socket(None) def run_cases(client, cases): threads = [] n_runs = 5 for case in cases: for i in range(n_runs): t = case(client) t.start() threads.append(t) for t in threads: t.join() for t in threads: assert t.passed, "%s.run() threw an exception" % repr(t) class _TestPoolingBase(unittest.TestCase): """Base class for all connection-pool tests.""" def setUp(self): self.c = rs_or_single_client() db = self.c[DB] db.unique.drop() db.test.drop() db.unique.insert_one({"_id": "jesse"}) db.test.insert_many([{} for _ in range(10)]) def tearDown(self): self.c.close() def create_pool( self, pair=(client_context.host, client_context.port), *args, **kwargs): # Start the pool with the correct ssl options. pool_options = client_context.client._topology_settings.pool_options kwargs['ssl_context'] = pool_options.ssl_context kwargs['ssl_match_hostname'] = pool_options.ssl_match_hostname return Pool(pair, PoolOptions(*args, **kwargs)) class TestPooling(_TestPoolingBase): def test_max_pool_size_validation(self): host, port = client_context.host, client_context.port self.assertRaises( ValueError, MongoClient, host=host, port=port, maxPoolSize=-1) self.assertRaises( ValueError, MongoClient, host=host, port=port, maxPoolSize='foo') c = MongoClient(host=host, port=port, maxPoolSize=100, connect=False) self.assertEqual(c.max_pool_size, 100) def test_no_disconnect(self): run_cases(self.c, [NonUnique, Unique, InsertOneAndFind]) def test_disconnect(self): run_cases(self.c, [InsertOneAndFind, Disconnect, Unique]) def test_pool_reuses_open_socket(self): # Test Pool's _check_closed() method doesn't close a healthy socket. cx_pool = self.create_pool(max_pool_size=10) cx_pool._check_interval_seconds = 0 # Always check. with cx_pool.get_socket({}) as sock_info: pass with cx_pool.get_socket({}) as new_sock_info: self.assertEqual(sock_info, new_sock_info) self.assertEqual(1, len(cx_pool.sockets)) def test_get_socket_and_exception(self): # get_socket() returns socket after a non-network error. 
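# Raising ZeroDivisionError inside the 'with' block is an application error,
# not a network error, so the pool is expected to check the socket back in
# rather than discard it.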
cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) with self.assertRaises(ZeroDivisionError): with cx_pool.get_socket({}) as sock_info: 1 / 0 # Socket was returned, not closed. with cx_pool.get_socket({}) as new_sock_info: self.assertEqual(sock_info, new_sock_info) self.assertEqual(1, len(cx_pool.sockets)) def test_pool_removes_closed_socket(self): # Test that Pool removes explicitly closed socket. cx_pool = self.create_pool() with cx_pool.get_socket({}) as sock_info: # Use SocketInfo's API to close the socket. sock_info.close_socket(None) self.assertEqual(0, len(cx_pool.sockets)) def test_pool_removes_dead_socket(self): # Test that Pool removes dead socket and the socket doesn't return # itself PYTHON-344 cx_pool = self.create_pool(max_pool_size=1, wait_queue_timeout=1) cx_pool._check_interval_seconds = 0 # Always check. with cx_pool.get_socket({}) as sock_info: # Simulate a closed socket without telling the SocketInfo it's # closed. sock_info.sock.close() self.assertTrue(sock_info.socket_closed()) with cx_pool.get_socket({}) as new_sock_info: self.assertEqual(0, len(cx_pool.sockets)) self.assertNotEqual(sock_info, new_sock_info) self.assertEqual(1, len(cx_pool.sockets)) # Semaphore was released. with cx_pool.get_socket({}): pass def test_socket_closed(self): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.connect((client_context.host, client_context.port)) socket_checker = SocketChecker() self.assertFalse(socket_checker.socket_closed(s)) s.close() self.assertTrue(socket_checker.socket_closed(s)) def test_return_socket_after_reset(self): pool = self.create_pool() with pool.get_socket({}) as sock: pool.reset() self.assertTrue(sock.closed) self.assertEqual(0, len(pool.sockets)) def test_pool_check(self): # Test that Pool recovers from two connection failures in a row. # This exercises code at the end of Pool._check(). cx_pool = self.create_pool(max_pool_size=1, connect_timeout=1, wait_queue_timeout=1) cx_pool._check_interval_seconds = 0 # Always check. self.addCleanup(cx_pool.close) with cx_pool.get_socket({}) as sock_info: # Simulate a closed socket without telling the SocketInfo it's # closed. sock_info.sock.close() # Swap pool's address with a bad one. address, cx_pool.address = cx_pool.address, ('foo.com', 1234) with self.assertRaises(AutoReconnect): with cx_pool.get_socket({}): pass # Back to normal, semaphore was correctly released. cx_pool.address = address with cx_pool.get_socket({}, checkout=True) as sock_info: pass sock_info.close_socket(None) def test_wait_queue_timeout(self): wait_queue_timeout = 2 # Seconds pool = self.create_pool( max_pool_size=1, wait_queue_timeout=wait_queue_timeout) self.addCleanup(pool.close) with pool.get_socket({}) as sock_info: start = time.time() with self.assertRaises(ConnectionFailure): with pool.get_socket({}): pass duration = time.time() - start self.assertTrue( abs(wait_queue_timeout - duration) < 1, "Waited %.2f seconds for a socket, expected %f" % ( duration, wait_queue_timeout)) def test_no_wait_queue_timeout(self): # Verify get_socket() with no wait_queue_timeout blocks forever. pool = self.create_pool(max_pool_size=1) self.addCleanup(pool.close) # Reach max_size. 
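# With max_pool_size=1, holding s1 forces the SocketGetter thread below to
# block in get_socket() until s1 is returned at the end of this 'with' block.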
with pool.get_socket({}) as s1: t = SocketGetter(self.c, pool) t.start() while t.state != 'get_socket': time.sleep(0.1) time.sleep(1) self.assertEqual(t.state, 'get_socket') while t.state != 'sock': time.sleep(0.1) self.assertEqual(t.state, 'sock') self.assertEqual(t.sock, s1) def test_wait_queue_multiple(self): wait_queue_multiple = 3 pool = self.create_pool( max_pool_size=2, wait_queue_multiple=wait_queue_multiple) # Reach max_size sockets. with pool.get_socket({}): with pool.get_socket({}): # Reach max_size * wait_queue_multiple waiters. threads = [] for _ in range(6): t = SocketGetter(self.c, pool) t.start() threads.append(t) time.sleep(1) for t in threads: self.assertEqual(t.state, 'get_socket') with self.assertRaises(ExceededMaxWaiters): with pool.get_socket({}): pass def test_no_wait_queue_multiple(self): pool = self.create_pool(max_pool_size=2) socks = [] for _ in range(2): # Pass 'checkout' so we can hold the socket. with pool.get_socket({}, checkout=True) as sock: socks.append(sock) threads = [] for _ in range(30): t = SocketGetter(self.c, pool) t.start() threads.append(t) time.sleep(1) for t in threads: self.assertEqual(t.state, 'get_socket') for socket_info in socks: socket_info.close_socket(None) class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): max_pool_size = 4 c = rs_or_single_client(maxPoolSize=max_pool_size) collection = c[DB].test # Need one document. collection.drop() collection.insert_one({}) # nthreads had better be much larger than max_pool_size to ensure that # max_pool_size sockets are actually required at some point in this # test's execution. cx_pool = get_pool(c) nthreads = 10 threads = [] lock = threading.Lock() self.n_passed = 0 def f(): for _ in range(5): collection.find_one({'$where': delay(0.1)}) assert len(cx_pool.sockets) <= max_pool_size with lock: self.n_passed += 1 for i in range(nthreads): t = threading.Thread(target=f) threads.append(t) t.start() joinall(threads) self.assertEqual(nthreads, self.n_passed) self.assertTrue(len(cx_pool.sockets) > 1) self.assertEqual(max_pool_size, cx_pool._socket_semaphore.counter) def test_max_pool_size_none(self): c = rs_or_single_client(maxPoolSize=None) collection = c[DB].test # Need one document. collection.drop() collection.insert_one({}) cx_pool = get_pool(c) nthreads = 10 threads = [] lock = threading.Lock() self.n_passed = 0 def f(): for _ in range(5): collection.find_one({'$where': delay(0.1)}) with lock: self.n_passed += 1 for i in range(nthreads): t = threading.Thread(target=f) threads.append(t) t.start() joinall(threads) self.assertEqual(nthreads, self.n_passed) self.assertTrue(len(cx_pool.sockets) > 1) def test_max_pool_size_zero(self): with self.assertRaises(ValueError): rs_or_single_client(maxPoolSize=0) def test_max_pool_size_with_connection_failure(self): # The pool acquires its semaphore before attempting to connect; ensure # it releases the semaphore on connection failure. test_pool = Pool( ('somedomainthatdoesntexist.org', 27017), PoolOptions( max_pool_size=1, connect_timeout=1, socket_timeout=1, wait_queue_timeout=1)) # First call to get_socket fails; if pool doesn't release its semaphore # then the second call raises "ConnectionFailure: Timed out waiting for # socket from pool" instead of AutoReconnect. for i in range(2): with self.assertRaises(AutoReconnect) as context: with test_pool.get_socket({}, checkout=True): pass # Testing for AutoReconnect instead of ConnectionFailure, above, # is sufficient right *now* to catch a semaphore leak. 
But that # seems error-prone, so check the message too. self.assertNotIn('waiting for socket from pool', str(context.exception)) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_pymongo.py000066400000000000000000000017441374256237000173630ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the pymongo module itself.""" import sys sys.path[0:0] = [""] import pymongo from test import unittest class TestPyMongo(unittest.TestCase): def test_mongo_client_alias(self): # Testing that pymongo module imports mongo_client.MongoClient self.assertEqual(pymongo.MongoClient, pymongo.mongo_client.MongoClient) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_raw_bson.py000066400000000000000000000165201374256237000175030ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
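# --- Illustrative sketch (not part of the original test suite) --------------
# The tests below exercise bson.raw_bson.RawBSONDocument, which wraps raw BSON
# bytes and defers decoding until a field is accessed. A minimal usage sketch
# using only public bson APIs; the helper name is invented for clarity:
def _raw_bson_usage_sketch():
    from bson import encode
    from bson.raw_bson import RawBSONDocument
    raw = RawBSONDocument(encode({'name': 'Sherlock'}))
    assert raw['name'] == 'Sherlock'                 # decoded lazily on access
    assert raw.raw == encode({'name': 'Sherlock'})   # original bytes preserved
# -----------------------------------------------------------------------------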
import datetime import uuid from bson import decode, encode from bson.binary import Binary, JAVA_LEGACY from bson.codec_options import CodecOptions from bson.errors import InvalidBSON from bson.raw_bson import RawBSONDocument, DEFAULT_RAW_BSON_OPTIONS from bson.son import SON from test import client_context, unittest from test.test_client import IntegrationTest class TestRawBSONDocument(IntegrationTest): # {u'_id': ObjectId('556df68b6e32ab21a95e0785'), # u'name': u'Sherlock', # u'addresses': [{u'street': u'Baker Street'}]} bson_string = ( b'Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t' b'\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e' b'\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00' ) document = RawBSONDocument(bson_string) @classmethod def setUpClass(cls): super(TestRawBSONDocument, cls).setUpClass() cls.client = client_context.client def tearDown(self): if client_context.connected: self.client.pymongo_test.test_raw.drop() def test_decode(self): self.assertEqual('Sherlock', self.document['name']) first_address = self.document['addresses'][0] self.assertIsInstance(first_address, RawBSONDocument) self.assertEqual('Baker Street', first_address['street']) def test_raw(self): self.assertEqual(self.bson_string, self.document.raw) def test_empty_doc(self): doc = RawBSONDocument(encode({})) with self.assertRaises(KeyError): doc['does-not-exist'] def test_invalid_bson_sequence(self): bson_byte_sequence = encode({'a': 1})+encode({}) with self.assertRaisesRegex(InvalidBSON, 'invalid object length'): RawBSONDocument(bson_byte_sequence) def test_invalid_bson_eoo(self): invalid_bson_eoo = encode({'a': 1})[:-1] + b'\x01' with self.assertRaisesRegex(InvalidBSON, 'bad eoo'): RawBSONDocument(invalid_bson_eoo) @client_context.require_connection def test_round_trip(self): db = self.client.get_database( 'pymongo_test', codec_options=CodecOptions(document_class=RawBSONDocument)) db.test_raw.insert_one(self.document) result = db.test_raw.find_one(self.document['_id']) self.assertIsInstance(result, RawBSONDocument) self.assertEqual(dict(self.document.items()), dict(result.items())) @client_context.require_connection def test_round_trip_raw_uuid(self): coll = self.client.get_database('pymongo_test').test_raw uid = uuid.uuid4() doc = {'_id': 1, 'bin4': Binary(uid.bytes, 4), 'bin3': Binary(uid.bytes, 3)} raw = RawBSONDocument(encode(doc)) coll.insert_one(raw) self.assertEqual(coll.find_one(), {'_id': 1, 'bin4': uid, 'bin3': uid}) # Test that the raw bytes haven't changed. raw_coll = coll.with_options(codec_options=DEFAULT_RAW_BSON_OPTIONS) self.assertEqual(raw_coll.find_one(), raw) def test_with_codec_options(self): # {u'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), # u'_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} # encoded with JAVA_LEGACY uuid representation. 
bson_string = ( b'-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02' b'\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM' b'\x01\x00\x00\x00' ) document = RawBSONDocument( bson_string, codec_options=CodecOptions(uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument)) self.assertEqual(uuid.UUID('026fab8f-975f-4965-9fbf-85ad874c60ff'), document['_id']) @client_context.require_connection def test_round_trip_codec_options(self): doc = { 'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), '_id': uuid.UUID('026fab8f-975f-4965-9fbf-85ad874c60ff') } db = self.client.pymongo_test coll = db.get_collection( 'test_raw', codec_options=CodecOptions(uuid_representation=JAVA_LEGACY)) coll.insert_one(doc) raw_java_legacy = CodecOptions(uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument) coll = db.get_collection('test_raw', codec_options=raw_java_legacy) self.assertEqual( RawBSONDocument(encode(doc, codec_options=raw_java_legacy)), coll.find_one()) @client_context.require_connection def test_raw_bson_document_embedded(self): doc = {'embedded': self.document} db = self.client.pymongo_test db.test_raw.insert_one(doc) result = db.test_raw.find_one() self.assertEqual(decode(self.document.raw), result['embedded']) # Make sure that CodecOptions are preserved. # {'embedded': [ # {u'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), # u'_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} # ]} # encoded with JAVA_LEGACY uuid representation. bson_string = ( b'D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00' b'\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00' b'\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00' b'\x00' ) rbd = RawBSONDocument( bson_string, codec_options=CodecOptions(uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument)) db.test_raw.drop() db.test_raw.insert_one(rbd) result = db.get_collection('test_raw', codec_options=CodecOptions( uuid_representation=JAVA_LEGACY)).find_one() self.assertEqual(rbd['embedded'][0]['_id'], result['embedded'][0]['_id']) @client_context.require_connection def test_write_response_raw_bson(self): coll = self.client.get_database( 'pymongo_test', codec_options=CodecOptions(document_class=RawBSONDocument)).test_raw # No Exceptions raised while handling write response. coll.insert_one(self.document) coll.delete_one(self.document) coll.insert_many([self.document]) coll.delete_many(self.document) coll.update_one(self.document, {'$set': {'a': 'b'}}, upsert=True) coll.update_many(self.document, {'$set': {'b': 'c'}}) def test_preserve_key_ordering(self): keyvaluepairs = [('a', 1), ('b', 2), ('c', 3),] rawdoc = RawBSONDocument(encode(SON(keyvaluepairs))) for rkey, elt in zip(rawdoc, keyvaluepairs): self.assertEqual(rkey, elt[0]) pymongo-3.11.0/test/test_read_concern.py000066400000000000000000000134051374256237000203120ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
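# --- Illustrative sketch (not part of the original test suite) --------------
# ReadConcern is a thin wrapper around the {'level': ...} sub-document attached
# to read commands; with no level it represents the server default and nothing
# is sent. A minimal sketch; the helper name is invented for clarity:
def _read_concern_usage_sketch():
    from pymongo.read_concern import ReadConcern
    assert ReadConcern().document == {}                        # server default
    assert ReadConcern('majority').document == {'level': 'majority'}
    assert ReadConcern('local').ok_for_legacy                  # pre-3.2 servers
# -----------------------------------------------------------------------------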
"""Test the read_concern module.""" from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure from pymongo.read_concern import ReadConcern from test import client_context, PyMongoTestCase from test.utils import single_client, rs_or_single_client, OvertCommandListener class TestReadConcern(PyMongoTestCase): @classmethod @client_context.require_connection def setUpClass(cls): cls.listener = OvertCommandListener() cls.client = single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test client_context.client.pymongo_test.create_collection('coll') @classmethod def tearDownClass(cls): cls.client.close() client_context.client.pymongo_test.drop_collection('coll') def tearDown(self): self.listener.results.clear() def test_read_concern(self): rc = ReadConcern() self.assertIsNone(rc.level) self.assertTrue(rc.ok_for_legacy) rc = ReadConcern('majority') self.assertEqual('majority', rc.level) self.assertFalse(rc.ok_for_legacy) rc = ReadConcern('local') self.assertEqual('local', rc.level) self.assertTrue(rc.ok_for_legacy) self.assertRaises(TypeError, ReadConcern, 42) def test_read_concern_uri(self): uri = 'mongodb://%s/?readConcernLevel=majority' % ( client_context.pair,) client = rs_or_single_client(uri, connect=False) self.assertEqual(ReadConcern('majority'), client.read_concern) @client_context.require_version_max(3, 1) def test_invalid_read_concern(self): coll = self.db.get_collection( 'coll', read_concern=ReadConcern('majority')) self.assertRaisesRegexp( ConfigurationError, 'read concern level of majority is not valid ' 'with a max wire version of [0-3]', coll.count) @client_context.require_version_min(3, 1, 9, -1) def test_find_command(self): # readConcern not sent in command if not specified. coll = self.db.coll tuple(coll.find({'field': 'value'})) self.assertNotIn('readConcern', self.listener.results['started'][0].command) self.listener.results.clear() # Explicitly set readConcern to 'local'. coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) tuple(coll.find({'field': 'value'})) self.assertEqualCommand( SON([('find', 'coll'), ('filter', {'field': 'value'}), ('readConcern', {'level': 'local'})]), self.listener.results['started'][0].command) @client_context.require_version_min(3, 1, 9, -1) def test_command_cursor(self): # readConcern not sent in command if not specified. coll = self.db.coll tuple(coll.aggregate([{'$match': {'field': 'value'}}])) self.assertNotIn('readConcern', self.listener.results['started'][0].command) self.listener.results.clear() # Explicitly set readConcern to 'local'. coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) tuple(coll.aggregate([{'$match': {'field': 'value'}}])) self.assertEqual( {'level': 'local'}, self.listener.results['started'][0].command['readConcern']) def test_aggregate_out(self): coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) tuple(coll.aggregate([{'$match': {'field': 'value'}}, {'$out': 'output_collection'}])) # Aggregate with $out supports readConcern MongoDB 4.2 onwards. 
if client_context.version >= (4, 1): self.assertIn('readConcern', self.listener.results['started'][0].command) else: self.assertNotIn('readConcern', self.listener.results['started'][0].command) def test_map_reduce_out(self): coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) coll.map_reduce('function() { emit(this._id, this.value); }', 'function(key, values) { return 42; }', out='output_collection') self.assertNotIn('readConcern', self.listener.results['started'][0].command) if client_context.version.at_least(3, 1, 9, -1): self.listener.results.clear() coll.map_reduce( 'function() { emit(this._id, this.value); }', 'function(key, values) { return 42; }', out={'inline': 1}) self.assertEqual( {'level': 'local'}, self.listener.results['started'][0].command['readConcern']) @client_context.require_version_min(3, 1, 9, -1) def test_inline_map_reduce(self): coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) tuple(coll.inline_map_reduce( 'function() { emit(this._id, this.value); }', 'function(key, values) { return 42; }')) self.assertEqual( {'level': 'local'}, self.listener.results['started'][0].command['readConcern']) pymongo-3.11.0/test/test_read_preferences.py000066400000000000000000000675061374256237000211770ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
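# --- Illustrative sketch (not part of the original test suite) --------------
# Read preference objects ultimately reduce to the document sent to the server
# (see TestMongosAndReadPreference below). A minimal sketch of that mapping;
# the helper name is invented for clarity:
def _read_preference_document_sketch():
    from pymongo.read_preferences import SecondaryPreferred
    pref = SecondaryPreferred(tag_sets=[{'dc': 'ny'}], max_staleness=90)
    assert pref.document == {'mode': 'secondaryPreferred',
                             'tags': [{'dc': 'ny'}],
                             'maxStalenessSeconds': 90}
# -----------------------------------------------------------------------------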
"""Test the replica_set_connection module.""" import contextlib import copy import pickle import random import sys import warnings sys.path[0:0] = [""] from bson.py3compat import MAXSIZE from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure from pymongo.message import _maybe_add_read_preference from pymongo.mongo_client import MongoClient from pymongo.read_preferences import (ReadPreference, MovingAverage, Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) from pymongo.server_description import ServerDescription from pymongo.server_selectors import readable_server_selector, Selection from pymongo.server_type import SERVER_TYPE from pymongo.write_concern import WriteConcern from test import (SkipTest, client_context, unittest, db_user, db_pwd) from test.test_replica_set_client import TestReplicaSetClientBase from test.utils import (connected, ignore_deprecations, one, OvertCommandListener, rs_client, single_client, wait_until) from test.version import Version class TestSelections(unittest.TestCase): @client_context.require_connection def test_bool(self): client = single_client() wait_until(lambda: client.address, "discover primary") selection = Selection.from_topology_description( client._topology.description) self.assertTrue(selection) self.assertFalse(selection.with_server_descriptions([])) class TestReadPreferenceObjects(unittest.TestCase): prefs = [Primary(), PrimaryPreferred(), Secondary(), Nearest(tag_sets=[{'a': 1}, {'b': 2}]), SecondaryPreferred(max_staleness=30)] def test_pickle(self): for pref in self.prefs: self.assertEqual(pref, pickle.loads(pickle.dumps(pref))) def test_copy(self): for pref in self.prefs: self.assertEqual(pref, copy.copy(pref)) def test_deepcopy(self): for pref in self.prefs: self.assertEqual(pref, copy.deepcopy(pref)) class TestReadPreferencesBase(TestReplicaSetClientBase): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): super(TestReadPreferencesBase, cls).setUpClass() def setUp(self): super(TestReadPreferencesBase, self).setUp() # Insert some data so we can use cursors in read_from_which_host self.client.pymongo_test.test.drop() self.client.get_database( "pymongo_test", write_concern=WriteConcern(w=self.w)).test.insert_many( [{'_id': i} for i in range(10)]) self.addCleanup(self.client.pymongo_test.test.drop) def read_from_which_host(self, client): """Do a find() on the client and return which host was used """ cursor = client.pymongo_test.test.find() next(cursor) return cursor.address def read_from_which_kind(self, client): """Do a find() on the client and return 'primary' or 'secondary' depending on which the client used. """ address = self.read_from_which_host(client) if address == client.primary: return 'primary' elif address in client.secondaries: return 'secondary' else: self.fail( 'Cursor used address %s, expected either primary ' '%s or secondaries %s' % ( address, client.primary, client.secondaries)) def assertReadsFrom(self, expected, **kwargs): c = rs_client(**kwargs) wait_until( lambda: len(c.nodes - c.arbiters) == self.w, "discovered all nodes") used = self.read_from_which_kind(c) self.assertEqual(expected, used, 'Cursor used %s, expected %s' % ( used, expected)) class TestSingleSlaveOk(TestReadPreferencesBase): def test_reads_from_secondary(self): host, port = next(iter(self.client.secondaries)) # Direct connection to a secondary. 
client = single_client(host, port) self.assertFalse(client.is_primary) # Regardless of read preference, we should be able to do # "reads" with a direct connection to a secondary. # See server-selection.rst#topology-type-single. self.assertEqual(client.read_preference, ReadPreference.PRIMARY) db = client.pymongo_test coll = db.test # Test find and find_one. self.assertIsNotNone(coll.find_one()) self.assertEqual(10, len(list(coll.find()))) # Test some database helpers. self.assertIsNotNone(db.collection_names()) self.assertIsNotNone(db.list_collection_names()) self.assertIsNotNone(db.validate_collection("test")) self.assertIsNotNone(db.command("ping")) # Test some collection helpers. self.assertEqual(10, coll.count_documents({})) self.assertEqual(10, len(coll.distinct("_id"))) self.assertIsNotNone(coll.aggregate([])) self.assertIsNotNone(coll.index_information()) # Test some "magic" namespace helpers. self.assertIsNotNone(db.current_op()) class TestReadPreferences(TestReadPreferencesBase): def test_mode_validation(self): for mode in (ReadPreference.PRIMARY, ReadPreference.PRIMARY_PREFERRED, ReadPreference.SECONDARY, ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST): self.assertEqual( mode, rs_client(read_preference=mode).read_preference) self.assertRaises( TypeError, rs_client, read_preference='foo') def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) self.assertEqual( [{}], rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{'k': 'v'}]) self.assertEqual( [{'k': 'v'}], rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{'k': 'v'}, {}]) self.assertEqual( [{'k': 'v'}, {}], rs_client(read_preference=S).read_preference.tag_sets) self.assertRaises(ValueError, Secondary, tag_sets=[]) # One dict not ok, must be a list of dicts self.assertRaises(TypeError, Secondary, tag_sets={'k': 'v'}) self.assertRaises(TypeError, Secondary, tag_sets='foo') self.assertRaises(TypeError, Secondary, tag_sets=['foo']) def test_threshold_validation(self): self.assertEqual(17, rs_client( localThresholdMS=17 ).local_threshold_ms) self.assertEqual(42, rs_client( localThresholdMS=42 ).local_threshold_ms) self.assertEqual(666, rs_client( localthresholdms=666 ).local_threshold_ms) self.assertEqual(0, rs_client( localthresholdms=0 ).local_threshold_ms) self.assertRaises(ValueError, rs_client, localthresholdms=-1) def test_zero_latency(self): ping_times = set() # Generate unique ping times. 
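# With localThresholdMS=0 the latency window admits only the single fastest
# member, so every 'nearest' read below should land on the same host.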
while len(ping_times) < len(self.client.nodes): ping_times.add(random.random()) for ping_time, host in zip(ping_times, self.client.nodes): ServerDescription._host_to_round_trip_time[host] = ping_time try: client = connected( rs_client(readPreference='nearest', localThresholdMS=0)) wait_until( lambda: client.nodes == self.client.nodes, "discovered all nodes") host = self.read_from_which_host(client) for _ in range(5): self.assertEqual(host, self.read_from_which_host(client)) finally: ServerDescription._host_to_round_trip_time.clear() def test_primary(self): self.assertReadsFrom( 'primary', read_preference=ReadPreference.PRIMARY) def test_primary_with_tags(self): # Tags not allowed with PRIMARY self.assertRaises( ConfigurationError, rs_client, tag_sets=[{'dc': 'ny'}]) def test_primary_preferred(self): self.assertReadsFrom( 'primary', read_preference=ReadPreference.PRIMARY_PREFERRED) def test_secondary(self): self.assertReadsFrom( 'secondary', read_preference=ReadPreference.SECONDARY) def test_secondary_preferred(self): self.assertReadsFrom( 'secondary', read_preference=ReadPreference.SECONDARY_PREFERRED) def test_nearest(self): # With high localThresholdMS, expect to read from any # member c = rs_client( read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds data_members = set(self.hosts).difference(set(self.arbiters)) # This is a probabilistic test; track which members we've read from so # far, and keep reading until we've used all the members or give up. # Chance of using only 2 of 3 members 10k times if there's no bug = # 3 * (2/3)**10000, very low. used = set() i = 0 while data_members.difference(used) and i < 10000: address = self.read_from_which_host(c) used.add(address) i += 1 not_used = data_members.difference(used) latencies = ', '.join( '%s: %dms' % (server.description.address, server.description.round_trip_time) for server in c._get_topology().select_servers( readable_server_selector)) self.assertFalse( not_used, "Expected to use primary and all secondaries for mode NEAREST," " but didn't use %s\nlatencies: %s" % (not_used, latencies)) class ReadPrefTester(MongoClient): def __init__(self, *args, **kwargs): self.has_read_from = set() client_options = client_context.default_client_options.copy() client_options.update(kwargs) super(ReadPrefTester, self).__init__(*args, **client_options) @contextlib.contextmanager def _socket_for_reads(self, read_preference, session): context = super(ReadPrefTester, self)._socket_for_reads( read_preference, session) with context as (sock_info, slave_ok): self.record_a_read(sock_info.address) yield sock_info, slave_ok @contextlib.contextmanager def _slaveok_for_server(self, read_preference, server, session, exhaust=False): context = super(ReadPrefTester, self)._slaveok_for_server( read_preference, server, session, exhaust=exhaust) with context as (sock_info, slave_ok): self.record_a_read(sock_info.address) yield sock_info, slave_ok def record_a_read(self, address): server = self._get_topology().select_server_by_address(address, 0) self.has_read_from.add(server) _PREF_MAP = [ (Primary, SERVER_TYPE.RSPrimary), (PrimaryPreferred, SERVER_TYPE.RSPrimary), (Secondary, SERVER_TYPE.RSSecondary), (SecondaryPreferred, SERVER_TYPE.RSSecondary), (Nearest, 'any') ] class TestCommandAndReadPreference(TestReplicaSetClientBase): @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): super(TestCommandAndReadPreference, cls).setUpClass() cls.c = ReadPrefTester( client_context.pair, replicaSet=cls.name, # Ignore round trip 
times, to test ReadPreference modes only. localThresholdMS=1000*1000) if client_context.auth_enabled: cls.c.admin.authenticate(db_user, db_pwd) cls.client_version = Version.from_client(cls.c) # mapReduce and group fail with no collection coll = cls.c.pymongo_test.get_collection( 'test', write_concern=WriteConcern(w=cls.w)) coll.insert_one({}) @classmethod def tearDownClass(cls): cls.c.drop_database('pymongo_test') cls.c.close() def executed_on_which_server(self, client, fn, *args, **kwargs): """Execute fn(*args, **kwargs) and return the Server instance used.""" client.has_read_from.clear() fn(*args, **kwargs) self.assertEqual(1, len(client.has_read_from)) return one(client.has_read_from) def assertExecutedOn(self, server_type, client, fn, *args, **kwargs): server = self.executed_on_which_server(client, fn, *args, **kwargs) self.assertEqual(SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type]) def _test_fn(self, server_type, fn): for _ in range(10): if server_type == 'any': used = set() for _ in range(1000): server = self.executed_on_which_server(self.c, fn) used.add(server.description.address) if len(used) == len(self.c.secondaries) + 1: # Success break unused = self.c.secondaries.union( set([self.c.primary]) ).difference(used) if unused: self.fail( "Some members not used for NEAREST: %s" % ( unused)) else: self.assertExecutedOn(server_type, self.c, fn) def _test_primary_helper(self, func): # Helpers that ignore read preference. self._test_fn(SERVER_TYPE.RSPrimary, func) def _test_coll_helper(self, secondary_ok, coll, meth, *args, **kwargs): for mode, server_type in _PREF_MAP: new_coll = coll.with_options(read_preference=mode()) func = lambda: getattr(new_coll, meth)(*args, **kwargs) if secondary_ok: self._test_fn(server_type, func) else: self._test_fn(SERVER_TYPE.RSPrimary, func) def test_command(self): # Test that the generic command helper obeys the read preference # passed to it. for mode, server_type in _PREF_MAP: func = lambda: self.c.pymongo_test.command('dbStats', read_preference=mode()) self._test_fn(server_type, func) def test_create_collection(self): # create_collection runs listCollections on the primary to check if # the collection already exists. 
self._test_primary_helper( lambda: self.c.pymongo_test.create_collection( 'some_collection%s' % random.randint(0, MAXSIZE))) @client_context.require_version_max(4, 1, 0, -1) def test_group(self): with warnings.catch_warnings(): warnings.simplefilter("ignore") self._test_coll_helper(True, self.c.pymongo_test.test, 'group', {'a': 1}, {}, {}, 'function() { }') def test_map_reduce(self): self._test_coll_helper(False, self.c.pymongo_test.test, 'map_reduce', 'function() { }', 'function() { }', {'inline': 1}) def test_inline_map_reduce(self): self._test_coll_helper(True, self.c.pymongo_test.test, 'inline_map_reduce', 'function() { }', 'function() { }') @ignore_deprecations def test_count(self): self._test_coll_helper(True, self.c.pymongo_test.test, 'count') def test_count_documents(self): self._test_coll_helper( True, self.c.pymongo_test.test, 'count_documents', {}) def test_estimated_document_count(self): self._test_coll_helper( True, self.c.pymongo_test.test, 'estimated_document_count') def test_distinct(self): self._test_coll_helper(True, self.c.pymongo_test.test, 'distinct', 'a') def test_aggregate(self): self._test_coll_helper(True, self.c.pymongo_test.test, 'aggregate', [{'$project': {'_id': 1}}]) def test_aggregate_write(self): self._test_coll_helper(False, self.c.pymongo_test.test, 'aggregate', [{'$project': {'_id': 1}}, {'$out': "agg_write_test"}]) class TestMovingAverage(unittest.TestCase): def test_moving_average(self): avg = MovingAverage() self.assertIsNone(avg.get()) avg.add_sample(10) self.assertAlmostEqual(10, avg.get()) avg.add_sample(20) self.assertAlmostEqual(12, avg.get()) avg.add_sample(30) self.assertAlmostEqual(15.6, avg.get()) class TestMongosAndReadPreference(unittest.TestCase): def test_read_preference_document(self): pref = Primary() self.assertEqual( pref.document, {'mode': 'primary'}) pref = PrimaryPreferred() self.assertEqual( pref.document, {'mode': 'primaryPreferred'}) pref = PrimaryPreferred(tag_sets=[{'dc': 'sf'}]) self.assertEqual( pref.document, {'mode': 'primaryPreferred', 'tags': [{'dc': 'sf'}]}) pref = PrimaryPreferred( tag_sets=[{'dc': 'sf'}], max_staleness=30) self.assertEqual( pref.document, {'mode': 'primaryPreferred', 'tags': [{'dc': 'sf'}], 'maxStalenessSeconds': 30}) pref = Secondary() self.assertEqual( pref.document, {'mode': 'secondary'}) pref = Secondary(tag_sets=[{'dc': 'sf'}]) self.assertEqual( pref.document, {'mode': 'secondary', 'tags': [{'dc': 'sf'}]}) pref = Secondary( tag_sets=[{'dc': 'sf'}], max_staleness=30) self.assertEqual( pref.document, {'mode': 'secondary', 'tags': [{'dc': 'sf'}], 'maxStalenessSeconds': 30}) pref = SecondaryPreferred() self.assertEqual( pref.document, {'mode': 'secondaryPreferred'}) pref = SecondaryPreferred(tag_sets=[{'dc': 'sf'}]) self.assertEqual( pref.document, {'mode': 'secondaryPreferred', 'tags': [{'dc': 'sf'}]}) pref = SecondaryPreferred( tag_sets=[{'dc': 'sf'}], max_staleness=30) self.assertEqual( pref.document, {'mode': 'secondaryPreferred', 'tags': [{'dc': 'sf'}], 'maxStalenessSeconds': 30}) pref = Nearest() self.assertEqual( pref.document, {'mode': 'nearest'}) pref = Nearest(tag_sets=[{'dc': 'sf'}]) self.assertEqual( pref.document, {'mode': 'nearest', 'tags': [{'dc': 'sf'}]}) pref = Nearest( tag_sets=[{'dc': 'sf'}], max_staleness=30) self.assertEqual( pref.document, {'mode': 'nearest', 'tags': [{'dc': 'sf'}], 'maxStalenessSeconds': 30}) with self.assertRaises(TypeError): Nearest(max_staleness=1.5) # Float is prohibited. 
with self.assertRaises(ValueError): Nearest(max_staleness=0) with self.assertRaises(ValueError): Nearest(max_staleness=-2) def test_read_preference_document_hedge(self): cases = { 'primaryPreferred': PrimaryPreferred, 'secondary': Secondary, 'secondaryPreferred': SecondaryPreferred, 'nearest': Nearest, } for mode, cls in cases.items(): with self.assertRaises(TypeError): cls(hedge=[]) pref = cls(hedge={}) self.assertEqual(pref.document, {'mode': mode}) out = _maybe_add_read_preference({}, pref) if cls == SecondaryPreferred: # SecondaryPreferred without hedge doesn't add $readPreference. self.assertEqual(out, {}) else: self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) hedge = {'enabled': True} pref = cls(hedge=hedge) self.assertEqual(pref.document, {'mode': mode, 'hedge': hedge}) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) hedge = {'enabled': False} pref = cls(hedge=hedge) self.assertEqual(pref.document, {'mode': mode, 'hedge': hedge}) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) hedge = {'enabled': False, 'extra': 'option'} pref = cls(hedge=hedge) self.assertEqual(pref.document, {'mode': mode, 'hedge': hedge}) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) # Require OP_MSG so that $readPreference is visible in the command event. @client_context.require_version_min(3, 6) def test_send_hedge(self): cases = { 'primaryPreferred': PrimaryPreferred, 'secondary': Secondary, 'secondaryPreferred': SecondaryPreferred, 'nearest': Nearest, } listener = OvertCommandListener() client = rs_client(event_listeners=[listener]) self.addCleanup(client.close) client.admin.command('ping') for mode, cls in cases.items(): pref = cls(hedge={'enabled': True}) coll = client.test.get_collection('test', read_preference=pref) listener.reset() coll.find_one() started = listener.results['started'] self.assertEqual(len(started), 1, started) cmd = started[0].command self.assertIn('$readPreference', cmd) self.assertEqual(cmd['$readPreference'], pref.document) def test_maybe_add_read_preference(self): # Primary doesn't add $readPreference out = _maybe_add_read_preference({}, Primary()) self.assertEqual(out, {}) pref = PrimaryPreferred() out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = PrimaryPreferred(tag_sets=[{'dc': 'nyc'}]) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = Secondary() out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = Secondary(tag_sets=[{'dc': 'nyc'}]) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) # SecondaryPreferred without tag_sets or max_staleness doesn't add # $readPreference pref = SecondaryPreferred() out = _maybe_add_read_preference({}, pref) self.assertEqual(out, {}) pref = SecondaryPreferred(tag_sets=[{'dc': 'nyc'}]) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = SecondaryPreferred(max_staleness=120) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) pref 
= Nearest() out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = Nearest(tag_sets=[{'dc': 'nyc'}]) out = _maybe_add_read_preference({}, pref) self.assertEqual( out, SON([("$query", {}), ("$readPreference", pref.document)])) criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))]) pref = Nearest() out = _maybe_add_read_preference(criteria, pref) self.assertEqual( out, SON([("$query", {}), ("$orderby", SON([("_id", 1)])), ("$readPreference", pref.document)])) pref = Nearest(tag_sets=[{'dc': 'nyc'}]) out = _maybe_add_read_preference(criteria, pref) self.assertEqual( out, SON([("$query", {}), ("$orderby", SON([("_id", 1)])), ("$readPreference", pref.document)])) @client_context.require_mongos def test_mongos(self): shard = client_context.client.config.shards.find_one()['host'] num_members = shard.count(',') + 1 if num_members == 1: raise SkipTest("Need a replica set shard to test.") coll = client_context.client.pymongo_test.get_collection( "test", write_concern=WriteConcern(w=num_members)) coll.drop() res = coll.insert_many([{} for _ in range(5)]) first_id = res.inserted_ids[0] last_id = res.inserted_ids[-1] # Note - this isn't a perfect test since there's no way to # tell what shard member a query ran on. for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): qcoll = coll.with_options(read_preference=pref) results = list(qcoll.find().sort([("_id", 1)])) self.assertEqual(first_id, results[0]["_id"]) self.assertEqual(last_id, results[-1]["_id"]) results = list(qcoll.find().sort([("_id", -1)])) self.assertEqual(first_id, results[-1]["_id"]) self.assertEqual(last_id, results[0]["_id"]) @client_context.require_mongos @client_context.require_version_min(3, 3, 12) def test_mongos_max_staleness(self): # Sanity check that we're sending maxStalenessSeconds coll = client_context.client.pymongo_test.get_collection( "test", read_preference=SecondaryPreferred(max_staleness=120)) # No error coll.find_one() coll = client_context.client.pymongo_test.get_collection( "test", read_preference=SecondaryPreferred(max_staleness=10)) try: coll.find_one() except OperationFailure as exc: self.assertEqual(160, exc.code) else: self.fail("mongos accepted invalid staleness") coll = single_client( readPreference='secondaryPreferred', maxStalenessSeconds=120).pymongo_test.test # No error coll.find_one() coll = single_client( readPreference='secondaryPreferred', maxStalenessSeconds=10).pymongo_test.test try: coll.find_one() except OperationFailure as exc: self.assertEqual(160, exc.code) else: self.fail("mongos accepted invalid staleness") if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_read_write_concern_spec.py000066400000000000000000000320161374256237000225350ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
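# --- Illustrative sketch (not part of the original test suite) --------------
# The spec tests below validate how URI options map onto WriteConcern and
# ReadConcern documents. A minimal sketch of the WriteConcern side of that
# mapping; the helper name is invented for clarity:
def _write_concern_document_sketch():
    from pymongo.write_concern import WriteConcern
    wc = WriteConcern(w='majority', wtimeout=1000)
    assert wc.document == {'w': 'majority', 'wtimeout': 1000}
    assert wc.acknowledged
    assert not WriteConcern(w=0).acknowledged        # unacknowledged writes
# -----------------------------------------------------------------------------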
"""Run the read and write concern tests.""" import json import os import sys import warnings sys.path[0:0] = [""] from pymongo import DESCENDING from pymongo.errors import (BulkWriteError, ConfigurationError, WTimeoutError, WriteConcernError) from pymongo.mongo_client import MongoClient from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern from test import (client_context, IntegrationTest, unittest) from test.utils import (EventListener, disable_replication, enable_replication, rs_or_single_client, TestCreator) from test.utils_spec_runner import SpecRunner _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'read_write_concern') class TestReadWriteConcernSpec(IntegrationTest): def test_omit_default_read_write_concern(self): listener = EventListener() # Client with default readConcern and writeConcern client = rs_or_single_client(event_listeners=[listener]) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) self.addCleanup(collection.drop) self.addCleanup(client.pymongo_test.collection2.drop) # Commands MUST NOT send the default read/write concern to the server. def rename_and_drop(): # Ensure collection exists. collection.insert_one({}) collection.rename('collection2') client.pymongo_test.collection2.drop() def insert_command_default_write_concern(): collection.database.command( 'insert', 'collection', documents=[{}], write_concern=WriteConcern()) ops = [ ('aggregate', lambda: list(collection.aggregate([]))), ('find', lambda: list(collection.find())), ('insert_one', lambda: collection.insert_one({})), ('update_one', lambda: collection.update_one({}, {'$set': {'x': 1}})), ('update_many', lambda: collection.update_many({}, {'$set': {'x': 1}})), ('delete_one', lambda: collection.delete_one({})), ('delete_many', lambda: collection.delete_many({})), ('bulk_write', lambda: collection.bulk_write([InsertOne({})])), ('rename_and_drop', rename_and_drop), ('command', insert_command_default_write_concern) ] for name, f in ops: listener.results.clear() f() self.assertGreaterEqual(len(listener.results['started']), 1) for i, event in enumerate(listener.results['started']): self.assertNotIn( 'readConcern', event.command, "%s sent default readConcern with %s" % ( name, event.command_name)) self.assertNotIn( 'writeConcern', event.command, "%s sent default writeConcern with %s" % ( name, event.command_name)) def assertWriteOpsRaise(self, write_concern, expected_exception): wc = write_concern.document # Set socket timeout to avoid indefinite stalls client = rs_or_single_client( w=wc['w'], wTimeoutMS=wc['wtimeout'], socketTimeoutMS=30000) db = client.get_database('pymongo_test') coll = db.test def insert_command(): coll.database.command( 'insert', 'new_collection', documents=[{}], writeConcern=write_concern.document, parse_write_concern_error=True) ops = [ ('insert_one', lambda: coll.insert_one({})), ('insert_many', lambda: coll.insert_many([{}, {}])), ('update_one', lambda: coll.update_one({}, {'$set': {'x': 1}})), ('update_many', lambda: coll.update_many({}, {'$set': {'x': 1}})), ('delete_one', lambda: coll.delete_one({})), ('delete_many', lambda: coll.delete_many({})), ('bulk_write', lambda: coll.bulk_write([InsertOne({})])), ('command', insert_command), ] ops_require_34 = [ ('aggregate', lambda: coll.aggregate([{'$out': 'out'}])), # SERVER-46668 Delete all the documents in the collection to # workaround a hang 
in createIndexes. ('delete_many', lambda: coll.delete_many({})), ('create_index', lambda: coll.create_index([('a', DESCENDING)])), ('create_indexes', lambda: coll.create_indexes([IndexModel('b')])), ('drop_index', lambda: coll.drop_index([('a', DESCENDING)])), ('create', lambda: db.create_collection('new')), ('rename', lambda: coll.rename('new')), ('drop', lambda: db.new.drop()), ] if client_context.version > (3, 4): ops.extend(ops_require_34) # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. if client_context.version[:2] != (3, 6): ops.append(('drop_database', lambda: client.drop_database(db))) for name, f in ops: # Ensure insert_many and bulk_write still raise BulkWriteError. if name in ('insert_many', 'bulk_write'): expected = BulkWriteError else: expected = expected_exception with self.assertRaises(expected, msg=name) as cm: f() if expected == BulkWriteError: bulk_result = cm.exception.details wc_errors = bulk_result['writeConcernErrors'] self.assertTrue(wc_errors) @client_context.require_replica_set def test_raise_write_concern_error(self): self.addCleanup(client_context.client.drop_database, 'pymongo_test') self.assertWriteOpsRaise( WriteConcern(w=client_context.w+1, wtimeout=1), WriteConcernError) # MongoDB 3.2 introduced the stopReplProducer failpoint. @client_context.require_version_min(3, 2) @client_context.require_secondaries_count(1) @client_context.require_test_commands def test_raise_wtimeout(self): self.addCleanup(client_context.client.drop_database, 'pymongo_test') self.addCleanup(enable_replication, client_context.client) # Disable replication to guarantee a wtimeout error. disable_replication(client_context.client) self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) @client_context.require_failCommand_fail_point def test_error_includes_errInfo(self): expected_wce = { "code": 100, "codeName": "UnsatisfiableWriteConcern", "errmsg": "Not enough data-bearing nodes", "errInfo": { "writeConcern": { "w": 2, "wtimeout": 0, "provenance": "clientSupplied" } } } cause_wce = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { "failCommands": ["insert"], "writeConcernError": expected_wce }, } with self.fail_point(cause_wce): # Write concern error on insert includes errInfo. with self.assertRaises(WriteConcernError) as ctx: self.db.test.insert_one({}) self.assertEqual(ctx.exception.details, expected_wce) # Test bulk_write as well. 
with self.assertRaises(BulkWriteError) as ctx: self.db.test.bulk_write([InsertOne({})]) expected_details = { 'writeErrors': [], 'writeConcernErrors': [expected_wce], 'nInserted': 1, 'nUpserted': 0, 'nMatched': 0, 'nModified': 0, 'nRemoved': 0, 'upserted': []} self.assertEqual(ctx.exception.details, expected_details) def normalize_write_concern(concern): result = {} for key in concern: if key.lower() == 'wtimeoutms': result['wtimeout'] = concern[key] elif key == 'journal': result['j'] = concern[key] else: result[key] = concern[key] return result def create_connection_string_test(test_case): def run_test(self): uri = test_case['uri'] valid = test_case['valid'] warning = test_case['warning'] if not valid: if warning is False: self.assertRaises( (ConfigurationError, ValueError), MongoClient, uri, connect=False) else: with warnings.catch_warnings(): warnings.simplefilter('error', UserWarning) self.assertRaises( UserWarning, MongoClient, uri, connect=False) else: client = MongoClient(uri, connect=False) if 'writeConcern' in test_case: document = client.write_concern.document self.assertEqual( document, normalize_write_concern(test_case['writeConcern'])) if 'readConcern' in test_case: document = client.read_concern.document self.assertEqual(document, test_case['readConcern']) return run_test def create_document_test(test_case): def run_test(self): valid = test_case['valid'] if 'writeConcern' in test_case: normalized = normalize_write_concern(test_case['writeConcern']) if not valid: self.assertRaises( (ConfigurationError, ValueError), WriteConcern, **normalized) else: concern = WriteConcern(**normalized) self.assertEqual( concern.document, test_case['writeConcernDocument']) self.assertEqual( concern.acknowledged, test_case['isAcknowledged']) self.assertEqual( concern.is_server_default, test_case['isServerDefault']) if 'readConcern' in test_case: # Any string for 'level' is equaly valid concern = ReadConcern(**test_case['readConcern']) self.assertEqual(concern.document, test_case['readConcernDocument']) self.assertEqual( not bool(concern.level), test_case['isServerDefault']) return run_test def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): dirname = os.path.split(dirpath)[-1] if dirname == 'operation': # This directory is tested by TestOperations. continue elif dirname == 'connection-string': create_test = create_connection_string_test else: create_test = create_document_test for filename in filenames: with open(os.path.join(dirpath, filename)) as test_stream: test_cases = json.load(test_stream)['tests'] fname = os.path.splitext(filename)[0] for test_case in test_cases: new_test = create_test(test_case) test_name = 'test_%s_%s_%s' % ( dirname.replace('-', '_'), fname.replace('-', '_'), str(test_case['description'].lower().replace(' ', '_'))) new_test.__name__ = test_name setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) create_tests() class TestOperation(SpecRunner): # Location of JSON test specifications. 
TEST_PATH = os.path.join(_TEST_PATH, 'operation') def get_outcome_coll_name(self, outcome, collection): """Spec says outcome has an optional 'collection.name'.""" return outcome['collection'].get('name', collection.name) def create_operation_test(scenario_def, test, name): @client_context.require_test_commands def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator( create_operation_test, TestOperation, TestOperation.TEST_PATH) test_creator.create_tests() if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_replica_set_client.py000066400000000000000000000317331374256237000215240ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the mongo_replica_set_client module.""" import sys import warnings import time sys.path[0:0] = [""] from bson.codec_options import CodecOptions from bson.son import SON from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, partition_node from pymongo.errors import (AutoReconnect, ConfigurationError, ConnectionFailure, NetworkTimeout, NotMasterError, OperationFailure) from pymongo.mongo_client import MongoClient from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.read_preferences import ReadPreference, Secondary, Nearest from pymongo.write_concern import WriteConcern from test import (client_context, client_knobs, IntegrationTest, unittest, SkipTest, db_pwd, db_user, MockClientTest, HAVE_IPADDRESS) from test.pymongo_mocks import MockClient from test.utils import (connected, delay, ignore_deprecations, one, rs_client, single_client, wait_until) class TestReplicaSetClientBase(IntegrationTest): @classmethod @client_context.require_replica_set def setUpClass(cls): super(TestReplicaSetClientBase, cls).setUpClass() cls.name = client_context.replica_set_name cls.w = client_context.w ismaster = client_context.ismaster cls.hosts = set(partition_node(h.lower()) for h in ismaster['hosts']) cls.arbiters = set(partition_node(h) for h in ismaster.get("arbiters", [])) repl_set_status = client_context.client.admin.command( 'replSetGetStatus') primary_info = [ m for m in repl_set_status['members'] if m['stateStr'] == 'PRIMARY' ][0] cls.primary = partition_node(primary_info['name'].lower()) cls.secondaries = set( partition_node(m['name'].lower()) for m in repl_set_status['members'] if m['stateStr'] == 'SECONDARY') class TestReplicaSetClient(TestReplicaSetClientBase): def test_deprecated(self): with warnings.catch_warnings(): warnings.simplefilter("error", DeprecationWarning) with self.assertRaises(DeprecationWarning): MongoReplicaSetClient() def test_connect(self): client = MongoClient( client_context.pair, replicaSet='fdlksjfdslkjfd', serverSelectionTimeoutMS=100) with self.assertRaises(ConnectionFailure): client.test.test.find_one() def test_repr(self): with ignore_deprecations(): client = MongoReplicaSetClient( client_context.host, client_context.port, replicaSet=self.name) self.assertIn("MongoReplicaSetClient(host=[", 
repr(client)) self.assertIn(client_context.pair, repr(client)) def test_properties(self): c = client_context.client c.admin.command('ping') wait_until(lambda: c.primary == self.primary, "discover primary") wait_until(lambda: c.arbiters == self.arbiters, "discover arbiters") wait_until(lambda: c.secondaries == self.secondaries, "discover secondaries") self.assertEqual(c.primary, self.primary) self.assertEqual(c.secondaries, self.secondaries) self.assertEqual(c.arbiters, self.arbiters) self.assertEqual(c.max_pool_size, 100) # Make sure MongoClient's properties are copied to Database and # Collection. for obj in c, c.pymongo_test, c.pymongo_test.test: self.assertEqual(obj.codec_options, CodecOptions()) self.assertEqual(obj.read_preference, ReadPreference.PRIMARY) self.assertEqual(obj.write_concern, WriteConcern()) cursor = c.pymongo_test.test.find() self.assertEqual( ReadPreference.PRIMARY, cursor._read_preference()) tag_sets = [{'dc': 'la', 'rack': '2'}, {'foo': 'bar'}] secondary = Secondary(tag_sets=tag_sets) c = rs_client( maxPoolSize=25, document_class=SON, tz_aware=True, read_preference=secondary, localThresholdMS=77, j=True) self.assertEqual(c.max_pool_size, 25) for obj in c, c.pymongo_test, c.pymongo_test.test: self.assertEqual(obj.codec_options, CodecOptions(SON, True)) self.assertEqual(obj.read_preference, secondary) self.assertEqual(obj.write_concern, WriteConcern(j=True)) cursor = c.pymongo_test.test.find() self.assertEqual( secondary, cursor._read_preference()) nearest = Nearest(tag_sets=[{'dc': 'ny'}, {}]) cursor = c.pymongo_test.get_collection( "test", read_preference=nearest).find() self.assertEqual(nearest, cursor._read_preference()) self.assertEqual(c.max_bson_size, 16777216) c.close() @client_context.require_secondaries_count(1) def test_timeout_does_not_mark_member_down(self): # If a query times out, the client shouldn't mark the member "down". # Disable background refresh. with client_knobs(heartbeat_frequency=999999): c = rs_client(socketTimeoutMS=1000, w=self.w) collection = c.pymongo_test.test collection.insert_one({}) # Query the primary. self.assertRaises( NetworkTimeout, collection.find_one, {'$where': delay(1.5)}) self.assertTrue(c.primary) collection.find_one() # No error. coll = collection.with_options( read_preference=ReadPreference.SECONDARY) # Query the secondary. self.assertRaises( NetworkTimeout, coll.find_one, {'$where': delay(1.5)}) self.assertTrue(c.secondaries) # No error. coll.find_one() @client_context.require_ipv6 def test_ipv6(self): if client_context.tls: if not HAVE_IPADDRESS: raise SkipTest("Need the ipaddress module to test with SSL") port = client_context.port c = rs_client("mongodb://[::1]:%d" % (port,)) # Client switches to IPv4 once it has first ismaster response. msg = 'discovered primary with IPv4 address "%r"' % (self.primary,) wait_until(lambda: c.primary == self.primary, msg) # Same outcome with both IPv4 and IPv6 seeds. 
c = rs_client("mongodb://[::1]:%d,localhost:%d" % (port, port)) wait_until(lambda: c.primary == self.primary, msg) if client_context.auth_enabled: auth_str = "%s:%s@" % (db_user, db_pwd) else: auth_str = "" uri = "mongodb://%slocalhost:%d,[::1]:%d" % (auth_str, port, port) client = rs_client(uri) client.pymongo_test.test.insert_one({"dummy": u"object"}) client.pymongo_test_bernie.test.insert_one({"dummy": u"object"}) dbs = client.list_database_names() self.assertTrue("pymongo_test" in dbs) self.assertTrue("pymongo_test_bernie" in dbs) client.close() def _test_kill_cursor_explicit(self, read_pref): with client_knobs(kill_cursor_frequency=0.01): c = rs_client(read_preference=read_pref, w=self.w) db = c.pymongo_test db.drop_collection("test") test = db.test test.insert_many([{"i": i} for i in range(20)]) # Partially evaluate cursor so it's left alive, then kill it cursor = test.find().batch_size(10) next(cursor) self.assertNotEqual(0, cursor.cursor_id) if read_pref == ReadPreference.PRIMARY: msg = "Expected cursor's address to be %s, got %s" % ( c.primary, cursor.address) self.assertEqual(cursor.address, c.primary, msg) else: self.assertNotEqual( cursor.address, c.primary, "Expected cursor's address not to be primary") cursor_id = cursor.cursor_id # Cursor dead on server - trigger a getMore on the same cursor_id # and check that the server returns an error. cursor2 = cursor.clone() cursor2._Cursor__id = cursor_id if sys.platform.startswith('java') or 'PyPy' in sys.version: # Explicitly kill cursor. cursor.close() else: # Implicitly kill it in CPython. del cursor time.sleep(5) self.assertRaises(OperationFailure, lambda: list(cursor2)) def test_kill_cursor_explicit_primary(self): self._test_kill_cursor_explicit(ReadPreference.PRIMARY) @client_context.require_secondaries_count(1) def test_kill_cursor_explicit_secondary(self): self._test_kill_cursor_explicit(ReadPreference.SECONDARY) @client_context.require_secondaries_count(1) def test_not_master_error(self): secondary_address = one(self.secondaries) direct_client = single_client(*secondary_address) with self.assertRaises(NotMasterError): direct_client.pymongo_test.collection.insert_one({}) db = direct_client.get_database( "pymongo_test", write_concern=WriteConcern(w=0)) with self.assertRaises(NotMasterError): db.collection.insert_one({}) class TestReplicaSetWireVersion(MockClientTest): @client_context.require_connection @client_context.require_no_auth def test_wire_version(self): c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1', replicaSet='rs', connect=False) self.addCleanup(c.close) c.set_wire_version_range('a:1', 3, 7) c.set_wire_version_range('b:2', 2, 3) c.set_wire_version_range('c:3', 3, 4) c.db.command('ismaster') # Connect. # A secondary doesn't overlap with us. c.set_wire_version_range('b:2', MAX_SUPPORTED_WIRE_VERSION + 1, MAX_SUPPORTED_WIRE_VERSION + 2) def raises_configuration_error(): try: c.db.collection.find_one() return False except ConfigurationError: return True wait_until(raises_configuration_error, 'notice we are incompatible with server') self.assertRaises(ConfigurationError, c.db.collection.insert_one, {}) class TestReplicaSetClientInternalIPs(MockClientTest): @client_context.require_connection def test_connect_with_internal_ips(self): # Client is passed an IP it can reach, 'a:1', but the RS config # only contains unreachable IPs like 'internal-ip'. PYTHON-608. 
client = MockClient( standalones=[], members=['a:1'], mongoses=[], ismaster_hosts=['internal-ip:27017'], host='a:1', replicaSet='rs', serverSelectionTimeoutMS=100) self.addCleanup(client.close) with self.assertRaises(AutoReconnect) as context: connected(client) self.assertIn("Could not reach any servers in [('internal-ip', 27017)]." " Replica set is configured with internal hostnames or IPs?", str(context.exception)) class TestReplicaSetClientMaxWriteBatchSize(MockClientTest): @client_context.require_connection def test_max_write_batch_size(self): c = MockClient( standalones=[], members=['a:1', 'b:2'], mongoses=[], host='a:1', replicaSet='rs', connect=False) self.addCleanup(c.close) c.set_max_write_batch_size('a:1', 1) c.set_max_write_batch_size('b:2', 2) # Uses primary's max batch size. self.assertEqual(c.max_write_batch_size, 1) # b becomes primary. c.mock_primary = 'b:2' wait_until(lambda: c.max_write_batch_size == 2, 'update max_write_batch_size') if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_replica_set_reconfig.py000066400000000000000000000135621374256237000220420ustar00rootroot00000000000000# Copyright 2013-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test clients and replica set configuration changes, using mocks.""" import sys sys.path[0:0] = [""] from pymongo.errors import ConnectionFailure, AutoReconnect from pymongo import ReadPreference from test import unittest, client_context, client_knobs, MockClientTest from test.pymongo_mocks import MockClient from test.utils import wait_until @client_context.require_connection def setUpModule(): pass class TestSecondaryBecomesStandalone(MockClientTest): # An administrator removes a secondary from a 3-node set and # brings it back up as standalone, without updating the other # members' config. Verify we don't continue using it. def test_client(self): c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1,b:2,c:3', replicaSet='rs', serverSelectionTimeoutMS=100) self.addCleanup(c.close) # MongoClient connects to primary by default. wait_until(lambda: c.address is not None, 'connect to primary') self.assertEqual(c.address, ('a', 1)) # C is brought up as a standalone. c.mock_members.remove('c:3') c.mock_standalones.append('c:3') # Fail over. c.kill_host('a:1') c.kill_host('b:2') # Force reconnect. c.close() with self.assertRaises(AutoReconnect): c.db.command('ismaster') self.assertEqual(c.address, None) def test_replica_set_client(self): c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1,b:2,c:3', replicaSet='rs') self.addCleanup(c.close) wait_until(lambda: ('b', 2) in c.secondaries, 'discover host "b"') wait_until(lambda: ('c', 3) in c.secondaries, 'discover host "c"') # C is brought up as a standalone. 
c.mock_members.remove('c:3') c.mock_standalones.append('c:3') wait_until(lambda: set([('b', 2)]) == c.secondaries, 'update the list of secondaries') self.assertEqual(('a', 1), c.primary) class TestSecondaryRemoved(MockClientTest): # An administrator removes a secondary from a 3-node set *without* # restarting it as standalone. def test_replica_set_client(self): c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1,b:2,c:3', replicaSet='rs') self.addCleanup(c.close) wait_until(lambda: ('b', 2) in c.secondaries, 'discover host "b"') wait_until(lambda: ('c', 3) in c.secondaries, 'discover host "c"') # C is removed. c.mock_ismaster_hosts.remove('c:3') wait_until(lambda: set([('b', 2)]) == c.secondaries, 'update list of secondaries') self.assertEqual(('a', 1), c.primary) class TestSocketError(MockClientTest): def test_socket_error_marks_member_down(self): # Disable background refresh. with client_knobs(heartbeat_frequency=999999): c = MockClient( standalones=[], members=['a:1', 'b:2'], mongoses=[], host='a:1', replicaSet='rs', serverSelectionTimeoutMS=100) self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 2, 'discover both nodes') # b now raises socket.error. c.mock_down_hosts.append('b:2') self.assertRaises( ConnectionFailure, c.db.collection.with_options( read_preference=ReadPreference.SECONDARY).find_one) self.assertEqual(1, len(c.nodes)) class TestSecondaryAdded(MockClientTest): def test_client(self): c = MockClient( standalones=[], members=['a:1', 'b:2'], mongoses=[], host='a:1', replicaSet='rs') self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 2, 'discover both nodes') # MongoClient connects to primary by default. self.assertEqual(c.address, ('a', 1)) self.assertEqual(set([('a', 1), ('b', 2)]), c.nodes) # C is added. c.mock_members.append('c:3') c.mock_ismaster_hosts.append('c:3') c.close() c.db.command('ismaster') self.assertEqual(c.address, ('a', 1)) wait_until(lambda: set([('a', 1), ('b', 2), ('c', 3)]) == c.nodes, 'reconnect to both secondaries') def test_replica_set_client(self): c = MockClient( standalones=[], members=['a:1', 'b:2'], mongoses=[], host='a:1', replicaSet='rs') self.addCleanup(c.close) wait_until(lambda: ('a', 1) == c.primary, 'discover the primary') wait_until(lambda: set([('b', 2)]) == c.secondaries, 'discover the secondary') # C is added. c.mock_members.append('c:3') c.mock_ismaster_hosts.append('c:3') wait_until(lambda: set([('b', 2), ('c', 3)]) == c.secondaries, 'discover the new secondary') self.assertEqual(('a', 1), c.primary) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_retryable_reads.py000066400000000000000000000101021374256237000210260ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test retryable reads spec.""" import os import sys sys.path[0:0] = [""] from pymongo.mongo_client import MongoClient from pymongo.write_concern import WriteConcern from test import unittest, client_context, PyMongoTestCase from test.utils import TestCreator from test.utils_spec_runner import SpecRunner # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'retryable_reads') class TestClientOptions(PyMongoTestCase): def test_default(self): client = MongoClient(connect=False) self.assertEqual(client.retry_reads, True) def test_kwargs(self): client = MongoClient(retryReads=True, connect=False) self.assertEqual(client.retry_reads, True) client = MongoClient(retryReads=False, connect=False) self.assertEqual(client.retry_reads, False) def test_uri(self): client = MongoClient('mongodb://h/?retryReads=true', connect=False) self.assertEqual(client.retry_reads, True) client = MongoClient('mongodb://h/?retryReads=false', connect=False) self.assertEqual(client.retry_reads, False) class TestSpec(SpecRunner): @classmethod @client_context.require_version_min(4, 0) # TODO: remove this once PYTHON-1948 is done. @client_context.require_no_mmap def setUpClass(cls): super(TestSpec, cls).setUpClass() if client_context.is_mongos and client_context.version[:2] <= (4, 0): raise unittest.SkipTest("4.0 mongos does not support failCommand") def maybe_skip_scenario(self, test): super(TestSpec, self).maybe_skip_scenario(test) skip_names = [ 'listCollectionObjects', 'listIndexNames', 'listDatabaseObjects'] for name in skip_names: if name.lower() in test['description'].lower(): self.skipTest('PyMongo does not support %s' % (name,)) # Skip changeStream related tests on MMAPv1. test_name = self.id().rsplit('.')[-1] if ('changestream' in test_name.lower() and client_context.storage_engine == 'mmapv1'): self.skipTest("MMAPv1 does not support change streams.") def get_scenario_coll_name(self, scenario_def): """Override a test's collection name to support GridFS tests.""" if 'bucket_name' in scenario_def: return scenario_def['bucket_name'] return super(TestSpec, self).get_scenario_coll_name(scenario_def) def setup_scenario(self, scenario_def): """Override a test's setup to support GridFS tests.""" if 'bucket_name' in scenario_def: db_name = self.get_scenario_db_name(scenario_def) db = client_context.client.get_database( db_name, write_concern=WriteConcern(w='majority')) # Create a bucket for the retryable reads GridFS tests. client_context.client.drop_database(db_name) if scenario_def['data']: data = scenario_def['data'] # Load data. db['fs.chunks'].insert_many(data['fs.chunks']) db['fs.files'].insert_many(data['fs.files']) else: super(TestSpec, self).setup_scenario(scenario_def) def create_test(scenario_def, test, name): @client_context.require_test_commands def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_test, TestSpec, _TEST_PATH) test_creator.create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_retryable_writes.py000066400000000000000000000514021374256237000212550ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test retryable writes.""" import copy import os import sys sys.path[0:0] = [""] from bson.int64 import Int64 from bson.objectid import ObjectId from bson.son import SON from pymongo.errors import (ConnectionFailure, OperationFailure, ServerSelectionTimeoutError) from pymongo.mongo_client import MongoClient from pymongo.operations import (InsertOne, DeleteMany, DeleteOne, ReplaceOne, UpdateMany, UpdateOne) from pymongo.write_concern import WriteConcern from test import unittest, client_context, IntegrationTest, SkipTest, client_knobs from test.test_crud_v1 import check_result as crud_v1_check_result from test.utils import (rs_or_single_client, DeprecationFilter, OvertCommandListener, TestCreator) from test.utils_spec_runner import SpecRunner # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'retryable_writes') class TestAllScenarios(SpecRunner): def get_object_name(self, op): return op.get('object', 'collection') def get_scenario_db_name(self, scenario_def): return scenario_def.get('database_name', 'pymongo_test') def get_scenario_coll_name(self, scenario_def): return scenario_def.get('collection_name', 'test') def run_test_ops(self, sessions, collection, test): outcome = test['outcome'] should_fail = outcome.get('error') result = None error = None try: result = self.run_operation( sessions, collection, test['operation']) except (ConnectionFailure, OperationFailure) as exc: error = exc if should_fail: self.assertIsNotNone(error, 'should have raised an error') else: self.assertIsNone(error) crud_v1_check_result(self, outcome['result'], result) def create_test(scenario_def, test, name): @client_context.require_test_commands @client_context.require_no_mmap def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_test, TestAllScenarios, _TEST_PATH) test_creator.create_tests() def _retryable_single_statement_ops(coll): return [ (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}), (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {'ordered': False}), (coll.bulk_write, [[ReplaceOne({}, {})]], {}), (coll.bulk_write, [[ReplaceOne({}, {}), ReplaceOne({}, {})]], {}), (coll.bulk_write, [[UpdateOne({}, {'$set': {'a': 1}}), UpdateOne({}, {'$set': {'a': 1}})]], {}), (coll.bulk_write, [[DeleteOne({})]], {}), (coll.bulk_write, [[DeleteOne({}), DeleteOne({})]], {}), (coll.insert_one, [{}], {}), (coll.insert_many, [[{}, {}]], {}), (coll.replace_one, [{}, {}], {}), (coll.update_one, [{}, {'$set': {'a': 1}}], {}), (coll.delete_one, [{}], {}), (coll.find_one_and_replace, [{}, {'a': 3}], {}), (coll.find_one_and_update, [{}, {'$set': {'a': 1}}], {}), (coll.find_one_and_delete, [{}, {}], {}), ] def retryable_single_statement_ops(coll): return _retryable_single_statement_ops(coll) + [ # Deprecated methods. # Insert with single or multiple documents. (coll.insert, [{}], {}), (coll.insert, [[{}]], {}), (coll.insert, [[{}, {}]], {}), # Save with and without an _id. (coll.save, [{}], {}), (coll.save, [{'_id': ObjectId()}], {}), # Non-multi update. 
(coll.update, [{}, {'$set': {'a': 1}}], {}), # Non-multi remove. (coll.remove, [{}], {'multi': False}), # Replace. (coll.find_and_modify, [{}, {'a': 3}], {}), # Update. (coll.find_and_modify, [{}, {'$set': {'a': 1}}], {}), # Delete. (coll.find_and_modify, [{}, {}], {'remove': True}), ] def non_retryable_single_statement_ops(coll): return [ (coll.bulk_write, [[UpdateOne({}, {'$set': {'a': 1}}), UpdateMany({}, {'$set': {'a': 1}})]], {}), (coll.bulk_write, [[DeleteOne({}), DeleteMany({})]], {}), (coll.update_many, [{}, {'$set': {'a': 1}}], {}), (coll.delete_many, [{}], {}), # Deprecated methods. # Multi remove. (coll.remove, [{}], {}), # Multi update. (coll.update, [{}, {'$set': {'a': 1}}], {'multi': True}), # Unacknowledged deprecated methods. (coll.insert, [{}], {'w': 0}), # Unacknowledged Non-multi update. (coll.update, [{}, {'$set': {'a': 1}}], {'w': 0}), # Unacknowledged Non-multi remove. (coll.remove, [{}], {'multi': False, 'w': 0}), # Unacknowledged Replace. (coll.find_and_modify, [{}, {'a': 3}], {'writeConcern': {'w': 0}}), # Unacknowledged Update. (coll.find_and_modify, [{}, {'$set': {'a': 1}}], {'writeConcern': {'w': 0}}), # Unacknowledged Delete. (coll.find_and_modify, [{}, {}], {'remove': True, 'writeConcern': {'w': 0}}), ] class IgnoreDeprecationsTest(IntegrationTest): @classmethod def setUpClass(cls): super(IgnoreDeprecationsTest, cls).setUpClass() cls.deprecation_filter = DeprecationFilter() @classmethod def tearDownClass(cls): cls.deprecation_filter.stop() super(IgnoreDeprecationsTest, cls).tearDownClass() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): @classmethod def setUpClass(cls): super(TestRetryableWritesMMAPv1, cls).setUpClass() # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.client = rs_or_single_client(retryWrites=True) cls.db = cls.client.pymongo_test @classmethod def tearDownClass(cls): cls.knobs.disable() cls.client.close() @client_context.require_version_min(3, 5) @client_context.require_no_standalone def test_actionable_error_message(self): if client_context.storage_engine != 'mmapv1': raise SkipTest('This cluster is not running MMAPv1') expected_msg = ("This MongoDB deployment does not support retryable " "writes. Please add retryWrites=false to your " "connection string.") for method, args, kwargs in retryable_single_statement_ops( self.db.retryable_write_test): with self.assertRaisesRegex(OperationFailure, expected_msg): method(*args, **kwargs) class TestRetryableWrites(IgnoreDeprecationsTest): @classmethod @client_context.require_no_mmap def setUpClass(cls): super(TestRetryableWrites, cls).setUpClass() # Speed up the tests by decreasing the heartbeat frequency. 
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.listener = OvertCommandListener() cls.client = rs_or_single_client( retryWrites=True, event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test @classmethod def tearDownClass(cls): cls.knobs.disable() cls.client.close() super(TestRetryableWrites, cls).tearDownClass() def setUp(self): if (client_context.version.at_least(3, 5) and client_context.is_rs and client_context.test_commands_enabled): self.client.admin.command(SON([ ('configureFailPoint', 'onPrimaryTransactionalWrite'), ('mode', 'alwaysOn')])) def tearDown(self): if (client_context.version.at_least(3, 5) and client_context.is_rs and client_context.test_commands_enabled): self.client.admin.command(SON([ ('configureFailPoint', 'onPrimaryTransactionalWrite'), ('mode', 'off')])) def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() client = rs_or_single_client( retryWrites=False, event_listeners=[listener]) self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops( client.db.retryable_write_test): msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) listener.results.clear() method(*args, **kwargs) for event in listener.results['started']: self.assertNotIn( 'txnNumber', event.command, '%s sent txnNumber with %s' % (msg, event.command_name)) @client_context.require_version_min(3, 5) @client_context.require_no_standalone def test_supported_single_statement_supported_cluster(self): for method, args, kwargs in retryable_single_statement_ops( self.db.retryable_write_test): msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) commands_started = self.listener.results['started'] self.assertEqual(len(self.listener.results['succeeded']), 1, msg) first_attempt = commands_started[0] self.assertIn( 'lsid', first_attempt.command, '%s sent no lsid with %s' % (msg, first_attempt.command_name)) initial_session_id = first_attempt.command['lsid'] self.assertIn( 'txnNumber', first_attempt.command, '%s sent no txnNumber with %s' % ( msg, first_attempt.command_name)) # There should be no retry when the failpoint is not active. 
if (client_context.is_mongos or not client_context.test_commands_enabled): self.assertEqual(len(commands_started), 1) continue initial_transaction_id = first_attempt.command['txnNumber'] retry_attempt = commands_started[1] self.assertIn( 'lsid', retry_attempt.command, '%s sent no lsid with %s' % (msg, first_attempt.command_name)) self.assertEqual( retry_attempt.command['lsid'], initial_session_id, msg) self.assertIn( 'txnNumber', retry_attempt.command, '%s sent no txnNumber with %s' % ( msg, first_attempt.command_name)) self.assertEqual(retry_attempt.command['txnNumber'], initial_transaction_id, msg) def test_supported_single_statement_unsupported_cluster(self): if client_context.version.at_least(3, 5) and ( client_context.is_rs or client_context.is_mongos): raise SkipTest('This cluster supports retryable writes') for method, args, kwargs in retryable_single_statement_ops( self.db.retryable_write_test): msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) for event in self.listener.results['started']: self.assertNotIn( 'txnNumber', event.command, '%s sent txnNumber with %s' % (msg, event.command_name)) def test_unsupported_single_statement(self): coll = self.db.retryable_write_test coll.insert_many([{}, {}]) coll_w0 = coll.with_options(write_concern=WriteConcern(w=0)) for method, args, kwargs in (non_retryable_single_statement_ops(coll) + retryable_single_statement_ops(coll_w0)): msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) started_events = self.listener.results['started'] self.assertEqual(len(self.listener.results['succeeded']), len(started_events), msg) self.assertEqual(len(self.listener.results['failed']), 0, msg) for event in started_events: self.assertNotIn( 'txnNumber', event.command, '%s sent txnNumber with %s' % (msg, event.command_name)) def test_server_selection_timeout_not_retried(self): """A ServerSelectionTimeoutError is not retried.""" listener = OvertCommandListener() client = MongoClient( 'somedomainthatdoesntexist.org', serverSelectionTimeoutMS=1, retryWrites=True, event_listeners=[listener]) for method, args, kwargs in retryable_single_statement_ops( client.db.retryable_write_test): msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) listener.results.clear() with self.assertRaises(ServerSelectionTimeoutError, msg=msg): method(*args, **kwargs) self.assertEqual(len(listener.results['started']), 0, msg) @client_context.require_version_min(3, 5) @client_context.require_replica_set @client_context.require_test_commands def test_retry_timeout_raises_original_error(self): """A ServerSelectionTimeoutError on the retry attempt raises the original error. """ listener = OvertCommandListener() client = rs_or_single_client( retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server def mock_select_server(*args, **kwargs): server = select_server(*args, **kwargs) def raise_error(*args, **kwargs): raise ServerSelectionTimeoutError( 'No primary available for writes') # Raise ServerSelectionTimeout on the retry attempt. 
topology.select_server = raise_error return server for method, args, kwargs in retryable_single_statement_ops( client.db.retryable_write_test): msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) listener.results.clear() topology.select_server = mock_select_server with self.assertRaises(ConnectionFailure, msg=msg): method(*args, **kwargs) self.assertEqual(len(listener.results['started']), 1, msg) @client_context.require_version_min(3, 5) @client_context.require_replica_set @client_context.require_test_commands def test_batch_splitting(self): """Test retry succeeds after failures during batch splitting.""" large = 's' * 1024 * 1024 * 15 coll = self.db.retryable_write_test coll.delete_many({}) self.listener.results.clear() bulk_result = coll.bulk_write([ InsertOne({'_id': 1, 'l': large}), InsertOne({'_id': 2, 'l': large}), InsertOne({'_id': 3, 'l': large}), UpdateOne({'_id': 1, 'l': large}, {'$unset': {'l': 1}, '$inc': {'count': 1}}), UpdateOne({'_id': 2, 'l': large}, {'$set': {'foo': 'bar'}}), DeleteOne({'l': large}), DeleteOne({'l': large})]) # Each command should fail and be retried. # With OP_MSG 3 inserts are one batch. 2 updates another. # 2 deletes a third. self.assertEqual(len(self.listener.results['started']), 6) self.assertEqual(coll.find_one(), {'_id': 1, 'count': 1}) # Assert the final result expected_result = { "writeErrors": [], "writeConcernErrors": [], "nInserted": 3, "nUpserted": 0, "nMatched": 2, "nModified": 2, "nRemoved": 2, "upserted": [], } self.assertEqual(bulk_result.bulk_api_result, expected_result) @client_context.require_version_min(3, 5) @client_context.require_replica_set @client_context.require_test_commands def test_batch_splitting_retry_fails(self): """Test retry fails during batch splitting.""" large = 's' * 1024 * 1024 * 15 coll = self.db.retryable_write_test coll.delete_many({}) self.client.admin.command(SON([ ('configureFailPoint', 'onPrimaryTransactionalWrite'), ('mode', {'skip': 3}), # The number of _documents_ to skip. ('data', {'failBeforeCommitExceptionCode': 1})])) self.listener.results.clear() with self.client.start_session() as session: initial_txn = session._server_session._transaction_id try: coll.bulk_write([InsertOne({'_id': 1, 'l': large}), InsertOne({'_id': 2, 'l': large}), InsertOne({'_id': 3, 'l': large}), InsertOne({'_id': 4, 'l': large})], session=session) except ConnectionFailure: pass else: self.fail("bulk_write should have failed") started = self.listener.results['started'] self.assertEqual(len(started), 3) self.assertEqual(len(self.listener.results['succeeded']), 1) expected_txn = Int64(initial_txn + 1) self.assertEqual(started[0].command['txnNumber'], expected_txn) self.assertEqual(started[0].command['lsid'], session.session_id) expected_txn = Int64(initial_txn + 2) self.assertEqual(started[1].command['txnNumber'], expected_txn) self.assertEqual(started[1].command['lsid'], session.session_id) started[1].command.pop('$clusterTime') started[2].command.pop('$clusterTime') self.assertEqual(started[1].command, started[2].command) final_txn = session._server_session._transaction_id self.assertEqual(final_txn, expected_txn) self.assertEqual(coll.find_one(projection={'_id': True}), {'_id': 1}) # TODO: Make this a real integration test where we stepdown the primary. 
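# --- Editor's note: illustrative sketch, not part of the original test suite. ---
# The tests above assert that retryable writes attach an ``lsid`` and a
# ``txnNumber`` to each write command (one txnNumber per batch when a bulk
# write is split). A minimal standalone way to observe those fields with
# command monitoring is sketched below; the URI and names are hypothetical and
# assume a locally reachable replica set.
from pymongo import MongoClient, monitoring


class TxnNumberLogger(monitoring.CommandListener):
    def started(self, event):
        # Retryable writes add lsid/txnNumber to insert, update, delete, etc.
        if 'txnNumber' in event.command:
            print(event.command_name, event.command['lsid'],
                  event.command['txnNumber'])

    def succeeded(self, event):
        pass

    def failed(self, event):
        pass


def _demo_txn_numbers():
    client = MongoClient('mongodb://localhost:27017/?replicaSet=rs0',
                         retryWrites=True,
                         event_listeners=[TxnNumberLogger()])
    coll = client.demo_db.demo_coll
    coll.insert_one({'x': 1})  # each retryable write gets its own txnNumber
    coll.delete_one({'x': 1})
    client.close()
# --- End editor's note. ---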
class TestRetryableWritesTxnNumber(IgnoreDeprecationsTest): @client_context.require_version_min(3, 6) @client_context.require_replica_set @client_context.require_no_mmap def test_increment_transaction_id_without_sending_command(self): """Test that the txnNumber field is properly incremented, even when the first attempt fails before sending the command. """ listener = OvertCommandListener() client = rs_or_single_client( retryWrites=True, event_listeners=[listener]) topology = client._topology select_server = topology.select_server def raise_connection_err_select_server(*args, **kwargs): # Raise ConnectionFailure on the first attempt and perform # normal selection on the retry attempt. topology.select_server = select_server raise ConnectionFailure('Connection refused') for method, args, kwargs in _retryable_single_statement_ops( client.db.retryable_write_test): listener.results.clear() topology.select_server = raise_connection_err_select_server with client.start_session() as session: kwargs = copy.deepcopy(kwargs) kwargs['session'] = session msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) initial_txn_id = session._server_session.transaction_id # Each operation should fail on the first attempt and succeed # on the second. method(*args, **kwargs) self.assertEqual(len(listener.results['started']), 1, msg) retry_cmd = listener.results['started'][0].command sent_txn_id = retry_cmd['txnNumber'] final_txn_id = session._server_session.transaction_id self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg) self.assertEqual(sent_txn_id, final_txn_id, msg) if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_saslprep.py000066400000000000000000000030551374256237000175210ustar00rootroot00000000000000# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import sys import warnings sys.path[0:0] = [""] from pymongo.saslprep import saslprep from test import unittest class TestSASLprep(unittest.TestCase): def test_saslprep(self): try: import stringprep except ImportError: self.assertRaises(TypeError, saslprep, u"anything...") # Bytes strings are ignored. self.assertEqual(saslprep(b"user"), b"user") else: # Examples from RFC4013, Section 3. self.assertEqual(saslprep(u"I\u00ADX"), u"IX") self.assertEqual(saslprep(u"user"), u"user") self.assertEqual(saslprep(u"USER"), u"USER") self.assertEqual(saslprep(u"\u00AA"), u"a") self.assertEqual(saslprep(u"\u2168"), u"IX") self.assertRaises(ValueError, saslprep, u"\u0007") self.assertRaises(ValueError, saslprep, u"\u0627\u0031") # Bytes strings are ignored. self.assertEqual(saslprep(b"user"), b"user") pymongo-3.11.0/test/test_sdam_monitoring_spec.py000066400000000000000000000327031374256237000220750ustar00rootroot00000000000000# Copyright 2016 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Run the sdam monitoring spec tests.""" import json import os import sys import time sys.path[0:0] = [""] from bson.json_util import object_hook from pymongo import monitoring from pymongo.common import clean_node from pymongo.errors import (ConnectionFailure, NotMasterError) from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from pymongo.server_description import ServerDescription from pymongo.topology_description import TOPOLOGY_TYPE from test import unittest, client_context, client_knobs, IntegrationTest from test.utils import (ServerAndTopologyEventListener, single_client, server_name_to_type, rs_or_single_client, wait_until) # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'sdam_monitoring') def compare_server_descriptions(expected, actual): if ((not expected['address'] == "%s:%s" % actual.address) or (not server_name_to_type(expected['type']) == actual.server_type)): return False expected_hosts = set( expected['arbiters'] + expected['passives'] + expected['hosts']) return expected_hosts == set("%s:%s" % s for s in actual.all_hosts) def compare_topology_descriptions(expected, actual): if not (TOPOLOGY_TYPE.__getattribute__( expected['topologyType']) == actual.topology_type): return False expected = expected['servers'] actual = actual.server_descriptions() if len(expected) != len(actual): return False for exp_server in expected: for address, actual_server in actual.items(): if compare_server_descriptions(exp_server, actual_server): break else: return False return True def compare_events(expected_dict, actual): if not expected_dict: return False, "Error: Bad expected value in YAML test" if not actual: return False, "Error: Event published was None" expected_type, expected = list(expected_dict.items())[0] if expected_type == "server_opening_event": if not isinstance(actual, monitoring.ServerOpeningEvent): return False, "Expected ServerOpeningEvent, got %s" % ( actual.__class__) if not expected['address'] == "%s:%s" % actual.server_address: return (False, "ServerOpeningEvent published with wrong address (expected" " %s, got %s" % (expected['address'], actual.server_address)) elif expected_type == "server_description_changed_event": if not isinstance(actual, monitoring.ServerDescriptionChangedEvent): return (False, "Expected ServerDescriptionChangedEvent, got %s" % ( actual.__class__)) if not expected['address'] == "%s:%s" % actual.server_address: return (False, "ServerDescriptionChangedEvent has wrong address" " (expected %s, got %s" % (expected['address'], actual.server_address)) if not compare_server_descriptions( expected['newDescription'], actual.new_description): return (False, "New ServerDescription incorrect in" " ServerDescriptionChangedEvent") if not compare_server_descriptions(expected['previousDescription'], actual.previous_description): return (False, "Previous ServerDescription incorrect in" " ServerDescriptionChangedEvent") elif expected_type == "server_closed_event": if not isinstance(actual, monitoring.ServerClosedEvent): return False, "Expected ServerClosedEvent, got %s" % ( 
actual.__class__) if not expected['address'] == "%s:%s" % actual.server_address: return (False, "ServerClosedEvent published with wrong address" " (expected %s, got %s" % (expected['address'], actual.server_address)) elif expected_type == "topology_opening_event": if not isinstance(actual, monitoring.TopologyOpenedEvent): return False, "Expected TopologyOpeningEvent, got %s" % ( actual.__class__) elif expected_type == "topology_description_changed_event": if not isinstance(actual, monitoring.TopologyDescriptionChangedEvent): return (False, "Expected TopologyDescriptionChangedEvent," " got %s" % (actual.__class__)) if not compare_topology_descriptions(expected['newDescription'], actual.new_description): return (False, "New TopologyDescription incorrect in " "TopologyDescriptionChangedEvent") if not compare_topology_descriptions( expected['previousDescription'], actual.previous_description): return (False, "Previous TopologyDescription incorrect in" " TopologyDescriptionChangedEvent") elif expected_type == "topology_closed_event": if not isinstance(actual, monitoring.TopologyClosedEvent): return False, "Expected TopologyClosedEvent, got %s" % ( actual.__class__) else: return False, "Incorrect event: expected %s, actual %s" % ( expected_type, actual) return True, "" def compare_multiple_events(i, expected_results, actual_results): events_in_a_row = [] j = i while(j < len(expected_results) and isinstance( actual_results[j], actual_results[i].__class__)): events_in_a_row.append(actual_results[j]) j += 1 message = '' for event in events_in_a_row: for k in range(i, j): passed, message = compare_events(expected_results[k], event) if passed: expected_results[k] = None break else: return i, False, message return j, True, '' class TestAllScenarios(unittest.TestCase): @classmethod @client_context.require_connection def setUp(cls): cls.all_listener = ServerAndTopologyEventListener() def create_test(scenario_def): def run_scenario(self): with client_knobs(events_queue_frequency=0.1): _run_scenario(self) def _run_scenario(self): class NoopMonitor(Monitor): """Override the _run method to do nothing.""" def _run(self): time.sleep(0.05) m = single_client(h=scenario_def['uri'], p=27017, event_listeners=[self.all_listener], _monitor_class=NoopMonitor) topology = m._get_topology() try: for phase in scenario_def['phases']: for (source, response) in phase['responses']: source_address = clean_node(source) topology.on_change(ServerDescription( address=source_address, ismaster=IsMaster(response), round_trip_time=0)) expected_results = phase['outcome']['events'] expected_len = len(expected_results) wait_until( lambda: len(self.all_listener.results) >= expected_len, "publish all events", timeout=15) # Wait some time to catch possible lagging extra events. time.sleep(0.5) i = 0 while i < expected_len: result = self.all_listener.results[i] if len( self.all_listener.results) > i else None # The order of ServerOpening/ClosedEvents doesn't matter if isinstance(result, (monitoring.ServerOpeningEvent, monitoring.ServerClosedEvent)): i, passed, message = compare_multiple_events( i, expected_results, self.all_listener.results) self.assertTrue(passed, message) else: self.assertTrue( *compare_events(expected_results[i], result)) i += 1 # Assert no extra events. 
extra_events = self.all_listener.results[expected_len:] if extra_events: self.fail('Extra events %r' % (extra_events,)) self.all_listener.reset() finally: m.close() return run_scenario def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json.load( scenario_stream, object_hook=object_hook) # Construct test from scenario. new_test = create_test(scenario_def) test_name = 'test_%s' % (os.path.splitext(filename)[0],) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) create_tests() class TestSdamMonitoring(IntegrationTest): @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): super(TestSdamMonitoring, cls).setUpClass() # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs(events_queue_frequency=0.1) cls.knobs.enable() cls.listener = ServerAndTopologyEventListener() retry_writes = client_context.supports_transactions() cls.test_client = rs_or_single_client( event_listeners=[cls.listener], retryWrites=retry_writes) cls.coll = cls.test_client[cls.client.db.name].test cls.coll.insert_one({}) @classmethod def tearDownClass(cls): cls.test_client.close() cls.knobs.disable() super(TestSdamMonitoring, cls).tearDownClass() def setUp(self): self.listener.reset() def _test_app_error(self, fail_command_opts, expected_error): address = self.test_client.address # Test that an application error causes a ServerDescriptionChangedEvent # to be published. data = {'failCommands': ['insert']} data.update(fail_command_opts) fail_insert = { 'configureFailPoint': 'failCommand', 'mode': {'times': 1}, 'data': data, } with self.fail_point(fail_insert): if self.test_client.retry_writes: self.coll.insert_one({}) else: with self.assertRaises(expected_error): self.coll.insert_one({}) self.coll.insert_one({}) def marked_unknown(event): return ( isinstance(event, monitoring.ServerDescriptionChangedEvent) and event.server_address == address and not event.new_description.is_server_type_known) def discovered_node(event): return ( isinstance(event, monitoring.ServerDescriptionChangedEvent) and event.server_address == address and not event.previous_description.is_server_type_known and event.new_description.is_server_type_known) def marked_unknown_and_rediscovered(): return (len(self.listener.matching(marked_unknown)) >= 1 and len(self.listener.matching(discovered_node)) >= 1) # Topology events are published asynchronously wait_until(marked_unknown_and_rediscovered, 'rediscover node') # Expect a single ServerDescriptionChangedEvent for the network error. marked_unknown_events = self.listener.matching(marked_unknown) self.assertEqual(len(marked_unknown_events), 1, marked_unknown_events) self.assertIsInstance( marked_unknown_events[0].new_description.error, expected_error) def test_network_error_publishes_events(self): self._test_app_error({'closeConnection': True}, ConnectionFailure) # In 4.4+, NotMaster errors from failCommand don't cause SDAM state # changes because topologyVersion is not incremented. 
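# --- Editor's note: illustrative sketch, not part of the original test suite. ---
# ``_test_app_error`` above drives the server-side ``failCommand`` fail point
# (available on mongod 4.0+ when test commands are enabled) to inject errors
# into a chosen command. A minimal standalone helper along the same lines,
# with hypothetical names and a local test deployment assumed, could look like:
from contextlib import contextmanager


@contextmanager
def fail_command(client, commands, **error_data):
    """Enable failCommand for one matching command, then switch it off."""
    data = {'failCommands': list(commands)}
    data.update(error_data)
    client.admin.command({'configureFailPoint': 'failCommand',
                          'mode': {'times': 1}, 'data': data})
    try:
        yield
    finally:
        client.admin.command({'configureFailPoint': 'failCommand',
                              'mode': 'off'})


def _demo_injected_network_error():
    from pymongo import MongoClient
    from pymongo.errors import ConnectionFailure
    # Hypothetical deployment started with --setParameter enableTestCommands=1.
    client = MongoClient(retryWrites=False)
    with fail_command(client, ['insert'], closeConnection=True):
        try:
            client.demo_db.demo_coll.insert_one({})
        except ConnectionFailure:
            pass  # the injected error surfaces as a network failure
    client.close()
# --- End editor's note. ---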
@client_context.require_version_max(4, 3) def test_not_master_error_publishes_events(self): self._test_app_error({'errorCode': 10107, 'closeConnection': False, 'errorLabels': ['RetryableWriteError']}, NotMasterError) def test_shutdown_error_publishes_events(self): self._test_app_error({'errorCode': 91, 'closeConnection': False, 'errorLabels': ['RetryableWriteError']}, NotMasterError) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_server.py000066400000000000000000000021501374256237000171710ustar00rootroot00000000000000# Copyright 2014-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the server module.""" import sys sys.path[0:0] = [""] from pymongo.ismaster import IsMaster from pymongo.server import Server from pymongo.server_description import ServerDescription from test import unittest class TestServer(unittest.TestCase): def test_repr(self): ismaster = IsMaster({'ok': 1}) sd = ServerDescription(('localhost', 27017), ismaster) server = Server(sd, pool=object(), monitor=object()) self.assertTrue('Standalone' in str(server)) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_server_description.py000066400000000000000000000152641374256237000216060ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the server_description module.""" import sys sys.path[0:0] = [""] from bson.objectid import ObjectId from bson.int64 import Int64 from pymongo.server_type import SERVER_TYPE from pymongo.ismaster import IsMaster from pymongo.server_description import ServerDescription from test import unittest address = ('localhost', 27017) def parse_ismaster_response(doc): ismaster_response = IsMaster(doc) return ServerDescription(address, ismaster_response) class TestServerDescription(unittest.TestCase): def test_unknown(self): # Default, no ismaster_response. 
s = ServerDescription(address) self.assertEqual(SERVER_TYPE.Unknown, s.server_type) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_mongos(self): s = parse_ismaster_response({'ok': 1, 'msg': 'isdbgrid'}) self.assertEqual(SERVER_TYPE.Mongos, s.server_type) self.assertEqual('Mongos', s.server_type_name) self.assertTrue(s.is_writable) self.assertTrue(s.is_readable) def test_primary(self): s = parse_ismaster_response( {'ok': 1, 'ismaster': True, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSPrimary, s.server_type) self.assertEqual('RSPrimary', s.server_type_name) self.assertTrue(s.is_writable) self.assertTrue(s.is_readable) def test_secondary(self): s = parse_ismaster_response( {'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSSecondary, s.server_type) self.assertEqual('RSSecondary', s.server_type_name) self.assertFalse(s.is_writable) self.assertTrue(s.is_readable) def test_arbiter(self): s = parse_ismaster_response( {'ok': 1, 'ismaster': False, 'arbiterOnly': True, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSArbiter, s.server_type) self.assertEqual('RSArbiter', s.server_type_name) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_other(self): s = parse_ismaster_response( {'ok': 1, 'ismaster': False, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSOther, s.server_type) self.assertEqual('RSOther', s.server_type_name) s = parse_ismaster_response({ 'ok': 1, 'ismaster': False, 'secondary': True, 'hidden': True, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSOther, s.server_type) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_ghost(self): s = parse_ismaster_response({'ok': 1, 'isreplicaset': True}) self.assertEqual(SERVER_TYPE.RSGhost, s.server_type) self.assertEqual('RSGhost', s.server_type_name) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_fields(self): s = parse_ismaster_response({ 'ok': 1, 'ismaster': False, 'secondary': True, 'primary': 'a:27017', 'tags': {'a': 'foo', 'b': 'baz'}, 'maxMessageSizeBytes': 1, 'maxBsonObjectSize': 2, 'maxWriteBatchSize': 3, 'minWireVersion': 4, 'maxWireVersion': 5, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSSecondary, s.server_type) self.assertEqual(('a', 27017), s.primary) self.assertEqual({'a': 'foo', 'b': 'baz'}, s.tags) self.assertEqual(1, s.max_message_size) self.assertEqual(2, s.max_bson_size) self.assertEqual(3, s.max_write_batch_size) self.assertEqual(4, s.min_wire_version) self.assertEqual(5, s.max_wire_version) def test_default_max_message_size(self): s = parse_ismaster_response({ 'ok': 1, 'ismaster': True, 'maxBsonObjectSize': 2}) # Twice max_bson_size. self.assertEqual(4, s.max_message_size) def test_standalone(self): s = parse_ismaster_response({'ok': 1, 'ismaster': True}) self.assertEqual(SERVER_TYPE.Standalone, s.server_type) # Mongod started with --slave. 
s = parse_ismaster_response({'ok': 1, 'ismaster': False}) self.assertEqual(SERVER_TYPE.Standalone, s.server_type) self.assertTrue(s.is_writable) self.assertTrue(s.is_readable) def test_ok_false(self): s = parse_ismaster_response({'ok': 0, 'ismaster': True}) self.assertEqual(SERVER_TYPE.Unknown, s.server_type) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_all_hosts(self): s = parse_ismaster_response({ 'ok': 1, 'ismaster': True, 'hosts': ['a'], 'passives': ['b:27018'], 'arbiters': ['c'] }) self.assertEqual( [('a', 27017), ('b', 27018), ('c', 27017)], sorted(s.all_hosts)) def test_repr(self): s = parse_ismaster_response({'ok': 1, 'msg': 'isdbgrid'}) self.assertEqual(repr(s), "") def test_topology_version(self): topology_version = {'processId': ObjectId(), 'counter': Int64('0')} s = parse_ismaster_response( {'ok': 1, 'ismaster': True, 'setName': 'rs', 'topologyVersion': topology_version}) self.assertEqual(SERVER_TYPE.RSPrimary, s.server_type) self.assertEqual(topology_version, s.topology_version) # Resetting a server to unknown preserves topology_version. s_unknown = s.to_unknown() self.assertEqual(SERVER_TYPE.Unknown, s_unknown.server_type) self.assertEqual(topology_version, s_unknown.topology_version) def test_topology_version_not_present(self): # No topologyVersion field. s = parse_ismaster_response( {'ok': 1, 'ismaster': True, 'setName': 'rs'}) self.assertEqual(SERVER_TYPE.RSPrimary, s.server_type) self.assertEqual(None, s.topology_version) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_server_selection.py000066400000000000000000000204311374256237000212400ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the topology module's Server Selection Spec implementation.""" import os import sys from pymongo import MongoClient from pymongo import ReadPreference from pymongo.errors import ServerSelectionTimeoutError from pymongo.server_selectors import writable_server_selector from pymongo.settings import TopologySettings from pymongo.topology import Topology sys.path[0:0] = [""] from test import client_context, unittest, IntegrationTest from test.utils import (rs_or_single_client, wait_until, EventListener, FunctionCallRecorder) from test.utils_selection_tests import ( create_selection_tests, get_addresses, get_topology_settings_dict, make_server_description) # Location of JSON test specifications. 
_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join('server_selection', 'server_selection')) class SelectionStoreSelector(object): """No-op selector that keeps track of what was passed to it.""" def __init__(self): self.selection = None def __call__(self, selection): self.selection = selection return selection class TestAllScenarios(create_selection_tests(_TEST_PATH)): pass class TestCustomServerSelectorFunction(IntegrationTest): @client_context.require_replica_set def test_functional_select_max_port_number_host(self): # Selector that returns server with highest port number. def custom_selector(servers): ports = [s.address[1] for s in servers] idx = ports.index(max(ports)) return [servers[idx]] # Initialize client with appropriate listeners. listener = EventListener() client = rs_or_single_client( server_selector=custom_selector, event_listeners=[listener]) self.addCleanup(client.close) coll = client.get_database( 'testdb', read_preference=ReadPreference.NEAREST).coll self.addCleanup(client.drop_database, 'testdb') # Wait the node list to be fully populated. def all_hosts_started(): return (len(client.admin.command('isMaster')['hosts']) == len(client._topology._description.readable_servers)) wait_until(all_hosts_started, 'receive heartbeat from all hosts') expected_port = max([ n.address[1] for n in client._topology._description.readable_servers]) # Insert 1 record and access it 10 times. coll.insert_one({'name': 'John Doe'}) for _ in range(10): coll.find_one({'name': 'John Doe'}) # Confirm all find commands are run against appropriate host. for command in listener.results['started']: if command.command_name == 'find': self.assertEqual( command.connection_id[1], expected_port) def test_invalid_server_selector(self): # Client initialization must fail if server_selector is not callable. for selector_candidate in [list(), 10, 'string', {}]: with self.assertRaisesRegex(ValueError, "must be a callable"): MongoClient(connect=False, server_selector=selector_candidate) # None value for server_selector is OK. MongoClient(connect=False, server_selector=None) @client_context.require_replica_set def test_selector_called(self): selector = FunctionCallRecorder(lambda x: x) # Client setup. mongo_client = rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection self.addCleanup(mongo_client.drop_database, 'testdb') self.addCleanup(mongo_client.close) # Do N operations and test selector is called at least N times. test_collection.insert_one({'age': 20, 'name': 'John'}) test_collection.insert_one({'age': 31, 'name': 'Jane'}) test_collection.update_one({'name': 'Jane'}, {'$set': {'age': 21}}) test_collection.find_one({'name': 'Roe'}) self.assertGreaterEqual(selector.call_count, 4) @client_context.require_replica_set def test_latency_threshold_application(self): selector = SelectionStoreSelector() scenario_def = { 'topology_description': { 'type': 'ReplicaSetWithPrimary', 'servers': [ {'address': 'b:27017', 'avg_rtt_ms': 10000, 'type': 'RSSecondary', 'tag': {}}, {'address': 'c:27017', 'avg_rtt_ms': 20000, 'type': 'RSSecondary', 'tag': {}}, {'address': 'a:27017', 'avg_rtt_ms': 30000, 'type': 'RSPrimary', 'tag': {}}, ]}} # Create & populate Topology such that all but one server is too slow. 
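# An illustrative aside, not part of the original file: the server_selector
# contract exercised by these tests is simply a callable that receives a list
# of server descriptions and returns a filtered list. A minimal sketch,
# assuming a reachable replica-set deployment:
#
#     from pymongo import MongoClient
#
#     def secondaries_only(server_descriptions):
#         # Keep only servers whose description reports an RSSecondary.
#         return [s for s in server_descriptions
#                 if s.server_type_name == 'RSSecondary']
#
#     client = MongoClient(server_selector=secondaries_only)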
rtt_times = [srv['avg_rtt_ms'] for srv in scenario_def['topology_description']['servers']] min_rtt_idx = rtt_times.index(min(rtt_times)) seeds, hosts = get_addresses( scenario_def["topology_description"]["servers"]) settings = get_topology_settings_dict( heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector) topology = Topology(TopologySettings(**settings)) topology.open() for server in scenario_def['topology_description']['servers']: server_description = make_server_description(server, hosts) topology.on_change(server_description) # Invoke server selection and assert no filtering based on latency # prior to custom server selection logic kicking in. server = topology.select_server(ReadPreference.NEAREST) self.assertEqual( len(selector.selection), len(topology.description.server_descriptions())) # Ensure proper filtering based on latency after custom selection. self.assertEqual( server.description.address, seeds[min_rtt_idx]) @client_context.require_replica_set def test_server_selector_bypassed(self): selector = FunctionCallRecorder(lambda x: x) scenario_def = { 'topology_description': { 'type': 'ReplicaSetNoPrimary', 'servers': [ {'address': 'b:27017', 'avg_rtt_ms': 10000, 'type': 'RSSecondary', 'tag': {}}, {'address': 'c:27017', 'avg_rtt_ms': 20000, 'type': 'RSSecondary', 'tag': {}}, {'address': 'a:27017', 'avg_rtt_ms': 30000, 'type': 'RSSecondary', 'tag': {}}, ]}} # Create & populate Topology such that no server is writeable. seeds, hosts = get_addresses( scenario_def["topology_description"]["servers"]) settings = get_topology_settings_dict( heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector) topology = Topology(TopologySettings(**settings)) topology.open() for server in scenario_def['topology_description']['servers']: server_description = make_server_description(server, hosts) topology.on_change(server_description) # Invoke server selection and assert no calls to our custom selector. with self.assertRaisesRegex( ServerSelectionTimeoutError, 'No primary available for writes'): topology.select_server( writable_server_selector, server_selection_timeout=0.1) self.assertEqual(selector.call_count, 0) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_server_selection_rtt.py000066400000000000000000000040201374256237000221250ustar00rootroot00000000000000# Copyright 2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the topology module.""" import json import os import sys sys.path[0:0] = [""] from test import unittest from pymongo.read_preferences import MovingAverage # Location of JSON test specifications. 
_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'server_selection/rtt') class TestAllScenarios(unittest.TestCase): pass def create_test(scenario_def): def run_scenario(self): moving_average = MovingAverage() if scenario_def['avg_rtt_ms'] != "NULL": moving_average.add_sample(scenario_def['avg_rtt_ms']) if scenario_def['new_rtt_ms'] != "NULL": moving_average.add_sample(scenario_def['new_rtt_ms']) self.assertAlmostEqual(moving_average.get(), scenario_def['new_avg_rtt']) return run_scenario def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json.load(scenario_stream) # Construct test from scenario. new_test = create_test(scenario_def) test_name = 'test_%s_%s' % ( dirname, os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_session.py000066400000000000000000001521641374256237000173610ustar00rootroot00000000000000# Copyright 2017 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the client_session module.""" import copy import os import sys from bson import DBRef from bson.py3compat import StringIO from gridfs import GridFS, GridFSBucket from pymongo import ASCENDING, InsertOne, IndexModel, OFF, monitoring from pymongo.common import _MAX_END_SESSIONS from pymongo.errors import (ConfigurationError, InvalidOperation, OperationFailure) from pymongo.monotonic import time as _time from pymongo.read_concern import ReadConcern from test import IntegrationTest, client_context, db_user, db_pwd, unittest, SkipTest from test.utils import (ignore_deprecations, rs_or_single_client, EventListener, TestCreator, wait_until) from test.utils_spec_runner import SpecRunner # Ignore auth commands like saslStart, so we can assert lsid is in all commands. class SessionTestListener(EventListener): def started(self, event): if not event.command_name.startswith('sasl'): super(SessionTestListener, self).started(event) def succeeded(self, event): if not event.command_name.startswith('sasl'): super(SessionTestListener, self).succeeded(event) def failed(self, event): if not event.command_name.startswith('sasl'): super(SessionTestListener, self).failed(event) def first_command_started(self): assert len(self.results['started']) >= 1, ( "No command-started events") return self.results['started'][0] def session_ids(client): return [s.session_id for s in copy.copy(client._topology._session_pool)] class TestSession(IntegrationTest): @classmethod @client_context.require_sessions def setUpClass(cls): super(TestSession, cls).setUpClass() # Create a second client so we can make sure clients cannot share # sessions. cls.client2 = rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". 
cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() @classmethod def tearDownClass(cls): monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) cls.client2.close() super(TestSession, cls).tearDownClass() def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener]) self.addCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = set(s['id'] for s in session_ids(self.client)) def tearDown(self): """All sessions used in the test must be returned to the pool.""" self.client.drop_database('pymongo_test') used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.results['started']: if 'lsid' in event.command: used_lsids.add(event.command['lsid']['id']) current_lsids = set(s['id'] for s in session_ids(self.client)) self.assertLessEqual(used_lsids, current_lsids) def _test_ops(self, client, *ops): listener = client.event_listeners()[0][0] for f, args, kw in ops: with client.start_session() as s: last_use = s._server_session.last_use start = _time() self.assertLessEqual(last_use, start) listener.results.clear() # In case "f" modifies its inputs. args = copy.copy(args) kw = copy.copy(kw) kw['session'] = s f(*args, **kw) self.assertGreaterEqual(s._server_session.last_use, start) self.assertGreaterEqual(len(listener.results['started']), 1) for event in listener.results['started']: self.assertTrue( 'lsid' in event.command, "%s sent no lsid with %s" % ( f.__name__, event.command_name)) self.assertEqual( s.session_id, event.command['lsid'], "%s sent wrong lsid with %s" % ( f.__name__, event.command_name)) self.assertFalse(s.has_ended) self.assertTrue(s.has_ended) with self.assertRaisesRegex(InvalidOperation, "ended session"): f(*args, **kw) # Test a session cannot be used on another client. with self.client2.start_session() as s: # In case "f" modifies its inputs. args = copy.copy(args) kw = copy.copy(kw) kw['session'] = s with self.assertRaisesRegex( InvalidOperation, 'Can only use session with the MongoClient' ' that started it'): f(*args, **kw) # No explicit session. for f, args, kw in ops: listener.results.clear() f(*args, **kw) self.assertGreaterEqual(len(listener.results['started']), 1) lsids = [] for event in listener.results['started']: self.assertTrue( 'lsid' in event.command, "%s sent no lsid with %s" % ( f.__name__, event.command_name)) lsids.append(event.command['lsid']) if not (sys.platform.startswith('java') or 'PyPy' in sys.version): # Server session was returned to pool. Ignore interpreters with # non-deterministic GC. for lsid in lsids: self.assertIn( lsid, session_ids(client), "%s did not return implicit session to pool" % ( f.__name__,)) def test_pool_lifo(self): # "Pool is LIFO" test from Driver Sessions Spec. a = self.client.start_session() b = self.client.start_session() a_id = a.session_id b_id = b.session_id a.end_session() b.end_session() s = self.client.start_session() self.assertEqual(b_id, s.session_id) self.assertNotEqual(a_id, s.session_id) s2 = self.client.start_session() self.assertEqual(a_id, s2.session_id) self.assertNotEqual(b_id, s2.session_id) s.end_session() s2.end_session() def test_end_session(self): # We test elsewhere that using an ended session throws InvalidOperation. 
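# A minimal sketch of the lifecycle asserted below (illustrative only, not
# part of the original file; assumes a deployment that supports sessions):
#
#     s = client.start_session()
#     s.session_id            # usable while the session is open
#     s.end_session()
#     s.session_id            # now raises InvalidOperation("ended session")
#
# Using "with client.start_session() as s:" ends the session automatically
# when the block exits.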
client = self.client s = client.start_session() self.assertFalse(s.has_ended) self.assertIsNotNone(s.session_id) s.end_session() self.assertTrue(s.has_ended) with self.assertRaisesRegex(InvalidOperation, "ended session"): s.session_id def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() client = rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: s.end_session() # Closing the client should end all sessions and clear the pool. self.assertEqual(len(client._topology._session_pool), _MAX_END_SESSIONS + 1) client.close() self.assertEqual(len(client._topology._session_pool), 0) end_sessions = [e for e in listener.results['started'] if e.command_name == 'endSessions'] self.assertEqual(len(end_sessions), 2) # Closing again should not send any commands. listener.results.clear() client.close() self.assertEqual(len(listener.results['started']), 0) @ignore_deprecations # fsync and unlock def test_client(self): client = self.client # Make sure if the test fails we unlock the server. def unlock(): try: client.unlock() except OperationFailure: pass self.addCleanup(unlock) ops = [ (client.server_info, [], {}), (client.database_names, [], {}), (client.drop_database, ['pymongo_test'], {}), ] if not client_context.is_mongos: ops.extend([ (client.fsync, [], {'lock': True}), (client.unlock, [], {}), ]) self._test_ops(client, *ops) def test_database(self): client = self.client db = client.pymongo_test ops = [ (db.command, ['ping'], {}), (db.create_collection, ['collection'], {}), (db.collection_names, [], {}), (db.list_collection_names, [], {}), (db.validate_collection, ['collection'], {}), (db.drop_collection, ['collection'], {}), (db.current_op, [], {}), (db.profiling_info, [], {}), (db.dereference, [DBRef('collection', 1)], {}), ] if not client_context.is_mongos: ops.append((db.set_profiling_level, [OFF], {})) ops.append((db.profiling_level, [], {})) self._test_ops(client, *ops) @client_context.require_auth @ignore_deprecations def test_user_admin(self): client = self.client db = client.pymongo_test self._test_ops( client, (db.add_user, ['session-test', 'pass'], {'roles': ['read']}), # Do it again to test updateUser command. (db.add_user, ['session-test', 'pass'], {'roles': ['read']}), (db.remove_user, ['session-test'], {})) @staticmethod def collection_write_ops(coll): """Generate database write ops for tests.""" return [ (coll.drop, [], {}), (coll.bulk_write, [[InsertOne({})]], {}), (coll.insert_one, [{}], {}), (coll.insert_many, [[{}, {}]], {}), (coll.replace_one, [{}, {}], {}), (coll.update_one, [{}, {'$set': {'a': 1}}], {}), (coll.update_many, [{}, {'$set': {'a': 1}}], {}), (coll.delete_one, [{}], {}), (coll.delete_many, [{}], {}), (coll.map_reduce, ['function() {}', 'function() {}', 'output'], {}), (coll.find_one_and_replace, [{}, {}], {}), (coll.find_one_and_update, [{}, {'$set': {'a': 1}}], {}), (coll.find_one_and_delete, [{}, {}], {}), (coll.rename, ['collection2'], {}), # Drop collection2 between tests of "rename", above. (coll.database.drop_collection, ['collection2'], {}), (coll.create_indexes, [[IndexModel('a')]], {}), (coll.create_index, ['a'], {}), (coll.drop_index, ['a_1'], {}), (coll.drop_indexes, [], {}), (coll.aggregate, [[{"$out": "aggout"}]], {}), ] def test_collection(self): client = self.client coll = client.pymongo_test.collection # Test some collection methods - the rest are in test_cursor. 
ops = self.collection_write_ops(coll) ops.extend([ (coll.distinct, ['a'], {}), (coll.find_one, [], {}), (coll.count, [], {}), (coll.count_documents, [{}], {}), (coll.inline_map_reduce, ['function() {}', 'function() {}'], {}), (coll.list_indexes, [], {}), (coll.index_information, [], {}), (coll.options, [], {}), (coll.aggregate, [[]], {}), ]) if client_context.supports_reindex: ops.append((coll.reindex, [], {})) self._test_ops(client, *ops) @client_context.require_no_mongos @client_context.require_version_max(4, 1, 0) @ignore_deprecations def test_parallel_collection_scan(self): listener = self.listener client = self.client coll = client.pymongo_test.collection coll.insert_many([{'_id': i} for i in range(1000)]) listener.results.clear() def scan(session=None): cursors = coll.parallel_scan(4, session=session) for c in cursors: c.batch_size(2) list(c) listener.results.clear() with client.start_session() as session: scan(session) cursor_lsids = {} for event in listener.results['started']: self.assertIn( 'lsid', event.command, "parallel_scan sent no lsid with %s" % (event.command_name, )) if event.command_name == 'getMore': cursor_id = event.command['getMore'] if cursor_id in cursor_lsids: self.assertEqual(cursor_lsids[cursor_id], event.command['lsid']) else: cursor_lsids[cursor_id] = event.command['lsid'] def test_cursor_clone(self): coll = self.client.pymongo_test.collection # Ensure some batches. coll.insert_many({} for _ in range(10)) self.addCleanup(coll.drop) with self.client.start_session() as s: cursor = coll.find(session=s) self.assertTrue(cursor.session is s) clone = cursor.clone() self.assertTrue(clone.session is s) # No explicit session. cursor = coll.find(batch_size=2) next(cursor) # Session is "owned" by cursor. self.assertIsNone(cursor.session) self.assertIsNotNone(cursor._Cursor__session) clone = cursor.clone() next(clone) self.assertIsNone(clone.session) self.assertIsNotNone(clone._Cursor__session) self.assertFalse(cursor._Cursor__session is clone._Cursor__session) cursor.close() clone.close() def test_cursor(self): listener = self.listener client = self.client coll = client.pymongo_test.collection coll.insert_many([{} for _ in range(1000)]) # Test all cursor methods. ops = [ ('find', lambda session: list(coll.find(session=session))), ('getitem', lambda session: coll.find(session=session)[0]), ('count', lambda session: coll.find(session=session).count()), ('distinct', lambda session: coll.find(session=session).distinct('a')), ('explain', lambda session: coll.find(session=session).explain()), ] for name, f in ops: with client.start_session() as s: listener.results.clear() f(session=s) self.assertGreaterEqual(len(listener.results['started']), 1) for event in listener.results['started']: self.assertTrue( 'lsid' in event.command, "%s sent no lsid with %s" % ( name, event.command_name)) self.assertEqual( s.session_id, event.command['lsid'], "%s sent wrong lsid with %s" % ( name, event.command_name)) with self.assertRaisesRegex(InvalidOperation, "ended session"): f(session=s) # No explicit session. 
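# Illustrative note (not part of the original file): when no session is
# passed, the driver checks an implicit server session out of the client's
# pool for each operation and returns it afterwards, so plain calls such as
#
#     coll.insert_one({'x': 1})
#     coll.find_one({'x': 1})
#
# still send an "lsid" with every command, typically reusing pooled ids.
# The loop below verifies exactly that for the cursor helpers.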
for name, f in ops: listener.results.clear() f(session=None) event0 = listener.first_command_started() self.assertTrue( 'lsid' in event0.command, "%s sent no lsid with %s" % ( name, event0.command_name)) lsid = event0.command['lsid'] for event in listener.results['started'][1:]: self.assertTrue( 'lsid' in event.command, "%s sent no lsid with %s" % ( name, event.command_name)) self.assertEqual( lsid, event.command['lsid'], "%s sent wrong lsid with %s" % ( name, event.command_name)) def test_gridfs(self): client = self.client fs = GridFS(client.pymongo_test) def new_file(session=None): grid_file = fs.new_file(_id=1, filename='f', session=session) # 1 MB, 5 chunks, to test that each chunk is fetched with same lsid. grid_file.write(b'a' * 1048576) grid_file.close() def find(session=None): files = list(fs.find({'_id': 1}, session=session)) for f in files: f.read() self._test_ops( client, (new_file, [], {}), (fs.put, [b'data'], {}), (lambda session=None: fs.get(1, session=session).read(), [], {}), (lambda session=None: fs.get_version('f', session=session).read(), [], {}), (lambda session=None: fs.get_last_version('f', session=session).read(), [], {}), (fs.list, [], {}), (fs.find_one, [1], {}), (lambda session=None: list(fs.find(session=session)), [], {}), (fs.exists, [1], {}), (find, [], {}), (fs.delete, [1], {})) def test_gridfs_bucket(self): client = self.client bucket = GridFSBucket(client.pymongo_test) def upload(session=None): stream = bucket.open_upload_stream('f', session=session) stream.write(b'a' * 1048576) stream.close() def upload_with_id(session=None): stream = bucket.open_upload_stream_with_id(1, 'f1', session=session) stream.write(b'a' * 1048576) stream.close() def open_download_stream(session=None): stream = bucket.open_download_stream(1, session=session) stream.read() def open_download_stream_by_name(session=None): stream = bucket.open_download_stream_by_name('f', session=session) stream.read() def find(session=None): files = list(bucket.find({'_id': 1}, session=session)) for f in files: f.read() sio = StringIO() self._test_ops( client, (upload, [], {}), (upload_with_id, [], {}), (bucket.upload_from_stream, ['f', b'data'], {}), (bucket.upload_from_stream_with_id, [2, 'f', b'data'], {}), (open_download_stream, [], {}), (open_download_stream_by_name, [], {}), (bucket.download_to_stream, [1, sio], {}), (bucket.download_to_stream_by_name, ['f', sio], {}), (find, [], {}), (bucket.rename, [1, 'f2'], {}), # Delete both files so _test_ops can run these operations twice. (bucket.delete, [1], {}), (bucket.delete, [2], {})) def test_gridfsbucket_cursor(self): client = self.client bucket = GridFSBucket(client.pymongo_test) for file_id in 1, 2: stream = bucket.open_upload_stream_with_id(file_id, str(file_id)) stream.write(b'a' * 1048576) stream.close() with client.start_session() as s: cursor = bucket.find(session=s) for f in cursor: f.read() self.assertFalse(s.has_ended) self.assertTrue(s.has_ended) # No explicit session. cursor = bucket.find(batch_size=1) files = [cursor.next()] s = cursor._Cursor__session self.assertFalse(s.has_ended) cursor.__del__() self.assertTrue(s.has_ended) self.assertIsNone(cursor._Cursor__session) # Files are still valid, they use their own sessions. for f in files: f.read() # Explicit session. with client.start_session() as s: cursor = bucket.find(session=s) s = cursor.session files = list(cursor) cursor.__del__() self.assertFalse(s.has_ended) for f in files: f.read() for f in files: # Attempt to read the file again. 
f.seek(0) with self.assertRaisesRegex(InvalidOperation, "ended session"): f.read() def test_aggregate(self): client = self.client coll = client.pymongo_test.collection def agg(session=None): list(coll.aggregate( [], batchSize=2, session=session)) # With empty collection. self._test_ops(client, (agg, [], {})) # Now with documents. coll.insert_many([{} for _ in range(10)]) self.addCleanup(coll.drop) self._test_ops(client, (agg, [], {})) def test_killcursors(self): client = self.client coll = client.pymongo_test.collection coll.insert_many([{} for _ in range(10)]) def explicit_close(session=None): cursor = coll.find(batch_size=2, session=session) next(cursor) cursor.close() self._test_ops(client, (explicit_close, [], {})) def test_aggregate_error(self): listener = self.listener client = self.client coll = client.pymongo_test.collection # 3.6.0 mongos only validates the aggregate pipeline when the # database exists. coll.insert_one({}) listener.results.clear() with self.assertRaises(OperationFailure): coll.aggregate([{'$badOperation': {'bar': 1}}]) event = listener.first_command_started() self.assertEqual(event.command_name, 'aggregate') lsid = event.command['lsid'] # Session was returned to pool despite error. self.assertIn(lsid, session_ids(client)) def _test_cursor_helper(self, create_cursor, close_cursor): coll = self.client.pymongo_test.collection coll.insert_many([{} for _ in range(1000)]) cursor = create_cursor(coll, None) next(cursor) # Session is "owned" by cursor. session = getattr(cursor, '_%s__session' % cursor.__class__.__name__) self.assertIsNotNone(session) lsid = session.session_id next(cursor) # Cursor owns its session unto death. self.assertNotIn(lsid, session_ids(self.client)) close_cursor(cursor) self.assertIn(lsid, session_ids(self.client)) # An explicit session is not ended by cursor.close() or list(cursor). with self.client.start_session() as s: cursor = create_cursor(coll, s) next(cursor) close_cursor(cursor) self.assertFalse(s.has_ended) lsid = s.session_id self.assertTrue(s.has_ended) self.assertIn(lsid, session_ids(self.client)) def test_cursor_close(self): self._test_cursor_helper( lambda coll, session: coll.find(session=session), lambda cursor: cursor.close()) def test_command_cursor_close(self): self._test_cursor_helper( lambda coll, session: coll.aggregate([], session=session), lambda cursor: cursor.close()) def test_cursor_del(self): self._test_cursor_helper( lambda coll, session: coll.find(session=session), lambda cursor: cursor.__del__()) def test_command_cursor_del(self): self._test_cursor_helper( lambda coll, session: coll.aggregate([], session=session), lambda cursor: cursor.__del__()) def test_cursor_exhaust(self): self._test_cursor_helper( lambda coll, session: coll.find(session=session), lambda cursor: list(cursor)) def test_command_cursor_exhaust(self): self._test_cursor_helper( lambda coll, session: coll.aggregate([], session=session), lambda cursor: list(cursor)) def test_cursor_limit_reached(self): self._test_cursor_helper( lambda coll, session: coll.find(limit=4, batch_size=2, session=session), lambda cursor: list(cursor)) def test_command_cursor_limit_reached(self): self._test_cursor_helper( lambda coll, session: coll.aggregate([], batchSize=900, session=session), lambda cursor: list(cursor)) def _test_unacknowledged_ops(self, client, *ops): listener = client.event_listeners()[0][0] for f, args, kw in ops: with client.start_session() as s: listener.results.clear() # In case "f" modifies its inputs. 
args = copy.copy(args) kw = copy.copy(kw) kw['session'] = s with self.assertRaises( ConfigurationError, msg="%s did not raise ConfigurationError" % ( f.__name__,)): f(*args, **kw) if f.__name__ == 'create_collection': # create_collection runs listCollections first. event = listener.results['started'].pop(0) self.assertEqual('listCollections', event.command_name) self.assertIn('lsid', event.command, "%s sent no lsid with %s" % ( f.__name__, event.command_name)) # Should not run any command before raising an error. self.assertFalse(listener.results['started'], "%s sent command" % (f.__name__,)) self.assertTrue(s.has_ended) # Unacknowledged write without a session does not send an lsid. for f, args, kw in ops: listener.results.clear() f(*args, **kw) self.assertGreaterEqual(len(listener.results['started']), 1) if f.__name__ == 'create_collection': # create_collection runs listCollections first. event = listener.results['started'].pop(0) self.assertEqual('listCollections', event.command_name) self.assertIn('lsid', event.command, "%s sent no lsid with %s" % ( f.__name__, event.command_name)) for event in listener.results['started']: self.assertNotIn('lsid', event.command, "%s sent lsid with %s" % ( f.__name__, event.command_name)) def test_unacknowledged_writes(self): # Ensure the collection exists. self.client.pymongo_test.test_unacked_writes.insert_one({}) client = rs_or_single_client(w=0, event_listeners=[self.listener]) self.addCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes ops = [ (client.drop_database, [db.name], {}), (db.create_collection, ['collection'], {}), (db.drop_collection, ['collection'], {}), ] ops.extend(self.collection_write_ops(coll)) self._test_unacknowledged_ops(client, *ops) def drop_db(): try: self.client.drop_database(db.name) return True except OperationFailure as exc: # Try again on BackgroundOperationInProgressForDatabase and # BackgroundOperationInProgressForNamespace. if exc.code in (12586, 12587): return False raise wait_until(drop_db, 'dropped database after w=0 writes') class TestCausalConsistency(unittest.TestCase): @classmethod def setUpClass(cls): cls.listener = SessionTestListener() cls.client = rs_or_single_client(event_listeners=[cls.listener]) @classmethod def tearDownClass(cls): cls.client.close() @client_context.require_sessions def setUp(self): super(TestCausalConsistency, self).setUp() @client_context.require_no_standalone def test_core(self): with self.client.start_session() as sess: self.assertIsNone(sess.cluster_time) self.assertIsNone(sess.operation_time) self.listener.results.clear() self.client.pymongo_test.test.find_one(session=sess) started = self.listener.results['started'][0] cmd = started.command self.assertIsNone(cmd.get('readConcern')) op_time = sess.operation_time self.assertIsNotNone(op_time) succeeded = self.listener.results['succeeded'][0] reply = succeeded.reply self.assertEqual(op_time, reply.get('operationTime')) # No explicit session self.client.pymongo_test.test.insert_one({}) self.assertEqual(sess.operation_time, op_time) self.listener.results.clear() try: self.client.pymongo_test.command('doesntexist', session=sess) except: pass failed = self.listener.results['failed'][0] failed_op_time = failed.failure.get('operationTime') # Some older builds of MongoDB 3.5 / 3.6 return None for # operationTime when a command fails. Make sure we don't # change operation_time to None. 
if failed_op_time is None: self.assertIsNotNone(sess.operation_time) else: self.assertEqual( sess.operation_time, failed_op_time) with self.client.start_session() as sess2: self.assertIsNone(sess2.cluster_time) self.assertIsNone(sess2.operation_time) self.assertRaises(TypeError, sess2.advance_cluster_time, 1) self.assertRaises(ValueError, sess2.advance_cluster_time, {}) self.assertRaises(TypeError, sess2.advance_operation_time, 1) # No error sess2.advance_cluster_time(sess.cluster_time) sess2.advance_operation_time(sess.operation_time) self.assertEqual(sess.cluster_time, sess2.cluster_time) self.assertEqual(sess.operation_time, sess2.operation_time) def _test_reads(self, op, exception=None): coll = self.client.pymongo_test.test with self.client.start_session() as sess: coll.find_one({}, session=sess) operation_time = sess.operation_time self.assertIsNotNone(operation_time) self.listener.results.clear() if exception: with self.assertRaises(exception): op(coll, sess) else: op(coll, sess) act = self.listener.results['started'][0].command.get( 'readConcern', {}).get('afterClusterTime') self.assertEqual(operation_time, act) @client_context.require_no_standalone def test_reads(self): # Make sure the collection exists. self.client.pymongo_test.test.insert_one({}) self._test_reads( lambda coll, session: list(coll.aggregate([], session=session))) self._test_reads( lambda coll, session: list(coll.find({}, session=session))) self._test_reads( lambda coll, session: coll.find_one({}, session=session)) self._test_reads( lambda coll, session: coll.count(session=session)) self._test_reads( lambda coll, session: coll.count_documents({}, session=session)) self._test_reads( lambda coll, session: coll.distinct('foo', session=session)) # SERVER-40938 removed support for casually consistent mapReduce. map_reduce_exc = None if client_context.version.at_least(4, 1, 12): map_reduce_exc = OperationFailure # SERVER-44635 The mapReduce in aggregation project added back # support for casually consistent mapReduce. 
if client_context.version < (4, 3): self._test_reads( lambda coll, session: coll.map_reduce( 'function() {}', 'function() {}', 'inline', session=session), exception=map_reduce_exc) self._test_reads( lambda coll, session: coll.inline_map_reduce( 'function() {}', 'function() {}', session=session), exception=map_reduce_exc) if (not client_context.is_mongos and not client_context.version.at_least(4, 1, 0)): def scan(coll, session): cursors = coll.parallel_scan(1, session=session) for cur in cursors: list(cur) self._test_reads( lambda coll, session: scan(coll, session=session)) self.assertRaises( ConfigurationError, self._test_reads, lambda coll, session: list( coll.aggregate_raw_batches([], session=session))) self.assertRaises( ConfigurationError, self._test_reads, lambda coll, session: list( coll.find_raw_batches({}, session=session))) self.assertRaises( ConfigurationError, self._test_reads, lambda coll, session: coll.estimated_document_count( session=session)) def _test_writes(self, op): coll = self.client.pymongo_test.test with self.client.start_session() as sess: op(coll, sess) operation_time = sess.operation_time self.assertIsNotNone(operation_time) self.listener.results.clear() coll.find_one({}, session=sess) act = self.listener.results['started'][0].command.get( 'readConcern', {}).get('afterClusterTime') self.assertEqual(operation_time, act) @client_context.require_no_standalone def test_writes(self): self._test_writes( lambda coll, session: coll.bulk_write( [InsertOne({})], session=session)) self._test_writes( lambda coll, session: coll.insert_one({}, session=session)) self._test_writes( lambda coll, session: coll.insert_many([{}], session=session)) self._test_writes( lambda coll, session: coll.replace_one( {'_id': 1}, {'x': 1}, session=session)) self._test_writes( lambda coll, session: coll.update_one( {}, {'$set': {'X': 1}}, session=session)) self._test_writes( lambda coll, session: coll.update_many( {}, {'$set': {'x': 1}}, session=session)) self._test_writes( lambda coll, session: coll.delete_one({}, session=session)) self._test_writes( lambda coll, session: coll.delete_many({}, session=session)) self._test_writes( lambda coll, session: coll.find_one_and_replace( {'x': 1}, {'y': 1}, session=session)) self._test_writes( lambda coll, session: coll.find_one_and_update( {'y': 1}, {'$set': {'x': 1}}, session=session)) self._test_writes( lambda coll, session: coll.find_one_and_delete( {'x': 1}, session=session)) self._test_writes( lambda coll, session: coll.create_index("foo", session=session)) self._test_writes( lambda coll, session: coll.create_indexes( [IndexModel([("bar", ASCENDING)])], session=session)) self._test_writes( lambda coll, session: coll.drop_index("foo_1", session=session)) self._test_writes( lambda coll, session: coll.drop_indexes(session=session)) if client_context.supports_reindex: self._test_writes( lambda coll, session: coll.reindex(session=session)) def _test_no_read_concern(self, op): coll = self.client.pymongo_test.test with self.client.start_session() as sess: coll.find_one({}, session=sess) operation_time = sess.operation_time self.assertIsNotNone(operation_time) self.listener.results.clear() op(coll, sess) rc = self.listener.results['started'][0].command.get( 'readConcern') self.assertIsNone(rc) @client_context.require_no_standalone def test_writes_do_not_include_read_concern(self): self._test_no_read_concern( lambda coll, session: coll.bulk_write( [InsertOne({})], session=session)) self._test_no_read_concern( lambda coll, session: coll.insert_one({}, 
session=session)) self._test_no_read_concern( lambda coll, session: coll.insert_many([{}], session=session)) self._test_no_read_concern( lambda coll, session: coll.replace_one( {'_id': 1}, {'x': 1}, session=session)) self._test_no_read_concern( lambda coll, session: coll.update_one( {}, {'$set': {'X': 1}}, session=session)) self._test_no_read_concern( lambda coll, session: coll.update_many( {}, {'$set': {'x': 1}}, session=session)) self._test_no_read_concern( lambda coll, session: coll.delete_one({}, session=session)) self._test_no_read_concern( lambda coll, session: coll.delete_many({}, session=session)) self._test_no_read_concern( lambda coll, session: coll.find_one_and_replace( {'x': 1}, {'y': 1}, session=session)) self._test_no_read_concern( lambda coll, session: coll.find_one_and_update( {'y': 1}, {'$set': {'x': 1}}, session=session)) self._test_no_read_concern( lambda coll, session: coll.find_one_and_delete( {'x': 1}, session=session)) self._test_no_read_concern( lambda coll, session: coll.create_index("foo", session=session)) self._test_no_read_concern( lambda coll, session: coll.create_indexes( [IndexModel([("bar", ASCENDING)])], session=session)) self._test_no_read_concern( lambda coll, session: coll.drop_index("foo_1", session=session)) self._test_no_read_concern( lambda coll, session: coll.drop_indexes(session=session)) self._test_no_read_concern( lambda coll, session: coll.map_reduce( 'function() {}', 'function() {}', 'mrout', session=session)) # They are not writes, but currentOp and explain also don't support # readConcern. self._test_no_read_concern( lambda coll, session: coll.database.current_op(session=session)) self._test_no_read_concern( lambda coll, session: coll.find({}, session=session).explain()) if client_context.supports_reindex: self._test_no_read_concern( lambda coll, session: coll.reindex(session=session)) @client_context.require_no_standalone @client_context.require_version_max(4, 1, 0) def test_aggregate_out_does_not_include_read_concern(self): self._test_no_read_concern( lambda coll, session: list( coll.aggregate([{"$out": "aggout"}], session=session))) @client_context.require_no_standalone def test_get_more_does_not_include_read_concern(self): coll = self.client.pymongo_test.test with self.client.start_session() as sess: coll.find_one({}, session=sess) operation_time = sess.operation_time self.assertIsNotNone(operation_time) coll.insert_many([{}, {}]) cursor = coll.find({}).batch_size(1) next(cursor) self.listener.results.clear() list(cursor) started = self.listener.results['started'][0] self.assertEqual(started.command_name, 'getMore') self.assertIsNone(started.command.get('readConcern')) def test_session_not_causal(self): with self.client.start_session(causal_consistency=False) as s: self.client.pymongo_test.test.insert_one({}, session=s) self.listener.results.clear() self.client.pymongo_test.test.find_one({}, session=s) act = self.listener.results['started'][0].command.get( 'readConcern', {}).get('afterClusterTime') self.assertIsNone(act) @client_context.require_standalone def test_server_not_causal(self): with self.client.start_session(causal_consistency=True) as s: self.client.pymongo_test.test.insert_one({}, session=s) self.listener.results.clear() self.client.pymongo_test.test.find_one({}, session=s) act = self.listener.results['started'][0].command.get( 'readConcern', {}).get('afterClusterTime') self.assertIsNone(act) @client_context.require_no_standalone @client_context.require_no_mmap def test_read_concern(self): with 
self.client.start_session(causal_consistency=True) as s: coll = self.client.pymongo_test.test coll.insert_one({}, session=s) self.listener.results.clear() coll.find_one({}, session=s) read_concern = self.listener.results['started'][0].command.get( 'readConcern') self.assertIsNotNone(read_concern) self.assertIsNone(read_concern.get('level')) self.assertIsNotNone(read_concern.get('afterClusterTime')) coll = coll.with_options(read_concern=ReadConcern("majority")) self.listener.results.clear() coll.find_one({}, session=s) read_concern = self.listener.results['started'][0].command.get( 'readConcern') self.assertIsNotNone(read_concern) self.assertEqual(read_concern.get('level'), 'majority') self.assertIsNotNone(read_concern.get('afterClusterTime')) @client_context.require_no_standalone def test_cluster_time_with_server_support(self): self.client.pymongo_test.test.insert_one({}) self.listener.results.clear() self.client.pymongo_test.test.find_one({}) after_cluster_time = self.listener.results['started'][0].command.get( '$clusterTime') self.assertIsNotNone(after_cluster_time) @client_context.require_standalone def test_cluster_time_no_server_support(self): self.client.pymongo_test.test.insert_one({}) self.listener.results.clear() self.client.pymongo_test.test.find_one({}) after_cluster_time = self.listener.results['started'][0].command.get( '$clusterTime') self.assertIsNone(after_cluster_time) class TestSessionsMultiAuth(IntegrationTest): @client_context.require_auth @client_context.require_sessions def setUp(self): super(TestSessionsMultiAuth, self).setUp() client_context.create_user( 'pymongo_test', 'second-user', 'pass', roles=['readWrite']) self.addCleanup(client_context.drop_user, 'pymongo_test','second-user') @ignore_deprecations def test_session_authenticate_multiple(self): listener = SessionTestListener() # Logged in as root. client = rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test db.authenticate('second-user', 'pass') with self.assertRaises(InvalidOperation): client.start_session() # No implicit sessions. listener.results.clear() db.collection.find_one() event = listener.first_command_started() self.assertNotIn( 'lsid', event.command, "find_one with multi-auth shouldn't have sent lsid with %s" % ( event.command_name)) @ignore_deprecations def test_explicit_session_logout(self): listener = SessionTestListener() # Changing auth invalidates the session. Start as root. client = rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test db.collection.insert_many([{} for _ in range(10)]) self.addCleanup(db.collection.drop) with client.start_session() as s: listener.results.clear() cursor = db.collection.find(session=s).batch_size(2) next(cursor) event = listener.first_command_started() self.assertEqual(event.command_name, 'find') self.assertEqual( s.session_id, event.command.get('lsid'), "find() sent wrong lsid with %s cmd" % (event.command_name,)) client.admin.logout() db.authenticate('second-user', 'pass') err = ('Cannot use session after authenticating with different' ' credentials') with self.assertRaisesRegex(InvalidOperation, err): # Auth has changed between find and getMore. 
list(cursor) with self.assertRaisesRegex(InvalidOperation, err): db.collection.bulk_write([InsertOne({})], session=s) with self.assertRaisesRegex(InvalidOperation, err): db.collection_names(session=s) with self.assertRaisesRegex(InvalidOperation, err): db.collection.find_one(session=s) with self.assertRaisesRegex(InvalidOperation, err): list(db.collection.aggregate([], session=s)) @ignore_deprecations def test_implicit_session_logout(self): listener = SessionTestListener() # Changing auth doesn't invalidate the session. Start as root. client = rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test for name, f in [ ('bulk_write', lambda: db.collection.bulk_write([InsertOne({})])), ('collection_names', db.collection_names), ('find_one', db.collection.find_one), ('aggregate', lambda: list(db.collection.aggregate([]))) ]: def sub_test(): listener.results.clear() f() for event in listener.results['started']: self.assertIn( 'lsid', event.command, "%s sent no lsid with %s" % ( name, event.command_name)) # We switch auth without clearing the pool of session ids. The # server considers these to be new sessions since it's a new user. # The old sessions time out on the server after 30 minutes. client.admin.logout() db.authenticate('second-user', 'pass') sub_test() db.logout() client.admin.authenticate(db_user, db_pwd) sub_test() class TestSessionsNotSupported(IntegrationTest): @client_context.require_version_max(3, 5, 10) def test_sessions_not_supported(self): with self.assertRaisesRegex( ConfigurationError, "Sessions are not supported"): self.client.start_session() class TestClusterTime(IntegrationTest): def setUp(self): super(TestClusterTime, self).setUp() if '$clusterTime' not in client_context.ismaster: raise SkipTest('$clusterTime not supported') @ignore_deprecations def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) self.addCleanup(collection.drop) self.addCleanup(client.pymongo_test.collection2.drop) def bulk_insert(ordered): if ordered: bulk = collection.initialize_ordered_bulk_op() else: bulk = collection.initialize_unordered_bulk_op() bulk.insert({}) bulk.execute() def rename_and_drop(): # Ensure collection exists. collection.insert_one({}) collection.rename('collection2') client.pymongo_test.collection2.drop() def insert_and_find(): cursor = collection.find().batch_size(1) for _ in range(10): # Advance the cluster time. collection.insert_one({}) next(cursor) cursor.close() def insert_and_aggregate(): cursor = collection.aggregate([], batchSize=1).batch_size(1) for _ in range(5): # Advance the cluster time. collection.insert_one({}) next(cursor) cursor.close() ops = [ # Tests from Driver Sessions Spec. ('ping', lambda: client.admin.command('ping')), ('aggregate', lambda: list(collection.aggregate([]))), ('find', lambda: list(collection.find())), ('insert_one', lambda: collection.insert_one({})), # Additional PyMongo tests. 
('insert_and_find', insert_and_find), ('insert_and_aggregate', insert_and_aggregate), ('update_one', lambda: collection.update_one({}, {'$set': {'x': 1}})), ('update_many', lambda: collection.update_many({}, {'$set': {'x': 1}})), ('delete_one', lambda: collection.delete_one({})), ('delete_many', lambda: collection.delete_many({})), ('bulk_write', lambda: collection.bulk_write([InsertOne({})])), ('ordered bulk', lambda: bulk_insert(True)), ('unordered bulk', lambda: bulk_insert(False)), ('rename_and_drop', rename_and_drop), ] for name, f in ops: listener.results.clear() # Call f() twice, insert to advance clusterTime, call f() again. f() f() collection.insert_one({}) f() self.assertGreaterEqual(len(listener.results['started']), 1) for i, event in enumerate(listener.results['started']): self.assertTrue( '$clusterTime' in event.command, "%s sent no $clusterTime with %s" % ( f.__name__, event.command_name)) if i > 0: succeeded = listener.results['succeeded'][i - 1] self.assertTrue( '$clusterTime' in succeeded.reply, "%s received no $clusterTime with %s" % ( f.__name__, succeeded.command_name)) self.assertTrue( event.command['$clusterTime']['clusterTime'] >= succeeded.reply['$clusterTime']['clusterTime'], "%s sent wrong $clusterTime with %s" % ( f.__name__, event.command_name)) class TestSpec(SpecRunner): # Location of JSON test specifications. TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'sessions') def last_two_command_events(self): """Return the last two command started events.""" started_events = self.listener.results['started'][-2:] self.assertEqual(2, len(started_events)) return started_events def assert_same_lsid_on_last_two_commands(self): """Run the assertSameLsidOnLastTwoCommands test operation.""" event1, event2 = self.last_two_command_events() self.assertEqual(event1.command['lsid'], event2.command['lsid']) def assert_different_lsid_on_last_two_commands(self): """Run the assertDifferentLsidOnLastTwoCommands test operation.""" event1, event2 = self.last_two_command_events() self.assertNotEqual(event1.command['lsid'], event2.command['lsid']) def assert_session_dirty(self, session): """Run the assertSessionDirty test operation. Assert that the given session is dirty. """ self.assertIsNotNone(session._server_session) self.assertTrue(session._server_session.dirty) def assert_session_not_dirty(self, session): """Run the assertSessionNotDirty test operation. Assert that the given session is not dirty. """ self.assertIsNotNone(session._server_session) self.assertFalse(session._server_session.dirty) def create_test(scenario_def, test, name): @client_context.require_test_commands def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_test, TestSpec, TestSpec.TEST_PATH) test_creator.create_tests() pymongo-3.11.0/test/test_son.py000066400000000000000000000157121374256237000164720ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
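# A brief orienting sketch for the SON tests that follow (illustrative, not
# part of the original file): SON is bson's dict subclass that preserves key
# insertion order, e.g.
#
#     from bson.son import SON
#     s = SON([("b", 2), ("a", 1)])
#     list(s.keys())          # ["b", "a"]: insertion order is preserved
#     s.to_dict()             # plain dict copy of the same data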
"""Tests for the son module.""" import copy import pickle import re import sys sys.path[0:0] = [""] from bson.py3compat import b from bson.son import SON from test import SkipTest, unittest class TestSON(unittest.TestCase): def test_ordered_dict(self): a1 = SON() a1["hello"] = "world" a1["mike"] = "awesome" a1["hello_"] = "mike" self.assertEqual(list(a1.items()), [("hello", "world"), ("mike", "awesome"), ("hello_", "mike")]) b2 = SON({"hello": "world"}) self.assertEqual(b2["hello"], "world") self.assertRaises(KeyError, lambda: b2["goodbye"]) def test_equality(self): a1 = SON({"hello": "world"}) b2 = SON((('hello', 'world'), ('mike', 'awesome'), ('hello_', 'mike'))) self.assertEqual(a1, SON({"hello": "world"})) self.assertEqual(b2, SON((('hello', 'world'), ('mike', 'awesome'), ('hello_', 'mike')))) self.assertEqual(b2, dict((('hello_', 'mike'), ('mike', 'awesome'), ('hello', 'world')))) self.assertNotEqual(a1, b2) self.assertNotEqual(b2, SON((('hello_', 'mike'), ('mike', 'awesome'), ('hello', 'world')))) # Explicitly test inequality self.assertFalse(a1 != SON({"hello": "world"})) self.assertFalse(b2 != SON((('hello', 'world'), ('mike', 'awesome'), ('hello_', 'mike')))) self.assertFalse(b2 != dict((('hello_', 'mike'), ('mike', 'awesome'), ('hello', 'world')))) # Embedded SON. d4 = SON([('blah', {'foo': SON()})]) self.assertEqual(d4, {'blah': {'foo': {}}}) self.assertEqual(d4, {'blah': {'foo': SON()}}) self.assertNotEqual(d4, {'blah': {'foo': []}}) # Original data unaffected. self.assertEqual(SON, d4['blah']['foo'].__class__) def test_to_dict(self): a1 = SON() b2 = SON([("blah", SON())]) c3 = SON([("blah", [SON()])]) d4 = SON([("blah", {"foo": SON()})]) self.assertEqual({}, a1.to_dict()) self.assertEqual({"blah": {}}, b2.to_dict()) self.assertEqual({"blah": [{}]}, c3.to_dict()) self.assertEqual({"blah": {"foo": {}}}, d4.to_dict()) self.assertEqual(dict, a1.to_dict().__class__) self.assertEqual(dict, b2.to_dict()["blah"].__class__) self.assertEqual(dict, c3.to_dict()["blah"][0].__class__) self.assertEqual(dict, d4.to_dict()["blah"]["foo"].__class__) # Original data unaffected. self.assertEqual(SON, d4['blah']['foo'].__class__) def test_pickle(self): simple_son = SON([]) complex_son = SON([('son', simple_son), ('list', [simple_son, simple_son])]) for protocol in range(pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.loads(pickle.dumps(complex_son, protocol=protocol)) self.assertEqual(pickled['son'], pickled['list'][0]) self.assertEqual(pickled['son'], pickled['list'][1]) def test_pickle_backwards_compatability(self): # This string was generated by pickling a SON object in pymongo # version 2.1.1 pickled_with_2_1_1 = b( "ccopy_reg\n_reconstructor\np0\n(cbson.son\nSON\np1\n" "c__builtin__\ndict\np2\n(dp3\ntp4\nRp5\n(dp6\n" "S'_SON__keys'\np7\n(lp8\nsb." 
) son_2_1_1 = pickle.loads(pickled_with_2_1_1) self.assertEqual(son_2_1_1, SON([])) def test_copying(self): simple_son = SON([]) complex_son = SON([('son', simple_son), ('list', [simple_son, simple_son])]) regex_son = SON([("x", re.compile("^hello.*"))]) reflexive_son = SON([('son', simple_son)]) reflexive_son["reflexive"] = reflexive_son simple_son1 = copy.copy(simple_son) self.assertEqual(simple_son, simple_son1) complex_son1 = copy.copy(complex_son) self.assertEqual(complex_son, complex_son1) regex_son1 = copy.copy(regex_son) self.assertEqual(regex_son, regex_son1) reflexive_son1 = copy.copy(reflexive_son) self.assertEqual(reflexive_son, reflexive_son1) # Test deepcopying simple_son1 = copy.deepcopy(simple_son) self.assertEqual(simple_son, simple_son1) regex_son1 = copy.deepcopy(regex_son) self.assertEqual(regex_son, regex_son1) complex_son1 = copy.deepcopy(complex_son) self.assertEqual(complex_son, complex_son1) reflexive_son1 = copy.deepcopy(reflexive_son) self.assertEqual(list(reflexive_son), list(reflexive_son1)) self.assertEqual(id(reflexive_son1), id(reflexive_son1["reflexive"])) def test_iteration(self): """ Test __iter__ """ # test success case test_son = SON([(1, 100), (2, 200), (3, 300)]) for ele in test_son: self.assertEqual(ele * 100, test_son[ele]) def test_contains_has(self): """ has_key and __contains__ """ test_son = SON([(1, 100), (2, 200), (3, 300)]) self.assertIn(1, test_son) self.assertTrue(2 in test_son, "in failed") self.assertFalse(22 in test_son, "in succeeded when it shouldn't") self.assertTrue(test_son.has_key(2), "has_key failed") self.assertFalse(test_son.has_key(22), "has_key succeeded when it shouldn't") def test_clears(self): """ Test clear() """ test_son = SON([(1, 100), (2, 200), (3, 300)]) test_son.clear() self.assertNotIn(1, test_son) self.assertEqual(0, len(test_son)) self.assertEqual(0, len(test_son.keys())) self.assertEqual({}, test_son.to_dict()) def test_len(self): """ Test len """ test_son = SON() self.assertEqual(0, len(test_son)) test_son = SON([(1, 100), (2, 200), (3, 300)]) self.assertEqual(3, len(test_son)) test_son.popitem() self.assertEqual(2, len(test_son)) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_son_manipulator.py000066400000000000000000000104621374256237000211020ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for SONManipulators. 
""" import sys import warnings sys.path[0:0] = [""] from bson.son import SON from pymongo import MongoClient from pymongo.son_manipulator import (NamespaceInjector, ObjectIdInjector, ObjectIdShuffler, SONManipulator) from test import client_context, qcheck, unittest class TestSONManipulator(unittest.TestCase): @classmethod def setUpClass(cls): cls.warn_context = warnings.catch_warnings() cls.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) client = MongoClient( client_context.host, client_context.port, connect=False) cls.db = client.pymongo_test @classmethod def tearDownClass(cls): cls.warn_context.__exit__() cls.warn_context = None def test_basic(self): manip = SONManipulator() collection = self.db.test def incoming_is_identity(son): return son == manip.transform_incoming(son, collection) qcheck.check_unittest(self, incoming_is_identity, qcheck.gen_mongo_dict(3)) def outgoing_is_identity(son): return son == manip.transform_outgoing(son, collection) qcheck.check_unittest(self, outgoing_is_identity, qcheck.gen_mongo_dict(3)) def test_id_injection(self): manip = ObjectIdInjector() collection = self.db.test def incoming_adds_id(son): son = manip.transform_incoming(son, collection) assert "_id" in son return True qcheck.check_unittest(self, incoming_adds_id, qcheck.gen_mongo_dict(3)) def outgoing_is_identity(son): return son == manip.transform_outgoing(son, collection) qcheck.check_unittest(self, outgoing_is_identity, qcheck.gen_mongo_dict(3)) def test_id_shuffling(self): manip = ObjectIdShuffler() collection = self.db.test def incoming_moves_id(son_in): son = manip.transform_incoming(son_in, collection) if not "_id" in son: return True for (k, v) in son.items(): self.assertEqual(k, "_id") break # Key order matters in SON equality test, # matching collections.OrderedDict if isinstance(son_in, SON): return son_in.to_dict() == son.to_dict() return son_in == son self.assertTrue(incoming_moves_id({})) self.assertTrue(incoming_moves_id({"_id": 12})) self.assertTrue(incoming_moves_id({"hello": "world", "_id": 12})) self.assertTrue(incoming_moves_id(SON([("hello", "world"), ("_id", 12)]))) def outgoing_is_identity(son): return son == manip.transform_outgoing(son, collection) qcheck.check_unittest(self, outgoing_is_identity, qcheck.gen_mongo_dict(3)) def test_ns_injection(self): manip = NamespaceInjector() collection = self.db.test def incoming_adds_ns(son): son = manip.transform_incoming(son, collection) assert "_ns" in son return son["_ns"] == collection.name qcheck.check_unittest(self, incoming_adds_ns, qcheck.gen_mongo_dict(3)) def outgoing_is_identity(son): return son == manip.transform_outgoing(son, collection) qcheck.check_unittest(self, outgoing_is_identity, qcheck.gen_mongo_dict(3)) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_srv_polling.py000066400000000000000000000212641374256237000202300ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Run the SRV support tests.""" import sys from time import sleep sys.path[0:0] = [""] import pymongo from pymongo import common from pymongo.errors import ConfigurationError from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.mongo_client import MongoClient from test import client_knobs, unittest from test.utils import wait_until, FunctionCallRecorder WAIT_TIME = 0.1 class SrvPollingKnobs(object): def __init__(self, ttl_time=None, min_srv_rescan_interval=None, dns_resolver_nodelist_response=None, count_resolver_calls=False): self.ttl_time = ttl_time self.min_srv_rescan_interval = min_srv_rescan_interval self.dns_resolver_nodelist_response = dns_resolver_nodelist_response self.count_resolver_calls = count_resolver_calls self.old_min_srv_rescan_interval = None self.old_dns_resolver_response = None def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL self.old_dns_resolver_response = \ pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval def mock_get_hosts_and_min_ttl(resolver, *args): nodes, ttl = self.old_dns_resolver_response(resolver) if self.dns_resolver_nodelist_response is not None: nodes = self.dns_resolver_nodelist_response() if self.ttl_time is not None: ttl = self.ttl_time return nodes, ttl if self.count_resolver_calls: patch_func = FunctionCallRecorder(mock_get_hosts_and_min_ttl) else: patch_func = mock_get_hosts_and_min_ttl pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func def __enter__(self): self.enable() def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \ self.old_dns_resolver_response def __exit__(self, exc_type, exc_val, exc_tb): self.disable() class TestSrvPolling(unittest.TestCase): BASE_SRV_RESPONSE = [ ("localhost.test.build.10gen.cc", 27017), ("localhost.test.build.10gen.cc", 27018)] CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc" def setUp(self): if not _HAVE_DNSPYTHON: raise unittest.SkipTest("SRV polling tests require the dnspython " "module") # Patch timeouts to ensure short rescan SRV interval. self.client_knobs = client_knobs( heartbeat_frequency=WAIT_TIME, min_heartbeat_interval=WAIT_TIME, events_queue_frequency=WAIT_TIME) self.client_knobs.enable() def tearDown(self): self.client_knobs.disable() def get_nodelist(self, client): return client._topology.description.server_descriptions().keys() def assert_nodelist_change(self, expected_nodelist, client): """Check if the client._topology eventually sees all nodes in the expected_nodelist. """ def predicate(): nodelist = self.get_nodelist(client) if set(expected_nodelist) == set(nodelist): return True return False wait_until(predicate, "see expected nodelist", timeout=100*WAIT_TIME) def assert_nodelist_nochange(self, expected_nodelist, client): """Check if the client._topology ever deviates from seeing all nodes in the expected_nodelist. Consistency is checked after sleeping for (WAIT_TIME * 10) seconds. Also check that the resolver is called at least once. 
""" sleep(WAIT_TIME*10) nodelist = self.get_nodelist(client) if set(expected_nodelist) != set(nodelist): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, 1, "resolver was never called") return True def run_scenario(self, dns_response, expect_change): if callable(dns_response): dns_resolver_response = dns_response else: def dns_resolver_response(): return dns_response if expect_change: assertion_method = self.assert_nodelist_change count_resolver_calls = False expected_response = dns_response else: assertion_method = self.assert_nodelist_nochange count_resolver_calls = True expected_response = self.BASE_SRV_RESPONSE # Patch timeouts to ensure short test running times. with SrvPollingKnobs( ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient(self.CONNECTION_STRING) self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. with SrvPollingKnobs( dns_resolver_nodelist_response=dns_resolver_response, count_resolver_calls=count_resolver_calls): assertion_method(expected_response, client) def test_addition(self): response = self.BASE_SRV_RESPONSE[:] response.append( ("localhost.test.build.10gen.cc", 27019)) self.run_scenario(response, True) def test_removal(self): response = self.BASE_SRV_RESPONSE[:] response.remove( ("localhost.test.build.10gen.cc", 27018)) self.run_scenario(response, True) def test_replace_one(self): response = self.BASE_SRV_RESPONSE[:] response.remove( ("localhost.test.build.10gen.cc", 27018)) response.append( ("localhost.test.build.10gen.cc", 27019)) self.run_scenario(response, True) def test_replace_both_with_one(self): response = [("localhost.test.build.10gen.cc", 27019)] self.run_scenario(response, True) def test_replace_both_with_two(self): response = [("localhost.test.build.10gen.cc", 27019), ("localhost.test.build.10gen.cc", 27020)] self.run_scenario(response, True) def test_dns_failures(self): from dns import exception for exc in (exception.FormError, exception.TooBig, exception.Timeout): def response_callback(*args): raise exc("DNS Failure!") self.run_scenario(response_callback, False) def test_dns_record_lookup_empty(self): response = [] self.run_scenario(response, False) def _test_recover_from_initial(self, initial_callback): # Construct a valid final response callback distinct from base. response_final = self.BASE_SRV_RESPONSE[:] response_final.pop() def final_callback(): return response_final with SrvPollingKnobs( ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, dns_resolver_nodelist_response=initial_callback, count_resolver_calls=True): # Client uses unpatched method to get initial nodelist client = MongoClient(self.CONNECTION_STRING) # Invalid DNS resolver response should not change nodelist. self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) with SrvPollingKnobs( ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, dns_resolver_nodelist_response=final_callback): # Nodelist should reflect new valid DNS resolver response. 
self.assert_nodelist_change(response_final, client) def test_recover_from_initially_empty_seedlist(self): def empty_seedlist(): return [] self._test_recover_from_initial(empty_seedlist) def test_recover_from_initially_erroring_seedlist(self): def erroring_seedlist(): raise ConfigurationError self._test_recover_from_initial(erroring_seedlist) if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/test_ssl.py000066400000000000000000000702121374256237000164700ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for SSL support.""" import os import socket import sys sys.path[0:0] = [""] try: from urllib.parse import quote_plus except ImportError: # Python 2 from urllib import quote_plus from pymongo import MongoClient, ssl_support from pymongo.errors import (ConfigurationError, ConnectionFailure, OperationFailure) from pymongo.ssl_support import HAVE_SSL, get_ssl_context, validate_cert_reqs, _ssl from pymongo.write_concern import WriteConcern from test import (IntegrationTest, client_context, db_pwd, db_user, SkipTest, unittest, HAVE_IPADDRESS) from test.utils import (EventListener, cat_files, connected, remove_all_users) _HAVE_PYOPENSSL = False try: # All of these must be available to use PyOpenSSL import OpenSSL import requests import service_identity _HAVE_PYOPENSSL = True except ImportError: pass if _HAVE_PYOPENSSL: from pymongo.ocsp_support import _load_trusted_ca_certs else: _load_trusted_ca_certs = None if HAVE_SSL: import ssl CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'certificates') CLIENT_PEM = os.path.join(CERT_PATH, 'client.pem') CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, 'password_protected.pem') CA_PEM = os.path.join(CERT_PATH, 'ca.pem') CA_BUNDLE_PEM = os.path.join(CERT_PATH, 'trusted-ca.pem') CRL_PEM = os.path.join(CERT_PATH, 'crl.pem') MONGODB_X509_USERNAME = ( "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client") _PY37PLUS = sys.version_info[:2] >= (3, 7) # To fully test this start a mongod instance (built with SSL support) like so: # mongod --dbpath /path/to/data/directory --sslOnNormalPorts \ # --sslPEMKeyFile /path/to/pymongo/test/certificates/server.pem \ # --sslCAFile /path/to/pymongo/test/certificates/ca.pem \ # --sslWeakCertificateValidation # Also, make sure you have 'server' as an alias for localhost in /etc/hosts # # Note: For all replica set tests to pass, the replica set configuration must # use 'localhost' for the hostname of all hosts. 
class TestClientSSL(unittest.TestCase): @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what " "happens without it.") def test_no_ssl_module(self): # Explicit self.assertRaises(ConfigurationError, MongoClient, ssl=True) # Implied self.assertRaises(ConfigurationError, MongoClient, ssl_certfile=CLIENT_PEM) @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") def test_config_ssl(self): # Tests various ssl configurations self.assertRaises(ValueError, MongoClient, ssl='foo') self.assertRaises(ConfigurationError, MongoClient, ssl=False, ssl_certfile=CLIENT_PEM) self.assertRaises(TypeError, MongoClient, ssl=0) self.assertRaises(TypeError, MongoClient, ssl=5.5) self.assertRaises(TypeError, MongoClient, ssl=[]) self.assertRaises(IOError, MongoClient, ssl_certfile="NoSuchFile") self.assertRaises(TypeError, MongoClient, ssl_certfile=True) self.assertRaises(TypeError, MongoClient, ssl_certfile=[]) self.assertRaises(IOError, MongoClient, ssl_keyfile="NoSuchFile") self.assertRaises(TypeError, MongoClient, ssl_keyfile=True) self.assertRaises(TypeError, MongoClient, ssl_keyfile=[]) # Test invalid combinations self.assertRaises(ConfigurationError, MongoClient, ssl=False, ssl_keyfile=CLIENT_PEM) self.assertRaises(ConfigurationError, MongoClient, ssl=False, ssl_certfile=CLIENT_PEM) self.assertRaises(ConfigurationError, MongoClient, ssl=False, ssl_keyfile=CLIENT_PEM, ssl_certfile=CLIENT_PEM) self.assertRaises( ValueError, validate_cert_reqs, 'ssl_cert_reqs', 3) self.assertRaises( ValueError, validate_cert_reqs, 'ssl_cert_reqs', -1) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', None), None) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', ssl.CERT_NONE), ssl.CERT_NONE) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', ssl.CERT_OPTIONAL), ssl.CERT_OPTIONAL) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', ssl.CERT_REQUIRED), ssl.CERT_REQUIRED) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', 0), ssl.CERT_NONE) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', 1), ssl.CERT_OPTIONAL) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', 2), ssl.CERT_REQUIRED) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', 'CERT_NONE'), ssl.CERT_NONE) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', 'CERT_OPTIONAL'), ssl.CERT_OPTIONAL) self.assertEqual( validate_cert_reqs('ssl_cert_reqs', 'CERT_REQUIRED'), ssl.CERT_REQUIRED) @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") def test_use_openssl_when_available(self): self.assertTrue(_ssl.IS_PYOPENSSL) @unittest.skipUnless(_HAVE_PYOPENSSL, "Cannot test without PyOpenSSL") def test_load_trusted_ca_certs(self): trusted_ca_certs = _load_trusted_ca_certs(CA_BUNDLE_PEM) self.assertEqual(2, len(trusted_ca_certs)) class TestSSL(IntegrationTest): def assertClientWorks(self, client): coll = client.pymongo_test.ssl_test.with_options( write_concern=WriteConcern(w=client_context.w)) coll.drop() coll.insert_one({'ssl': True}) self.assertTrue(coll.find_one()['ssl']) coll.drop() @classmethod @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") def setUpClass(cls): super(TestSSL, cls).setUpClass() # MongoClient should connect to the primary by default. 
cls.saved_port = MongoClient.PORT MongoClient.PORT = client_context.port @classmethod def tearDownClass(cls): MongoClient.PORT = cls.saved_port super(TestSSL, cls).tearDownClass() @client_context.require_tls def test_simple_ssl(self): # Expects the server to be running with ssl and with # no --sslPEMKeyFile or with --sslWeakCertificateValidation self.assertClientWorks(self.client) @client_context.require_ssl_certfile def test_ssl_pem_passphrase(self): # Expects the server to be running with server.pem and ca.pem # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem if not hasattr(ssl, 'SSLContext') and not _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, MongoClient, 'localhost', ssl=True, ssl_certfile=CLIENT_ENCRYPTED_PEM, ssl_pem_passphrase="qwerty", ssl_ca_certs=CA_PEM, serverSelectionTimeoutMS=100) else: connected(MongoClient('localhost', ssl=True, ssl_certfile=CLIENT_ENCRYPTED_PEM, ssl_pem_passphrase="qwerty", ssl_ca_certs=CA_PEM, serverSelectionTimeoutMS=5000, **self.credentials)) uri_fmt = ("mongodb://localhost/?ssl=true" "&ssl_certfile=%s&ssl_pem_passphrase=qwerty" "&ssl_ca_certs=%s&serverSelectionTimeoutMS=5000") connected(MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials)) @client_context.require_ssl_certfile @client_context.require_no_auth def test_cert_ssl_implicitly_set(self): # Expects the server to be running with server.pem and ca.pem # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # # test that setting ssl_certfile causes ssl to be set to True client = MongoClient(client_context.host, client_context.port, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) response = client.admin.command('ismaster') if 'setName' in response: client = MongoClient(client_context.pair, replicaSet=response['setName'], w=len(response['hosts']), ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) self.assertClientWorks(client) @client_context.require_ssl_certfile @client_context.require_no_auth def test_cert_ssl_validation(self): # Expects the server to be running with server.pem and ca.pem # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # client = MongoClient('localhost', ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM) response = client.admin.command('ismaster') if 'setName' in response: if response['primary'].split(":")[0] != 'localhost': raise SkipTest("No hosts in the replicaset for 'localhost'. 
" "Cannot validate hostname in the certificate") client = MongoClient('localhost', replicaSet=response['setName'], w=len(response['hosts']), ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM) self.assertClientWorks(client) if HAVE_IPADDRESS: client = MongoClient('127.0.0.1', ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM) self.assertClientWorks(client) @client_context.require_ssl_certfile @client_context.require_no_auth def test_cert_ssl_uri_support(self): # Expects the server to be running with server.pem and ca.pem # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # uri_fmt = ("mongodb://localhost/?ssl=true&ssl_certfile=%s&ssl_cert_reqs" "=%s&ssl_ca_certs=%s&ssl_match_hostname=true") client = MongoClient(uri_fmt % (CLIENT_PEM, 'CERT_REQUIRED', CA_PEM)) self.assertClientWorks(client) @client_context.require_ssl_certfile @client_context.require_no_auth def test_cert_ssl_validation_optional(self): # Expects the server to be running with server.pem and ca.pem # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # client = MongoClient('localhost', ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_OPTIONAL, ssl_ca_certs=CA_PEM) response = client.admin.command('ismaster') if 'setName' in response: if response['primary'].split(":")[0] != 'localhost': raise SkipTest("No hosts in the replicaset for 'localhost'. " "Cannot validate hostname in the certificate") client = MongoClient('localhost', replicaSet=response['setName'], w=len(response['hosts']), ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_OPTIONAL, ssl_ca_certs=CA_PEM) self.assertClientWorks(client) @client_context.require_ssl_certfile @client_context.require_server_resolvable def test_cert_ssl_validation_hostname_matching(self): # Expects the server to be running with server.pem and ca.pem # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # Python > 2.7.9. If SSLContext doesn't have load_default_certs # it also doesn't have check_hostname. 
ctx = get_ssl_context( None, None, None, None, ssl.CERT_NONE, None, False, True) if hasattr(ctx, 'load_default_certs'): self.assertFalse(ctx.check_hostname) ctx = get_ssl_context( None, None, None, None, ssl.CERT_NONE, None, True, True) self.assertFalse(ctx.check_hostname) ctx = get_ssl_context( None, None, None, None, ssl.CERT_REQUIRED, None, False, True) self.assertFalse(ctx.check_hostname) ctx = get_ssl_context( None, None, None, None, ssl.CERT_REQUIRED, None, True, True) if _PY37PLUS: self.assertTrue(ctx.check_hostname) else: self.assertFalse(ctx.check_hostname) response = self.client.admin.command('ismaster') with self.assertRaises(ConnectionFailure): connected(MongoClient('server', ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM, serverSelectionTimeoutMS=500, **self.credentials)) connected(MongoClient('server', ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM, ssl_match_hostname=False, serverSelectionTimeoutMS=500, **self.credentials)) if 'setName' in response: with self.assertRaises(ConnectionFailure): connected(MongoClient('server', replicaSet=response['setName'], ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM, serverSelectionTimeoutMS=500, **self.credentials)) connected(MongoClient('server', replicaSet=response['setName'], ssl=True, ssl_certfile=CLIENT_PEM, ssl_cert_reqs=ssl.CERT_REQUIRED, ssl_ca_certs=CA_PEM, ssl_match_hostname=False, serverSelectionTimeoutMS=500, **self.credentials)) @client_context.require_ssl_certfile def test_ssl_crlfile_support(self): if not hasattr(ssl, 'VERIFY_CRL_CHECK_LEAF') or _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, MongoClient, 'localhost', ssl=True, ssl_ca_certs=CA_PEM, ssl_crlfile=CRL_PEM, serverSelectionTimeoutMS=100) else: connected(MongoClient('localhost', ssl=True, ssl_ca_certs=CA_PEM, serverSelectionTimeoutMS=100, **self.credentials)) with self.assertRaises(ConnectionFailure): connected(MongoClient('localhost', ssl=True, ssl_ca_certs=CA_PEM, ssl_crlfile=CRL_PEM, serverSelectionTimeoutMS=100, **self.credentials)) uri_fmt = ("mongodb://localhost/?ssl=true&" "ssl_ca_certs=%s&serverSelectionTimeoutMS=100") connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) uri_fmt = ("mongodb://localhost/?ssl=true&ssl_crlfile=%s" "&ssl_ca_certs=%s&serverSelectionTimeoutMS=100") with self.assertRaises(ConnectionFailure): connected(MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials)) @client_context.require_ssl_certfile @client_context.require_server_resolvable def test_validation_with_system_ca_certs(self): # Expects the server to be running with server.pem and ca.pem. # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # --sslWeakCertificateValidation # if sys.platform == "win32": raise SkipTest("Can't test system ca certs on Windows.") if sys.version_info < (2, 7, 9): raise SkipTest("Can't load system CA certificates.") if (ssl.OPENSSL_VERSION.lower().startswith('libressl') and sys.platform == 'darwin' and not _ssl.IS_PYOPENSSL): raise SkipTest( "LibreSSL on OSX doesn't support setting CA certificates " "using SSL_CERT_FILE environment variable.") # Tell OpenSSL where CA certificates live. 
os.environ['SSL_CERT_FILE'] = CA_PEM try: with self.assertRaises(ConnectionFailure): # Server cert is verified but hostname matching fails connected(MongoClient('server', ssl=True, serverSelectionTimeoutMS=100, **self.credentials)) # Server cert is verified. Disable hostname matching. connected(MongoClient('server', ssl=True, ssl_match_hostname=False, serverSelectionTimeoutMS=100, **self.credentials)) # Server cert and hostname are verified. connected(MongoClient('localhost', ssl=True, serverSelectionTimeoutMS=100, **self.credentials)) # Server cert and hostname are verified. connected( MongoClient( 'mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=100', **self.credentials)) finally: os.environ.pop('SSL_CERT_FILE') def test_system_certs_config_error(self): ctx = get_ssl_context( None, None, None, None, ssl.CERT_NONE, None, False, True) if ((sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr(ctx, "load_default_certs")): raise SkipTest( "Can't test when system CA certificates are loadable.") have_certifi = ssl_support.HAVE_CERTIFI have_wincertstore = ssl_support.HAVE_WINCERTSTORE # Force the test regardless of environment. ssl_support.HAVE_CERTIFI = False ssl_support.HAVE_WINCERTSTORE = False try: with self.assertRaises(ConfigurationError): MongoClient("mongodb://localhost/?ssl=true") finally: ssl_support.HAVE_CERTIFI = have_certifi ssl_support.HAVE_WINCERTSTORE = have_wincertstore def test_certifi_support(self): if hasattr(ssl, "SSLContext"): # SSLSocket doesn't provide ca_certs attribute on pythons # with SSLContext and SSLContext provides no information # about ca_certs. raise SkipTest("Can't test when SSLContext available.") if not ssl_support.HAVE_CERTIFI: raise SkipTest("Need certifi to test certifi support.") have_wincertstore = ssl_support.HAVE_WINCERTSTORE # Force the test on Windows, regardless of environment. ssl_support.HAVE_WINCERTSTORE = False try: ctx = get_ssl_context( None, None, None, CA_PEM, ssl.CERT_REQUIRED, None, True, True) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, CA_PEM) ctx = get_ssl_context(None, None, None, None, None, None, True, True) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, ssl_support.certifi.where()) finally: ssl_support.HAVE_WINCERTSTORE = have_wincertstore def test_wincertstore(self): if sys.platform != "win32": raise SkipTest("Only valid on Windows.") if hasattr(ssl, "SSLContext"): # SSLSocket doesn't provide ca_certs attribute on pythons # with SSLContext and SSLContext provides no information # about ca_certs. raise SkipTest("Can't test when SSLContext available.") if not ssl_support.HAVE_WINCERTSTORE: raise SkipTest("Need wincertstore to test wincertstore.") ctx = get_ssl_context( None, None, None, CA_PEM, ssl.CERT_REQUIRED, None, True, True) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, CA_PEM) ctx = get_ssl_context(None, None, None, None, None, None, True, True) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, ssl_support._WINCERTS.name) @client_context.require_auth @client_context.require_ssl_certfile def test_mongodb_x509_auth(self): host, port = client_context.host, client_context.port ssl_client = MongoClient( client_context.pair, ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) self.addCleanup(remove_all_users, ssl_client['$external']) ssl_client.admin.authenticate(db_user, db_pwd) # Give x509 user all necessary privileges. 
client_context.create_user('$external', MONGODB_X509_USERNAME, roles=[ {'role': 'readWriteAnyDatabase', 'db': 'admin'}, {'role': 'userAdminAnyDatabase', 'db': 'admin'}]) noauth = MongoClient( client_context.pair, ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) self.assertRaises(OperationFailure, noauth.pymongo_test.test.count) listener = EventListener() auth = MongoClient( client_context.pair, authMechanism='MONGODB-X509', ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM, event_listeners=[listener]) if client_context.version.at_least(3, 3, 12): # No error auth.pymongo_test.test.find_one() names = listener.started_command_names() if client_context.version.at_least(4, 4, -1): # Speculative auth skips the authenticate command. self.assertEqual(names, ['find']) else: self.assertEqual(names, ['authenticate', 'find']) else: # Should require a username with self.assertRaises(ConfigurationError): auth.pymongo_test.test.find_one() uri = ('mongodb://%s@%s:%d/?authMechanism=' 'MONGODB-X509' % ( quote_plus(MONGODB_X509_USERNAME), host, port)) client = MongoClient(uri, ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) # No error client.pymongo_test.test.find_one() uri = 'mongodb://%s:%d/?authMechanism=MONGODB-X509' % (host, port) client = MongoClient(uri, ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) if client_context.version.at_least(3, 3, 12): # No error client.pymongo_test.test.find_one() else: # Should require a username with self.assertRaises(ConfigurationError): client.pymongo_test.test.find_one() # Auth should fail if username and certificate do not match uri = ('mongodb://%s@%s:%d/?authMechanism=' 'MONGODB-X509' % ( quote_plus("not the username"), host, port)) bad_client = MongoClient( uri, ssl=True, ssl_cert_reqs="CERT_NONE", ssl_certfile=CLIENT_PEM) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() bad_client = MongoClient( client_context.pair, username="not the username", authMechanism='MONGODB-X509', ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CLIENT_PEM) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() # Invalid certificate (using CA certificate as client certificate) uri = ('mongodb://%s@%s:%d/?authMechanism=' 'MONGODB-X509' % ( quote_plus(MONGODB_X509_USERNAME), host, port)) try: connected(MongoClient(uri, ssl=True, ssl_cert_reqs=ssl.CERT_NONE, ssl_certfile=CA_PEM, serverSelectionTimeoutMS=100)) except (ConnectionFailure, ConfigurationError): pass else: self.fail("Invalid certificate accepted.") @client_context.require_ssl_certfile def test_connect_with_ca_bundle(self): def remove(path): try: os.remove(path) except OSError: pass temp_ca_bundle = os.path.join(CERT_PATH, 'trusted-ca-bundle.pem') self.addCleanup(remove, temp_ca_bundle) # Add the CA cert file to the bundle. cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) with MongoClient('localhost', tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle) as client: self.assertTrue(client.admin.command('ismaster')) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_streaming_protocol.py000066400000000000000000000211211374256237000215740ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the database module.""" import sys import time sys.path[0:0] = [""] from pymongo import monitoring from test import (client_context, IntegrationTest, unittest) from test.utils import (HeartbeatEventListener, rs_or_single_client, single_client, ServerEventListener, wait_until) class TestStreamingProtocol(IntegrationTest): @client_context.require_failCommand_appName def test_failCommand_streaming(self): listener = ServerEventListener() hb_listener = HeartbeatEventListener() client = rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName='failingIsMasterTest') self.addCleanup(client.close) # Force a connection. client.admin.command('ping') address = client.address listener.reset() fail_ismaster = { 'configureFailPoint': 'failCommand', 'mode': {'times': 4}, 'data': { 'failCommands': ['isMaster'], 'closeConnection': False, 'errorCode': 10107, 'appName': 'failingIsMasterTest', }, } with self.fail_point(fail_ismaster): def _marked_unknown(event): return (event.server_address == address and not event.new_description.is_server_type_known) def _discovered_node(event): return (event.server_address == address and not event.previous_description.is_server_type_known and event.new_description.is_server_type_known) def marked_unknown(): return len(listener.matching(_marked_unknown)) >= 1 def rediscovered(): return len(listener.matching(_discovered_node)) >= 1 # Topology events are published asynchronously wait_until(marked_unknown, 'mark node unknown') wait_until(rediscovered, 'rediscover node') # Server should be selectable. client.admin.command('ping') @client_context.require_failCommand_appName def test_streaming_rtt(self): listener = ServerEventListener() hb_listener = HeartbeatEventListener() # On Windows, RTT can actually be 0.0 because time.time() only has # 1-15 millisecond resolution. We need to delay the initial isMaster # to ensure that RTT is never zero. name = 'streamingRttTest' delay_ismaster = { 'configureFailPoint': 'failCommand', 'mode': {'times': 1000}, 'data': { 'failCommands': ['isMaster'], 'blockConnection': True, 'blockTimeMS': 20, # This can be uncommented after SERVER-49220 is fixed. # 'appName': name, }, } with self.fail_point(delay_ismaster): client = rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name) self.addCleanup(client.close) # Force a connection. client.admin.command('ping') address = client.address delay_ismaster['data']['blockTimeMS'] = 500 delay_ismaster['data']['appName'] = name with self.fail_point(delay_ismaster): def rtt_exceeds_250_ms(): # XXX: Add a public TopologyDescription getter to MongoClient? topology = client._topology sd = topology.description.server_descriptions()[address] return sd.round_trip_time > 0.250 wait_until(rtt_exceeds_250_ms, 'exceed 250ms RTT') # Server should be selectable. client.admin.command('ping') def changed_event(event): return (event.server_address == address and isinstance( event, monitoring.ServerDescriptionChangedEvent)) # There should only be one event published, for the initial discovery. 
events = listener.matching(changed_event) self.assertEqual(1, len(events)) self.assertGreater(events[0].new_description.round_trip_time, 0) @client_context.require_failCommand_appName def test_monitor_waits_after_server_check_error(self): hb_listener = HeartbeatEventListener() client = rs_or_single_client( event_listeners=[hb_listener], heartbeatFrequencyMS=500, appName='waitAfterErrorTest') self.addCleanup(client.close) # Force a connection. client.admin.command('ping') address = client.address fail_ismaster = { 'mode': {'times': 50}, 'data': { 'failCommands': ['isMaster'], 'closeConnection': False, 'errorCode': 91, # This can be uncommented after SERVER-49220 is fixed. # 'appName': 'waitAfterErrorTest', }, } with self.fail_point(fail_ismaster): time.sleep(2) # Server should be selectable. client.admin.command('ping') def hb_started(event): return (isinstance(event, monitoring.ServerHeartbeatStartedEvent) and event.connection_id == address) hb_started_events = hb_listener.matching(hb_started) # Explanation of the expected heartbeat events: # Time: event # 0ms: create MongoClient # 1ms: run monitor handshake, 1 # 2ms: run awaitable isMaster, 2 # 3ms: run configureFailPoint # 502ms: isMaster fails for the first time with command error # 1002ms: run monitor handshake, 3 # 1502ms: run monitor handshake, 4 # 2002ms: run monitor handshake, 5 # 2003ms: disable configureFailPoint # 2004ms: isMaster succeeds, 6 # 2004ms: awaitable isMaster, 7 self.assertGreater(len(hb_started_events), 7) # This can be reduced to ~15 after SERVER-49220 is fixed. self.assertLess(len(hb_started_events), 40) @client_context.require_failCommand_appName def test_heartbeat_awaited_flag(self): hb_listener = HeartbeatEventListener() client = single_client( event_listeners=[hb_listener], heartbeatFrequencyMS=500, appName='heartbeatEventAwaitedFlag') self.addCleanup(client.close) # Force a connection. client.admin.command('ping') def hb_succeeded(event): return isinstance(event, monitoring.ServerHeartbeatSucceededEvent) def hb_failed(event): return isinstance(event, monitoring.ServerHeartbeatFailedEvent) fail_heartbeat = { 'mode': {'times': 2}, 'data': { 'failCommands': ['isMaster'], 'closeConnection': True, 'appName': 'heartbeatEventAwaitedFlag', }, } with self.fail_point(fail_heartbeat): wait_until(lambda: hb_listener.matching(hb_failed), "published failed event") # Reconnect. client.admin.command('ping') hb_succeeded_events = hb_listener.matching(hb_succeeded) hb_failed_events = hb_listener.matching(hb_failed) self.assertFalse(hb_succeeded_events[0].awaited) self.assertTrue(hb_failed_events[0].awaited) # Depending on thread scheduling, the failed heartbeat could occur on # the second or third check. events = [type(e) for e in hb_listener.results[:4]] if events == [monitoring.ServerHeartbeatStartedEvent, monitoring.ServerHeartbeatSucceededEvent, monitoring.ServerHeartbeatStartedEvent, monitoring.ServerHeartbeatFailedEvent]: self.assertFalse(hb_succeeded_events[1].awaited) else: self.assertTrue(hb_succeeded_events[1].awaited) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_threads.py000066400000000000000000000145311374256237000173230ustar00rootroot00000000000000# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test that pymongo is thread safe.""" import threading from test import (client_context, db_user, db_pwd, IntegrationTest, unittest) from test.utils import rs_or_single_client_noauth, rs_or_single_client from test.utils import joinall from pymongo.errors import OperationFailure @client_context.require_connection def setUpModule(): pass class AutoAuthenticateThreads(threading.Thread): def __init__(self, collection, num): threading.Thread.__init__(self) self.coll = collection self.num = num self.success = False self.setDaemon(True) def run(self): for i in range(self.num): self.coll.insert_one({'num': i}) self.coll.find_one({'num': i}) self.success = True class SaveAndFind(threading.Thread): def __init__(self, collection): threading.Thread.__init__(self) self.collection = collection self.setDaemon(True) self.passed = False def run(self): sum = 0 for document in self.collection.find(): sum += document["x"] assert sum == 499500, "sum was %d not 499500" % sum self.passed = True class Insert(threading.Thread): def __init__(self, collection, n, expect_exception): threading.Thread.__init__(self) self.collection = collection self.n = n self.expect_exception = expect_exception self.setDaemon(True) def run(self): for _ in range(self.n): error = True try: self.collection.insert_one({"test": "insert"}) error = False except: if not self.expect_exception: raise if self.expect_exception: assert error class Update(threading.Thread): def __init__(self, collection, n, expect_exception): threading.Thread.__init__(self) self.collection = collection self.n = n self.expect_exception = expect_exception self.setDaemon(True) def run(self): for _ in range(self.n): error = True try: self.collection.update_one({"test": "unique"}, {"$set": {"test": "update"}}) error = False except: if not self.expect_exception: raise if self.expect_exception: assert error class Disconnect(threading.Thread): def __init__(self, client, n): threading.Thread.__init__(self) self.client = client self.n = n self.passed = False def run(self): for _ in range(self.n): self.client.close() self.passed = True class TestThreads(IntegrationTest): def setUp(self): self.db = self.client.pymongo_test def test_threading(self): self.db.drop_collection("test") self.db.test.insert_many([{"x": i} for i in range(1000)]) threads = [] for i in range(10): t = SaveAndFind(self.db.test) t.start() threads.append(t) joinall(threads) def test_safe_insert(self): self.db.drop_collection("test1") self.db.test1.insert_one({"test": "insert"}) self.db.drop_collection("test2") self.db.test2.insert_one({"test": "insert"}) self.db.test2.create_index("test", unique=True) self.db.test2.find_one() okay = Insert(self.db.test1, 2000, False) error = Insert(self.db.test2, 2000, True) error.start() okay.start() error.join() okay.join() def test_safe_update(self): self.db.drop_collection("test1") self.db.test1.insert_one({"test": "update"}) self.db.test1.insert_one({"test": "unique"}) self.db.drop_collection("test2") self.db.test2.insert_one({"test": "update"}) self.db.test2.insert_one({"test": "unique"}) self.db.test2.create_index("test", unique=True) self.db.test2.find_one() okay = 
Update(self.db.test1, 2000, False) error = Update(self.db.test2, 2000, True) error.start() okay.start() error.join() okay.join() def test_client_disconnect(self): db = rs_or_single_client(serverSelectionTimeoutMS=30000).pymongo_test db.drop_collection("test") db.test.insert_many([{"x": i} for i in range(1000)]) # Start 10 threads that execute a query, and 10 threads that call # client.close() 10 times in a row. threads = [SaveAndFind(db.test) for _ in range(10)] threads.extend(Disconnect(db.client, 10) for _ in range(10)) for t in threads: t.start() for t in threads: t.join(300) for t in threads: self.assertTrue(t.passed) class TestThreadsAuth(IntegrationTest): @classmethod @client_context.require_auth def setUpClass(cls): super(TestThreadsAuth, cls).setUpClass() def test_auto_auth_login(self): # Create the database upfront to workaround SERVER-39167. self.client.auth_test.test.insert_one({}) self.addCleanup(self.client.drop_database, "auth_test") client = rs_or_single_client_noauth() self.assertRaises(OperationFailure, client.auth_test.test.find_one) # Admin auth client.admin.authenticate(db_user, db_pwd) nthreads = 10 threads = [] for _ in range(nthreads): t = AutoAuthenticateThreads(client.auth_test.test, 10) t.start() threads.append(t) joinall(threads) for t in threads: self.assertTrue(t.success) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_timestamp.py000066400000000000000000000052711374256237000176750ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
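# A minimal sketch of the pattern the threading tests above rely on: a single
# MongoClient (and collections derived from it) shared across worker threads,
# since MongoClient maintains its own connection pool and is thread-safe. The
# "demo" database name, collection name, and counts are illustrative
# assumptions only; a locally running mongod on the default port is assumed.
import threading

from pymongo import MongoClient


def worker(collection, n):
    # Each thread reuses the shared client's connection pool.
    for i in range(n):
        collection.insert_one({'n': i})


if __name__ == '__main__':
    coll = MongoClient().demo.shared_test
    coll.drop()
    threads = [threading.Thread(target=worker, args=(coll, 100))
               for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(coll.count_documents({}))  # expect 400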
"""Tests for the Timestamp class.""" import datetime import sys import copy import pickle sys.path[0:0] = [""] from bson.timestamp import Timestamp from bson.tz_util import utc from test import unittest class TestTimestamp(unittest.TestCase): def test_timestamp(self): t = Timestamp(123, 456) self.assertEqual(t.time, 123) self.assertEqual(t.inc, 456) self.assertTrue(isinstance(t, Timestamp)) def test_datetime(self): d = datetime.datetime(2010, 5, 5, tzinfo=utc) t = Timestamp(d, 0) self.assertEqual(1273017600, t.time) self.assertEqual(d, t.as_datetime()) def test_datetime_copy_pickle(self): d = datetime.datetime(2010, 5, 5, tzinfo=utc) t = Timestamp(d, 0) dc = copy.deepcopy(d) self.assertEqual(dc, t.as_datetime()) for protocol in [0, 1, 2, -1]: pkl = pickle.dumps(d, protocol=protocol) dp = pickle.loads(pkl) self.assertEqual(dp, t.as_datetime()) def test_exceptions(self): self.assertRaises(TypeError, Timestamp) self.assertRaises(TypeError, Timestamp, None, 123) self.assertRaises(TypeError, Timestamp, 1.2, 123) self.assertRaises(TypeError, Timestamp, 123, None) self.assertRaises(TypeError, Timestamp, 123, 1.2) self.assertRaises(ValueError, Timestamp, 0, -1) self.assertRaises(ValueError, Timestamp, -1, 0) self.assertTrue(Timestamp(0, 0)) def test_equality(self): t = Timestamp(1, 1) self.assertNotEqual(t, Timestamp(0, 1)) self.assertNotEqual(t, Timestamp(1, 0)) self.assertEqual(t, Timestamp(1, 1)) # Explicitly test inequality self.assertFalse(t != Timestamp(1, 1)) def test_hash(self): self.assertEqual(hash(Timestamp(1, 2)), hash(Timestamp(1, 2))) self.assertNotEqual(hash(Timestamp(1, 2)), hash(Timestamp(1, 3))) self.assertNotEqual(hash(Timestamp(1, 2)), hash(Timestamp(2, 2))) def test_repr(self): t = Timestamp(0, 0) self.assertEqual(repr(t), "Timestamp(0, 0)") if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_topology.py000066400000000000000000000727561374256237000175620ustar00rootroot00000000000000# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test the topology module.""" import sys sys.path[0:0] = [""] from bson.py3compat import imap from pymongo import common from pymongo.read_preferences import ReadPreference, Secondary from pymongo.server_type import SERVER_TYPE from pymongo.topology import (_ErrorContext, Topology) from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.errors import (AutoReconnect, ConfigurationError, ConnectionFailure) from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from pymongo.pool import PoolOptions from pymongo.server_description import ServerDescription from pymongo.server_selectors import (any_server_selector, writable_server_selector) from pymongo.settings import TopologySettings from test import client_knobs, unittest from test.utils import MockPool, wait_until class MockMonitor(object): def __init__(self, server_description, topology, pool, topology_settings): self._server_description = server_description self.opened = False def cancel_check(self): pass def open(self): self.opened = True def request_check(self): pass def close(self): self.opened = False class SetNameDiscoverySettings(TopologySettings): def get_topology_type(self): return TOPOLOGY_TYPE.ReplicaSetNoPrimary address = ('a', 27017) def create_mock_topology( seeds=None, replica_set_name=None, monitor_class=MockMonitor): partitioned_seeds = list(imap(common.partition_node, seeds or ['a'])) topology_settings = TopologySettings( partitioned_seeds, replica_set_name=replica_set_name, pool_class=MockPool, monitor_class=monitor_class) t = Topology(topology_settings) t.open() return t def got_ismaster(topology, server_address, ismaster_response): server_description = ServerDescription( server_address, IsMaster(ismaster_response), 0) topology.on_change(server_description) def disconnected(topology, server_address): # Create new description of server type Unknown. topology.on_change(ServerDescription(server_address)) def get_server(topology, hostname): return topology.get_server_by_address((hostname, 27017)) def get_type(topology, hostname): return get_server(topology, hostname).description.server_type def get_monitor(topology, hostname): return get_server(topology, hostname)._monitor class TopologyTest(unittest.TestCase): """Disables periodic monitoring, to make tests deterministic.""" def setUp(self): super(TopologyTest, self).setUp() self.client_knobs = client_knobs(heartbeat_frequency=999999) self.client_knobs.enable() self.addCleanup(self.client_knobs.disable) class TestTopologyConfiguration(TopologyTest): def test_timeout_configuration(self): pool_options = PoolOptions(connect_timeout=1, socket_timeout=2) topology_settings = TopologySettings(pool_options=pool_options) t = Topology(topology_settings=topology_settings) t.open() # Get the default server. server = t.get_server_by_address(('localhost', 27017)) # The pool for application operations obeys our settings. self.assertEqual(1, server._pool.opts.connect_timeout) self.assertEqual(2, server._pool.opts.socket_timeout) # The pool for monitoring operations uses our connect_timeout as both # its connect_timeout and its socket_timeout. monitor = server._monitor self.assertEqual(1, monitor._pool.opts.connect_timeout) self.assertEqual(1, monitor._pool.opts.socket_timeout) # The monitor, not its pool, is responsible for calling ismaster. 
self.assertFalse(monitor._pool.handshake) class TestSingleServerTopology(TopologyTest): def test_direct_connection(self): for server_type, ismaster_response in [ (SERVER_TYPE.RSPrimary, { 'ok': 1, 'ismaster': True, 'hosts': ['a'], 'setName': 'rs', 'maxWireVersion': 6}), (SERVER_TYPE.RSSecondary, { 'ok': 1, 'ismaster': False, 'secondary': True, 'hosts': ['a'], 'setName': 'rs', 'maxWireVersion': 6}), (SERVER_TYPE.Mongos, { 'ok': 1, 'ismaster': True, 'msg': 'isdbgrid', 'maxWireVersion': 6}), (SERVER_TYPE.RSArbiter, { 'ok': 1, 'ismaster': False, 'arbiterOnly': True, 'hosts': ['a'], 'setName': 'rs', 'maxWireVersion': 6}), (SERVER_TYPE.Standalone, { 'ok': 1, 'ismaster': True, 'maxWireVersion': 6}), # Slave. (SERVER_TYPE.Standalone, { 'ok': 1, 'ismaster': False, 'maxWireVersion': 6}), ]: t = create_mock_topology() # Can't select a server while the only server is of type Unknown. with self.assertRaisesRegex(ConnectionFailure, 'No servers found yet'): t.select_servers(any_server_selector, server_selection_timeout=0) got_ismaster(t, address, ismaster_response) # Topology type never changes. self.assertEqual(TOPOLOGY_TYPE.Single, t.description.topology_type) # No matter whether the server is writable, # select_servers() returns it. s = t.select_server(writable_server_selector) self.assertEqual(server_type, s.description.server_type) # Topology type single is always readable and writable regardless # of server type or state. self.assertEqual(t.description.topology_type_name, 'Single') self.assertTrue(t.description.has_writable_server()) self.assertTrue(t.description.has_readable_server()) self.assertTrue(t.description.has_readable_server(Secondary())) self.assertTrue(t.description.has_readable_server( Secondary(tag_sets=[{'tag': 'does-not-exist'}]))) def test_reopen(self): t = create_mock_topology() # Additional calls are permitted. t.open() t.open() def test_unavailable_seed(self): t = create_mock_topology() disconnected(t, address) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) def test_round_trip_time(self): round_trip_time = 125 available = True class TestMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): if available: return (IsMaster({'ok': 1, 'maxWireVersion': 6}), round_trip_time) else: raise AutoReconnect('mock monitor error') t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) s = t.select_server(writable_server_selector) self.assertEqual(125, s.description.round_trip_time) round_trip_time = 25 t.request_check_all() # Exponential weighted average: .8 * 125 + .2 * 25 = 105. self.assertAlmostEqual(105, s.description.round_trip_time) # The server is temporarily down. available = False t.request_check_all() def raises_err(): try: t.select_server(writable_server_selector, server_selection_timeout=0.1) except ConnectionFailure: return True else: return False wait_until(raises_err, 'discover server is down') self.assertIsNone(s.description.round_trip_time) # Bring it back, RTT is now 20 milliseconds. available = True round_trip_time = 20 def new_average(): # We reset the average to the most recent measurement. 
description = s.description return (description.round_trip_time is not None and round(abs(20 - description.round_trip_time), 7) == 0) tries = 0 while not new_average(): t.request_check_all() tries += 1 if tries > 10: self.fail("Didn't ever calculate correct new average") class TestMultiServerTopology(TopologyTest): def test_readable_writable(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) got_ismaster(t, ('b', 27017), { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a', 'b']}) self.assertTrue( t.description.topology_type_name, 'ReplicaSetWithPrimary') self.assertTrue(t.description.has_writable_server()) self.assertTrue(t.description.has_readable_server()) self.assertTrue( t.description.has_readable_server(Secondary())) self.assertFalse( t.description.has_readable_server( Secondary(tag_sets=[{'tag': 'exists'}]))) t = create_mock_topology(replica_set_name='rs') got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': False, 'secondary': False, 'setName': 'rs', 'hosts': ['a', 'b']}) got_ismaster(t, ('b', 27017), { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a', 'b']}) self.assertTrue( t.description.topology_type_name, 'ReplicaSetNoPrimary') self.assertFalse(t.description.has_writable_server()) self.assertFalse(t.description.has_readable_server()) self.assertTrue( t.description.has_readable_server(Secondary())) self.assertFalse( t.description.has_readable_server( Secondary(tag_sets=[{'tag': 'exists'}]))) t = create_mock_topology(replica_set_name='rs') got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) got_ismaster(t, ('b', 27017), { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a', 'b'], 'tags': {'tag': 'exists'}}) self.assertTrue( t.description.topology_type_name, 'ReplicaSetWithPrimary') self.assertTrue(t.description.has_writable_server()) self.assertTrue(t.description.has_readable_server()) self.assertTrue( t.description.has_readable_server(Secondary())) self.assertTrue( t.description.has_readable_server( Secondary(tag_sets=[{'tag': 'exists'}]))) def test_close(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) got_ismaster(t, ('b', 27017), { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a', 'b']}) self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b')) self.assertTrue(get_monitor(t, 'a').opened) self.assertTrue(get_monitor(t, 'b').opened) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) t.close() self.assertEqual(2, len(t.description.server_descriptions())) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'b')) self.assertFalse(get_monitor(t, 'a').opened) self.assertFalse(get_monitor(t, 'b').opened) self.assertEqual('rs', t.description.replica_set_name) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) # A closed topology should not be updated when receiving an isMaster. 
got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b', 'c']}) self.assertEqual(2, len(t.description.server_descriptions())) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'b')) self.assertFalse(get_monitor(t, 'a').opened) self.assertFalse(get_monitor(t, 'b').opened) # Server c should not have been added. self.assertEqual(None, get_server(t, 'c')) self.assertEqual('rs', t.description.replica_set_name) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) def test_handle_error(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) got_ismaster(t, ('b', 27017), { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a', 'b']}) errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True) t.handle_error(('a', 27017), errctx) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b')) self.assertEqual('rs', t.description.replica_set_name) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) t.handle_error(('b', 27017), errctx) self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'b')) self.assertEqual('rs', t.description.replica_set_name) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) def test_handle_getlasterror(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) got_ismaster(t, ('b', 27017), { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a', 'b']}) t.handle_getlasterror(('a', 27017), 'not master') self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b')) self.assertEqual('rs', t.description.replica_set_name) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) got_ismaster(t, ('a', 27017), { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a', 'b']}) self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) def test_handle_error_removed_server(self): t = create_mock_topology(replica_set_name='rs') # No error resetting a server not in the TopologyDescription. errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True) t.handle_error(('b', 27017), errctx) # Server was *not* added as type Unknown. self.assertFalse(t.has_server(('b', 27017))) def test_handle_getlasterror_removed_server(self): t = create_mock_topology(replica_set_name='rs') # No error resetting a server not in the TopologyDescription. t.handle_getlasterror(('b', 27017), 'not master') # Server was *not* added as type Unknown. self.assertFalse(t.has_server(('b', 27017))) def test_discover_set_name_from_primary(self): # Discovering a replica set without the setName supplied by the user # is not yet supported by MongoClient, but Topology can do it. 
topology_settings = SetNameDiscoverySettings( seeds=[address], pool_class=MockPool, monitor_class=MockMonitor) t = Topology(topology_settings) self.assertEqual(t.description.replica_set_name, None) self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetNoPrimary) t.open() got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a']}) self.assertEqual(t.description.replica_set_name, 'rs') self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetWithPrimary) # Another response from the primary. Tests the code that processes # primary response when topology type is already ReplicaSetWithPrimary. got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a']}) # No change. self.assertEqual(t.description.replica_set_name, 'rs') self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetWithPrimary) def test_discover_set_name_from_secondary(self): # Discovering a replica set without the setName supplied by the user # is not yet supported by MongoClient, but Topology can do it. topology_settings = SetNameDiscoverySettings( seeds=[address], pool_class=MockPool, monitor_class=MockMonitor) t = Topology(topology_settings) self.assertEqual(t.description.replica_set_name, None) self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetNoPrimary) t.open() got_ismaster(t, address, { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a']}) self.assertEqual(t.description.replica_set_name, 'rs') self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetNoPrimary) def test_wire_version(self): t = create_mock_topology(replica_set_name='rs') t.description.check_compatible() # No error. got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a']}) # Use defaults. server = t.get_server_by_address(address) self.assertEqual(server.description.min_wire_version, 0) self.assertEqual(server.description.max_wire_version, 0) got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a'], 'minWireVersion': 1, 'maxWireVersion': 5}) self.assertEqual(server.description.min_wire_version, 1) self.assertEqual(server.description.max_wire_version, 5) # Incompatible. got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a'], 'minWireVersion': 11, 'maxWireVersion': 12}) try: t.select_servers(any_server_selector) except ConfigurationError as e: # Error message should say which server failed and why. self.assertEqual( str(e), "Server at a:27017 requires wire version 11, but this version " "of PyMongo only supports up to %d." % (common.MAX_SUPPORTED_WIRE_VERSION,)) else: self.fail('No error with incompatible wire version') # Incompatible. got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a'], 'minWireVersion': 0, 'maxWireVersion': 0}) try: t.select_servers(any_server_selector) except ConfigurationError as e: # Error message should say which server failed and why. self.assertEqual( str(e), "Server at a:27017 reports wire version 0, but this version " "of PyMongo requires at least %d (MongoDB %s)." 
                % (common.MIN_SUPPORTED_WIRE_VERSION,
                   common.MIN_SUPPORTED_SERVER_VERSION))
        else:
            self.fail('No error with incompatible wire version')

    def test_max_write_batch_size(self):
        t = create_mock_topology(seeds=['a', 'b'], replica_set_name='rs')

        def write_batch_size():
            s = t.select_server(writable_server_selector)
            return s.description.max_write_batch_size

        got_ismaster(t, ('a', 27017), {
            'ok': 1,
            'ismaster': True,
            'setName': 'rs',
            'hosts': ['a', 'b'],
            'maxWireVersion': 6,
            'maxWriteBatchSize': 1})

        got_ismaster(t, ('b', 27017), {
            'ok': 1,
            'ismaster': False,
            'secondary': True,
            'setName': 'rs',
            'hosts': ['a', 'b'],
            'maxWireVersion': 6,
            'maxWriteBatchSize': 2})

        # Uses primary's max batch size.
        self.assertEqual(1, write_batch_size())

        # b becomes primary.
        got_ismaster(t, ('b', 27017), {
            'ok': 1,
            'ismaster': True,
            'setName': 'rs',
            'hosts': ['a', 'b'],
            'maxWireVersion': 6,
            'maxWriteBatchSize': 2})

        self.assertEqual(2, write_batch_size())

    def test_topology_repr(self):
        t = create_mock_topology(replica_set_name='rs')
        self.addCleanup(t.close)
        got_ismaster(t, ('a', 27017), {
            'ok': 1,
            'ismaster': True,
            'setName': 'rs',
            'hosts': ['a', 'c', 'b']})
        self.assertEqual(
            repr(t.description),
            "<TopologyDescription id: %s, "
            "topology_type: ReplicaSetWithPrimary, servers: "
            "[<ServerDescription ('a', 27017) server_type: RSPrimary, rtt: 0>, "
            "<ServerDescription ('b', 27017) server_type: Unknown, rtt: None>, "
            "<ServerDescription ('c', 27017) server_type: Unknown, rtt: None>]>"
            % (t._topology_id,))


def wait_for_master(topology):
    """Wait for a Topology to discover a writable server.

    If the monitor is currently calling ismaster, a blocking call to
    select_server from this thread can trigger a spurious wake of the
    monitor thread. In applications this is harmless but it would break
    some tests, so we pass server_selection_timeout=0 and poll instead.
    """
    def get_master():
        try:
            return topology.select_server(writable_server_selector, 0)
        except ConnectionFailure:
            return None

    return wait_until(get_master, 'find master')


class TestTopologyErrors(TopologyTest):
    # Errors when calling ismaster.

    def test_pool_reset(self):
        # ismaster succeeds at first, then always raises socket error.
        ismaster_count = [0]

        class TestMonitor(Monitor):
            def _check_with_socket(self, *args, **kwargs):
                ismaster_count[0] += 1
                if ismaster_count[0] == 1:
                    return IsMaster({'ok': 1, 'maxWireVersion': 6}), 0
                else:
                    raise AutoReconnect('mock monitor error')

        t = create_mock_topology(monitor_class=TestMonitor)
        self.addCleanup(t.close)
        server = wait_for_master(t)
        self.assertEqual(1, ismaster_count[0])
        generation = server.pool.generation

        # Pool is reset by ismaster failure.
        t.request_check_all()
        self.assertNotEqual(generation, server.pool.generation)

    def test_ismaster_retry(self):
        # ismaster succeeds at first, then raises socket error, then succeeds.
        ismaster_count = [0]

        class TestMonitor(Monitor):
            def _check_with_socket(self, *args, **kwargs):
                ismaster_count[0] += 1
                if ismaster_count[0] in (1, 3):
                    return IsMaster({'ok': 1, 'maxWireVersion': 6}), 0
                else:
                    raise AutoReconnect(
                        'mock monitor error #%s' % (ismaster_count[0],))

        t = create_mock_topology(monitor_class=TestMonitor)
        self.addCleanup(t.close)
        server = wait_for_master(t)
        self.assertEqual(1, ismaster_count[0])
        self.assertEqual(SERVER_TYPE.Standalone,
                         server.description.server_type)

        # Second ismaster call, server is marked Unknown, then the monitor
        # immediately runs a retry (third ismaster).
        t.request_check_all()
        # The third ismaster call (the immediate retry) happens sometime soon
        # after the failed check triggered by request_check_all. Wait until
        # the server becomes known again.
server = t.select_server(writable_server_selector, 0.250) self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) self.assertEqual(3, ismaster_count[0]) def test_internal_monitor_error(self): exception = AssertionError('internal error') class TestMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): raise exception t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) with self.assertRaisesRegex(ConnectionFailure, 'internal error'): t.select_server(any_server_selector, server_selection_timeout=0.5) class TestServerSelectionErrors(TopologyTest): def assertMessage(self, message, topology, selector=any_server_selector): with self.assertRaises(ConnectionFailure) as context: topology.select_server(selector, server_selection_timeout=0) self.assertIn(message, str(context.exception)) def test_no_primary(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, address, { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'rs', 'hosts': ['a']}) self.assertMessage('No replica set members match selector "Primary()"', t, ReadPreference.PRIMARY) self.assertMessage('No primary available for writes', t, writable_server_selector) def test_no_secondary(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, address, { 'ok': 1, 'ismaster': True, 'setName': 'rs', 'hosts': ['a']}) self.assertMessage( 'No replica set members match selector' ' "Secondary(tag_sets=None, max_staleness=-1, hedge=None)"', t, ReadPreference.SECONDARY) self.assertMessage( "No replica set members match selector" " \"Secondary(tag_sets=[{'dc': 'ny'}], max_staleness=-1, " "hedge=None)\"", t, Secondary(tag_sets=[{'dc': 'ny'}])) def test_bad_replica_set_name(self): t = create_mock_topology(replica_set_name='rs') got_ismaster(t, address, { 'ok': 1, 'ismaster': False, 'secondary': True, 'setName': 'wrong', 'hosts': ['a']}) self.assertMessage( 'No replica set members available for replica set name "rs"', t) def test_multiple_standalones(self): # Standalones are removed from a topology with multiple seeds. t = create_mock_topology(seeds=['a', 'b']) got_ismaster(t, ('a', 27017), {'ok': 1}) got_ismaster(t, ('b', 27017), {'ok': 1}) self.assertMessage('No servers available', t) def test_no_mongoses(self): # Standalones are removed from a topology with multiple seeds. t = create_mock_topology(seeds=['a', 'b']) # Discover a mongos and change topology type to Sharded. got_ismaster(t, ('a', 27017), {'ok': 1, 'msg': 'isdbgrid'}) # Oops, both servers are standalone now. Remove them. got_ismaster(t, ('a', 27017), {'ok': 1}) got_ismaster(t, ('b', 27017), {'ok': 1}) self.assertMessage('No mongoses available', t) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_transactions.py000066400000000000000000000444461374256237000204110ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
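# ---------------------------------------------------------------------------
# Illustrative aside (not part of the original pymongo test suite): the
# "No servers available" / "No primary available for writes" messages that
# TestServerSelectionErrors above asserts on are what an application sees
# when server selection times out. A minimal sketch using only public
# PyMongo API; the hostname and timeout value are hypothetical examples.
def _sketch_server_selection_error():
    from pymongo import MongoClient
    from pymongo.errors import ServerSelectionTimeoutError

    # Deliberately unreachable seed and a very short selection timeout.
    client = MongoClient('mongodb://unreachable.invalid:27017/',
                         serverSelectionTimeoutMS=10)
    try:
        client.admin.command('ping')
    except ServerSelectionTimeoutError as exc:
        # The exception text carries the same selector/topology details that
        # the assertMessage() helper checks for in the tests above.
        print('server selection failed: %s' % (exc,))
    finally:
        client.close()
# ---------------------------------------------------------------------------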
"""Execute Transactions Spec tests.""" import os import sys sys.path[0:0] = [""] from bson.py3compat import StringIO from pymongo import client_session, WriteConcern from pymongo.client_session import TransactionOptions from pymongo.errors import (CollectionInvalid, ConfigurationError, ConnectionFailure, InvalidOperation, OperationFailure) from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from gridfs import GridFS, GridFSBucket from test import unittest, client_context from test.utils import (rs_client, single_client, wait_until, OvertCommandListener, TestCreator) from test.utils_spec_runner import SpecRunner # Location of JSON test specifications. _TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'transactions') _TXN_TESTS_DEBUG = os.environ.get('TRANSACTION_TESTS_DEBUG') # Max number of operations to perform after a transaction to prove unpinning # occurs. Chosen so that there's a low false positive rate. With 2 mongoses, # 50 attempts yields a one in a quadrillion chance of a false positive # (1/(0.5^50)). UNPIN_TEST_MAX_ATTEMPTS = 50 class TransactionsBase(SpecRunner): @classmethod def setUpClass(cls): super(TransactionsBase, cls).setUpClass() if client_context.supports_transactions(): for address in client_context.mongoses: cls.mongos_clients.append(single_client('%s:%s' % address)) @classmethod def tearDownClass(cls): for client in cls.mongos_clients: client.close() super(TransactionsBase, cls).tearDownClass() def maybe_skip_scenario(self, test): super(TransactionsBase, self).maybe_skip_scenario(test) if ('secondary' in self.id() and not client_context.is_mongos and not client_context.has_secondaries): raise unittest.SkipTest('No secondaries') class TestTransactions(TransactionsBase): @client_context.require_transactions def test_transaction_options_validation(self): default_options = TransactionOptions() self.assertIsNone(default_options.read_concern) self.assertIsNone(default_options.write_concern) self.assertIsNone(default_options.read_preference) self.assertIsNone(default_options.max_commit_time_ms) # No error when valid options are provided. 
TransactionOptions(read_concern=ReadConcern(), write_concern=WriteConcern(), read_preference=ReadPreference.PRIMARY, max_commit_time_ms=10000) with self.assertRaisesRegex(TypeError, "read_concern must be "): TransactionOptions(read_concern={}) with self.assertRaisesRegex(TypeError, "write_concern must be "): TransactionOptions(write_concern={}) with self.assertRaisesRegex( ConfigurationError, "transactions do not support unacknowledged write concern"): TransactionOptions(write_concern=WriteConcern(w=0)) with self.assertRaisesRegex( TypeError, "is not valid for read_preference"): TransactionOptions(read_preference={}) with self.assertRaisesRegex( TypeError, "max_commit_time_ms must be an integer or None"): TransactionOptions(max_commit_time_ms="10000") @client_context.require_transactions def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" client = rs_client(w=0) self.addCleanup(client.close) db = client.test coll = db.test coll.insert_one({}) with client.start_session() as s: with s.start_transaction(write_concern=WriteConcern(w=1)): self.assertTrue(coll.insert_one({}, session=s).acknowledged) self.assertTrue(coll.insert_many( [{}, {}], session=s).acknowledged) self.assertTrue(coll.bulk_write( [InsertOne({})], session=s).acknowledged) self.assertTrue(coll.replace_one( {}, {}, session=s).acknowledged) self.assertTrue(coll.update_one( {}, {"$set": {"a": 1}}, session=s).acknowledged) self.assertTrue(coll.update_many( {}, {"$set": {"a": 1}}, session=s).acknowledged) self.assertTrue(coll.delete_one({}, session=s).acknowledged) self.assertTrue(coll.delete_many({}, session=s).acknowledged) coll.find_one_and_delete({}, session=s) coll.find_one_and_replace({}, {}, session=s) coll.find_one_and_update({}, {"$set": {"a": 1}}, session=s) unsupported_txn_writes = [ (client.drop_database, [db.name], {}), (db.drop_collection, ['collection'], {}), (coll.drop, [], {}), (coll.map_reduce, ['function() {}', 'function() {}', 'output'], {}), (coll.rename, ['collection2'], {}), # Drop collection2 between tests of "rename", above. (coll.database.drop_collection, ['collection2'], {}), (coll.create_indexes, [[IndexModel('a')]], {}), (coll.create_index, ['a'], {}), (coll.drop_index, ['a_1'], {}), (coll.drop_indexes, [], {}), (coll.aggregate, [[{"$out": "aggout"}]], {}), ] # Creating a collection in a transaction requires MongoDB 4.4+. if client_context.version < (4, 3, 4): unsupported_txn_writes.extend([ (db.create_collection, ['collection'], {}), ]) for op in unsupported_txn_writes: op, args, kwargs = op with client.start_session() as s: kwargs['session'] = s s.start_transaction(write_concern=WriteConcern(w=1)) with self.assertRaises(OperationFailure): op(*args, **kwargs) s.abort_transaction() @client_context.require_transactions @client_context.require_multiple_mongoses def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. coll.insert_one({}) self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. 
with s.start_transaction(): coll.insert_one({}, session=s) addresses = set() for _ in range(UNPIN_TEST_MAX_ATTEMPTS): with s.start_transaction(): cursor = coll.find({}, session=s) self.assertTrue(next(cursor)) addresses.add(cursor.address) # Break early if we can. if len(addresses) > 1: break self.assertGreater(len(addresses), 1) @client_context.require_transactions @client_context.require_multiple_mongoses def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. coll.insert_one({}) self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): coll.insert_one({}, session=s) addresses = set() for _ in range(UNPIN_TEST_MAX_ATTEMPTS): cursor = coll.find({}, session=s) self.assertTrue(next(cursor)) addresses.add(cursor.address) # Break early if we can. if len(addresses) > 1: break self.assertGreater(len(addresses), 1) @client_context.require_transactions @client_context.require_version_min(4, 3, 4) def test_create_collection(self): client = client_context.client db = client.pymongo_test coll = db.test_create_collection self.addCleanup(coll.drop) # Use with_transaction to avoid StaleConfig errors on sharded clusters. def create_and_insert(session): coll2 = db.create_collection(coll.name, session=session) self.assertEqual(coll, coll2) coll.insert_one({}, session=session) with client.start_session() as s: s.with_transaction(create_and_insert) # Outside a transaction we raise CollectionInvalid on existing colls. with self.assertRaises(CollectionInvalid): db.create_collection(coll.name) # Inside a transaction we raise the OperationFailure from create. 
with client.start_session() as s: s.start_transaction() with self.assertRaises(OperationFailure) as ctx: db.create_collection(coll.name, session=s) self.assertEqual(ctx.exception.code, 48) # NamespaceExists @client_context.require_transactions def test_gridfs_does_not_support_transactions(self): client = client_context.client db = client.pymongo_test gfs = GridFS(db) bucket = GridFSBucket(db) def gridfs_find(*args, **kwargs): return gfs.find(*args, **kwargs).next() def gridfs_open_upload_stream(*args, **kwargs): bucket.open_upload_stream(*args, **kwargs).write(b'1') gridfs_ops = [ (gfs.put, (b'123',)), (gfs.get, (1,)), (gfs.get_version, ('name',)), (gfs.get_last_version, ('name',)), (gfs.delete, (1, )), (gfs.list, ()), (gfs.find_one, ()), (gridfs_find, ()), (gfs.exists, ()), (gridfs_open_upload_stream, ('name',)), (bucket.upload_from_stream, ('name', b'data',)), (bucket.download_to_stream, (1, StringIO(),)), (bucket.download_to_stream_by_name, ('name', StringIO(),)), (bucket.delete, (1,)), (bucket.find, ()), (bucket.open_download_stream, (1,)), (bucket.open_download_stream_by_name, ('name',)), (bucket.rename, (1, 'new-name',)), ] with client.start_session() as s, s.start_transaction(): for op, args in gridfs_ops: with self.assertRaisesRegex( InvalidOperation, 'GridFS does not support multi-document transactions', ): op(*args, session=s) class PatchSessionTimeout(object): """Patches the client_session's with_transaction timeout for testing.""" def __init__(self, mock_timeout): self.real_timeout = client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT self.mock_timeout = mock_timeout def __enter__(self): client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.mock_timeout return self def __exit__(self, exc_type, exc_val, exc_tb): client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout class TestTransactionsConvenientAPI(TransactionsBase): TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'transactions-convenient-api') @client_context.require_transactions def test_callback_raises_custom_error(self): class _MyException(Exception):pass def raise_error(_): raise _MyException() with self.client.start_session() as s: with self.assertRaises(_MyException): s.with_transaction(raise_error) @client_context.require_transactions def test_callback_returns_value(self): def callback(_): return 'Foo' with self.client.start_session() as s: self.assertEqual(s.with_transaction(callback), 'Foo') self.db.test.insert_one({}) def callback(session): self.db.test.insert_one({}, session=session) return 'Foo' with self.client.start_session() as s: self.assertEqual(s.with_transaction(callback), 'Foo') @client_context.require_transactions def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = rs_client(event_listeners=[listener]) coll = client[self.db.name].test def callback(session): coll.insert_one({}, session=session) err = { 'ok': 0, 'errmsg': 'Transaction 7819 has been aborted.', 'code': 251, 'codeName': 'NoSuchTransaction', 'errorLabels': ['TransientTransactionError'], } raise OperationFailure(err['errmsg'], err['code'], err) # Create the collection. 
coll.insert_one({}) listener.results.clear() with client.start_session() as s: with PatchSessionTimeout(0): with self.assertRaises(OperationFailure): s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ['insert', 'abortTransaction']) @client_context.require_test_commands @client_context.require_transactions def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = rs_client(event_listeners=[listener]) coll = client[self.db.name].test def callback(session): coll.insert_one({}, session=session) # Create the collection. coll.insert_one({}) self.set_fail_point({ 'configureFailPoint': 'failCommand', 'mode': {'times': 1}, 'data': { 'failCommands': ['commitTransaction'], 'errorCode': 251, # NoSuchTransaction }}) self.addCleanup(self.set_fail_point, { 'configureFailPoint': 'failCommand', 'mode': 'off'}) listener.results.clear() with client.start_session() as s: with PatchSessionTimeout(0): with self.assertRaises(OperationFailure): s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ['insert', 'commitTransaction']) @client_context.require_test_commands @client_context.require_transactions def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = rs_client(event_listeners=[listener]) coll = client[self.db.name].test def callback(session): coll.insert_one({}, session=session) # Create the collection. coll.insert_one({}) self.set_fail_point({ 'configureFailPoint': 'failCommand', 'mode': {'times': 2}, 'data': { 'failCommands': ['commitTransaction'], 'closeConnection': True}}) self.addCleanup(self.set_fail_point, { 'configureFailPoint': 'failCommand', 'mode': 'off'}) listener.results.clear() with client.start_session() as s: with PatchSessionTimeout(0): with self.assertRaises(ConnectionFailure): s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic # retry). self.assertEqual(listener.started_command_names(), ['insert', 'commitTransaction', 'commitTransaction']) # Tested here because this supports Motor's convenient transactions API. 
@client_context.require_transactions def test_in_transaction_property(self): client = client_context.client coll = client.test.testcollection coll.insert_one({}) self.addCleanup(coll.drop) with client.start_session() as s: self.assertFalse(s.in_transaction) s.start_transaction() self.assertTrue(s.in_transaction) coll.insert_one({}, session=s) self.assertTrue(s.in_transaction) s.commit_transaction() self.assertFalse(s.in_transaction) with client.start_session() as s: s.start_transaction() # commit empty transaction s.commit_transaction() self.assertFalse(s.in_transaction) with client.start_session() as s: s.start_transaction() s.abort_transaction() self.assertFalse(s.in_transaction) # Using a callback def callback(session): self.assertTrue(session.in_transaction) with client.start_session() as s: self.assertFalse(s.in_transaction) s.with_transaction(callback) self.assertFalse(s.in_transaction) def create_test(scenario_def, test, name): @client_context.require_test_commands @client_context.require_transactions def run_scenario(self): self.run_scenario(scenario_def, test) return run_scenario test_creator = TestCreator(create_test, TestTransactions, _TEST_PATH) test_creator.create_tests() TestCreator(create_test, TestTransactionsConvenientAPI, TestTransactionsConvenientAPI.TEST_PATH).create_tests() if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_uri_parser.py000066400000000000000000000603661374256237000200530ustar00rootroot00000000000000# Copyright 2011-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
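# ---------------------------------------------------------------------------
# Illustrative aside (not part of the original pymongo test suite): the
# transaction tests above exercise ClientSession.with_transaction(), which
# starts a transaction, runs a callback, retries on transient errors, and
# commits. A minimal usage sketch with public API only; the URI, database
# and collection names are hypothetical examples.
def _sketch_with_transaction():
    from pymongo import MongoClient, WriteConcern
    from pymongo.read_concern import ReadConcern

    client = MongoClient('mongodb://localhost:27017/')

    def callback(session):
        # Every operation in the callback must pass the session explicitly.
        orders = session.client.example_db.orders
        inventory = session.client.example_db.inventory
        orders.insert_one({'sku': 'abc123', 'qty': 100}, session=session)
        inventory.update_one({'sku': 'abc123'},
                             {'$inc': {'qty': -100}},
                             session=session)

    with client.start_session() as session:
        # with_transaction commits on success and retries the callback on
        # TransientTransactionError, which is the behavior the convenient
        # API tests above depend on.
        session.with_transaction(
            callback,
            read_concern=ReadConcern('local'),
            write_concern=WriteConcern(w='majority'))
    client.close()
# ---------------------------------------------------------------------------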
"""Test the pymongo uri_parser module.""" import copy import sys import warnings try: from ssl import CERT_NONE except ImportError: CERT_NONE = 0 sys.path[0:0] = [""] from bson.binary import JAVA_LEGACY from bson.py3compat import string_type, _unicode from pymongo import ReadPreference from pymongo.errors import ConfigurationError, InvalidURI from pymongo.uri_parser import (parse_userinfo, split_hosts, split_options, parse_uri) from test import unittest class TestURI(unittest.TestCase): def test_validate_userinfo(self): self.assertRaises(InvalidURI, parse_userinfo, 'foo@') self.assertRaises(InvalidURI, parse_userinfo, ':password') self.assertRaises(InvalidURI, parse_userinfo, 'fo::o:p@ssword') self.assertRaises(InvalidURI, parse_userinfo, ':') self.assertTrue(parse_userinfo('user:password')) self.assertEqual(('us:r', 'p@ssword'), parse_userinfo('us%3Ar:p%40ssword')) self.assertEqual(('us er', 'p ssword'), parse_userinfo('us+er:p+ssword')) self.assertEqual(('us er', 'p ssword'), parse_userinfo('us%20er:p%20ssword')) self.assertEqual(('us+er', 'p+ssword'), parse_userinfo('us%2Ber:p%2Bssword')) self.assertEqual(('dev1@FOO.COM', ''), parse_userinfo('dev1%40FOO.COM')) self.assertEqual(('dev1@FOO.COM', ''), parse_userinfo('dev1%40FOO.COM:')) def test_split_hosts(self): self.assertRaises(ConfigurationError, split_hosts, 'localhost:27017,') self.assertRaises(ConfigurationError, split_hosts, ',localhost:27017') self.assertRaises(ConfigurationError, split_hosts, 'localhost:27017,,localhost:27018') self.assertEqual([('localhost', 27017), ('example.com', 27017)], split_hosts('localhost,example.com')) self.assertEqual([('localhost', 27018), ('example.com', 27019)], split_hosts('localhost:27018,example.com:27019')) self.assertEqual([('/tmp/mongodb-27017.sock', None)], split_hosts('/tmp/mongodb-27017.sock')) self.assertEqual([('/tmp/mongodb-27017.sock', None), ('example.com', 27017)], split_hosts('/tmp/mongodb-27017.sock,' 'example.com:27017')) self.assertEqual([('example.com', 27017), ('/tmp/mongodb-27017.sock', None)], split_hosts('example.com:27017,' '/tmp/mongodb-27017.sock')) self.assertRaises(ValueError, split_hosts, '::1', 27017) self.assertRaises(ValueError, split_hosts, '[::1:27017') self.assertRaises(ValueError, split_hosts, '::1') self.assertRaises(ValueError, split_hosts, '::1]:27017') self.assertEqual([('::1', 27017)], split_hosts('[::1]:27017')) self.assertEqual([('::1', 27017)], split_hosts('[::1]')) def test_split_options(self): self.assertRaises(ConfigurationError, split_options, 'foo') self.assertRaises(ConfigurationError, split_options, 'foo=bar;foo') self.assertTrue(split_options('ssl=true')) self.assertTrue(split_options('connect=true')) self.assertTrue(split_options('ssl_match_hostname=true')) # Test Invalid URI options that should throw warnings. 
with warnings.catch_warnings(): warnings.filterwarnings('error') self.assertRaises(Warning, split_options, 'foo=bar', warn=True) self.assertRaises(Warning, split_options, 'socketTimeoutMS=foo', warn=True) self.assertRaises(Warning, split_options, 'socketTimeoutMS=0.0', warn=True) self.assertRaises(Warning, split_options, 'connectTimeoutMS=foo', warn=True) self.assertRaises(Warning, split_options, 'connectTimeoutMS=0.0', warn=True) self.assertRaises(Warning, split_options, 'connectTimeoutMS=1e100000', warn=True) self.assertRaises(Warning, split_options, 'connectTimeoutMS=-1e100000', warn=True) self.assertRaises(Warning, split_options, 'ssl=foo', warn=True) self.assertRaises(Warning, split_options, 'connect=foo', warn=True) self.assertRaises(Warning, split_options, 'ssl_match_hostname=foo', warn=True) self.assertRaises(Warning, split_options, 'connectTimeoutMS=inf', warn=True) self.assertRaises(Warning, split_options, 'connectTimeoutMS=-inf', warn=True) self.assertRaises(Warning, split_options, 'wtimeoutms=foo', warn=True) self.assertRaises(Warning, split_options, 'wtimeoutms=5.5', warn=True) self.assertRaises(Warning, split_options, 'fsync=foo', warn=True) self.assertRaises(Warning, split_options, 'fsync=5.5', warn=True) self.assertRaises(Warning, split_options, 'authMechanism=foo', warn=True) # Test invalid options with warn=False. self.assertRaises(ConfigurationError, split_options, 'foo=bar') self.assertRaises(ValueError, split_options, 'socketTimeoutMS=foo') self.assertRaises(ValueError, split_options, 'socketTimeoutMS=0.0') self.assertRaises(ValueError, split_options, 'connectTimeoutMS=foo') self.assertRaises(ValueError, split_options, 'connectTimeoutMS=0.0') self.assertRaises(ValueError, split_options, 'connectTimeoutMS=1e100000') self.assertRaises(ValueError, split_options, 'connectTimeoutMS=-1e100000') self.assertRaises(ValueError, split_options, 'ssl=foo') self.assertRaises(ValueError, split_options, 'connect=foo') self.assertRaises(ValueError, split_options, 'ssl_match_hostname=foo') self.assertRaises(ValueError, split_options, 'connectTimeoutMS=inf') self.assertRaises(ValueError, split_options, 'connectTimeoutMS=-inf') self.assertRaises(ValueError, split_options, 'wtimeoutms=foo') self.assertRaises(ValueError, split_options, 'wtimeoutms=5.5') self.assertRaises(ValueError, split_options, 'fsync=foo') self.assertRaises(ValueError, split_options, 'fsync=5.5') self.assertRaises(ValueError, split_options, 'authMechanism=foo') # Test splitting options works when valid. 
self.assertTrue(split_options('socketTimeoutMS=300')) self.assertTrue(split_options('connectTimeoutMS=300')) self.assertEqual({'sockettimeoutms': 0.3}, split_options('socketTimeoutMS=300')) self.assertEqual({'sockettimeoutms': 0.0001}, split_options('socketTimeoutMS=0.1')) self.assertEqual({'connecttimeoutms': 0.3}, split_options('connectTimeoutMS=300')) self.assertEqual({'connecttimeoutms': 0.0001}, split_options('connectTimeoutMS=0.1')) self.assertTrue(split_options('connectTimeoutMS=300')) self.assertTrue(isinstance(split_options('w=5')['w'], int)) self.assertTrue(isinstance(split_options('w=5.5')['w'], string_type)) self.assertTrue(split_options('w=foo')) self.assertTrue(split_options('w=majority')) self.assertTrue(split_options('wtimeoutms=500')) self.assertEqual({'fsync': True}, split_options('fsync=true')) self.assertEqual({'fsync': False}, split_options('fsync=false')) self.assertEqual({'authmechanism': 'GSSAPI'}, split_options('authMechanism=GSSAPI')) self.assertEqual({'authmechanism': 'MONGODB-CR'}, split_options('authMechanism=MONGODB-CR')) self.assertEqual({'authmechanism': 'SCRAM-SHA-1'}, split_options('authMechanism=SCRAM-SHA-1')) self.assertEqual({'authsource': 'foobar'}, split_options('authSource=foobar')) self.assertEqual({'maxpoolsize': 50}, split_options('maxpoolsize=50')) def test_parse_uri(self): self.assertRaises(InvalidURI, parse_uri, "http://foobar.com") self.assertRaises(InvalidURI, parse_uri, "http://foo@foobar.com") self.assertRaises(ValueError, parse_uri, "mongodb://::1", 27017) orig = { 'nodelist': [("localhost", 27017)], 'username': None, 'password': None, 'database': None, 'collection': None, 'options': {}, 'fqdn': None } res = copy.deepcopy(orig) self.assertEqual(res, parse_uri("mongodb://localhost")) res.update({'username': 'fred', 'password': 'foobar'}) self.assertEqual(res, parse_uri("mongodb://fred:foobar@localhost")) res.update({'database': 'baz'}) self.assertEqual(res, parse_uri("mongodb://fred:foobar@localhost/baz")) res = copy.deepcopy(orig) res['nodelist'] = [("example1.com", 27017), ("example2.com", 27017)] self.assertEqual(res, parse_uri("mongodb://example1.com:27017," "example2.com:27017")) res = copy.deepcopy(orig) res['nodelist'] = [("localhost", 27017), ("localhost", 27018), ("localhost", 27019)] self.assertEqual(res, parse_uri("mongodb://localhost," "localhost:27018,localhost:27019")) res = copy.deepcopy(orig) res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://localhost/foo")) res = copy.deepcopy(orig) self.assertEqual(res, parse_uri("mongodb://localhost/")) res.update({'database': 'test', 'collection': 'yield_historical.in'}) self.assertEqual(res, parse_uri("mongodb://" "localhost/test.yield_historical.in")) res.update({'username': 'fred', 'password': 'foobar'}) self.assertEqual(res, parse_uri("mongodb://fred:foobar@localhost/" "test.yield_historical.in")) res = copy.deepcopy(orig) res['nodelist'] = [("example1.com", 27017), ("example2.com", 27017)] res.update({'database': 'test', 'collection': 'yield_historical.in'}) self.assertEqual(res, parse_uri("mongodb://example1.com:27017,example2.com" ":27017/test.yield_historical.in")) # Test socket path without escaped characters. self.assertRaises(InvalidURI, parse_uri, "mongodb:///tmp/mongodb-27017.sock") # Test with escaped characters. 
res = copy.deepcopy(orig) res['nodelist'] = [("example2.com", 27017), ("/tmp/mongodb-27017.sock", None)] self.assertEqual(res, parse_uri("mongodb://example2.com," "%2Ftmp%2Fmongodb-27017.sock")) res = copy.deepcopy(orig) res['nodelist'] = [("shoe.sock.pants.co.uk", 27017), ("/tmp/mongodb-27017.sock", None)] res['database'] = "nethers_db" self.assertEqual(res, parse_uri("mongodb://shoe.sock.pants.co.uk," "%2Ftmp%2Fmongodb-27017.sock/nethers_db")) res = copy.deepcopy(orig) res['nodelist'] = [("/tmp/mongodb-27017.sock", None), ("example2.com", 27017)] res.update({'database': 'test', 'collection': 'yield_historical.in'}) self.assertEqual(res, parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock," "example2.com:27017" "/test.yield_historical.in")) res = copy.deepcopy(orig) res['nodelist'] = [("/tmp/mongodb-27017.sock", None), ("example2.com", 27017)] res.update({'database': 'test', 'collection': 'yield_historical.sock'}) self.assertEqual(res, parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock," "example2.com:27017/test.yield_historical" ".sock")) res = copy.deepcopy(orig) res['nodelist'] = [("example2.com", 27017)] res.update({'database': 'test', 'collection': 'yield_historical.sock'}) self.assertEqual(res, parse_uri("mongodb://example2.com:27017" "/test.yield_historical.sock")) res = copy.deepcopy(orig) res['nodelist'] = [("/tmp/mongodb-27017.sock", None)] res.update({'database': 'test', 'collection': 'mongodb-27017.sock'}) self.assertEqual(res, parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock" "/test.mongodb-27017.sock")) res = copy.deepcopy(orig) res['nodelist'] = [('/tmp/mongodb-27020.sock', None), ("::1", 27017), ("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 27018), ("192.168.0.212", 27019), ("localhost", 27018)] self.assertEqual(res, parse_uri("mongodb://%2Ftmp%2Fmongodb-27020.sock" ",[::1]:27017,[2001:0db8:" "85a3:0000:0000:8a2e:0370:7334]," "192.168.0.212:27019,localhost", 27018)) res = copy.deepcopy(orig) res.update({'username': 'fred', 'password': 'foobar'}) res.update({'database': 'test', 'collection': 'yield_historical.in'}) self.assertEqual(res, parse_uri("mongodb://fred:foobar@localhost/" "test.yield_historical.in")) res = copy.deepcopy(orig) res['database'] = 'test' res['collection'] = 'name/with "delimiters' self.assertEqual( res, parse_uri("mongodb://localhost/test.name/with \"delimiters")) res = copy.deepcopy(orig) res['options'] = { 'readpreference': ReadPreference.SECONDARY.mongos_mode } self.assertEqual(res, parse_uri( "mongodb://localhost/?readPreference=secondary")) # Various authentication tests res = copy.deepcopy(orig) res['options'] = {'authmechanism': 'MONGODB-CR'} res['username'] = 'user' res['password'] = 'password' self.assertEqual(res, parse_uri("mongodb://user:password@localhost/" "?authMechanism=MONGODB-CR")) res = copy.deepcopy(orig) res['options'] = {'authmechanism': 'MONGODB-CR', 'authsource': 'bar'} res['username'] = 'user' res['password'] = 'password' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user:password@localhost/foo" "?authSource=bar;authMechanism=MONGODB-CR")) res = copy.deepcopy(orig) res['options'] = {'authmechanism': 'MONGODB-CR'} res['username'] = 'user' res['password'] = '' self.assertEqual(res, parse_uri("mongodb://user:@localhost/" "?authMechanism=MONGODB-CR")) res = copy.deepcopy(orig) res['username'] = 'user@domain.com' res['password'] = 'password' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user%40domain.com:password" "@localhost/foo")) res = copy.deepcopy(orig) res['options'] = {'authmechanism': 
'GSSAPI'} res['username'] = 'user@domain.com' res['password'] = 'password' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user%40domain.com:password" "@localhost/foo?authMechanism=GSSAPI")) res = copy.deepcopy(orig) res['options'] = {'authmechanism': 'GSSAPI'} res['username'] = 'user@domain.com' res['password'] = '' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user%40domain.com" "@localhost/foo?authMechanism=GSSAPI")) res = copy.deepcopy(orig) res['options'] = { 'readpreference': ReadPreference.SECONDARY.mongos_mode, 'readpreferencetags': [ {'dc': 'west', 'use': 'website'}, {'dc': 'east', 'use': 'website'} ] } res['username'] = 'user@domain.com' res['password'] = 'password' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user%40domain.com:password" "@localhost/foo?readpreference=secondary&" "readpreferencetags=dc:west,use:website&" "readpreferencetags=dc:east,use:website")) res = copy.deepcopy(orig) res['options'] = { 'readpreference': ReadPreference.SECONDARY.mongos_mode, 'readpreferencetags': [ {'dc': 'west', 'use': 'website'}, {'dc': 'east', 'use': 'website'}, {} ] } res['username'] = 'user@domain.com' res['password'] = 'password' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user%40domain.com:password" "@localhost/foo?readpreference=secondary&" "readpreferencetags=dc:west,use:website&" "readpreferencetags=dc:east,use:website&" "readpreferencetags=")) res = copy.deepcopy(orig) res['options'] = {'uuidrepresentation': JAVA_LEGACY} res['username'] = 'user@domain.com' res['password'] = 'password' res['database'] = 'foo' self.assertEqual(res, parse_uri("mongodb://user%40domain.com:password" "@localhost/foo?uuidrepresentation=" "javaLegacy")) with warnings.catch_warnings(): warnings.filterwarnings('error') self.assertRaises(Warning, parse_uri, "mongodb://user%40domain.com:password" "@localhost/foo?uuidrepresentation=notAnOption", warn=True) self.assertRaises(ValueError, parse_uri, "mongodb://user%40domain.com:password" "@localhost/foo?uuidrepresentation=notAnOption") def test_parse_ssl_paths(self): # Turn off "validate" since these paths don't exist on filesystem. self.assertEqual( {'collection': None, 'database': None, 'nodelist': [('/MongoDB.sock', None)], 'options': {'ssl_certfile': '/a/b'}, 'password': 'foo/bar', 'username': 'jesse', 'fqdn': None}, parse_uri( 'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=/a/b', validate=False)) self.assertEqual( {'collection': None, 'database': None, 'nodelist': [('/MongoDB.sock', None)], 'options': {'ssl_certfile': 'a/b'}, 'password': 'foo/bar', 'username': 'jesse', 'fqdn': None}, parse_uri( 'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?ssl_certfile=a/b', validate=False)) def test_tlsinsecure_simple(self): # check that tlsInsecure is expanded correctly. uri = "mongodb://example.com/?tlsInsecure=true" res = { "ssl_match_hostname": False, "ssl_cert_reqs": CERT_NONE, "tlsinsecure": True, 'ssl_check_ocsp_endpoint': False} self.assertEqual(res, parse_uri(uri)["options"]) def test_tlsinsecure_legacy_conflict(self): # must not allow use of tlsinsecure alongside legacy TLS options. # same check for modern TLS options is performed in the spec-tests. uri = "mongodb://srv.com/?tlsInsecure=true&ssl_match_hostname=true" with self.assertRaises(InvalidURI): parse_uri(uri, validate=False, warn=False, normalize=False) def test_normalize_options(self): # check that options are converted to their internal names correctly. 
uri = ("mongodb://example.com/?tls=true&appname=myapp&maxPoolSize=10&" "fsync=true&wtimeout=10") res = { "ssl": True, "appname": "myapp", "maxpoolsize": 10, "fsync": True, "wtimeoutms": 10} self.assertEqual(res, parse_uri(uri)["options"]) def test_waitQueueMultiple_deprecated(self): uri = "mongodb://example.com/?waitQueueMultiple=5" with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter('always') parse_uri(uri) self.assertEqual(len(ctx), 1) self.assertTrue(issubclass(ctx[0].category, DeprecationWarning)) def test_unquote_after_parsing(self): quoted_val = "val%21%40%23%24%25%5E%26%2A%28%29_%2B%2C%3A+etc" unquoted_val = "val!@#$%^&*()_+,: etc" uri = ("mongodb://user:password@localhost/?authMechanism=MONGODB-AWS" "&authMechanismProperties=AWS_SESSION_TOKEN:"+quoted_val) res = parse_uri(uri) options = { 'authmechanism': 'MONGODB-AWS', 'authmechanismproperties': { 'AWS_SESSION_TOKEN': unquoted_val}} self.assertEqual(options, res['options']) uri = (("mongodb://localhost/foo?readpreference=secondary&" "readpreferencetags=dc:west,"+quoted_val+":"+quoted_val+"&" "readpreferencetags=dc:east,use:"+quoted_val)) res = parse_uri(uri) options = { 'readpreference': ReadPreference.SECONDARY.mongos_mode, 'readpreferencetags': [ {'dc': 'west', unquoted_val: unquoted_val}, {'dc': 'east', 'use': unquoted_val} ] } self.assertEqual(options, res['options']) def test_redact_AWS_SESSION_TOKEN(self): unquoted_colon = "token:" uri = ("mongodb://user:password@localhost/?authMechanism=MONGODB-AWS" "&authMechanismProperties=AWS_SESSION_TOKEN:"+unquoted_colon) with self.assertRaisesRegex( ValueError, 'auth mechanism properties must be key:value pairs like ' 'SERVICE_NAME:mongodb, not AWS_SESSION_TOKEN:' ', did you forget to percent-escape the token with ' 'quote_plus?'): parse_uri(uri) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_uri_spec.py000066400000000000000000000217751374256237000175120ustar00rootroot00000000000000# Copyright 2011-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Test that the pymongo.uri_parser module is compliant with the connection string and uri options specifications.""" import json import os import sys import warnings sys.path[0:0] = [""] from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _HAVE_SNAPPY from pymongo.uri_parser import parse_uri from test import clear_warning_registry, unittest CONN_STRING_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.path.join('connection_string', 'test')) URI_OPTIONS_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), 'uri_options') TEST_DESC_SKIP_LIST = [ "Valid options specific to single-threaded drivers are parsed correctly", "Invalid serverSelectionTryOnce causes a warning", "tlsDisableCertificateRevocationCheck can be set to true", "tlsDisableCertificateRevocationCheck can be set to false", "tlsAllowInvalidCertificates and tlsDisableCertificateRevocationCheck both present (and true) raises an error", "tlsAllowInvalidCertificates=true and tlsDisableCertificateRevocationCheck=false raises an error", "tlsAllowInvalidCertificates=false and tlsDisableCertificateRevocationCheck=true raises an error", "tlsAllowInvalidCertificates and tlsDisableCertificateRevocationCheck both present (and false) raises an error", "tlsDisableCertificateRevocationCheck and tlsAllowInvalidCertificates both present (and true) raises an error", "tlsDisableCertificateRevocationCheck=true and tlsAllowInvalidCertificates=false raises an error", "tlsDisableCertificateRevocationCheck=false and tlsAllowInvalidCertificates=true raises an error", "tlsDisableCertificateRevocationCheck and tlsAllowInvalidCertificates both present (and false) raises an error", "tlsInsecure and tlsDisableCertificateRevocationCheck both present (and true) raises an error", "tlsInsecure=true and tlsDisableCertificateRevocationCheck=false raises an error", "tlsInsecure=false and tlsDisableCertificateRevocationCheck=true raises an error", "tlsInsecure and tlsDisableCertificateRevocationCheck both present (and false) raises an error", "tlsDisableCertificateRevocationCheck and tlsInsecure both present (and true) raises an error", "tlsDisableCertificateRevocationCheck=true and tlsInsecure=false raises an error", "tlsDisableCertificateRevocationCheck=false and tlsInsecure=true raises an error", "tlsDisableCertificateRevocationCheck and tlsInsecure both present (and false) raises an error", "tlsDisableCertificateRevocationCheck and tlsDisableOCSPEndpointCheck both present (and true) raises an error", "tlsDisableCertificateRevocationCheck=true and tlsDisableOCSPEndpointCheck=false raises an error", "tlsDisableCertificateRevocationCheck=false and tlsDisableOCSPEndpointCheck=true raises an error", "tlsDisableCertificateRevocationCheck and tlsDisableOCSPEndpointCheck both present (and false) raises an error", "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and true) raises an error", "tlsDisableOCSPEndpointCheck=true and tlsDisableCertificateRevocationCheck=false raises an error", "tlsDisableOCSPEndpointCheck=false and tlsDisableCertificateRevocationCheck=true raises an error", "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and false) raises an error"] class TestAllScenarios(unittest.TestCase): def setUp(self): clear_warning_registry() def get_error_message_template(expected, artefact): return "%s %s for test '%s'" % ( "Expected" if expected else "Unexpected", artefact, "%s") def 
run_scenario_in_dir(target_workdir): def workdir_context_decorator(func): def modified_test_scenario(*args, **kwargs): original_workdir = os.getcwd() os.chdir(target_workdir) func(*args, **kwargs) os.chdir(original_workdir) return modified_test_scenario return workdir_context_decorator def create_test(test, test_workdir): def run_scenario(self): compressors = (test.get('options') or {}).get('compressors', []) if 'snappy' in compressors and not _HAVE_SNAPPY: self.skipTest('This test needs the snappy module.') valid = True warning = False with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter('always') try: options = parse_uri(test['uri'], warn=True) except Exception: valid = False else: warning = len(ctx) > 0 expected_valid = test.get('valid', True) self.assertEqual( valid, expected_valid, get_error_message_template( not expected_valid, "error") % test['description']) if expected_valid: expected_warning = test.get('warning', False) self.assertEqual( warning, expected_warning, get_error_message_template( expected_warning, "warning") % test['description']) # Compare hosts and port. if test['hosts'] is not None: self.assertEqual( len(test['hosts']), len(options['nodelist']), "Incorrect number of hosts parsed from URI") for exp, actual in zip(test['hosts'], options['nodelist']): self.assertEqual(exp['host'], actual[0], "Expected host %s but got %s" % (exp['host'], actual[0])) if exp['port'] is not None: self.assertEqual(exp['port'], actual[1], "Expected port %s but got %s" % (exp['port'], actual)) # Compare auth options. auth = test['auth'] if auth is not None: auth['database'] = auth.pop('db') # db == database # Special case for PyMongo's collection parsing. if options.get('collection') is not None: options['database'] += "." + options['collection'] for elm in auth: if auth[elm] is not None: self.assertEqual(auth[elm], options[elm], "Expected %s but got %s" % (auth[elm], options[elm])) # Compare URI options. 
err_msg = "For option %s expected %s but got %s" if test['options']: opts = options['options'] for opt in test['options']: lopt = opt.lower() optname = INTERNAL_URI_OPTION_NAME_MAP.get(lopt, lopt) if opts.get(optname) is not None: if opts[optname] == test['options'][opt]: expected_value = test['options'][opt] else: expected_value = validate( lopt, test['options'][opt])[1] self.assertEqual( opts[optname], expected_value, err_msg % (opt, expected_value, opts[optname],)) else: self.fail( "Missing expected option %s" % (opt,)) return run_scenario_in_dir(test_workdir)(run_scenario) def create_tests(test_path): for dirpath, _, filenames in os.walk(test_path): dirname = os.path.split(dirpath) dirname = os.path.split(dirname[-2])[-1] + '_' + dirname[-1] for filename in filenames: if not filename.endswith('.json'): # skip everything that is not a test specification continue with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json.load(scenario_stream) for testcase in scenario_def['tests']: dsc = testcase['description'] if dsc in TEST_DESC_SKIP_LIST: print("Skipping test '%s'" % dsc) continue testmethod = create_test(testcase, dirpath) testname = 'test_%s_%s_%s' % ( dirname, os.path.splitext(filename)[0], str(dsc).replace(' ', '_')) testmethod.__name__ = testname setattr(TestAllScenarios, testmethod.__name__, testmethod) for test_path in [CONN_STRING_TEST_PATH, URI_OPTIONS_TEST_PATH]: create_tests(test_path) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/test_write_concern.py000066400000000000000000000045211374256237000205300ustar00rootroot00000000000000# Copyright 2018-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Run the unit tests for WriteConcern.""" import collections import unittest from pymongo.write_concern import WriteConcern class TestWriteConcern(unittest.TestCase): def test_equality(self): concern = WriteConcern(j=True, wtimeout=3000) self.assertEqual(concern, WriteConcern(j=True, wtimeout=3000)) self.assertNotEqual(concern, WriteConcern()) def test_equality_to_none(self): concern = WriteConcern() self.assertNotEqual(concern, None) # Explicitly use the != operator. 
self.assertTrue(concern != None) # noqa def test_equality_compatible_type(self): class _FakeWriteConcern(object): def __init__(self, **document): self.document = document def __eq__(self, other): try: return self.document == other.document except AttributeError: return NotImplemented def __ne__(self, other): try: return self.document != other.document except AttributeError: return NotImplemented self.assertEqual(WriteConcern(j=True), _FakeWriteConcern(j=True)) self.assertEqual(_FakeWriteConcern(j=True), WriteConcern(j=True)) self.assertEqual(WriteConcern(j=True), _FakeWriteConcern(j=True)) self.assertEqual(WriteConcern(wtimeout=42), _FakeWriteConcern(wtimeout=42)) self.assertNotEqual(WriteConcern(wtimeout=42), _FakeWriteConcern(wtimeout=2000)) def test_equality_incompatible_type(self): _fake_type = collections.namedtuple('NotAWriteConcern', ['document']) self.assertNotEqual(WriteConcern(j=True), _fake_type({'j': True})) if __name__ == '__main__': unittest.main() pymongo-3.11.0/test/unicode/000077500000000000000000000000001374256237000157025ustar00rootroot00000000000000pymongo-3.11.0/test/unicode/test_utf8.py000066400000000000000000000044331374256237000202050ustar00rootroot00000000000000import sys sys.path[0:0] = [""] from bson import encode from bson.errors import InvalidStringData from bson.py3compat import PY3 from test import unittest class TestUTF8(unittest.TestCase): # Verify that python and bson have the same understanding of # legal utf-8 if the first byte is 0xf4 (244) def _assert_same_utf8_validation(self, data): try: data.decode('utf-8') py_is_legal = True except UnicodeDecodeError: py_is_legal = False try: encode({'x': data}) bson_is_legal = True except InvalidStringData: bson_is_legal = False self.assertEqual(py_is_legal, bson_is_legal, data) @unittest.skipIf(PY3, "python3 has strong separation between bytes/unicode") def test_legal_utf8_full_coverage(self): # This test takes 400 seconds. Which is too long to run each time. # However it is the only one which covers all possible bit combinations # in the 244 space. b1 = chr(0xf4) for b2 in map(chr, range(255)): m2 = b1 + b2 self._assert_same_utf8_validation(m2) for b3 in map(chr, range(255)): m3 = m2 + b3 self._assert_same_utf8_validation(m3) for b4 in map(chr, range(255)): m4 = m3 + b4 self._assert_same_utf8_validation(m4) # In python3: # - 'bytes' are not checked with isLegalutf # - 'unicode' We cannot create unicode objects with invalid utf8, since it # would result in non valid code-points. @unittest.skipIf(PY3, "python3 has strong separation between bytes/unicode") def test_legal_utf8_few_samples(self): good_samples = [ '\xf4\x80\x80\x80', '\xf4\x8a\x80\x80', '\xf4\x8e\x80\x80', '\xf4\x81\x80\x80', ] for data in good_samples: self._assert_same_utf8_validation(data) bad_samples = [ '\xf4\x00\x80\x80', '\xf4\x3a\x80\x80', '\xf4\x7f\x80\x80', '\xf4\x90\x80\x80', '\xf4\xff\x80\x80', ] for data in bad_samples: self._assert_same_utf8_validation(data) if __name__ == "__main__": unittest.main() pymongo-3.11.0/test/uri_options/000077500000000000000000000000001374256237000166265ustar00rootroot00000000000000pymongo-3.11.0/test/uri_options/ca.pem000066400000000000000000000002251374256237000177130ustar00rootroot00000000000000# This file exists solely for the purpose of facilitating drivers which check for the existence of files specified in the URI options at parse time. 
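# ---------------------------------------------------------------------------
# Illustrative aside (not part of the original pymongo test suite): the
# test_utf8.py checks above revolve around the 0xf4 lead byte, the highest
# legal first byte of a 4-byte UTF-8 sequence. A quick Python 3 sketch of
# the boundary those tests probe:
def _sketch_utf8_f4_boundary():
    # 0xf4 0x8f 0xbf 0xbf encodes U+10FFFF, the last valid code point.
    assert b'\xf4\x8f\xbf\xbf'.decode('utf-8') == '\U0010ffff'

    # 0xf4 0x90 0x80 0x80 would be U+110000, beyond the Unicode range, so
    # Python rejects it -- the same judgement BSON string encoding must
    # share for python and bson to agree on what is legal.
    try:
        b'\xf4\x90\x80\x80'.decode('utf-8')
    except UnicodeDecodeError:
        return True
    return False
# ---------------------------------------------------------------------------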
pymongo-3.11.0/test/uri_options/cert.pem000066400000000000000000000002251374256237000202650ustar00rootroot00000000000000# This file exists solely for the purpose of facilitating drivers which check for the existence of files specified in the URI options at parse time. pymongo-3.11.0/test/uri_options/client.pem000066400000000000000000000002251374256237000206060ustar00rootroot00000000000000# This file exists solely for the purpose of facilitating drivers which check for the existence of files specified in the URI options at parse time. pymongo-3.11.0/test/utils.py000066400000000000000000000710001374256237000157640ustar00rootroot00000000000000# Copyright 2012-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for testing pymongo """ import collections import contextlib import functools import os import re import shutil import sys import threading import time import warnings from collections import defaultdict from functools import partial from bson import json_util, py3compat from bson.objectid import ObjectId from bson.son import SON from pymongo import (MongoClient, monitoring, read_preferences) from pymongo.errors import ConfigurationError, OperationFailure from pymongo.monitoring import _SENSITIVE_COMMANDS, ConnectionPoolListener from pymongo.pool import (_CancellationContext, PoolOptions) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import (any_server_selector, writable_server_selector) from pymongo.server_type import SERVER_TYPE from pymongo.write_concern import WriteConcern from test import (client_context, db_user, db_pwd) if sys.version_info[0] < 3: # Python 2.7, use our backport. 
from test.barrier import Barrier else: from threading import Barrier IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) class CMAPListener(ConnectionPoolListener): def __init__(self): self.events = [] def reset(self): self.events = [] def add_event(self, event): self.events.append(event) def event_count(self, event_type): return len([event for event in self.events[:] if isinstance(event, event_type)]) def connection_created(self, event): self.add_event(event) def connection_ready(self, event): self.add_event(event) def connection_closed(self, event): self.add_event(event) def connection_check_out_started(self, event): self.add_event(event) def connection_check_out_failed(self, event): self.add_event(event) def connection_checked_out(self, event): self.add_event(event) def connection_checked_in(self, event): self.add_event(event) def pool_created(self, event): self.add_event(event) def pool_cleared(self, event): self.add_event(event) def pool_closed(self, event): self.add_event(event) class EventListener(monitoring.CommandListener): def __init__(self): self.results = defaultdict(list) def started(self, event): self.results['started'].append(event) def succeeded(self, event): self.results['succeeded'].append(event) def failed(self, event): self.results['failed'].append(event) def started_command_names(self): """Return list of command names started.""" return [event.command_name for event in self.results['started']] def reset(self): """Reset the state of this listener.""" self.results.clear() class WhiteListEventListener(EventListener): def __init__(self, *commands): self.commands = set(commands) super(WhiteListEventListener, self).__init__() def started(self, event): if event.command_name in self.commands: super(WhiteListEventListener, self).started(event) def succeeded(self, event): if event.command_name in self.commands: super(WhiteListEventListener, self).succeeded(event) def failed(self, event): if event.command_name in self.commands: super(WhiteListEventListener, self).failed(event) class OvertCommandListener(EventListener): """A CommandListener that ignores sensitive commands.""" def started(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: super(OvertCommandListener, self).started(event) def succeeded(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: super(OvertCommandListener, self).succeeded(event) def failed(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: super(OvertCommandListener, self).failed(event) class _ServerEventListener(object): """Listens to all events.""" def __init__(self): self.results = [] def opened(self, event): self.results.append(event) def description_changed(self, event): self.results.append(event) def closed(self, event): self.results.append(event) def matching(self, matcher): """Return the matching events.""" results = self.results[:] return [event for event in results if matcher(event)] def reset(self): self.results = [] class ServerEventListener(_ServerEventListener, monitoring.ServerListener): """Listens to Server events.""" class ServerAndTopologyEventListener(ServerEventListener, monitoring.TopologyListener): """Listens to Server and Topology events.""" class HeartbeatEventListener(monitoring.ServerHeartbeatListener): """Listens to only server heartbeat events.""" def __init__(self): self.results = [] def started(self, event): self.results.append(event) def succeeded(self, event): self.results.append(event) def failed(self, event): self.results.append(event) def matching(self, matcher): 
"""Return the matching events.""" results = self.results[:] return [event for event in results if matcher(event)] class MockSocketInfo(object): def __init__(self): self.cancel_context = _CancellationContext() self.more_to_come = False def close_socket(self, reason): pass def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): pass class MockPool(object): def __init__(self, *args, **kwargs): self.generation = 0 self._lock = threading.Lock() self.opts = PoolOptions() def get_socket(self, all_credentials, checkout=False): return MockSocketInfo() def return_socket(self, *args, **kwargs): pass def _reset(self): with self._lock: self.generation += 1 def reset(self): self._reset() def close(self): self._reset() def update_is_writable(self, is_writable): pass def remove_stale_sockets(self, *args, **kwargs): pass class ScenarioDict(dict): """Dict that returns {} for any unknown key, recursively.""" def __init__(self, data): def convert(v): if isinstance(v, collections.Mapping): return ScenarioDict(v) if isinstance(v, (py3compat.string_type, bytes)): return v if isinstance(v, collections.Sequence): return [convert(item) for item in v] return v dict.__init__(self, [(k, convert(v)) for k, v in data.items()]) def __getitem__(self, item): try: return dict.__getitem__(self, item) except KeyError: # Unlike a defaultdict, don't set the key, just return a dict. return ScenarioDict({}) class CompareType(object): """Class that compares equal to any object of the given type.""" def __init__(self, type): self.type = type def __eq__(self, other): return isinstance(other, self.type) def __ne__(self, other): """Needed for Python 2.""" return not self.__eq__(other) class FunctionCallRecorder(object): """Utility class to wrap a callable and record its invocations.""" def __init__(self, function): self._function = function self._call_list = [] def __call__(self, *args, **kwargs): self._call_list.append((args, kwargs)) return self._function(*args, **kwargs) def reset(self): """Wipes the call list.""" self._call_list = [] def call_list(self): """Returns a copy of the call list.""" return self._call_list[:] @property def call_count(self): """Returns the number of times the function has been called.""" return len(self._call_list) class TestCreator(object): """Class to create test cases from specifications.""" def __init__(self, create_test, test_class, test_path): """Create a TestCreator object. :Parameters: - `create_test`: callback that returns a test case. The callback must accept the following arguments - a dictionary containing the entire test specification (the `scenario_def`), a dictionary containing the specification for which the test case will be generated (the `test_def`). - `test_class`: the unittest.TestCase class in which to create the test case. - `test_path`: path to the directory containing the JSON files with the test specifications. 
""" self._create_test = create_test self._test_class = test_class self.test_path = test_path def _ensure_min_max_server_version(self, scenario_def, method): """Test modifier that enforces a version range for the server on a test case.""" if 'minServerVersion' in scenario_def: min_ver = tuple( int(elt) for elt in scenario_def['minServerVersion'].split('.')) if min_ver is not None: method = client_context.require_version_min(*min_ver)(method) if 'maxServerVersion' in scenario_def: max_ver = tuple( int(elt) for elt in scenario_def['maxServerVersion'].split('.')) if max_ver is not None: method = client_context.require_version_max(*max_ver)(method) return method @staticmethod def valid_topology(run_on_req): return client_context.is_topology_type( run_on_req.get('topology', ['single', 'replicaset', 'sharded'])) @staticmethod def min_server_version(run_on_req): version = run_on_req.get('minServerVersion') if version: min_ver = tuple(int(elt) for elt in version.split('.')) return client_context.version >= min_ver return True @staticmethod def max_server_version(run_on_req): version = run_on_req.get('maxServerVersion') if version: max_ver = tuple(int(elt) for elt in version.split('.')) return client_context.version <= max_ver return True def should_run_on(self, scenario_def): run_on = scenario_def.get('runOn', []) if not run_on: # Always run these tests. return True for req in run_on: if (self.valid_topology(req) and self.min_server_version(req) and self.max_server_version(req)): return True return False def ensure_run_on(self, scenario_def, method): """Test modifier that enforces a 'runOn' on a test case.""" return client_context._require( lambda: self.should_run_on(scenario_def), "runOn not satisfied", method) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" return scenario_def['tests'] def create_tests(self): for dirpath, _, filenames in os.walk(self.test_path): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: # Use tz_aware=False to match how CodecOptions decodes # dates. opts = json_util.JSONOptions(tz_aware=False) scenario_def = ScenarioDict( json_util.loads(scenario_stream.read(), json_options=opts)) test_type = os.path.splitext(filename)[0] # Construct test from scenario. 
for test_def in self.tests(scenario_def): test_name = 'test_%s_%s_%s' % ( dirname, test_type.replace("-", "_").replace('.', '_'), str(test_def['description'].replace(" ", "_").replace( '.', '_'))) new_test = self._create_test( scenario_def, test_def, test_name) new_test = self._ensure_min_max_server_version( scenario_def, new_test) new_test = self.ensure_run_on( scenario_def, new_test) new_test.__name__ = test_name setattr(self._test_class, new_test.__name__, new_test) def _connection_string(h, authenticate): if h.startswith("mongodb://"): return h elif client_context.auth_enabled and authenticate: return "mongodb://%s:%s@%s" % (db_user, db_pwd, str(h)) else: return "mongodb://%s" % (str(h),) def _mongo_client(host, port, authenticate=True, directConnection=False, **kwargs): """Create a new client over SSL/TLS if necessary.""" host = host or client_context.host port = port or client_context.port client_options = client_context.default_client_options.copy() if client_context.replica_set_name and not directConnection: client_options['replicaSet'] = client_context.replica_set_name client_options.update(kwargs) client = MongoClient(_connection_string(host, authenticate), port, **client_options) return client def single_client_noauth(h=None, p=None, **kwargs): """Make a direct connection. Don't authenticate.""" return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) def single_client(h=None, p=None, **kwargs): """Make a direct connection, and authenticate if necessary.""" return _mongo_client(h, p, directConnection=True, **kwargs) def rs_client_noauth(h=None, p=None, **kwargs): """Connect to the replica set. Don't authenticate.""" return _mongo_client(h, p, authenticate=False, **kwargs) def rs_client(h=None, p=None, **kwargs): """Connect to the replica set and authenticate if necessary.""" return _mongo_client(h, p, **kwargs) def rs_or_single_client_noauth(h=None, p=None, **kwargs): """Connect to the replica set if there is one, otherwise the standalone. Like rs_or_single_client, but does not authenticate. """ return _mongo_client(h, p, authenticate=False, **kwargs) def rs_or_single_client(h=None, p=None, **kwargs): """Connect to the replica set if there is one, otherwise the standalone. Authenticates if necessary. """ return _mongo_client(h, p, **kwargs) def ensure_all_connected(client): """Ensure that the client's connection pool has socket connections to all members of a replica set. Raises ConfigurationError when called with a non-replica set client. Depending on the use-case, the caller may need to clear any event listeners that are configured on the client. """ ismaster = client.admin.command("isMaster") if 'setName' not in ismaster: raise ConfigurationError("cluster is not a replica set") target_host_list = set(ismaster['hosts']) connected_host_list = set([ismaster['me']]) admindb = client.get_database('admin') # Run isMaster until we have connected to each host at least once. while connected_host_list != target_host_list: ismaster = admindb.command("isMaster", read_preference=ReadPreference.SECONDARY) connected_host_list.update([ismaster["me"]]) def one(s): """Get one element of a set""" return next(iter(s)) def oid_generated_on_process(oid): """Makes a determination as to whether the given ObjectId was generated by the current process, based on the 5-byte random number in the ObjectId. 
""" return ObjectId._random() == oid.binary[4:9] def delay(sec): return '''function() { sleep(%f * 1000); return true; }''' % sec def get_command_line(client): command_line = client.admin.command('getCmdLineOpts') assert command_line['ok'] == 1, "getCmdLineOpts() failed" return command_line def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() def camel_to_upper_camel(camel): return camel[0].upper() + camel[1:] def camel_to_snake_args(arguments): for arg_name in list(arguments): c2s = camel_to_snake(arg_name) arguments[c2s] = arguments.pop(arg_name) return arguments def parse_collection_options(opts): if 'readPreference' in opts: opts['read_preference'] = parse_read_preference( opts.pop('readPreference')) if 'writeConcern' in opts: opts['write_concern'] = WriteConcern( **dict(opts.pop('writeConcern'))) if 'readConcern' in opts: opts['read_concern'] = ReadConcern( **dict(opts.pop('readConcern'))) return opts def server_started_with_option(client, cmdline_opt, config_opt): """Check if the server was started with a particular option. :Parameters: - `cmdline_opt`: The command line option (i.e. --nojournal) - `config_opt`: The config file option (i.e. nojournal) """ command_line = get_command_line(client) if 'parsed' in command_line: parsed = command_line['parsed'] if config_opt in parsed: return parsed[config_opt] argv = command_line['argv'] return cmdline_opt in argv def server_started_with_auth(client): try: command_line = get_command_line(client) except OperationFailure as e: msg = e.details.get('errmsg', '') if e.code == 13 or 'unauthorized' in msg or 'login' in msg: # Unauthorized. return True raise # MongoDB >= 2.0 if 'parsed' in command_line: parsed = command_line['parsed'] # MongoDB >= 2.6 if 'security' in parsed: security = parsed['security'] # >= rc3 if 'authorization' in security: return security['authorization'] == 'enabled' # < rc3 return security.get('auth', False) or bool(security.get('keyFile')) return parsed.get('auth', False) or bool(parsed.get('keyFile')) # Legacy argv = command_line['argv'] return '--auth' in argv or '--keyFile' in argv def server_started_with_nojournal(client): command_line = get_command_line(client) # MongoDB 2.6. if 'parsed' in command_line: parsed = command_line['parsed'] if 'storage' in parsed: storage = parsed['storage'] if 'journal' in storage: return not storage['journal']['enabled'] return server_started_with_option(client, '--nojournal', 'nojournal') def server_is_master_with_slave(client): command_line = get_command_line(client) if 'parsed' in command_line: return command_line['parsed'].get('master', False) return '--master' in command_line['argv'] def drop_collections(db): # Drop all non-system collections in this database. for coll in db.list_collection_names( filter={"name": {"$regex": r"^(?!system\.)"}}): db.drop_collection(coll) def remove_all_users(db): db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w}) def joinall(threads): """Join threads with a 5-minute timeout, assert joins succeeded""" for t in threads: t.join(300) assert not t.is_alive(), "Thread %s hung" % t def connected(client): """Convenience to wait for a newly-constructed client to connect.""" with warnings.catch_warnings(): # Ignore warning that "ismaster" is always routed to primary even # if client's read preference isn't PRIMARY. 
warnings.simplefilter("ignore", UserWarning) client.admin.command('ismaster') # Force connection. return client def wait_until(predicate, success_description, timeout=10): """Wait up to 10 seconds (by default) for predicate to be true. E.g.: wait_until(lambda: client.primary == ('a', 1), 'connect to the primary') If the lambda-expression isn't true after 10 seconds, we raise AssertionError("Didn't ever connect to the primary"). Returns the predicate's first true value. """ start = time.time() interval = min(float(timeout)/100, 0.1) while True: retval = predicate() if retval: return retval if time.time() - start > timeout: raise AssertionError("Didn't ever %s" % success_description) time.sleep(interval) def repl_set_step_down(client, **kwargs): """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" cmd = SON([('replSetStepDown', 1)]) cmd.update(kwargs) # Unfreeze a secondary to ensure a speedy election. client.admin.command( 'replSetFreeze', 0, read_preference=ReadPreference.SECONDARY) client.admin.command(cmd) def is_mongos(client): res = client.admin.command('ismaster') return res.get('msg', '') == 'isdbgrid' def assertRaisesExactly(cls, fn, *args, **kwargs): """ Unlike the standard assertRaises, this checks that a function raises a specific class of exception, and not a subclass. E.g., check that MongoClient() raises ConnectionFailure but not its subclass, AutoReconnect. """ try: fn(*args, **kwargs) except Exception as e: assert e.__class__ == cls, "got %s, expected %s" % ( e.__class__.__name__, cls.__name__) else: raise AssertionError("%s not raised" % cls) @contextlib.contextmanager def _ignore_deprecations(): with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) yield def ignore_deprecations(wrapped=None): """A context manager or a decorator.""" if wrapped: @functools.wraps(wrapped) def wrapper(*args, **kwargs): with _ignore_deprecations(): return wrapped(*args, **kwargs) return wrapper else: return _ignore_deprecations() class DeprecationFilter(object): def __init__(self, action="ignore"): """Start filtering deprecations.""" self.warn_context = warnings.catch_warnings() self.warn_context.__enter__() warnings.simplefilter(action, DeprecationWarning) def stop(self): """Stop filtering deprecations.""" self.warn_context.__exit__() self.warn_context = None def get_pool(client): """Get the standalone, primary, or mongos pool.""" topology = client._get_topology() server = topology.select_server(writable_server_selector) return server.pool def get_pools(client): """Get all pools.""" return [ server.pool for server in client._get_topology().select_servers(any_server_selector)] # Constants for run_threads and lazy_client_trial. NTRIALS = 5 NTHREADS = 10 def run_threads(collection, target): """Run a target function in many threads. target is a function taking a Collection and an integer. 
""" threads = [] for i in range(NTHREADS): bound_target = partial(target, collection, i) threads.append(threading.Thread(target=bound_target)) for t in threads: t.start() for t in threads: t.join(60) assert not t.is_alive() @contextlib.contextmanager def frequent_thread_switches(): """Make concurrency bugs more likely to manifest.""" interval = None if not sys.platform.startswith('java'): if hasattr(sys, 'getswitchinterval'): interval = sys.getswitchinterval() sys.setswitchinterval(1e-6) else: interval = sys.getcheckinterval() sys.setcheckinterval(1) try: yield finally: if not sys.platform.startswith('java'): if hasattr(sys, 'setswitchinterval'): sys.setswitchinterval(interval) else: sys.setcheckinterval(interval) def lazy_client_trial(reset, target, test, get_client): """Test concurrent operations on a lazily-connecting client. `reset` takes a collection and resets it for the next trial. `target` takes a lazily-connecting collection and an index from 0 to NTHREADS, and performs some operation, e.g. an insert. `test` takes the lazily-connecting collection and asserts a post-condition to prove `target` succeeded. """ collection = client_context.client.pymongo_test.test with frequent_thread_switches(): for i in range(NTRIALS): reset(collection) lazy_client = get_client() lazy_collection = lazy_client.pymongo_test.test run_threads(lazy_collection, target) test(lazy_collection) def gevent_monkey_patched(): """Check if gevent's monkey patching is active.""" # In Python 3.6 importing gevent.socket raises an ImportWarning. with warnings.catch_warnings(): warnings.simplefilter("ignore", ImportWarning) try: import socket import gevent.socket return socket.socket is gevent.socket.socket except ImportError: return False def eventlet_monkey_patched(): """Check if eventlet's monkey patching is active.""" try: import threading import eventlet return (threading.current_thread.__module__ == 'eventlet.green.threading') except ImportError: return False def is_greenthread_patched(): return gevent_monkey_patched() or eventlet_monkey_patched() def disable_replication(client): """Disable replication on all secondaries, requires MongoDB 3.2.""" for host, port in client.secondaries: secondary = single_client(host, port) secondary.admin.command('configureFailPoint', 'stopReplProducer', mode='alwaysOn') def enable_replication(client): """Enable replication on all secondaries, requires MongoDB 3.2.""" for host, port in client.secondaries: secondary = single_client(host, port) secondary.admin.command('configureFailPoint', 'stopReplProducer', mode='off') class ExceptionCatchingThread(threading.Thread): """A thread that stores any exception encountered from run().""" def __init__(self, *args, **kwargs): self.exc = None super(ExceptionCatchingThread, self).__init__(*args, **kwargs) def run(self): try: super(ExceptionCatchingThread, self).run() except BaseException as exc: self.exc = exc raise def parse_read_preference(pref): # Make first letter lowercase to match read_pref's modes. mode_string = pref.get('mode', 'primary') mode_string = mode_string[:1].lower() + mode_string[1:] mode = read_preferences.read_pref_mode_from_name(mode_string) max_staleness = pref.get('maxStalenessSeconds', -1) tag_sets = pref.get('tag_sets') return read_preferences.make_read_preference( mode, tag_sets=tag_sets, max_staleness=max_staleness) def server_name_to_type(name): """Convert a ServerType name to the corresponding value. 
For SDAM tests.""" # Special case, some tests in the spec include the PossiblePrimary # type, but only single-threaded drivers need that type. We call # possible primaries Unknown. if name == 'PossiblePrimary': return SERVER_TYPE.Unknown return getattr(SERVER_TYPE, name) def cat_files(dest, *sources): """Cat multiple files into dest.""" with open(dest, 'wb') as fdst: for src in sources: with open(src, 'rb') as fsrc: shutil.copyfileobj(fsrc, fdst) @contextlib.contextmanager def assertion_context(msg): """A context manager that adds info to an assertion failure.""" try: yield except AssertionError as exc: msg = '%s (%s)' % (exc, msg) py3compat.reraise(type(exc), msg, sys.exc_info()[2]) pymongo-3.11.0/test/utils_selection_tests.py000066400000000000000000000224251374256237000212620ustar00rootroot00000000000000# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for testing Server Selection and Max Staleness.""" import datetime import os import sys sys.path[0:0] = [""] from bson import json_util from pymongo.common import clean_node, HEARTBEAT_FREQUENCY from pymongo.errors import AutoReconnect, ConfigurationError from pymongo.ismaster import IsMaster from pymongo.server_description import ServerDescription from pymongo.settings import TopologySettings from pymongo.server_selectors import writable_server_selector from pymongo.topology import Topology from test import unittest from test.utils import MockPool, parse_read_preference class MockMonitor(object): def __init__(self, server_description, topology, pool, topology_settings): pass def cancel_check(self): pass def open(self): pass def request_check(self): pass def close(self): pass def get_addresses(server_list): seeds = [] hosts = [] for server in server_list: seeds.append(clean_node(server['address'])) hosts.append(server['address']) return seeds, hosts def make_last_write_date(server): epoch = datetime.datetime.utcfromtimestamp(0) millis = server.get('lastWrite', {}).get('lastWriteDate') if millis: diff = ((millis % 1000) + 1000) % 1000 seconds = (millis - diff) / 1000 micros = diff * 1000 return epoch + datetime.timedelta( seconds=seconds, microseconds=micros) else: # "Unknown" server. 
return epoch def make_server_description(server, hosts): """Make a ServerDescription from server info in a JSON test.""" server_type = server['type'] if server_type == "Unknown": return ServerDescription(clean_node(server['address']), IsMaster({})) ismaster_response = {'ok': True, 'hosts': hosts} if server_type != "Standalone" and server_type != "Mongos": ismaster_response['setName'] = "rs" if server_type == "RSPrimary": ismaster_response['ismaster'] = True elif server_type == "RSSecondary": ismaster_response['secondary'] = True elif server_type == "Mongos": ismaster_response['msg'] = 'isdbgrid' ismaster_response['lastWrite'] = { 'lastWriteDate': make_last_write_date(server) } for field in 'maxWireVersion', 'tags', 'idleWritePeriodMillis': if field in server: ismaster_response[field] = server[field] ismaster_response.setdefault('maxWireVersion', 6) # Sets _last_update_time to now. sd = ServerDescription(clean_node(server['address']), IsMaster(ismaster_response), round_trip_time=server['avg_rtt_ms'] / 1000.0) if 'lastUpdateTime' in server: sd._last_update_time = server['lastUpdateTime'] / 1000.0 # ms to sec. return sd def get_topology_type_name(scenario_def): td = scenario_def['topology_description'] name = td['type'] if name == 'Unknown': # PyMongo never starts a topology in type Unknown. return 'Sharded' if len(td['servers']) > 1 else 'Single' else: return name def get_topology_settings_dict(**kwargs): settings = dict( monitor_class=MockMonitor, heartbeat_frequency=HEARTBEAT_FREQUENCY, pool_class=MockPool ) settings.update(kwargs) return settings def create_test(scenario_def): def run_scenario(self): # Initialize topologies. if 'heartbeatFrequencyMS' in scenario_def: frequency = int(scenario_def['heartbeatFrequencyMS']) / 1000.0 else: frequency = HEARTBEAT_FREQUENCY seeds, hosts = get_addresses( scenario_def['topology_description']['servers']) settings = get_topology_settings_dict( heartbeat_frequency=frequency, seeds=seeds ) # "Eligible servers" is defined in the server selection spec as # the set of servers matching both the ReadPreference's mode # and tag sets. top_latency = Topology(TopologySettings(**settings)) top_latency.open() # "In latency window" is defined in the server selection # spec as the subset of suitable_servers that falls within the # allowable latency window. settings['local_threshold_ms'] = 1000000 top_suitable = Topology(TopologySettings(**settings)) top_suitable.open() # Update topologies with server descriptions. for server in scenario_def['topology_description']['servers']: server_description = make_server_description(server, hosts) top_suitable.on_change(server_description) top_latency.on_change(server_description) # Create server selector. if scenario_def.get("operation") == "write": pref = writable_server_selector else: # Make first letter lowercase to match read_pref's modes. pref_def = scenario_def['read_preference'] if scenario_def.get('error'): with self.assertRaises((ConfigurationError, ValueError)): # Error can be raised when making Read Pref or selecting. pref = parse_read_preference(pref_def) top_latency.select_server(pref) return pref = parse_read_preference(pref_def) # Select servers. 
if not scenario_def.get('suitable_servers'): with self.assertRaises(AutoReconnect): top_suitable.select_server(pref, server_selection_timeout=0) return if not scenario_def['in_latency_window']: with self.assertRaises(AutoReconnect): top_latency.select_server(pref, server_selection_timeout=0) return actual_suitable_s = top_suitable.select_servers( pref, server_selection_timeout=0) actual_latency_s = top_latency.select_servers( pref, server_selection_timeout=0) expected_suitable_servers = {} for server in scenario_def['suitable_servers']: server_description = make_server_description(server, hosts) expected_suitable_servers[server['address']] = server_description actual_suitable_servers = {} for s in actual_suitable_s: actual_suitable_servers["%s:%d" % (s.description.address[0], s.description.address[1])] = s.description self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers)) for k, actual in actual_suitable_servers.items(): expected = expected_suitable_servers[k] self.assertEqual(expected.address, actual.address) self.assertEqual(expected.server_type, actual.server_type) self.assertEqual(expected.round_trip_time, actual.round_trip_time) self.assertEqual(expected.tags, actual.tags) self.assertEqual(expected.all_hosts, actual.all_hosts) expected_latency_servers = {} for server in scenario_def['in_latency_window']: server_description = make_server_description(server, hosts) expected_latency_servers[server['address']] = server_description actual_latency_servers = {} for s in actual_latency_s: actual_latency_servers["%s:%d" % (s.description.address[0], s.description.address[1])] = s.description self.assertEqual(len(actual_latency_servers), len(expected_latency_servers)) for k, actual in actual_latency_servers.items(): expected = expected_latency_servers[k] self.assertEqual(expected.address, actual.address) self.assertEqual(expected.server_type, actual.server_type) self.assertEqual(expected.round_trip_time, actual.round_trip_time) self.assertEqual(expected.tags, actual.tags) self.assertEqual(expected.all_hosts, actual.all_hosts) return run_scenario def create_selection_tests(test_dir): class TestAllScenarios(unittest.TestCase): pass for dirpath, _, filenames in os.walk(test_dir): dirname = os.path.split(dirpath) dirname = os.path.split(dirname[-2])[-1] + '_' + dirname[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json_util.loads(scenario_stream.read()) # Construct test from scenario. new_test = create_test(scenario_def) test_name = 'test_%s_%s' % ( dirname, os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) return TestAllScenarios pymongo-3.11.0/test/utils_spec_runner.py000066400000000000000000000724171374256237000204040ustar00rootroot00000000000000# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Utilities for testing driver specs.""" import copy import threading from bson import decode, encode from bson.binary import Binary, STANDARD from bson.codec_options import CodecOptions from bson.int64 import Int64 from bson.py3compat import iteritems, abc, string_type, text_type from bson.son import SON from gridfs import GridFSBucket from pymongo import (client_session, helpers, operations) from pymongo.command_cursor import CommandCursor from pymongo.cursor import Cursor from pymongo.errors import (BulkWriteError, OperationFailure, PyMongoError) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import _WriteResult, BulkWriteResult from pymongo.write_concern import WriteConcern from test import (client_context, client_knobs, IntegrationTest, unittest) from test.utils import (camel_to_snake, camel_to_snake_args, camel_to_upper_camel, CompareType, CMAPListener, OvertCommandListener, parse_read_preference, rs_client, ServerAndTopologyEventListener, HeartbeatEventListener) class SpecRunnerThread(threading.Thread): def __init__(self, name): super(SpecRunnerThread, self).__init__() self.name = name self.exc = None self.setDaemon(True) self.cond = threading.Condition() self.ops = [] self.stopped = False def schedule(self, work): self.ops.append(work) with self.cond: self.cond.notify() def stop(self): self.stopped = True with self.cond: self.cond.notify() def run(self): while not self.stopped or self.ops: if not self. ops: with self.cond: self.cond.wait(10) if self.ops: try: work = self.ops.pop(0) work() except Exception as exc: self.exc = exc self.stop() class SpecRunner(IntegrationTest): @classmethod def setUpClass(cls): super(SpecRunner, cls).setUpClass() cls.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() @classmethod def tearDownClass(cls): cls.knobs.disable() super(SpecRunner, cls).tearDownClass() def setUp(self): super(SpecRunner, self).setUp() self.targets = {} self.listener = None self.pool_listener = None self.server_listener = None self.maxDiff = None def _set_fail_point(self, client, command_args): cmd = SON([('configureFailPoint', 'failCommand')]) cmd.update(command_args) client.admin.command(cmd) def set_fail_point(self, command_args): cmd = SON([('configureFailPoint', 'failCommand')]) cmd.update(command_args) clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: self._set_fail_point(client, cmd) def targeted_fail_point(self, session, fail_point): """Run the targetedFailPoint test operation. Enable the fail point on the session's pinned mongos. """ clients = {c.address: c for c in self.mongos_clients} client = clients[session._pinned_address] self._set_fail_point(client, fail_point) self.addCleanup(self.set_fail_point, {'mode': 'off'}) def assert_session_pinned(self, session): """Run the assertSessionPinned test operation. Assert that the given session is pinned. """ self.assertIsNotNone(session._transaction.pinned_address) def assert_session_unpinned(self, session): """Run the assertSessionUnpinned test operation. Assert that the given session is not pinned. 
""" self.assertIsNone(session._pinned_address) self.assertIsNone(session._transaction.pinned_address) def assert_collection_exists(self, database, collection): """Run the assertCollectionExists test operation.""" db = self.client[database] self.assertIn(collection, db.list_collection_names()) def assert_collection_not_exists(self, database, collection): """Run the assertCollectionNotExists test operation.""" db = self.client[database] self.assertNotIn(collection, db.list_collection_names()) def assert_index_exists(self, database, collection, index): """Run the assertIndexExists test operation.""" coll = self.client[database][collection] self.assertIn(index, [doc['name'] for doc in coll.list_indexes()]) def assert_index_not_exists(self, database, collection, index): """Run the assertIndexNotExists test operation.""" coll = self.client[database][collection] self.assertNotIn(index, [doc['name'] for doc in coll.list_indexes()]) def assertErrorLabelsContain(self, exc, expected_labels): labels = [l for l in expected_labels if exc.has_error_label(l)] self.assertEqual(labels, expected_labels) def assertErrorLabelsOmit(self, exc, omit_labels): for label in omit_labels: self.assertFalse( exc.has_error_label(label), msg='error labels should not contain %s' % (label,)) def kill_all_sessions(self): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: try: client.admin.command('killAllSessions', []) except OperationFailure: # "operation was interrupted" by killing the command's # own session. pass def check_command_result(self, expected_result, result): # Only compare the keys in the expected result. filtered_result = {} for key in expected_result: try: filtered_result[key] = result[key] except KeyError: pass self.assertEqual(filtered_result, expected_result) # TODO: factor the following function with test_crud.py. def check_result(self, expected_result, result): if isinstance(result, _WriteResult): for res in expected_result: prop = camel_to_snake(res) # SPEC-869: Only BulkWriteResult has upserted_count. if (prop == "upserted_count" and not isinstance(result, BulkWriteResult)): if result.upserted_id is not None: upserted_count = 1 else: upserted_count = 0 self.assertEqual(upserted_count, expected_result[res], prop) elif prop == "inserted_ids": # BulkWriteResult does not have inserted_ids. if isinstance(result, BulkWriteResult): self.assertEqual(len(expected_result[res]), result.inserted_count) else: # InsertManyResult may be compared to [id1] from the # crud spec or {"0": id1} from the retryable write spec. ids = expected_result[res] if isinstance(ids, dict): ids = [ids[str(i)] for i in range(len(ids))] self.assertEqual(ids, result.inserted_ids, prop) elif prop == "upserted_ids": # Convert indexes from strings to integers. ids = expected_result[res] expected_ids = {} for str_index in ids: expected_ids[int(str_index)] = ids[str_index] self.assertEqual(expected_ids, result.upserted_ids, prop) else: self.assertEqual( getattr(result, prop), expected_result[res], prop) return True else: self.assertEqual(result, expected_result) def get_object_name(self, op): """Allow subclasses to override handling of 'object' Transaction spec says 'object' is required. 
""" return op['object'] @staticmethod def parse_options(opts): if 'readPreference' in opts: opts['read_preference'] = parse_read_preference( opts.pop('readPreference')) if 'writeConcern' in opts: opts['write_concern'] = WriteConcern( **dict(opts.pop('writeConcern'))) if 'readConcern' in opts: opts['read_concern'] = ReadConcern( **dict(opts.pop('readConcern'))) if 'maxTimeMS' in opts: opts['max_time_ms'] = opts.pop('maxTimeMS') if 'maxCommitTimeMS' in opts: opts['max_commit_time_ms'] = opts.pop('maxCommitTimeMS') if 'hint' in opts: hint = opts.pop('hint') if not isinstance(hint, string_type): hint = list(iteritems(hint)) opts['hint'] = hint # Properly format 'hint' arguments for the Bulk API tests. if 'requests' in opts: reqs = opts.pop('requests') for req in reqs: args = req.pop('arguments') if 'hint' in args: hint = args.pop('hint') if not isinstance(hint, string_type): hint = list(iteritems(hint)) args['hint'] = hint req['arguments'] = args opts['requests'] = reqs return dict(opts) def run_operation(self, sessions, collection, operation): original_collection = collection name = camel_to_snake(operation['name']) if name == 'run_command': name = 'command' elif name == 'download_by_name': name = 'open_download_stream_by_name' elif name == 'download': name = 'open_download_stream' database = collection.database collection = database.get_collection(collection.name) if 'collectionOptions' in operation: collection = collection.with_options( **self.parse_options(operation['collectionOptions'])) object_name = self.get_object_name(operation) if object_name == 'gridfsbucket': # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). obj = GridFSBucket( database, bucket_name=collection.name, disable_md5=True) else: objects = { 'client': database.client, 'database': database, 'collection': collection, 'testRunner': self } objects.update(sessions) obj = objects[object_name] # Combine arguments with options and handle special cases. arguments = operation.get('arguments', {}) arguments.update(arguments.pop("options", {})) self.parse_options(arguments) cmd = getattr(obj, name) for arg_name in list(arguments): c2s = camel_to_snake(arg_name) # PyMongo accepts sort as list of tuples. if arg_name == "sort": sort_dict = arguments[arg_name] arguments[arg_name] = list(iteritems(sort_dict)) # Named "key" instead not fieldName. if arg_name == "fieldName": arguments["key"] = arguments.pop(arg_name) # Aggregate uses "batchSize", while find uses batch_size. elif ((arg_name == "batchSize" or arg_name == "allowDiskUse") and name == "aggregate"): continue # Requires boolean returnDocument. elif arg_name == "returnDocument": arguments[c2s] = arguments.pop(arg_name) == "After" elif c2s == "requests": # Parse each request into a bulk write model. requests = [] for request in arguments["requests"]: bulk_model = camel_to_upper_camel(request["name"]) bulk_class = getattr(operations, bulk_model) bulk_arguments = camel_to_snake_args(request["arguments"]) requests.append(bulk_class(**dict(bulk_arguments))) arguments["requests"] = requests elif arg_name == "session": arguments['session'] = sessions[arguments['session']] elif (name in ('command', 'run_admin_command') and arg_name == 'command'): # Ensure the first key is the command name. 
ordered_command = SON([(operation['command_name'], 1)]) ordered_command.update(arguments['command']) arguments['command'] = ordered_command elif name == 'open_download_stream' and arg_name == 'id': arguments['file_id'] = arguments.pop(arg_name) elif name != 'find' and c2s == 'max_time_ms': # find is the only method that accepts snake_case max_time_ms. # All other methods take kwargs which must use the server's # camelCase maxTimeMS. See PYTHON-1855. arguments['maxTimeMS'] = arguments.pop('max_time_ms') elif name == 'with_transaction' and arg_name == 'callback': callback_ops = arguments[arg_name]['operations'] arguments['callback'] = lambda _: self.run_operations( sessions, original_collection, copy.deepcopy(callback_ops), in_with_transaction=True) elif name == 'drop_collection' and arg_name == 'collection': arguments['name_or_collection'] = arguments.pop(arg_name) elif name == 'create_collection' and arg_name == 'collection': arguments['name'] = arguments.pop(arg_name) elif name == 'create_index' and arg_name == 'keys': arguments['keys'] = list(arguments.pop(arg_name).items()) elif name == 'drop_index' and arg_name == 'name': arguments['index_or_name'] = arguments.pop(arg_name) else: arguments[c2s] = arguments.pop(arg_name) if name == 'run_on_thread': args = {'sessions': sessions, 'collection': collection} args.update(arguments) arguments = args result = cmd(**dict(arguments)) if name == "aggregate": if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: # Read from the primary to ensure causal consistency. out = collection.database.get_collection( arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY) return out.find() if name == "map_reduce": if isinstance(result, dict) and 'results' in result: return result['results'] if 'download' in name: result = Binary(result.read()) if isinstance(result, Cursor) or isinstance(result, CommandCursor): return list(result) return result def allowable_errors(self, op): """Allow encryption spec to override expected error classes.""" return (PyMongoError,) def _run_op(self, sessions, collection, op, in_with_transaction): expected_result = op.get('result') if expect_error(op): with self.assertRaises(self.allowable_errors(op), msg=op['name']) as context: self.run_operation(sessions, collection, op.copy()) if expect_error_message(expected_result): if isinstance(context.exception, BulkWriteError): errmsg = str(context.exception.details).lower() else: errmsg = str(context.exception).lower() self.assertIn(expected_result['errorContains'].lower(), errmsg) if expect_error_code(expected_result): self.assertEqual(expected_result['errorCodeName'], context.exception.details.get('codeName')) if expect_error_labels_contain(expected_result): self.assertErrorLabelsContain( context.exception, expected_result['errorLabelsContain']) if expect_error_labels_omit(expected_result): self.assertErrorLabelsOmit( context.exception, expected_result['errorLabelsOmit']) # Reraise the exception if we're in the with_transaction # callback. 
if in_with_transaction: raise context.exception else: result = self.run_operation(sessions, collection, op.copy()) if 'result' in op: if op['name'] == 'runCommand': self.check_command_result(expected_result, result) else: self.check_result(expected_result, result) def run_operations(self, sessions, collection, ops, in_with_transaction=False): for op in ops: self._run_op(sessions, collection, op, in_with_transaction) # TODO: factor with test_command_monitoring.py def check_events(self, test, listener, session_ids): res = listener.results if not len(test['expectations']): return # Give a nicer message when there are missing or extra events cmds = decode_raw([event.command for event in res['started']]) self.assertEqual( len(res['started']), len(test['expectations']), cmds) for i, expectation in enumerate(test['expectations']): event_type = next(iter(expectation)) event = res['started'][i] # The tests substitute 42 for any number other than 0. if (event.command_name == 'getMore' and event.command['getMore']): event.command['getMore'] = Int64(42) elif event.command_name == 'killCursors': event.command['cursors'] = [Int64(42)] elif event.command_name == 'update': # TODO: remove this once PYTHON-1744 is done. # Add upsert and multi fields back into expectations. updates = expectation[event_type]['command']['updates'] for update in updates: update.setdefault('upsert', False) update.setdefault('multi', False) # Replace afterClusterTime: 42 with actual afterClusterTime. expected_cmd = expectation[event_type]['command'] expected_read_concern = expected_cmd.get('readConcern') if expected_read_concern is not None: time = expected_read_concern.get('afterClusterTime') if time == 42: actual_time = event.command.get( 'readConcern', {}).get('afterClusterTime') if actual_time is not None: expected_read_concern['afterClusterTime'] = actual_time recovery_token = expected_cmd.get('recoveryToken') if recovery_token == 42: expected_cmd['recoveryToken'] = CompareType(dict) # Replace lsid with a name like "session0" to match test. 
if 'lsid' in event.command: for name, lsid in session_ids.items(): if event.command['lsid'] == lsid: event.command['lsid'] = name break for attr, expected in expectation[event_type].items(): actual = getattr(event, attr) expected = wrap_types(expected) if isinstance(expected, dict): for key, val in expected.items(): if val is None: if key in actual: self.fail("Unexpected key [%s] in %r" % ( key, actual)) elif key not in actual: self.fail("Expected key [%s] in %r" % ( key, actual)) else: self.assertEqual(val, decode_raw(actual[key]), "Key [%s] in %s" % (key, actual)) else: self.assertEqual(actual, expected) def maybe_skip_scenario(self, test): if test.get('skipReason'): raise unittest.SkipTest(test.get('skipReason')) def get_scenario_db_name(self, scenario_def): """Allow subclasses to override a test's database name.""" return scenario_def['database_name'] def get_scenario_coll_name(self, scenario_def): """Allow subclasses to override a test's collection name.""" return scenario_def['collection_name'] def get_outcome_coll_name(self, outcome, collection): """Allow subclasses to override outcome collection.""" return collection.name def run_test_ops(self, sessions, collection, test): """Added to allow retryable writes spec to override a test's operation.""" self.run_operations(sessions, collection, test['operations']) def parse_client_options(self, opts): """Allow encryption spec to override a clientOptions parsing.""" # Convert test['clientOptions'] to dict to avoid a Jython bug using # "**" with ScenarioDict. return dict(opts) def setup_scenario(self, scenario_def): """Allow specs to override a test's setup.""" db_name = self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) db = client_context.client.get_database( db_name, write_concern=WriteConcern(w='majority')) coll = db[coll_name] coll.drop() db.create_collection(coll_name) if scenario_def['data']: # Load data. coll.insert_many(scenario_def['data']) def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) # Kill all sessions before and after each test to prevent an open # transaction (from a test failure) from blocking collection/database # operations during test set up and tear down. self.kill_all_sessions() self.addCleanup(self.kill_all_sessions) self.setup_scenario(scenario_def) database_name = self.get_scenario_db_name(scenario_def) collection_name = self.get_scenario_coll_name(scenario_def) # SPEC-1245 workaround StaleDbVersion on distinct for c in self.mongos_clients: c[database_name][collection_name].distinct("x") # Configure the fail point before creating the client. if 'failPoint' in test: fp = test['failPoint'] self.set_fail_point(fp) self.addCleanup(self.set_fail_point, { 'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'}) listener = OvertCommandListener() pool_listener = CMAPListener() server_listener = ServerAndTopologyEventListener() # Create a new client, to avoid interference from pooled sessions. client_options = self.parse_client_options(test['clientOptions']) # MMAPv1 does not support retryable writes. 
if (client_options.get('retryWrites') is True and client_context.storage_engine == 'mmapv1'): self.skipTest("MMAPv1 does not support retryWrites=True") use_multi_mongos = test['useMultipleMongoses'] if client_context.is_mongos and use_multi_mongos: client = rs_client( client_context.mongos_seeds(), event_listeners=[listener, pool_listener, server_listener], **client_options) else: client = rs_client( event_listeners=[listener, pool_listener, server_listener], **client_options) self.scenario_client = client self.listener = listener self.pool_listener = pool_listener self.server_listener = server_listener # Close the client explicitly to avoid having too many threads open. self.addCleanup(client.close) # Create session0 and session1. sessions = {} session_ids = {} for i in range(2): # Don't attempt to create sessions if they are not supported by # the running server version. if not client_context.sessions_enabled: break session_name = 'session%d' % i opts = camel_to_snake_args(test['sessionOptions'][session_name]) if 'default_transaction_options' in opts: txn_opts = self.parse_options( opts['default_transaction_options']) txn_opts = client_session.TransactionOptions(**txn_opts) opts['default_transaction_options'] = txn_opts s = client.start_session(**dict(opts)) sessions[session_name] = s # Store lsid so we can access it after end_session, in check_events. session_ids[session_name] = s.session_id self.addCleanup(end_sessions, sessions) collection = client[database_name][collection_name] self.run_test_ops(sessions, collection, test) end_sessions(sessions) self.check_events(test, listener, session_ids) # Disable fail points. if 'failPoint' in test: fp = test['failPoint'] self.set_fail_point({ 'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'}) # Assert final state is expected. outcome = test['outcome'] expected_c = outcome.get('collection') if expected_c is not None: outcome_coll_name = self.get_outcome_coll_name( outcome, collection) # Read from the primary with local read concern to ensure causal # consistency. outcome_coll = client_context.client[ collection.database.name].get_collection( outcome_coll_name, read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern('local')) actual_data = list(outcome_coll.find(sort=[('_id', 1)])) # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. self.assertEqual(wrap_types(expected_c['data']), actual_data) def expect_any_error(op): if isinstance(op, dict): return op.get('error') return False def expect_error_message(expected_result): if isinstance(expected_result, dict): return isinstance(expected_result['errorContains'], text_type) return False def expect_error_code(expected_result): if isinstance(expected_result, dict): return expected_result['errorCodeName'] return False def expect_error_labels_contain(expected_result): if isinstance(expected_result, dict): return expected_result['errorLabelsContain'] return False def expect_error_labels_omit(expected_result): if isinstance(expected_result, dict): return expected_result['errorLabelsOmit'] return False def expect_error(op): expected_result = op.get('result') return (expect_any_error(op) or expect_error_message(expected_result) or expect_error_code(expected_result) or expect_error_labels_contain(expected_result) or expect_error_labels_omit(expected_result)) def end_sessions(sessions): for s in sessions.values(): # Aborts the transaction if it's open. 
s.end_session() OPTS = CodecOptions(document_class=dict, uuid_representation=STANDARD) def decode_raw(val): """Decode RawBSONDocuments in the given container.""" if isinstance(val, (list, abc.Mapping)): return decode(encode({'v': val}, codec_options=OPTS), OPTS)['v'] return val TYPES = { 'binData': Binary, 'long': Int64, } def wrap_types(val): """Support $$type assertion in command results.""" if isinstance(val, list): return [wrap_types(v) for v in val] if isinstance(val, abc.Mapping): typ = val.get('$$type') if typ: return CompareType(TYPES[typ]) d = {} for key in val: d[key] = wrap_types(val[key]) return d return val pymongo-3.11.0/test/version.py000066400000000000000000000056621374256237000163240ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Some tools for running tests based on MongoDB server version.""" class Version(tuple): def __new__(cls, *version): padded_version = cls._padded(version, 4) return super(Version, cls).__new__(cls, tuple(padded_version)) @classmethod def _padded(cls, iter, length, padding=0): l = list(iter) if len(l) < length: for _ in range(length - len(l)): l.append(padding) return l @classmethod def from_string(cls, version_string): mod = 0 bump_patch_level = False if version_string.endswith("+"): version_string = version_string[0:-1] mod = 1 elif version_string.endswith("-pre-"): version_string = version_string[0:-5] mod = -1 elif version_string.endswith("-"): version_string = version_string[0:-1] mod = -1 # Deal with '-rcX' substrings if '-rc' in version_string: version_string = version_string[0:version_string.find('-rc')] mod = -1 # Deal with git describe generated substrings elif '-' in version_string: version_string = version_string[0:version_string.find('-')] mod = -1 bump_patch_level = True version = [int(part) for part in version_string.split(".")] version = cls._padded(version, 3) # Make from_string and from_version_array agree. For example: # MongoDB Enterprise > db.runCommand('buildInfo').versionArray # [ 3, 2, 1, -100 ] # MongoDB Enterprise > db.runCommand('buildInfo').version # 3.2.0-97-g1ef94fe if bump_patch_level: version[-1] += 1 version.append(mod) return Version(*version) @classmethod def from_version_array(cls, version_array): version = list(version_array) if version[-1] < 0: version[-1] = -1 version = cls._padded(version, 3) return Version(*version) @classmethod def from_client(cls, client): info = client.server_info() if 'versionArray' in info: return cls.from_version_array(info['versionArray']) return cls.from_string(info['version']) def at_least(self, *other_version): return self >= Version(*other_version) def __str__(self): return ".".join(map(str, self)) pymongo-3.11.0/tools/000077500000000000000000000000001374256237000144355ustar00rootroot00000000000000pymongo-3.11.0/tools/README.rst000066400000000000000000000001171374256237000161230ustar00rootroot00000000000000Tools ===== This directory contains tools for use with the ``pymongo`` module. 
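
``benchmark.py`` and ``ocsptest.py`` are standalone scripts and can be run
directly with Python. As a rough illustration (the host and CA file below are
placeholders, not values taken from this repository)::

    python tools/benchmark.py
    python tools/ocsptest.py --host example.com --port 443 --ca_file /path/to/ca.pem

``benchmark.py`` expects a ``mongod`` reachable at the driver's default host
and port, while ``ocsptest.py`` only needs network access to the TLS server
whose OCSP behaviour is being checked.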
pymongo-3.11.0/tools/benchmark.py000066400000000000000000000135411374256237000167450ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MongoDB benchmarking suite.""" from __future__ import print_function import time import sys sys.path[0:0] = [""] import datetime from pymongo import mongo_client from pymongo import ASCENDING trials = 2 per_trial = 5000 batch_size = 100 small = {} medium = {"integer": 5, "number": 5.05, "boolean": False, "array": ["test", "benchmark"] } # this is similar to the benchmark data posted to the user list large = {"base_url": "http://www.example.com/test-me", "total_word_count": 6743, "access_time": datetime.datetime.utcnow(), "meta_tags": {"description": "i am a long description string", "author": "Holly Man", "dynamically_created_meta_tag": "who know\n what" }, "page_structure": {"counted_tags": 3450, "no_of_js_attached": 10, "no_of_images": 6 }, "harvested_words": ["10gen", "web", "open", "source", "application", "paas", "platform-as-a-service", "technology", "helps", "developers", "focus", "building", "mongodb", "mongo"] * 20 } def setup_insert(db, collection, object): db.drop_collection(collection) def insert(db, collection, object): for i in range(per_trial): to_insert = object.copy() to_insert["x"] = i db[collection].insert(to_insert) def insert_batch(db, collection, object): for i in range(per_trial / batch_size): db[collection].insert([object] * batch_size) def find_one(db, collection, x): for _ in range(per_trial): db[collection].find_one({"x": x}) def find(db, collection, x): for _ in range(per_trial): for _ in db[collection].find({"x": x}): pass def timed(name, function, args=[], setup=None): times = [] for _ in range(trials): if setup: setup(*args) start = time.time() function(*args) times.append(time.time() - start) best_time = min(times) print("{0:s}{1:d}".format(name + (60 - len(name)) * ".", per_trial / best_time)) return best_time def main(): c = mongo_client.MongoClient(connectTimeoutMS=60*1000) # jack up timeout c.drop_database("benchmark") db = c.benchmark timed("insert (small, no index)", insert, [db, 'small_none', small], setup_insert) timed("insert (medium, no index)", insert, [db, 'medium_none', medium], setup_insert) timed("insert (large, no index)", insert, [db, 'large_none', large], setup_insert) db.small_index.create_index("x", ASCENDING) timed("insert (small, indexed)", insert, [db, 'small_index', small]) db.medium_index.create_index("x", ASCENDING) timed("insert (medium, indexed)", insert, [db, 'medium_index', medium]) db.large_index.create_index("x", ASCENDING) timed("insert (large, indexed)", insert, [db, 'large_index', large]) timed("batch insert (small, no index)", insert_batch, [db, 'small_bulk', small], setup_insert) timed("batch insert (medium, no index)", insert_batch, [db, 'medium_bulk', medium], setup_insert) timed("batch insert (large, no index)", insert_batch, [db, 'large_bulk', large], setup_insert) timed("find_one (small, no index)", find_one, [db, 'small_none', 
per_trial / 2]) timed("find_one (medium, no index)", find_one, [db, 'medium_none', per_trial / 2]) timed("find_one (large, no index)", find_one, [db, 'large_none', per_trial / 2]) timed("find_one (small, indexed)", find_one, [db, 'small_index', per_trial / 2]) timed("find_one (medium, indexed)", find_one, [db, 'medium_index', per_trial / 2]) timed("find_one (large, indexed)", find_one, [db, 'large_index', per_trial / 2]) timed("find (small, no index)", find, [db, 'small_none', per_trial / 2]) timed("find (medium, no index)", find, [db, 'medium_none', per_trial / 2]) timed("find (large, no index)", find, [db, 'large_none', per_trial / 2]) timed("find (small, indexed)", find, [db, 'small_index', per_trial / 2]) timed("find (medium, indexed)", find, [db, 'medium_index', per_trial / 2]) timed("find (large, indexed)", find, [db, 'large_index', per_trial / 2]) # timed("find range (small, no index)", find, # [db, 'small_none', # {"$gt": per_trial / 4, "$lt": 3 * per_trial / 4}]) # timed("find range (medium, no index)", find, # [db, 'medium_none', # {"$gt": per_trial / 4, "$lt": 3 * per_trial / 4}]) # timed("find range (large, no index)", find, # [db, 'large_none', # {"$gt": per_trial / 4, "$lt": 3 * per_trial / 4}]) timed("find range (small, indexed)", find, [db, 'small_index', {"$gt": per_trial / 2, "$lt": per_trial / 2 + batch_size}]) timed("find range (medium, indexed)", find, [db, 'medium_index', {"$gt": per_trial / 2, "$lt": per_trial / 2 + batch_size}]) timed("find range (large, indexed)", find, [db, 'large_index', {"$gt": per_trial / 2, "$lt": per_trial / 2 + batch_size}]) if __name__ == "__main__": # cProfile.run("main()") main() pymongo-3.11.0/tools/clean.py000066400000000000000000000021301374256237000160650ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Clean up script for build artifacts. Only really intended to be used by internal build scripts. """ import os import sys try: os.remove("pymongo/_cmessage.so") os.remove("bson/_cbson.so") except: pass try: os.remove("pymongo/_cmessage.pyd") os.remove("bson/_cbson.pyd") except: pass try: from pymongo import _cmessage sys.exit("could still import _cmessage") except ImportError: pass try: from bson import _cbson sys.exit("could still import _cbson") except ImportError: pass pymongo-3.11.0/tools/fail_if_no_c.py000066400000000000000000000015121374256237000173750ustar00rootroot00000000000000# Copyright 2009-2015 MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Fail if the C extension module doesn't exist. Only really intended to be used by internal build scripts. """ import sys sys.path[0:0] = [""] import bson import pymongo if not pymongo.has_c() or not bson.has_c(): sys.exit("could not load C extensions") pymongo-3.11.0/tools/ocsptest.py000066400000000000000000000036761374256237000166670ustar00rootroot00000000000000# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); you # may not use this file except in compliance with the License. You # may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or # implied. See the License for the specific language governing # permissions and limitations under the License. import argparse import logging import socket from ssl import CERT_REQUIRED from pymongo.pyopenssl_context import SSLContext from pymongo.ssl_support import get_ssl_context # Enable logs in this format: # 2020-06-08 23:49:35,982 DEBUG ocsp_support Peer did not staple an OCSP response FORMAT = '%(asctime)s %(levelname)s %(module)s %(message)s' logging.basicConfig(format=FORMAT, level=logging.DEBUG) def check_ocsp(host, port, capath): ctx = get_ssl_context( None, # certfile None, # keyfile None, # passphrase capath, CERT_REQUIRED, None, # crlfile True, # match_hostname True) # check_ocsp_endpoint # Ensure we're using pyOpenSSL. assert isinstance(ctx, SSLContext) s = socket.socket() s.connect((host, port)) try: s = ctx.wrap_socket(s, server_hostname=host) finally: s.close() def main(): parser = argparse.ArgumentParser( description='Debug OCSP') parser.add_argument( '--host', type=str, required=True, help="Host to connect to") parser.add_argument( '-p', '--port', type=int, default=443, help="Port to connect to") parser.add_argument( '--ca_file', type=str, default=None, help="CA file for host") args = parser.parse_args() check_ocsp(args.host, args.port, args.ca_file) if __name__ == '__main__': main()