pax_global_header00006660000000000000000000000064145656750150014530gustar00rootroot0000000000000052 comment=0c2a6e20f52b9678d8f3361040bf5a2a7a0c08a6 pytds-1.15.0/000077500000000000000000000000001456567501500127575ustar00rootroot00000000000000pytds-1.15.0/.gitignore000066400000000000000000000003421456567501500147460ustar00rootroot00000000000000*.pyc *.py~ *.swp build dist MANIFEST python_tds.egg-info .coverage coverage.xml test.sh env/ env2/ env3/ prof/ RELEASE-VERSION .DS_Store .idea/ docs/_build .tox/ /.test-cache/ /.cache/ /.pytest_cache/ /tests/.connection.json pytds-1.15.0/.pylintrc000066400000000000000000000001641456567501500146250ustar00rootroot00000000000000[MAIN] # At this point I prefer to have if/else even when there is a return in the if branch disable=no-else-returnpytds-1.15.0/.readthedocs.yaml000066400000000000000000000010731456567501500162070ustar00rootroot00000000000000# .readthedocs.yaml # Read the Docs configuration file # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Set the version of Python and other tools you might need build: os: ubuntu-22.04 tools: python: "3.11" # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # We recommend specifying your dependencies to enable reproducible builds: # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html python: install: - requirements: docs/requirements.txt pytds-1.15.0/.travis.yml000066400000000000000000000005631456567501500150740ustar00rootroot00000000000000language: python python: - "2.7" - "3.5" - "3.6" - "3.7" install: - python --version - "python -c \"import struct; print(struct.calcsize('P') * 8)\"" - pip install -e . - pip install -r test_requirements.txt script: - pytest -v --junitxml=junit-results.xml --cov=./ - codecov - python profiling/profile_smp.py - python profiling/profile_reader.py pytds-1.15.0/LICENSE.txt000066400000000000000000000020741456567501500146050ustar00rootroot00000000000000The MIT License (MIT) Copyright (c) 2014 Mikhail Denisenko Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. pytds-1.15.0/MANIFEST.in000066400000000000000000000001121456567501500145070ustar00rootroot00000000000000include requirements.txt test_requirements.txt RELEASE-VERSION version.py pytds-1.15.0/README.rst000066400000000000000000000043071456567501500144520ustar00rootroot00000000000000pytds ===== .. image:: https://secure.travis-ci.org/denisenkom/pytds.png?branch=master :target: https://travis-ci.org/denisenkom/pytds .. image:: https://ci.appveyor.com/api/projects/status/a5h4y29063crqtet?svg=true :target: https://ci.appveyor.com/project/denisenkom/pytds .. image:: http://img.shields.io/pypi/v/python-tds.svg :target: https://pypi.python.org/pypi/python-tds/ .. image:: https://codecov.io/gh/denisenkom/pytds/branch/master/graph/badge.svg :target: https://codecov.io/gh/denisenkom/pytds `Python DBAPI`_ driver for MSSQL using pure Python TDS (Tabular Data Stream) protocol implementation. Doesn't depend on ADO or FreeTDS. Can be used on any platform, including Linux, MacOS, Windows. It can be used with https://pypi.python.org/pypi/django-sqlserver as a Django database backend. Features -------- * Fully supports new MSSQL 2008 date types: datetime2, date, time, datetimeoffset * MARS * Bulk insert * Table-valued parameters * TLS connection encryption * Kerberos support on non-Windows platforms (requires kerberos package) Installation ------------ To install run this command: .. code-block:: bash $ pip install python-tds If you want to use TLS you should also install pyOpenSSL package: .. code-block:: bash $ pip install pyOpenSSL For a better performance install bitarray package too: .. code-block:: bash $ pip install bitarray To use Kerberos on non-Windows platforms (experimental) install kerberos package: .. code-block:: bash $ pip install kerberos Documentation ------------- Documentation is available at https://python-tds.readthedocs.io/en/latest/. Example ------- To connect to database do .. code-block:: python import pytds with pytds.connect('server', 'database', 'user', 'password') as conn: with conn.cursor() as cur: cur.execute("select 1") cur.fetchall() To enable TLS you should also provide cafile parameter which should be a file name containing trusted CAs in PEM format. For detailed documentation of connection parameters see: `pytds.connect`_ .. _Python DBAPI: http://legacy.python.org/dev/peps/pep-0249/ .. _pytds.connect: https://python-tds.readthedocs.io/en/latest/pytds.html#pytds.connect pytds-1.15.0/RELEASE.rst000066400000000000000000000007531456567501500145760ustar00rootroot00000000000000How to release new version to pypi ================================== Make sure you don't have any local changes and that you are on the right branch by running: git status Verify that Travis CI tests are passing for current branch. Tag current commit by running: git tag -a Check build: python setup.py sdist Push to github: git push && git push --tags Install twine: python3 -m pip install --user --upgrade twine Upload to pypi: python3 -m twine upload dist/* pytds-1.15.0/appveyor.yml000066400000000000000000000027111456567501500153500ustar00rootroot00000000000000version: 1.0.{build} os: Windows Server 2012 R2 environment: INAPPVEYOR: 1 HOST: localhost SQLUSER: sa SQLPASSWORD: Password12! DATABASE: test matrix: - PYTHON: "C:\\Python38" SQLINSTANCE: SQL2016 - PYTHON: "C:\\Python38-x64" SQLINSTANCE: SQL2016 - PYTHON: "C:\\Python38-x64" SQLINSTANCE: SQL2014 - PYTHON: "C:\\Python38-x64" SQLINSTANCE: SQL2012SP1 - PYTHON: "C:\\Python38-x64" SQLINSTANCE: SQL2008R2SP2 install: - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" - python --version - "python -c \"import struct; print(struct.calcsize('P') * 8)\"" - pip install -e . - pip install -r test_requirements.txt build_script: - python setup.py sdist before_test: # setup SQL Server - ps: | $instanceName = $env:SQLINSTANCE Start-Service "MSSQL`$$instanceName" Start-Service "SQLBrowser" - sqlcmd -S "(local)\%SQLINSTANCE%" -Q "Use [master]; CREATE DATABASE test; ALTER DATABASE test SET READ_COMMITTED_SNAPSHOT ON; ALTER DATABASE test SET ALLOW_SNAPSHOT_ISOLATION ON" - sqlcmd -S "(local)\%SQLINSTANCE%" -h -1 -Q "set nocount on; Select @@version" test_script: - mypy src - ruff check src - pytest -v --junitxml=junit-results.xml --cov=./ - ps: | $wc = New-Object 'System.Net.WebClient' $url = "https://ci.appveyor.com/api/testresults/junit/$env:APPVEYOR_JOB_ID" $wc.UploadFile($url, (Resolve-Path .\junit-results.xml)); Write-Output $url - codecov pytds-1.15.0/ci.bat000066400000000000000000000006651456567501500140510ustar00rootroot00000000000000%PYTHONHOME%\scripts\virtualenv env set PYTHONHOME= env\scripts\pip install coverage --use-mirrors env\scripts\pip install -r requirements.txt --use-mirrors env\scripts\pip install -r test_requirements.txt --use-mirrors env\scripts\pip install -r test_requirements26.txt --use-mirrors env\scripts\nosetests --with-coverage --cover-erase --cover-package=pytds --cover-xml --cover-xml-file=coverage.xml --with-xunit --xunit-file=xunit.xml pytds-1.15.0/ci.sh000077500000000000000000000006231456567501500137120ustar00rootroot00000000000000#!/bin/bash set -e $PYTHONHOME/bin/virtualenv env . env/bin/activate pip install coverage --use-mirrors pip install -r requirements.txt --use-mirrors pip install -r test_requirements.txt --use-mirrors pip install -r test_requirements26.txt --use-mirrors || true nosetests --with-coverage --cover-erase --cover-package=pytds --cover-xml --cover-xml-file=coverage.xml --with-xunit --xunit-file=xunit.xml pytds-1.15.0/conftest.py000066400000000000000000000002061456567501500151540ustar00rootroot00000000000000def pytest_configure(config): plugin = config.pluginmanager.getplugin("mypy") plugin.mypy_argv.append("--check-untyped-defs") pytds-1.15.0/docs/000077500000000000000000000000001456567501500137075ustar00rootroot00000000000000pytds-1.15.0/docs/Makefile000066400000000000000000000151721456567501500153550ustar00rootroot00000000000000# Makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) endif # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " html to make standalone HTML files" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " xml to make Docutils-native XML files" @echo " pseudoxml to make pseudoxml-XML files for display purposes" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: rm -rf $(BUILDDIR)/* html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/python-tds.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/python-tds.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/python-tds" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/python-tds" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." latexpdfja: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through platex and dvipdfmx..." $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. The text files are in $(BUILDDIR)/text." man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." xml: $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml @echo @echo "Build finished. The XML files are in $(BUILDDIR)/xml." pseudoxml: $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml @echo @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." pytds-1.15.0/docs/conf.py000066400000000000000000000202261456567501500152100ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # python-tds documentation build configuration file, created by # sphinx-quickstart on Sun Apr 28 23:59:50 2013. # # This file is execfile()d with the current directory set to its containing dir. # # Note that not all possible configuration values are present in this # autogenerated file. # # All configuration values have a default; values that are commented out # serve to show the default. import sys, os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. pkg_root = os.path.normpath(os.path.join(os.path.dirname(__file__), "..")) sys.path.insert(0, os.path.join(pkg_root, "src")) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ "sphinx.ext.autodoc", "sphinx.ext.doctest", "sphinx.ext.todo", "sphinx.ext.coverage", "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # The suffix of source filenames. source_suffix = ".rst" # The encoding of source files. # source_encoding = 'utf-8-sig' # The master toctree document. master_doc = "index" # General information about the project. project = "python-tds" copyright = "2013, Mikhail Denisenko" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. version = "1.6" # The full version, including alpha/beta/rc tags. release = "1.6" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: # today = '' # Else, today_fmt is used as the format for a strftime call. # today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. # default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. # add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). # add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. # show_authors = False # The name of the Pygments (syntax highlighting) style to use. pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. # keep_warnings = False # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. # html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". # html_title = None # A shorter title for the navigation bar. Default is the same as html_title. # html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. # html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. # html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. # html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. # html_use_smartypants = True # Custom sidebar templates, maps document names to template names. # html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. # html_additional_pages = {} # If false, no module index is generated. # html_domain_indices = True # If false, no index is generated. # html_use_index = True # If true, the index is split into individual pages for each letter. # html_split_index = False # If true, links to the reST sources are added to the pages. # html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. # html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. # html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. # html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). # html_file_suffix = None # Output file base name for HTML help builder. htmlhelp_basename = "python-tdsdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { # The paper size ('letterpaper' or 'a4paper'). #'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt', # Additional stuff for the LaTeX preamble. #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ ( "index", "python-tds.tex", "python-tds Documentation", "Mikhail Denisenko", "manual", ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. # latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. # latex_use_parts = False # If true, show page references after internal links. # latex_show_pagerefs = False # If true, show URL addresses after external links. # latex_show_urls = False # Documents to append as an appendix to all manuals. # latex_appendices = [] # If false, no module index is generated. # latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ("index", "python-tds", "python-tds Documentation", ["Mikhail Denisenko"], 1) ] # If true, show URL addresses after external links. # man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ( "index", "python-tds", "python-tds Documentation", "Mikhail Denisenko", "python-tds", "One line description of project.", "Miscellaneous", ), ] # Documents to append as an appendix to all manuals. # texinfo_appendices = [] # If false, no module index is generated. # texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. # texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. # texinfo_no_detailmenu = False pytds-1.15.0/docs/extensions.rst000066400000000000000000000016441456567501500166450ustar00rootroot00000000000000`pytds.extensions` -- Extensions to the DB API ============================================== .. module:: pytds.extensions .. _isolation-level-constants: Isolation level constants ------------------------- .. data:: ISOLATION_LEVEL_READ_UNCOMMITTED Transaction can read uncommitted data .. data:: ISOLATION_LEVEL_READ_COMMITTED Transaction can read only committed data, will block on attempt to read modified uncommitted data .. data:: ISOLATION_LEVEL_REPEATABLE_READ Transaction will place lock on read records, other transactions will block trying to modify such records .. data:: ISOLATION_LEVEL_SERIALIZABLE Transaction will lock tables to prevent other transactions from inserting new data that would match selected recordsets .. data:: ISOLATION_LEVEL_SNAPSHOT Allows non-blocking consistent reads on a snapshot for transaction without blocking other transactions changes pytds-1.15.0/docs/index.rst000066400000000000000000000070041456567501500155510ustar00rootroot00000000000000.. python-tds documentation master file, created by sphinx-quickstart on Sun Apr 28 23:59:50 2013. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Pytds - Microsoft SQL Server database adapter for Python ======================================================== Pytds is the top to bottom pure Python TDS implementation, that means cross-platform, and no dependency on ADO or FreeTDS. It supports large parameters (>4000 characters), MARS, timezones, new date types (datetime2, date, time, datetimeoffset). Even though it is implemented in Python performance is comparable to ADO and FreeTDS bindings. It also supports Python 3. .. rubric:: Contents .. toctree:: :maxdepth: 2 pytds extensions Connection to Mirrored Servers ============================== When MSSQL server is setup with mirroring you should connect to it using two parameters of :func:`pytds.connect`, one parameter is ``server`` this should be a main server and parameter ``failover_partner`` should be a mirror server. See also `MSDN article `_. Table Valued Parameters ======================= Here is example of using TVP: .. code-block:: py with conn.cursor() as cur: cur.execute('CREATE TYPE dbo.CategoryTableType AS TABLE ( CategoryID int, CategoryName nvarchar(50) )') conn.commit() tvp = pytds.TableValuedParam(type_name='dbo.CategoryTableType', rows=rows_gen()) cur.execute('SELECT * FROM %s', (tvp,)) Using Binary Parameters ======================= To use a parameter that is of a binary or varbinary type, you need to wrap the value with pytds.Binary(). This function accepts bytes objects so be sure to convert buffers or file-like objects to bytes first. Examples of wrapping various kinds of bytes representations: .. code-block:: py pytds.Binary(b'') pytds.Binary(b'\x00\x01\x02') pytds.Binary(b'x' * 9000) An example of how you might store an image from a file in a varbinary(MAX) field: .. code-block:: py image=Image.open(image_path) with io.BytesIO() as output: image.save(output, format="jpeg") image_data = pytds.Binary(output.getvalue()) with pytds.connect(dns='your connection info') as conn: with conn.cursor() as cur: cur.execute("insert into table_name (text_field, binary_field) values (%s, %s)", (image_name, image_data)) conn.commit() Testing ======= To run tests you need to have tox installed. Also you would want to have different versions of Python, you can use pyenv to install those. At a minimun you should set HOST environment variable to point to your SQL server, e.g.: .. code-block:: bash export HOST=mysqlserver it could also specify SQL server named instance, e.g.: .. code-block:: bash export HOST=mysqlserver\\myinstance By default tests will use SQL server integrated authentication using user sa with password sa and database test. You can specify different user name, password, database with SQLUSER, SQLPASSWORD, DATABASE environment variables. To enable testing NTLM authentication you should specify NTLM_USER and NTLM_PASSWORD environment variables. Once environment variables are setup you can run tests by running command: .. code-block:: bash tox Test configuration stored in tox.ini file at the root of the repository. Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` pytds-1.15.0/docs/install.rst000066400000000000000000000001621456567501500161060ustar00rootroot00000000000000Installation ============ As easy as: :: pip install python-tds All requirements are automatically installed. pytds-1.15.0/docs/make.bat000066400000000000000000000150651456567501500153230ustar00rootroot00000000000000@ECHO OFF REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set BUILDDIR=_build set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . set I18NSPHINXOPTS=%SPHINXOPTS% . if NOT "%PAPER%" == "" ( set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% ) if "%1" == "" goto help if "%1" == "help" ( :help echo.Please use `make ^` where ^ is one of echo. html to make standalone HTML files echo. dirhtml to make HTML files named index.html in directories echo. singlehtml to make a single large HTML file echo. pickle to make pickle files echo. json to make JSON files echo. htmlhelp to make HTML files and a HTML help project echo. qthelp to make HTML files and a qthelp project echo. devhelp to make HTML files and a Devhelp project echo. epub to make an epub echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter echo. text to make text files echo. man to make manual pages echo. texinfo to make Texinfo files echo. gettext to make PO message catalogs echo. changes to make an overview over all changed/added/deprecated items echo. xml to make Docutils-native XML files echo. pseudoxml to make pseudoxml-XML files for display purposes echo. linkcheck to check all external links for integrity echo. doctest to run all doctests embedded in the documentation if enabled goto end ) if "%1" == "clean" ( for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i del /q /s %BUILDDIR%\* goto end ) %SPHINXBUILD% 2> nul if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) if "%1" == "html" ( %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/html. goto end ) if "%1" == "dirhtml" ( %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. goto end ) if "%1" == "singlehtml" ( %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml if errorlevel 1 exit /b 1 echo. echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. goto end ) if "%1" == "pickle" ( %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the pickle files. goto end ) if "%1" == "json" ( %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can process the JSON files. goto end ) if "%1" == "htmlhelp" ( %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run HTML Help Workshop with the ^ .hhp project file in %BUILDDIR%/htmlhelp. goto end ) if "%1" == "qthelp" ( %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp if errorlevel 1 exit /b 1 echo. echo.Build finished; now you can run "qcollectiongenerator" with the ^ .qhcp project file in %BUILDDIR%/qthelp, like this: echo.^> qcollectiongenerator %BUILDDIR%\qthelp\python-tds.qhcp echo.To view the help file: echo.^> assistant -collectionFile %BUILDDIR%\qthelp\python-tds.ghc goto end ) if "%1" == "devhelp" ( %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp if errorlevel 1 exit /b 1 echo. echo.Build finished. goto end ) if "%1" == "epub" ( %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub if errorlevel 1 exit /b 1 echo. echo.Build finished. The epub file is in %BUILDDIR%/epub. goto end ) if "%1" == "latex" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex if errorlevel 1 exit /b 1 echo. echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdf" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf cd %BUILDDIR%/.. echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "latexpdfja" ( %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex cd %BUILDDIR%/latex make all-pdf-ja cd %BUILDDIR%/.. echo. echo.Build finished; the PDF files are in %BUILDDIR%/latex. goto end ) if "%1" == "text" ( %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text if errorlevel 1 exit /b 1 echo. echo.Build finished. The text files are in %BUILDDIR%/text. goto end ) if "%1" == "man" ( %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man if errorlevel 1 exit /b 1 echo. echo.Build finished. The manual pages are in %BUILDDIR%/man. goto end ) if "%1" == "texinfo" ( %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo if errorlevel 1 exit /b 1 echo. echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. goto end ) if "%1" == "gettext" ( %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale if errorlevel 1 exit /b 1 echo. echo.Build finished. The message catalogs are in %BUILDDIR%/locale. goto end ) if "%1" == "changes" ( %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes if errorlevel 1 exit /b 1 echo. echo.The overview file is in %BUILDDIR%/changes. goto end ) if "%1" == "linkcheck" ( %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck if errorlevel 1 exit /b 1 echo. echo.Link check complete; look for any errors in the above output ^ or in %BUILDDIR%/linkcheck/output.txt. goto end ) if "%1" == "doctest" ( %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest if errorlevel 1 exit /b 1 echo. echo.Testing of doctests in the sources finished, look at the ^ results in %BUILDDIR%/doctest/output.txt. goto end ) if "%1" == "xml" ( %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml if errorlevel 1 exit /b 1 echo. echo.Build finished. The XML files are in %BUILDDIR%/xml. goto end ) if "%1" == "pseudoxml" ( %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml if errorlevel 1 exit /b 1 echo. echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. goto end ) :end pytds-1.15.0/docs/pytds.rst000066400000000000000000000011461456567501500156060ustar00rootroot00000000000000`pytds` -- main module ---------------------- .. automodule:: pytds :members: `pytds.login` -- various login mechanisms, e.g. NTLM, Negotiate, SSPI ---------------------------------------------------------------------- .. automodule:: pytds.login :members: `pytds.tds_base` -- Internal classes ------------------------------------------ .. automodule:: pytds.tds_base :members: `pytds.tds_types` -- Column type classes ------------------------------------------ .. automodule:: pytds.tds_types :members: `pytds.tz` -- timezones ----------------------- .. automodule:: pytds.tz :members: pytds-1.15.0/docs/requirements.txt000066400000000000000000000000001456567501500171610ustar00rootroot00000000000000pytds-1.15.0/profiling/000077500000000000000000000000001456567501500147505ustar00rootroot00000000000000pytds-1.15.0/profiling/profile_reader.py000066400000000000000000000025041456567501500203050ustar00rootroot00000000000000import struct import cProfile import pstats import pytds.tds_socket BUFSIZE = 4096 HEADER = struct.Struct(">BBHHBx") class Sock: def __init__(self): self._read_pos = 0 self._buf = bytearray(b"\x00" * BUFSIZE) HEADER.pack_into(self._buf, 0, 0, 0, BUFSIZE, 0, 0) def sendall(self, data, flags=0): pass def recv_into(self, buffer, size=0): if size == 0: size = len(buffer) if self._read_pos >= BUFSIZE: HEADER.pack_into(self._buf, 0, 0, 0, BUFSIZE, 0, 0) self._read_pos = 0 to_read = min(size, BUFSIZE - self._read_pos) buffer[:to_read] = self._buf[self._read_pos : self._read_pos + to_read] return to_read def recv(self, size): if self._read_pos >= len(self._buf): HEADER.pack_into(self._buf, 0, 0, 0, BUFSIZE, 0, 0) self._read_pos = 0 res = self._buf[self._read_pos : self._read_pos + size] self._read_pos += len(res) return res def close(self): pass class Session: def __init__(self): self._transport = Sock() sess = Session() rdr = pytds.tds._TdsReader(sess) pr = cProfile.Profile() pr.enable() for _ in range(50000): rdr.recv(BUFSIZE) pr.disable() sortby = "tottime" ps = pstats.Stats(pr).sort_stats(sortby) ps.print_stats() pytds-1.15.0/profiling/profile_smp.log000066400000000000000000000011401456567501500177660ustar00rootroot00000000000000This file is a log of runs of profile_*.py SMP Profiling: It is used to measure performance improvements in SMP subsystem. Dec 11 2017 duration 0.7 sec at commit c1659058eb5ca35ed86ee366f7508a035d280ca4 Dec 12 2017 duration 0.6 sec at commit d6a3859d3ebcbafcf9335a0f4c6ee47673af2461 Dec 13 2017 duration 0.66 sec at commit 21297d4 Dec 14 2017 duration 0.74 sec after commit 9257a6b (improved TDS reader by 50%) TDS Reader profiling: Dec 13 2017 duration 0.64 sec at commit d255ce4 Dec 14 2017 duration 0.3 sec after commit 9257a6b Dec 17 2017 duration 0.35 sec after commit 24d3f97 (fix timeout test) pytds-1.15.0/profiling/profile_smp.py000066400000000000000000000023521456567501500176430ustar00rootroot00000000000000import struct import cProfile import pstats import io import pytds.smp transport = None bufsize = 512 smp_header = struct.Struct("= len(self._buf): self._seq += 1 smp_header.pack_into(self._buf, 0, 0x53, 0x8, 0, bufsize, self._seq, 4) self._read_pos = 0 res = self._buf[self._read_pos : self._read_pos + size] self._read_pos += len(res) return res def close(self): pass sock = Sock() mgr = pytds.smp.SmpManager(transport=sock) sess = mgr.create_session() pr = cProfile.Profile() pr.enable() buf = bytearray(b"\x00" * bufsize) for _ in range(50000): sess.recv_into(buf) pr.disable() sortby = "tottime" ps = pstats.Stats(pr).sort_stats(sortby) ps.print_stats() pytds-1.15.0/pytest.ini000066400000000000000000000002121456567501500150030ustar00rootroot00000000000000[pytest] log_level=DEBUG log_format=%(created)f %(filename)s %(lineno)d %(levelname)s %(message)s log_date_format=%H:%M:%S #addopts=--mypypytds-1.15.0/requirements.txt000066400000000000000000000000001456567501500162310ustar00rootroot00000000000000pytds-1.15.0/setup.py000066400000000000000000000017031456567501500144720ustar00rootroot00000000000000import os import setuptools from setuptools import setup import version requirements = list( open(os.path.join(os.path.dirname(__file__), "requirements.txt"), "r").readlines() ) print(setuptools.find_packages("src")) setup( name="python-tds", version=version.get_git_version(), description="Python DBAPI driver for MSSQL using pure Python TDS (Tabular Data Stream) protocol implementation", author="Mikhail Denisenko", author_email="denisenkom@gmail.com", url="https://github.com/denisenkom/pytds", license="MIT", packages=["pytds"], package_dir={"": "src"}, classifiers=[ "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3.3", "Programming Language :: Python :: 3.4", "Programming Language :: Python :: 3.5", ], zip_safe=True, install_requires=requirements, ) pytds-1.15.0/shell.py000066400000000000000000000022261456567501500144420ustar00rootroot00000000000000# simple interactive shell for MSSQL server import pytds import os def main(): conn = pytds.connect( dsn=os.getenv("HOST", "localhost"), user=os.getenv("SQLUSER", "sa"), password=os.getenv("SQLPASSWORD"), cafile="/Users/denisenk/opensource/pytds/ca.pem", enc_login_only=True, ) while True: try: sql = input("sql> ") except KeyboardInterrupt: return with conn.cursor() as cursor: try: cursor.execute(sql) except pytds.ProgrammingError as e: print("Error: " + str(e)) else: for _, msg in cursor.messages: print(msg.text) if cursor.description: print("\t".join(col[0] for col in cursor.description)) print("-" * 80) count = 0 for row in cursor: print("\t".join(str(col) for col in row)) count += 1 print("-" * 80) print("Returned {} rows".format(count)) print() main() pytds-1.15.0/src/000077500000000000000000000000001456567501500135465ustar00rootroot00000000000000pytds-1.15.0/src/pytds/000077500000000000000000000000001456567501500147115ustar00rootroot00000000000000pytds-1.15.0/src/pytds/__init__.py000066400000000000000000000464611456567501500170350ustar00rootroot00000000000000"""DB-SIG compliant module for communicating with MS SQL servers""" from __future__ import annotations from collections import deque import datetime import os import socket import time import uuid import warnings from typing import Any from pytds.tds_types import TzInfoFactoryType from . import lcid from . import connection_pool import pytds.tz from .connection import MarsConnection, NonMarsConnection, Connection from .cursor import Cursor # noqa: F401 # export for backward compatibility from .login import KerberosAuth, SspiAuth, AuthProtocol # noqa: F401 # export for backward compatibility from .row_strategies import ( tuple_row_strategy, list_row_strategy, # noqa: F401 # export for backward compatibility dict_row_strategy, namedtuple_row_strategy, # noqa: F401 # export for backward compatibility recordtype_row_strategy, # noqa: F401 # export for backward compatibility RowStrategy, ) from .tds_socket import _TdsSocket from . import instance_browser_client from . import tds_base from . import utils from . import login as pytds_login from .tds_base import ( Error, # noqa: F401 # export for backward compatibility LoginError, # noqa: F401 # export for backward compatibility DatabaseError, # noqa: F401 # export for backward compatibility ProgrammingError, # noqa: F401 # export for backward compatibility IntegrityError, # noqa: F401 # export for backward compatibility DataError, # noqa: F401 # export for backward compatibility InternalError, # noqa: F401 # export for backward compatibility InterfaceError, # noqa: F401 # export for backward compatibility TimeoutError, # noqa: F401 # export for backward compatibility OperationalError, # noqa: F401 # export for backward compatibility NotSupportedError, # noqa: F401 # export for backward compatibility Warning, # noqa: F401 # export for backward compatibility ClosedConnectionError, # noqa: F401 # export for backward compatibility Column, # noqa: F401 # export for backward compatibility PreLoginEnc, # noqa: F401 # export for backward compatibility ) from .tds_types import TableValuedParam, Binary # noqa: F401 # export for backward compatibility from .tds_base import ( ROWID, # noqa: F401 # export for backward compatibility DECIMAL, # noqa: F401 # export for backward compatibility STRING, # noqa: F401 # export for backward compatibility BINARY, # noqa: F401 # export for backward compatibility NUMBER, # noqa: F401 # export for backward compatibility DATETIME, # noqa: F401 # export for backward compatibility INTEGER, # noqa: F401 # export for backward compatibility REAL, # noqa: F401 # export for backward compatibility XML, # noqa: F401 # export for backward compatibility output, # noqa: F401 # export for backward compatibility default, # noqa: F401 # export for backward compatibility ) from . import tls from .tds_base import logger __author__ = "Mikhail Denisenko " try: __version__ = utils.package_version("python-tds") except Exception: __version__ = "DEV" intversion = utils.ver_to_int(__version__) #: Compliant with DB SIG 2.0 apilevel = "2.0" #: Module may be shared, but not connections threadsafety = 1 #: This module uses extended python format codes paramstyle = "pyformat" # map to servers deques, used to store active/passive servers # between calls to connect function # deques are used because they can be rotated _servers_deques: dict[ tuple[tuple[tuple[str, int | None, str], ...], str | None], deque[tuple[Any, int | None, str]], ] = {} def _get_servers_deque( servers: tuple[tuple[str, int | None, str], ...], database: str | None ) -> deque[tuple[Any, int | None, str]]: """Returns deque of servers for given tuple of servers and database name. This deque have active server at the begining, if first server is not accessible at the moment the deque will be rotated, second server will be moved to the first position, thirt to the second position etc, and previously first server will be moved to the last position. This allows to remember last successful server between calls to connect function. """ key = (servers, database) if key not in _servers_deques: _servers_deques[key] = deque(servers) return _servers_deques[key] def connect( dsn: str | None = None, database: str | None = None, user: str | None = None, password: str | None = None, timeout: float | None = None, login_timeout: float = 15, as_dict: bool | None = None, appname: str | None = None, port: int | None = None, tds_version: int = tds_base.TDS74, autocommit: bool = False, blocksize: int = 4096, use_mars: bool = False, auth: AuthProtocol | None = None, readonly: bool = False, load_balancer: tds_base.LoadBalancer | None = None, use_tz: datetime.tzinfo | None = None, bytes_to_unicode: bool = True, row_strategy: RowStrategy | None = None, failover_partner: str | None = None, server: str | None = None, cafile: str | None = None, sock: socket.socket | None = None, validate_host: bool = True, enc_login_only: bool = False, disable_connect_retry: bool = False, pooling: bool = False, use_sso: bool = False, isolation_level: int = 0, ): """ Opens connection to the database :keyword dsn: SQL server host and instance: [\\] :type dsn: string :keyword failover_partner: secondary database host, used if primary is not accessible :type failover_partner: string :keyword database: the database to initially connect to :type database: string :keyword user: database user to connect as :type user: string :keyword password: user's password :type password: string :keyword timeout: query timeout in seconds, default 0 (no timeout) :type timeout: int :keyword login_timeout: timeout for connection and login in seconds, default 15 :type login_timeout: int :keyword as_dict: whether rows should be returned as dictionaries instead of tuples. :type as_dict: boolean :keyword appname: Set the application name to use for the connection :type appname: string :keyword port: the TCP port to use to connect to the server :type port: int :keyword tds_version: Maximum TDS version to use, should only be used for testing :type tds_version: int :keyword autocommit: Enable or disable database level autocommit :type autocommit: bool :keyword blocksize: Size of block for the TDS protocol, usually should not be used :type blocksize: int :keyword use_mars: Enable or disable MARS :type use_mars: bool :keyword auth: An instance of authentication method class, e.g. Ntlm or Sspi :keyword readonly: Allows to enable read-only mode for connection, only supported by MSSQL 2012, earlier versions will ignore this parameter :type readonly: bool :keyword load_balancer: An instance of load balancer class to use, if not provided will not use load balancer :keyword use_tz: Provides timezone for naive database times, if not provided date and time will be returned in naive format :keyword bytes_to_unicode: If true single byte database strings will be converted to unicode Python strings, otherwise will return strings as ``bytes`` without conversion. :type bytes_to_unicode: bool :keyword row_strategy: strategy used to create rows, determines type of returned rows, can be custom or one of: :func:`tuple_row_strategy`, :func:`list_row_strategy`, :func:`dict_row_strategy`, :func:`namedtuple_row_strategy`, :func:`recordtype_row_strategy` :type row_strategy: function of list of column names returning row factory :keyword cafile: Name of the file containing trusted CAs in PEM format, if provided will enable TLS :type cafile: str :keyword validate_host: Host name validation during TLS connection is enabled by default, if you disable it you will be vulnerable to MitM type of attack. :type validate_host: bool :keyword enc_login_only: Allows you to scope TLS encryption only to an authentication portion. This means that anyone who can observe traffic on your network will be able to see all your SQL requests and potentially modify them. :type enc_login_only: bool :keyword use_sso: Enables SSO login, e.g. Kerberos using SSPI on Windows and kerberos package on other platforms. Cannot be used together with auth parameter. :returns: An instance of :class:`Connection` """ if use_sso and auth: raise ValueError("use_sso cannot be used with auth parameter defined") login = tds_base._TdsLogin() login.client_host_name = socket.gethostname()[:128] login.library = "Python TDS Library" login.user_name = user or "" login.password = password or "" login.app_name = appname or "pytds" login.port = port login.language = "" # use database default login.attach_db_file = "" login.tds_version = tds_version if tds_version < tds_base.TDS70: raise ValueError("This TDS version is not supported") login.database = database or "" login.bulk_copy = False login.client_lcid = lcid.LANGID_ENGLISH_US login.use_mars = use_mars login.pid = os.getpid() login.change_password = "" login.client_id = uuid.getnode() # client mac address login.cafile = cafile login.validate_host = validate_host login.enc_login_only = enc_login_only if cafile: if not tls.OPENSSL_AVAILABLE: raise ValueError( "You are trying to use encryption but pyOpenSSL does not work, you probably " "need to install it first" ) login.tls_ctx = tls.create_context(cafile) if login.enc_login_only: login.enc_flag = PreLoginEnc.ENCRYPT_OFF else: login.enc_flag = PreLoginEnc.ENCRYPT_ON else: login.tls_ctx = None login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP if use_tz: login.client_tz = use_tz else: login.client_tz = pytds.tz.local # that will set: # ANSI_DEFAULTS to ON, # IMPLICIT_TRANSACTIONS to OFF, # TEXTSIZE to 0x7FFFFFFF (2GB) (TDS 7.2 and below), TEXTSIZE to infinite (introduced in TDS 7.3), # and ROWCOUNT to infinite login.option_flag2 = tds_base.TDS_ODBC_ON login.connect_timeout = login_timeout login.query_timeout = timeout login.blocksize = blocksize login.readonly = readonly login.load_balancer = load_balancer login.bytes_to_unicode = bytes_to_unicode if server and dsn: raise ValueError("Both server and dsn shouldn't be specified") if server: warnings.warn( "server parameter is deprecated, use dsn instead", DeprecationWarning ) dsn = server if load_balancer and failover_partner: raise ValueError( "Both load_balancer and failover_partner shoudln't be specified" ) servers: list[tuple[str, int | None]] = [] if load_balancer: servers += ((srv, None) for srv in load_balancer.choose()) else: servers += [(dsn or "localhost", port)] if failover_partner: servers.append((failover_partner, port)) parsed_servers: list[tuple[str, int | None, str]] = [] for srv, instance_port in servers: host, instance = utils.parse_server(srv) if instance and instance_port: raise ValueError("Both instance and port shouldn't be specified") parsed_servers.append((host, instance_port, instance)) if use_sso: spn = f"MSSQLSvc@{parsed_servers[0][0]}:{parsed_servers[0][1]}" try: login.auth = pytds_login.SspiAuth(spn=spn) except ImportError: login.auth = pytds_login.KerberosAuth(spn) else: login.auth = auth login.servers = _get_servers_deque(tuple(parsed_servers), database) # unique connection identifier used to pool connection key = ( dsn, login.user_name, login.app_name, login.tds_version, login.database, login.client_lcid, login.use_mars, login.cafile, login.blocksize, login.readonly, login.bytes_to_unicode, login.auth, login.client_tz, autocommit, ) tzinfo_factory = None if use_tz is None else pytds.tz.FixedOffsetTimezone assert ( row_strategy is None or as_dict is None ), "Both row_startegy and as_dict were specified, you should use either one or another" if as_dict: row_strategy = dict_row_strategy elif row_strategy is not None: row_strategy = row_strategy else: row_strategy = tuple_row_strategy # default row strategy if disable_connect_retry: first_try_time = login.connect_timeout else: first_try_time = login.connect_timeout * 0.08 def attempt(attempt_timeout: float) -> Connection: if pooling: res = connection_pool.connection_pool.take(key) if res is not None: tds_socket, sess = res sess.callproc("sp_reset_connection", []) tds_socket._row_strategy = row_strategy if tds_socket.mars_enabled: return MarsConnection( pooling=pooling, key=key, tds_socket=tds_socket, ) else: return NonMarsConnection( pooling=pooling, key=key, tds_socket=tds_socket, ) host, port, instance = login.servers[0] return _connect( login=login, host=host, port=port, instance=instance, timeout=attempt_timeout, pooling=pooling, key=key, autocommit=autocommit, isolation_level=isolation_level, tzinfo_factory=tzinfo_factory, sock=sock, use_tz=use_tz, row_strategy=row_strategy, ) def ex_handler(ex: Exception) -> None: if isinstance(ex, LoginError): raise ex elif isinstance(ex, BrokenPipeError): # Allow to retry when BrokenPipeError is received pass elif isinstance(ex, OperationalError): # if there are more than one message this means # that the login was successful, like in the # case when database is not accessible # mssql returns 2 messages: # 1) Cannot open database "" requested by the login. The login failed. # 2) Login failed for user '' # in this case we want to retry if ex.msg_no in ( 18456, # login failed 18486, # account is locked 18487, # password expired 18488, # password should be changed 18452, # login from untrusted domain ): raise ex else: raise ex return utils.exponential_backoff( work=attempt, ex_handler=ex_handler, max_time_sec=login.connect_timeout, first_attempt_time_sec=first_try_time, ) def _connect( login: tds_base._TdsLogin, host: str, port: int | None, instance: str, timeout: float, pooling: bool, key: connection_pool.PoolKeyType, autocommit: bool, isolation_level: int, tzinfo_factory: TzInfoFactoryType | None, sock: socket.socket | None, use_tz: datetime.tzinfo | None, row_strategy: RowStrategy, ) -> Connection: """ Establish physical connection and login. """ try: login.server_name = host login.instance_name = instance resolved_port = instance_browser_client.resolve_instance_port( server=host, port=port, instance=instance, timeout=timeout ) if not sock: logger.info("Opening socket to %s:%d", host, resolved_port) sock = socket.create_connection((host, resolved_port), timeout) except Exception as e: raise LoginError(f"Cannot connect to server '{host}': {e}") from e try: sock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1) # default keep alive should be 30 seconds according to spec: # https://msdn.microsoft.com/en-us/library/dd341108.aspx sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 30) sock.settimeout(timeout) tds_socket = _TdsSocket( sock=sock, tzinfo_factory=tzinfo_factory, use_tz=use_tz, row_strategy=row_strategy, autocommit=autocommit, login=login, isolation_level=isolation_level, ) logger.info("Performing login on the connection") route = tds_socket.login() if route is not None: logger.info( "Connection was rerouted to %s:%d", route["server"], route["port"] ) sock.close() ### Change SPN once route exists if isinstance(login.auth, pytds_login.SspiAuth): route_spn = f"MSSQLSvc@{host}:{port}" login.auth = pytds_login.SspiAuth( user_name=login.user_name, password=login.password, server_name=host, port=port, spn=route_spn, ) return _connect( login=login, host=route["server"], port=route["port"], instance=instance, timeout=timeout, pooling=pooling, key=key, autocommit=autocommit, isolation_level=isolation_level, tzinfo_factory=tzinfo_factory, use_tz=use_tz, row_strategy=row_strategy, sock=None, ) if not autocommit: tds_socket.main_session.begin_tran() sock.settimeout(login.query_timeout) if tds_socket.mars_enabled: return MarsConnection( pooling=pooling, key=key, tds_socket=tds_socket, ) else: return NonMarsConnection( pooling=pooling, key=key, tds_socket=tds_socket, ) except Exception: sock.close() raise def Date(year: int, month: int, day: int) -> datetime.date: return datetime.date(year, month, day) def DateFromTicks(ticks: float) -> datetime.date: return datetime.date.fromtimestamp(ticks) def Time( hour: int, minute: int, second: int, microsecond: int = 0, tzinfo: datetime.tzinfo | None = None, ) -> datetime.time: return datetime.time(hour, minute, second, microsecond, tzinfo) def TimeFromTicks(ticks: float) -> datetime.time: return Time(*time.localtime(ticks)[3:6]) def Timestamp( year: int, month: int, day: int, hour: int, minute: int, second: int, microseconds: int = 0, tzinfo: datetime.tzinfo | None = None, ) -> datetime.datetime: return datetime.datetime( year, month, day, hour, minute, second, microseconds, tzinfo ) def TimestampFromTicks(ticks: float) -> datetime.datetime: return datetime.datetime.fromtimestamp(ticks) pytds-1.15.0/src/pytds/collate.py000066400000000000000000000236361456567501500167200ustar00rootroot00000000000000import codecs import struct TDS_CHARSET_ISO_8859_1 = 1 TDS_CHARSET_CP1251 = 2 TDS_CHARSET_CP1252 = 3 TDS_CHARSET_UCS_2LE = 4 TDS_CHARSET_UNICODE = 5 ucs2_codec = codecs.lookup("utf_16_le") def sortid2charset(sort_id): sql_collate = sort_id # # The table from the MSQLServer reference "Windows Collation Designators" # and from " NLS Information for Microsoft Windows XP" # if sql_collate in ( 30, # SQL_Latin1_General_CP437_BIN 31, # SQL_Latin1_General_CP437_CS_AS 32, # SQL_Latin1_General_CP437_CI_AS 33, # SQL_Latin1_General_Pref_CP437_CI_AS 34, ): # SQL_Latin1_General_CP437_CI_AI return "CP437" elif sql_collate in ( 40, # SQL_Latin1_General_CP850_BIN 41, # SQL_Latin1_General_CP850_CS_AS 42, # SQL_Latin1_General_CP850_CI_AS 43, # SQL_Latin1_General_Pref_CP850_CI_AS 44, # SQL_Latin1_General_CP850_CI_AI 49, # SQL_1xCompat_CP850_CI_AS 55, # SQL_AltDiction_CP850_CS_AS 56, # SQL_AltDiction_Pref_CP850_CI_AS 57, # SQL_AltDiction_CP850_CI_AI 58, # SQL_Scandinavian_Pref_CP850_CI_AS 59, # SQL_Scandinavian_CP850_CS_AS 60, # SQL_Scandinavian_CP850_CI_AS 61, ): # SQL_AltDiction_CP850_CI_AS return "CP850" elif sql_collate in ( 80, # SQL_Latin1_General_1250_BIN 81, # SQL_Latin1_General_CP1250_CS_AS 82, # SQL_Latin1_General_Cp1250_CI_AS_KI_WI 83, # SQL_Czech_Cp1250_CS_AS_KI_WI 84, # SQL_Czech_Cp1250_CI_AS_KI_WI 85, # SQL_Hungarian_Cp1250_CS_AS_KI_WI 86, # SQL_Hungarian_Cp1250_CI_AS_KI_WI 87, # SQL_Polish_Cp1250_CS_AS_KI_WI 88, # SQL_Polish_Cp1250_CI_AS_KI_WI 89, # SQL_Romanian_Cp1250_CS_AS_KI_WI 90, # SQL_Romanian_Cp1250_CI_AS_KI_WI 91, # SQL_Croatian_Cp1250_CS_AS_KI_WI 92, # SQL_Croatian_Cp1250_CI_AS_KI_WI 93, # SQL_Slovak_Cp1250_CS_AS_KI_WI 94, # SQL_Slovak_Cp1250_CI_AS_KI_WI 95, # SQL_Slovenian_Cp1250_CS_AS_KI_WI 96, # SQL_Slovenian_Cp1250_CI_AS_KI_WI ): return "CP1250" elif sql_collate in ( 104, # SQL_Latin1_General_1251_BIN 105, # SQL_Latin1_General_CP1251_CS_AS 106, # SQL_Latin1_General_CP1251_CI_AS 107, # SQL_Ukrainian_Cp1251_CS_AS_KI_WI 108, # SQL_Ukrainian_Cp1251_CI_AS_KI_WI ): return "CP1251" elif sql_collate in ( 51, # SQL_Latin1_General_Cp1_CS_AS_KI_WI 52, # SQL_Latin1_General_Cp1_CI_AS_KI_WI 53, # SQL_Latin1_General_Pref_Cp1_CI_AS_KI_WI 54, # SQL_Latin1_General_Cp1_CI_AI_KI_WI 183, # SQL_Danish_Pref_Cp1_CI_AS_KI_WI 184, # SQL_SwedishPhone_Pref_Cp1_CI_AS_KI_WI 185, # SQL_SwedishStd_Pref_Cp1_CI_AS_KI_WI 186, # SQL_Icelandic_Pref_Cp1_CI_AS_KI_WI ): return "CP1252" elif sql_collate in ( 112, # SQL_Latin1_General_1253_BIN 113, # SQL_Latin1_General_CP1253_CS_AS 114, # SQL_Latin1_General_CP1253_CI_AS 120, # SQL_MixDiction_CP1253_CS_AS 121, # SQL_AltDiction_CP1253_CS_AS 122, # SQL_AltDiction2_CP1253_CS_AS 124, # SQL_Latin1_General_CP1253_CI_AI ): return "CP1253" elif sql_collate in ( 128, # SQL_Latin1_General_1254_BIN 129, # SQL_Latin1_General_Cp1254_CS_AS_KI_WI 130, # SQL_Latin1_General_Cp1254_CI_AS_KI_WI ): return "CP1254" elif sql_collate in ( 136, # SQL_Latin1_General_1255_BIN 137, # SQL_Latin1_General_CP1255_CS_AS 138, # SQL_Latin1_General_CP1255_CI_AS ): return "CP1255" elif sql_collate in ( 144, # SQL_Latin1_General_1256_BIN 145, # SQL_Latin1_General_CP1256_CS_AS 146, # SQL_Latin1_General_CP1256_CI_AS ): return "CP1256" elif sql_collate in ( 152, # SQL_Latin1_General_1257_BIN 153, # SQL_Latin1_General_CP1257_CS_AS 154, # SQL_Latin1_General_CP1257_CI_AS 155, # SQL_Estonian_Cp1257_CS_AS_KI_WI 156, # SQL_Estonian_Cp1257_CI_AS_KI_WI 157, # SQL_Latvian_Cp1257_CS_AS_KI_WI 158, # SQL_Latvian_Cp1257_CI_AS_KI_WI 159, # SQL_Lithuanian_Cp1257_CS_AS_KI_WI 160, # SQL_Lithuanian_Cp1257_CI_AS_KI_WI ): return "CP1257" else: raise Exception("Invalid collation: 0x%X" % (sql_collate,)) def lcid2charset(lcid): if lcid in ( 0x405, 0x40E, # 0x1040e 0x415, 0x418, 0x41A, 0x41B, 0x41C, 0x424, # 0x81a, seem wrong in XP table TODO check 0x104E, ): return "CP1250" elif lcid in ( 0x402, 0x419, 0x422, 0x423, 0x42F, 0x43F, 0x440, 0x444, 0x450, 0x81A, # ?? 0x82C, 0x843, 0xC1A, ): return "CP1251" elif lcid in ( 0x1007, 0x1009, 0x100A, 0x100C, 0x1407, 0x1409, 0x140A, 0x140C, 0x1809, 0x180A, 0x180C, 0x1C09, 0x1C0A, 0x2009, 0x200A, 0x2409, 0x240A, 0x2809, 0x280A, 0x2C09, 0x2C0A, 0x3009, 0x300A, 0x3409, 0x340A, 0x380A, 0x3C0A, 0x400A, 0x403, 0x406, 0x407, # 0x10407 0x409, 0x40A, 0x40B, 0x40C, 0x40F, 0x410, 0x413, 0x414, 0x416, 0x41D, 0x421, 0x42D, 0x436, 0x437, # 0x10437 0x438, # 0x439, ??? Unicode only 0x43E, 0x440A, 0x441, 0x456, 0x480A, 0x4C0A, 0x500A, 0x807, 0x809, 0x80A, 0x80C, 0x810, 0x813, 0x814, 0x816, 0x81D, 0x83E, 0xC07, 0xC09, 0xC0A, 0xC0C, ): return "CP1252" elif lcid == 0x408: return "CP1253" elif lcid in (0x41F, 0x42C, 0x443): return "CP1254" elif lcid == 0x40D: return "CP1255" elif lcid in ( 0x1001, 0x1401, 0x1801, 0x1C01, 0x2001, 0x2401, 0x2801, 0x2C01, 0x3001, 0x3401, 0x3801, 0x3C01, 0x4001, 0x401, 0x420, 0x429, 0x801, 0xC01, ): return "CP1256" elif lcid in (0x425, 0x426, 0x427, 0x827): # ?? return "CP1257" elif lcid == 0x42A: return "CP1258" elif lcid == 0x41E: return "CP874" elif lcid == 0x411: # 0x10411 return "CP932" elif lcid in (0x1004, 0x804): # 0x20804 return "CP936" elif lcid == 0x412: # 0x10412 return "CP949" elif lcid in ( 0x1404, 0x404, # 0x30404 0xC04, ): return "CP950" else: return "CP1252" class Collation(object): _coll_struct = struct.Struct("> 26 return cls( lcid=lcid, ignore_case=ignore_case, ignore_accent=ignore_accent, ignore_width=ignore_width, ignore_kana=ignore_kana, binary=binary, binary2=binary2, version=version, sort_id=sort_id, ) def pack(self): lump = 0 lump |= self.lcid & 0xFFFFF lump |= (self.version << 26) & 0xF0000000 if self.ignore_case: lump |= self.f_ignore_case if self.ignore_accent: lump |= self.f_ignore_accent if self.ignore_width: lump |= self.f_ignore_width if self.ignore_kana: lump |= self.f_ignore_kana if self.binary: lump |= self.f_binary if self.binary2: lump |= self.f_binary2 return self._coll_struct.pack(lump, self.sort_id) def get_charset(self): if self.sort_id: return sortid2charset(self.sort_id) else: return lcid2charset(self.lcid) def get_codec(self): return codecs.lookup(self.get_charset()) # TODO: define __repr__ and __unicode__ raw_collation = Collation(0, 0, 0, 0, 0, 0, 0, 0, 0) pytds-1.15.0/src/pytds/connection.py000066400000000000000000000217351456567501500174320ustar00rootroot00000000000000""" This module implements DBAPI connection classes for both MARS and non-MARS variants """ from __future__ import annotations import typing import warnings import weakref from . import tds_base from .tds_socket import _TdsSocket from . import row_strategies from .tds_base import logger from . import connection_pool if typing.TYPE_CHECKING: from .cursor import Cursor, NonMarsCursor, _MarsCursor class Connection(typing.Protocol): """ This class defines interface for connection object according to DBAPI specification. This interface is implemented by MARS and non-MARS connection classes. """ @property def autocommit(self) -> bool: ... @autocommit.setter def autocommit(self, value: bool) -> None: ... @property def isolation_level(self) -> int: ... @isolation_level.setter def isolation_level(self, level: int) -> None: ... def __enter__(self) -> BaseConnection: ... def __exit__(self, *args) -> None: ... def commit(self) -> None: ... def rollback(self) -> None: ... def close(self) -> None: ... @property def mars_enabled(self) -> bool: ... def cursor(self) -> Cursor: ... class BaseConnection(Connection): """ Base connection class. It implements most of the common logic for MARS and non-MARS connection classes. """ _connection_closed_exception = tds_base.InterfaceError("Connection closed") def __init__( self, pooling: bool, key: connection_pool.PoolKeyType, tds_socket: _TdsSocket, ) -> None: # _tds_socket is set to None when connection is closed self._tds_socket: _TdsSocket | None = tds_socket self._key = key self._pooling = pooling # references to all cursors opened from connection # those references used to close cursors when connection is closed self._cursors: weakref.WeakSet[Cursor] = weakref.WeakSet() @property def as_dict(self) -> bool: """ Instructs all cursors this connection creates to return results as a dictionary rather than a tuple. """ if not self._tds_socket: raise self._connection_closed_exception return ( self._tds_socket.main_session.row_strategy == row_strategies.dict_row_strategy ) @as_dict.setter def as_dict(self, value: bool) -> None: warnings.warn( "setting as_dict property on the active connection, instead create connection with needed row_strategy", DeprecationWarning, ) if not self._tds_socket: raise self._connection_closed_exception if value: self._tds_socket.main_session.row_strategy = ( row_strategies.dict_row_strategy ) else: self._tds_socket.main_session.row_strategy = ( row_strategies.tuple_row_strategy ) @property def autocommit_state(self) -> bool: """ An alias for `autocommit`, provided for compatibility with pymssql """ if not self._tds_socket: raise self._connection_closed_exception return self._tds_socket.main_session.autocommit def set_autocommit(self, value: bool) -> None: """An alias for `autocommit`, provided for compatibility with ADO dbapi""" if not self._tds_socket: raise self._connection_closed_exception self._tds_socket.main_session.autocommit = value @property def autocommit(self) -> bool: """ The current state of autocommit on the connection. """ if not self._tds_socket: raise self._connection_closed_exception return self._tds_socket.main_session.autocommit @autocommit.setter def autocommit(self, value: bool) -> None: if not self._tds_socket: raise self._connection_closed_exception self._tds_socket.main_session.autocommit = value @property def isolation_level(self) -> int: """Isolation level for transactions, for possible values see :ref:`isolation-level-constants` .. seealso:: `SET TRANSACTION ISOLATION LEVEL`__ in MSSQL documentation .. __: http://msdn.microsoft.com/en-us/library/ms173763.aspx """ if not self._tds_socket: raise self._connection_closed_exception return self._tds_socket.main_session.isolation_level @isolation_level.setter def isolation_level(self, level: int) -> None: if not self._tds_socket: raise self._connection_closed_exception self._tds_socket.main_session.isolation_level = level @property def tds_version(self) -> int: """ Version of the TDS protocol that is being used by this connection """ if not self._tds_socket: raise self._connection_closed_exception return self._tds_socket.tds_version @property def product_version(self): """ Version of the MSSQL server """ if not self._tds_socket: raise self._connection_closed_exception return self._tds_socket.product_version def __enter__(self) -> BaseConnection: return self def __exit__(self, *args) -> None: self.close() def commit(self) -> None: """ Commit transaction which is currently in progress. """ if not self._tds_socket: raise self._connection_closed_exception # Setting cont to True to start new transaction # after current transaction is rolled back self._tds_socket.main_session.commit(cont=True) def rollback(self) -> None: """ Roll back transaction which is currently in progress. """ if self._tds_socket: # Setting cont to True to start new transaction # after current transaction is rolled back self._tds_socket.main_session.rollback(cont=True) def close(self) -> None: """Close connection to an MS SQL Server. This function tries to close the connection and free all memory used. It can be called more than once in a row. No exception is raised in this case. """ if self._tds_socket: logger.debug("Closing connection") if self._pooling: connection_pool.connection_pool.add( self._key, (self._tds_socket, self._tds_socket.main_session) ) else: self._tds_socket.close() logger.debug("Closing all cursors which were opened by connection") for cursor in self._cursors: cursor.close() self._tds_socket = None class MarsConnection(BaseConnection): """ MARS connection class, this object is created by calling :func:`connect` with use_mars parameter set to False. """ def __init__( self, pooling: bool, key: connection_pool.PoolKeyType, tds_socket: _TdsSocket, ): super().__init__(pooling=pooling, key=key, tds_socket=tds_socket) @property def mars_enabled(self) -> bool: return True def cursor(self) -> _MarsCursor: """ Return cursor object that can be used to make queries and fetch results from the database. """ from .cursor import _MarsCursor if not self._tds_socket: raise self._connection_closed_exception cursor = _MarsCursor( connection=self, session=self._tds_socket.create_session(), ) self._cursors.add(cursor) return cursor def close(self): if self._tds_socket: self._tds_socket.close_all_mars_sessions() super().close() class NonMarsConnection(BaseConnection): """ Non-MARS connection class, this object should be created by calling :func:`connect` with use_mars parameter set to False. """ def __init__( self, pooling: bool, key: connection_pool.PoolKeyType, tds_socket: _TdsSocket, ): super().__init__(pooling=pooling, key=key, tds_socket=tds_socket) self._active_cursor: NonMarsCursor | None = None @property def mars_enabled(self) -> bool: return False def cursor(self) -> NonMarsCursor: """ Return cursor object that can be used to make queries and fetch results from the database. """ from .cursor import NonMarsCursor if not self._tds_socket: raise self._connection_closed_exception # Only one cursor can be active at any given time if self._active_cursor: self._active_cursor.cancel() self._active_cursor.close() cursor = NonMarsCursor( connection=self, session=self._tds_socket.main_session, ) self._active_cursor = cursor self._cursors.add(cursor) return cursor pytds-1.15.0/src/pytds/connection_pool.py000066400000000000000000000020161456567501500204520ustar00rootroot00000000000000from __future__ import annotations import datetime from typing import Optional, Union, Tuple from pytds.tds_base import AuthProtocol from pytds.tds_socket import _TdsSocket, _TdsSession PoolKeyType = Tuple[ Optional[str], Optional[str], Optional[str], int, Optional[str], int, bool, Optional[str], int, bool, bool, Union[AuthProtocol, None], datetime.tzinfo, bool, ] class ConnectionPool: def __init__(self, max_pool_size: int = 100, min_pool_size: int = 0): self._max_pool_size = max_pool_size self._pool: dict[PoolKeyType, list[tuple[_TdsSocket, _TdsSession]]] = {} def add(self, key: PoolKeyType, conn: tuple[_TdsSocket, _TdsSession]) -> None: self._pool.setdefault(key, []).append(conn) def take(self, key: PoolKeyType) -> tuple[_TdsSocket, _TdsSession] | None: conns = self._pool.get(key, []) if len(conns) > 0: return conns.pop() else: return None connection_pool = ConnectionPool() pytds-1.15.0/src/pytds/cursor.py000066400000000000000000000555451456567501500166160ustar00rootroot00000000000000""" This module implements DBAPI cursor classes for MARS and non-MARS """ from __future__ import annotations import collections import csv import typing import warnings from collections.abc import Iterable import pytds from pytds.connection import Connection, MarsConnection, NonMarsConnection from pytds.tds_types import NVarCharType, TzInfoFactoryType from pytds.tds_socket import _TdsSession from pytds import tds_base from .tds_base import logger class Cursor(typing.Protocol, Iterable): """ This class defines an interface for cursor classes. It is implemented by MARS and non-MARS cursor classes. """ def __enter__(self) -> Cursor: ... def __exit__(self, *args) -> None: ... def get_proc_outputs(self) -> list[typing.Any]: ... def callproc( self, procname: tds_base.InternalProc | str, parameters: dict[str, typing.Any] | tuple[typing.Any, ...] = (), ) -> list[typing.Any]: ... @property def return_value(self) -> int | None: ... @property def spid(self) -> int: ... @property def connection(self) -> Connection | None: ... def get_proc_return_status(self) -> int | None: ... def cancel(self) -> None: ... def close(self) -> None: ... def execute( self, operation: str, params: list[typing.Any] | tuple[typing.Any, ...] | dict[str, typing.Any] | None = (), ) -> Cursor: ... def executemany( self, operation: str, params_seq: Iterable[ list[typing.Any] | tuple[typing.Any, ...] | dict[str, typing.Any] ], ) -> None: ... def execute_scalar( self, query_string: str, params: list[typing.Any] | tuple[typing.Any, ...] | dict[str, typing.Any] | None = None, ) -> typing.Any: ... def nextset(self) -> bool | None: ... @property def rowcount(self) -> int: ... @property def description(self): ... def set_stream(self, column_idx: int, stream) -> None: ... @property def messages( self ) -> ( list[ tuple[ typing.Type, tds_base.IntegrityError | tds_base.ProgrammingError | tds_base.OperationalError, ] ] | None ): ... @property def native_description(self): ... def fetchone(self) -> typing.Any: ... def fetchmany(self, size=None) -> list[typing.Any]: ... def fetchall(self) -> list[typing.Any]: ... @staticmethod def setinputsizes(sizes=None) -> None: ... @staticmethod def setoutputsize(size=None, column=0) -> None: ... def copy_to( self, file: Iterable[str] | None = None, table_or_view: str | None = None, sep: str = "\t", columns: Iterable[tds_base.Column | str] | None = None, check_constraints: bool = False, fire_triggers: bool = False, keep_nulls: bool = False, kb_per_batch: int | None = None, rows_per_batch: int | None = None, order: str | None = None, tablock: bool = False, schema: str | None = None, null_string: str | None = None, data: Iterable[tuple[typing.Any, ...]] | None = None, ): ... class BaseCursor(Cursor, collections.abc.Iterator): """ This class represents a base database cursor, which is used to issue queries and fetch results from a database connection. There are two actual cursor classes: one for MARS connections and one for non-MARS connections. """ _cursor_closed_exception = tds_base.InterfaceError("Cursor is closed") def __init__(self, connection: Connection, session: _TdsSession): self.arraysize = 1 # Null value in _session means cursor was closed self._session: _TdsSession | None = session # Keeping strong reference to connection to prevent connection from being garbage collected # while there are active cursors self._connection: Connection | None = connection @property def connection(self) -> Connection | None: warnings.warn( "connection property is deprecated on the cursor object and will be removed in future releases", DeprecationWarning, ) return self._connection def __enter__(self) -> BaseCursor: return self def __exit__(self, *args) -> None: self.close() def __iter__(self) -> BaseCursor: """ Return self to make cursors compatibile with Python iteration protocol. """ return self def get_proc_outputs(self) -> list[typing.Any]: """ If stored procedure has result sets and OUTPUT parameters use this method after you processed all result sets to get values of the OUTPUT parameters. :return: A list of output parameter values. """ if self._session is None: raise self._cursor_closed_exception return self._session.get_proc_outputs() def callproc( self, procname: tds_base.InternalProc | str, parameters: dict[str, typing.Any] | tuple[typing.Any, ...] = (), ) -> list[typing.Any]: """ Call a stored procedure with the given name. :param procname: The name of the procedure to call :type procname: str :keyword parameters: The optional parameters for the procedure :type parameters: sequence Note: If stored procedure has OUTPUT parameters and result sets this method will not return values for OUTPUT parameters, you should call get_proc_outputs to get values for OUTPUT parameters. """ if self._session is None: raise self._cursor_closed_exception return self._session.callproc(procname, parameters) @property def return_value(self) -> int | None: """Alias to :func:`get_proc_return_status`""" return self.get_proc_return_status() @property def spid(self) -> int: """MSSQL Server's session ID (SPID) It can be used to correlate connections between client and server logs. """ if self._session is None: raise self._cursor_closed_exception return self._session._spid def _get_tzinfo_factory(self) -> TzInfoFactoryType | None: if self._session is None: raise self._cursor_closed_exception return self._session.tzinfo_factory def _set_tzinfo_factory(self, tzinfo_factory: TzInfoFactoryType | None) -> None: if self._session is None: raise self._cursor_closed_exception self._session.tzinfo_factory = tzinfo_factory tzinfo_factory = property(_get_tzinfo_factory, _set_tzinfo_factory) def get_proc_return_status(self) -> int | None: """Last executed stored procedure's return value Returns integer value returned by `RETURN` statement from last executed stored procedure. If no value was not returned or no stored procedure was executed return `None`. """ if self._session is None: return None return self._session.get_proc_return_status() def cancel(self) -> None: """Cancel currently executing statement or stored procedure call""" if self._session is None: return self._session.cancel_if_pending() def close(self) -> None: """ Closes the cursor. The cursor is unusable from this point. """ logger.debug("Closing cursor") self._session = None self._connection = None T = typing.TypeVar("T") def execute( self, operation: str, params: list[typing.Any] | tuple[typing.Any, ...] | dict[str, typing.Any] | None = (), ) -> BaseCursor: """Execute an SQL query Optionally query can be executed with parameters. To make parametrized query use `%s` in the query to denote a parameter and pass a tuple with parameter values, e.g.: .. code-block:: execute("select %s, %s", (1,2)) This will execute query replacing first `%s` with first parameter value - 1, and second `%s` with second parameter value -2. Another option is to use named parameters with passing a dictionary, e.g.: .. code-block:: execute("select %(param1)s, %(param2)s", {param1=1, param2=2}) Both those ways of passing parameters is safe from SQL injection attacks. This function does not return results of the execution. Use :func:`fetchone` or similar to fetch results. """ if self._session is None: raise self._cursor_closed_exception self._session.execute(operation, params) # for compatibility with pyodbc return self def executemany( self, operation: str, params_seq: Iterable[ list[typing.Any] | tuple[typing.Any, ...] | dict[str, typing.Any] ], ) -> None: """ Execute same SQL query multiple times for each parameter set in the `params_seq` list. """ if self._session is None: raise self._cursor_closed_exception self._session.executemany(operation=operation, params_seq=params_seq) def execute_scalar( self, query_string: str, params: list[typing.Any] | tuple[typing.Any, ...] | dict[str, typing.Any] | None = None, ) -> typing.Any: """ This method executes SQL query then returns first column of first row or the result. Query can be parametrized, see :func:`execute` method for details. This method is useful if you want just a single value, as in: .. code-block:: conn.execute_scalar('SELECT COUNT(*) FROM employees') This method works in the same way as ``iter(conn).next()[0]``. Remaining rows, if any, can still be iterated after calling this method. """ if self._session is None: raise self._cursor_closed_exception return self._session.execute_scalar(query_string, params) def nextset(self) -> bool | None: """Move to next recordset in batch statement, all rows of current recordset are discarded if present. :returns: true if successful or ``None`` when there are no more recordsets """ if self._session is None: raise self._cursor_closed_exception return self._session.next_set() @property def rowcount(self) -> int: """Number of rows affected by previous statement :returns: -1 if this information was not supplied by the server """ if self._session is None: return -1 return self._session.rows_affected @property def description(self): """Cursor description, see http://legacy.python.org/dev/peps/pep-0249/#description""" if self._session is None: return None res = self._session.res_info if res: return res.description else: return None def set_stream(self, column_idx: int, stream) -> None: """ This function can be used to efficiently receive values which can be very large, e.g. `TEXT`, `VARCHAR(MAX)`, `VARBINARY(MAX)`. When streaming is not enabled, values are loaded to memory as they are received from server and once entire row is loaded, it is returned. With this function streaming receiver can be specified via `stream` parameter which will receive chunks of the data as they are received. For each received chunk driver will call stream's write method. For example this can be used to save value of a field into a file, or to proces value as it is being received. For string fields chunks are represented as unicode strings. For binary fields chunks are represented as `bytes` strings. Example usage: .. code-block:: cursor.execute("select N'very large field'") cursor.set_stream(0, StringIO()) row = cursor.fetchone() # now row[0] contains instance of a StringIO object which was gradually # filled with output from server for first column. :param column_idx: Zero based index of a column for which to setup streaming receiver :type column_idx: int :param stream: Stream object that will be receiving chunks of data via it's `write` method. """ if self._session is None: raise self._cursor_closed_exception res_info = self._session.res_info if not res_info: raise ValueError("No result set is active") if len(res_info.columns) <= column_idx or column_idx < 0: raise ValueError("Invalid value for column_idx") res_info.columns[column_idx].serializer.set_chunk_handler( pytds.tds_types._StreamChunkedHandler(stream) ) @property def messages( self ) -> ( list[ tuple[ typing.Type, tds_base.IntegrityError | tds_base.ProgrammingError | tds_base.OperationalError, ] ] | None ): """Messages generated by server, see http://legacy.python.org/dev/peps/pep-0249/#cursor-messages""" if self._session: result = [] for msg in self._session.messages: ex = tds_base._create_exception_by_message(msg) result.append((type(ex), ex)) return result else: return None @property def native_description(self): """todo document""" if self._session is None: return None res = self._session.res_info if res: return res.native_descr else: return None def fetchone(self) -> typing.Any: """Fetch next row. Returns row using currently configured factory, or ``None`` if there are no more rows """ if self._session is None: raise self._cursor_closed_exception return self._session.fetchone() def fetchmany(self, size=None) -> list[typing.Any]: """Fetch next N rows :param size: Maximum number of rows to return, default value is cursor.arraysize :returns: List of rows """ if self._session is None: raise self._cursor_closed_exception if size is None: size = self.arraysize rows = [] for _ in range(size): row = self.fetchone() if not row: break rows.append(row) return rows def fetchall(self) -> list[typing.Any]: """Fetch all remaining rows Do not use this if you expect large number of rows returned by the server, since this method will load all rows into memory. It is more efficient to load and process rows by iterating over them. """ if self._session is None: raise self._cursor_closed_exception return list(row for row in self) def __next__(self) -> typing.Any: row = self.fetchone() if row is None: raise StopIteration return row @staticmethod def setinputsizes(sizes=None) -> None: """ This method does nothing, as permitted by DB-API specification. """ pass @staticmethod def setoutputsize(size=None, column=0) -> None: """ This method does nothing, as permitted by DB-API specification. """ pass def copy_to( self, file: Iterable[str] | None = None, table_or_view: str | None = None, sep: str = "\t", columns: Iterable[tds_base.Column | str] | None = None, check_constraints: bool = False, fire_triggers: bool = False, keep_nulls: bool = False, kb_per_batch: int | None = None, rows_per_batch: int | None = None, order: str | None = None, tablock: bool = False, schema: str | None = None, null_string: str | None = None, data: Iterable[collections.abc.Sequence[typing.Any]] | None = None, ): """*Experimental*. Efficiently load data to database from file using ``BULK INSERT`` operation :param file: Source file-like object, should be in csv format. Specify either this or data, not both. :param table_or_view: Destination table or view in the database :type table_or_view: str Optional parameters: :keyword sep: Separator used in csv file :type sep: str :keyword columns: List of :class:`pytds.tds_base.Column` objects or column names in target table to insert to. SQL Server will do some conversions, so these may not have to match the actual table definition exactly. If not provided will insert into all columns assuming nvarchar(4000) NULL for all columns. If only the column name is provided, the type is assumed to be nvarchar(4000) NULL. If rows are given with file, you cannot specify non-string data types. If rows are given with data, the values must be a type supported by the serializer for the column in tds_types. :type columns: list :keyword check_constraints: Check table constraints for incoming data :type check_constraints: bool :keyword fire_triggers: Enable or disable triggers for table :type fire_triggers: bool :keyword keep_nulls: If enabled null values inserted as-is, instead of inserting default value for column :type keep_nulls: bool :keyword kb_per_batch: Kilobytes per batch can be used to optimize performance, see MSSQL server documentation for details :type kb_per_batch: int :keyword rows_per_batch: Rows per batch can be used to optimize performance, see MSSQL server documentation for details :type rows_per_batch: int :keyword order: The ordering of the data in source table. List of columns with ASC or DESC suffix. E.g. ``['order_id ASC', 'name DESC']`` Can be used to optimize performance, see MSSQL server documentation for details :type order: list :keyword tablock: Enable or disable table lock for the duration of bulk load :keyword schema: Name of schema for table or view, if not specified default schema will be used :keyword null_string: String that should be interpreted as a NULL when reading the CSV file. Has no meaning if using data instead of file. :keyword data: The data to insert as an iterable of rows, which are iterables of values. Specify either data parameter or file parameter but not both. """ if self._session is None: raise self._cursor_closed_exception # conn = self._conn() rows: Iterable[collections.abc.Sequence[typing.Any]] if data is None: if file is None: raise ValueError("No data was specified via file or data parameter") reader = csv.reader(file, delimiter=sep) if null_string is not None: def _convert_null_strings(csv_reader): for row in csv_reader: yield [r if r != null_string else None for r in row] reader = _convert_null_strings(reader) rows = reader else: rows = data obj_name = tds_base.tds_quote_id(table_or_view) if schema: obj_name = f"{tds_base.tds_quote_id(schema)}.{obj_name}" if columns: metadata = [] for column in columns: if isinstance(column, tds_base.Column): metadata.append(column) else: metadata.append( tds_base.Column( name=column, type=NVarCharType(size=4000), flags=tds_base.Column.fNullable, ) ) else: self.execute(f"select top 1 * from {obj_name} where 1<>1") metadata = [ tds_base.Column( name=col[0], type=NVarCharType(size=4000), flags=tds_base.Column.fNullable if col[6] else 0, ) for col in self.description ] col_defs = ",".join( f"{tds_base.tds_quote_id(col.column_name)} {col.type.get_declaration()}" for col in metadata ) with_opts = [] if check_constraints: with_opts.append("CHECK_CONSTRAINTS") if fire_triggers: with_opts.append("FIRE_TRIGGERS") if keep_nulls: with_opts.append("KEEP_NULLS") if kb_per_batch: with_opts.append("KILOBYTES_PER_BATCH = {0}".format(kb_per_batch)) if rows_per_batch: with_opts.append("ROWS_PER_BATCH = {0}".format(rows_per_batch)) if order: with_opts.append("ORDER({0})".format(",".join(order))) if tablock: with_opts.append("TABLOCK") with_part = "" if with_opts: with_part = "WITH ({0})".format(",".join(with_opts)) operation = "INSERT BULK {0}({1}) {2}".format(obj_name, col_defs, with_part) self.execute(operation) self._session.submit_bulk(metadata, rows) self._session.process_simple_request() class NonMarsCursor(BaseCursor): """ This class represents a non-MARS database cursor, which is used to issue queries and fetch results from a database connection. Non-MARS connections allow only one cursor to be active at a given time. """ def __init__(self, connection: NonMarsConnection, session: _TdsSession): super().__init__(connection=connection, session=session) class _MarsCursor(BaseCursor): """ This class represents a MARS database cursor, which is used to issue queries and fetch results from a database connection. MARS connections allow multiple cursors to be active at the same time. """ def __init__(self, connection: MarsConnection, session: _TdsSession): super().__init__( connection=connection, session=session, ) @property def spid(self) -> int: # not thread safe for connection return self.execute_scalar("select @@SPID") def close(self) -> None: """ Closes the cursor. The cursor is unusable from this point. """ logger.debug("Closing MARS cursor") if self._session is not None: self._session.close() self._session = None self._connection = None pytds-1.15.0/src/pytds/extensions.py000066400000000000000000000012411456567501500174600ustar00rootroot00000000000000#: Transaction can read uncommitted data ISOLATION_LEVEL_READ_UNCOMMITTED = 1 #: Transaction can read only committed data, will block on attempt #: to read modified uncommitted data ISOLATION_LEVEL_READ_COMMITTED = 2 #: Transaction will place lock on read records, other transactions #: will block trying to modify such records ISOLATION_LEVEL_REPEATABLE_READ = 3 #: Transaction will lock tables to prevent other transactions #: from inserting new data that would match selected recordsets ISOLATION_LEVEL_SERIALIZABLE = 4 #: Allows non-blocking consistent reads on a snapshot for transaction without #: blocking other transactions changes ISOLATION_LEVEL_SNAPSHOT = 5 pytds-1.15.0/src/pytds/instance_browser_client.py000066400000000000000000000052361456567501500221760ustar00rootroot00000000000000""" This module implements client interface for MSSQL server browser which provides information about MSSQL server instances running on the host via UDP socket at port 1434. """ from __future__ import annotations import socket import typing from . import tds_base from .tds_base import logger def parse_instances_response(msg: bytes) -> dict[str, dict[str, str]] | None: """ Parses instances response as received from MSSQL server browser endpoint """ name: str | None = None if len(msg) > 3 and tds_base.my_ord(msg[0]) == 5: tokens = msg[3:].decode("ascii").split(";") results: dict[str, dict[str, str]] = {} instdict: dict[str, str] = {} got_name = False for token in tokens: if got_name and name: instdict[name] = token got_name = False else: name = token if not name: if not instdict: break results[instdict["InstanceName"].upper()] = instdict instdict = {} continue got_name = True return results return None def tds7_get_instances( ip_addr: typing.Any, timeout: float = 5 ) -> dict[str, dict[str, str]] | None: """ Get MSSQL instances information from instance browser service endpoint. Returns a dictionary keyed by instance name of dictionaries of instances information. """ with socket.socket(type=socket.SOCK_DGRAM) as s: s.settimeout(timeout) # send the request s.sendto(b"\x03", (ip_addr, 1434)) msg = s.recv(16 * 1024 - 1) # got data, read and parse return parse_instances_response(msg) def resolve_instance_port( server: typing.Any, port: int | None, instance: str, timeout: float = 5 ) -> int: """ Resolve MSSQL server instance's port, if instance name is provided and port not provided """ if instance and not port: logger.info("querying %s for list of instances", server) instances = tds7_get_instances(server, timeout=timeout) if not instances: raise RuntimeError( "Querying list of instances failed, returned value has invalid format" ) if instance not in instances: raise tds_base.LoginError( f"Instance {instance} not found on server {server}" ) instdict = instances[instance] if "tcp" not in instdict: raise tds_base.LoginError( f"Instance {instance} doen't have tcp connections enabled" ) port = int(instdict["tcp"]) return port or 1433 pytds-1.15.0/src/pytds/lcid.py000066400000000000000000000457241456567501500162120ustar00rootroot00000000000000# Copyright 2010 Michael Murr # # This file is part of LibForensics. # # LibForensics is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # LibForensics is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the GNU Lesser General Public License # along with LibForensics. If not, see . """Constants for Locale IDs. (LCIDs)""" __docformat__ = "restructuredtext en" __all__ = [ "LANGID_AFRIKAANS", "LANGID_ALBANIAN", "LANGID_AMHARIC", "LANGID_ARABIC", "LANGID_ARABIC_ALGERIA", "LANGID_ARABIC_BAHRAIN", "LANGID_ARABIC_EGYPT", "LANGID_ARABIC_IRAQ", "LANGID_ARABIC_JORDAN", "LANGID_ARABIC_KUWAIT", "LANGID_ARABIC_LEBANON", "LANGID_ARABIC_LIBYA", "LANGID_ARABIC_MOROCCO", "LANGID_ARABIC_OMAN", "LANGID_ARABIC_QATAR", "LANGID_ARABIC_SYRIA", "LANGID_ARABIC_TUNISIA", "LANGID_ARABIC_UAE", "LANGID_ARABIC_YEMEN", "LANGID_ARMENIAN", "LANGID_ASSAMESE", "LANGID_AZERI_CYRILLIC", "LANGID_AZERI_LATIN", "LANGID_BASQUE", "LANGID_BELGIAN_DUTCH", "LANGID_BELGIAN_FRENCH", "LANGID_BENGALI", "LANGID_BULGARIAN", "LANGID_BURMESE", "LANGID_BYELORUSSIAN", "LANGID_CATALAN", "LANGID_CHEROKEE", "LANGID_CHINESE_HONG_KONG_SAR", "LANGID_CHINESE_MACAO_SAR", "LANGID_CHINESE_SINGAPORE", "LANGID_CROATIAN", "LANGID_CZECH", "LANGID_DANISH", "LANGID_DIVEHI", "LANGID_DUTCH", "LANGID_EDO", "LANGID_ENGLISH_AUS", "LANGID_ENGLISH_BELIZE", "LANGID_ENGLISH_CANADIAN", "LANGID_ENGLISH_CARIBBEAN", "LANGID_ENGLISH_INDONESIA", "LANGID_ENGLISH_IRELAND", "LANGID_ENGLISH_JAMAICA", "LANGID_ENGLISH_NEW_ZEALAND", "LANGID_ENGLISH_PHILIPPINES", "LANGID_ENGLISH_SOUTH_AFRICA", "LANGID_ENGLISH_TRINIDAD_TOBAGO", "LANGID_ENGLISH_UK", "LANGID_ENGLISH_US", "LANGID_ENGLISH_ZIMBABWE", "LANGID_ESTONIAN", "LANGID_FAEROESE", "LANGID_FILIPINO", "LANGID_FINNISH", "LANGID_FRENCH", "LANGID_FRENCH_CAMEROON", "LANGID_FRENCH_CANADIAN", "LANGID_FRENCH_CONGO_D_R_C", "LANGID_FRENCH_COTED_IVOIRE", "LANGID_FRENCH_HAITI", "LANGID_FRENCH_LUXEMBOURG", "LANGID_FRENCH_MALI", "LANGID_FRENCH_MONACO", "LANGID_FRENCH_MOROCCO", "LANGID_FRENCH_REUNION", "LANGID_FRENCH_SENEGAL", "LANGID_FRENCH_WEST_INDIES", "LANGID_FRISIAN_NETHERLANDS", "LANGID_FULFULDE", "LANGID_GAELIC_IRELAND", "LANGID_GAELIC_SCOTLAND", "LANGID_GALICIAN", "LANGID_GEORGIAN", "LANGID_GERMAN", "LANGID_GERMAN_AUSTRIA", "LANGID_GERMAN_LIECHTENSTEIN", "LANGID_GERMAN_LUXEMBOURG", "LANGID_GREEK", "LANGID_GUARANI", "LANGID_GUJARATI", "LANGID_HAUSA", "LANGID_HAWAIIAN", "LANGID_HEBREW", "LANGID_HINDI", "LANGID_HUNGARIAN", "LANGID_IBIBIO", "LANGID_ICELANDIC", "LANGID_IGBO", "LANGID_INDONESIAN", "LANGID_INUKTITUT", "LANGID_ITALIAN", "LANGID_JAPANESE", "LANGID_KANNADA", "LANGID_KANURI", "LANGID_KASHMIRI", "LANGID_KAZAKH", "LANGID_KHMER", "LANGID_KIRGHIZ", "LANGID_KONKANI", "LANGID_KOREAN", "LANGID_KYRGYZ", "LANGID_LANGUAGE_NONE", "LANGID_LAO", "LANGID_LATIN", "LANGID_LATVIAN", "LANGID_LITHUANIAN", "LANGID_MACEDONIAN_FYROM", "LANGID_MALAYALAM", "LANGID_MALAYSIAN", "LANGID_MALAY_BRUNEI_DARUSSALAM", "LANGID_MALTESE", "LANGID_MANIPURI", "LANGID_MARATHI", "LANGID_MEXICAN_SPANISH", "LANGID_MONGOLIAN", "LANGID_NEPALI", "LANGID_NORWEGIAN_BOKMOL", "LANGID_NORWEGIAN_NYNORSK", "LANGID_NO_PROOFING", "LANGID_ORIYA", "LANGID_OROMO", "LANGID_PASHTO", "LANGID_PERSIAN", "LANGID_POLISH", "LANGID_PORTUGUESE", "LANGID_PORTUGUESE_BRAZIL", "LANGID_PUNJABI", "LANGID_RHAETO_ROMANIC", "LANGID_ROMANIAN", "LANGID_ROMANIAN_MOLDOVA", "LANGID_RUSSIAN", "LANGID_RUSSIAN_MOLDOVA", "LANGID_SAMI_LAPPISH", "LANGID_SANSKRIT", "LANGID_SERBIAN_CYRILLIC", "LANGID_SERBIAN_LATIN", "LANGID_SESOTHO", "LANGID_SIMPLIFIED_CHINESE", "LANGID_SINDHI", "LANGID_SINDHI_PAKISTAN", "LANGID_SINHALESE", "LANGID_SLOVAK", "LANGID_SLOVENIAN", "LANGID_SOMALI", "LANGID_SORBIAN", "LANGID_SPANISH", "LANGID_SPANISH_ARGENTINA", "LANGID_SPANISH_BOLIVIA", "LANGID_SPANISH_CHILE", "LANGID_SPANISH_COLOMBIA", "LANGID_SPANISH_COSTA_RICA", "LANGID_SPANISH_DOMINICAN_REPUBLIC", "LANGID_SPANISH_ECUADOR", "LANGID_SPANISH_EL_SALVADOR", "LANGID_SPANISH_GUATEMALA", "LANGID_SPANISH_HONDURAS", "LANGID_SPANISH_MODERN_SORT", "LANGID_SPANISH_NICARAGUA", "LANGID_SPANISH_PANAMA", "LANGID_SPANISH_PARAGUAY", "LANGID_SPANISH_PERU", "LANGID_SPANISH_PUERTO_RICO", "LANGID_SPANISH_URUGUAY", "LANGID_SPANISH_VENEZUELA", "LANGID_SUTU", "LANGID_SWAHILI", "LANGID_SWEDISH", "LANGID_SWEDISH_FINLAND", "LANGID_SWISS_FRENCH", "LANGID_SWISS_GERMAN", "LANGID_SWISS_ITALIAN", "LANGID_SYRIAC", "LANGID_TAJIK", "LANGID_TAMAZIGHT", "LANGID_TAMAZIGHT_LATIN", "LANGID_TAMIL", "LANGID_TATAR", "LANGID_TELUGU", "LANGID_THAI", "LANGID_TIBETAN", "LANGID_TIGRIGNA_ERITREA", "LANGID_TIGRIGNA_ETHIOPIC", "LANGID_TRADITIONAL_CHINESE", "LANGID_TSONGA", "LANGID_TSWANA", "LANGID_TURKISH", "LANGID_TURKMEN", "LANGID_UKRAINIAN", "LANGID_URDU", "LANGID_UZBEK_CYRILLIC", "LANGID_UZBEK_LATIN", "LANGID_VENDA", "LANGID_VIETNAMESE", "LANGID_WELSH", "LANGID_XHOSA", "LANGID_YI", "LANGID_YIDDISH", "LANGID_YORUBA", "LANGID_ZULU", "lang_id_names", ] LANGID_AFRIKAANS = 1078 LANGID_ALBANIAN = 1052 LANGID_AMHARIC = 1118 LANGID_ARABIC = 1025 LANGID_ARABIC_ALGERIA = 5121 LANGID_ARABIC_BAHRAIN = 15361 LANGID_ARABIC_EGYPT = 3073 LANGID_ARABIC_IRAQ = 2049 LANGID_ARABIC_JORDAN = 11265 LANGID_ARABIC_KUWAIT = 13313 LANGID_ARABIC_LEBANON = 12289 LANGID_ARABIC_LIBYA = 4097 LANGID_ARABIC_MOROCCO = 6145 LANGID_ARABIC_OMAN = 8193 LANGID_ARABIC_QATAR = 16385 LANGID_ARABIC_SYRIA = 10241 LANGID_ARABIC_TUNISIA = 7169 LANGID_ARABIC_UAE = 14337 LANGID_ARABIC_YEMEN = 9217 LANGID_ARMENIAN = 1067 LANGID_ASSAMESE = 1101 LANGID_AZERI_CYRILLIC = 2092 LANGID_AZERI_LATIN = 1068 LANGID_BASQUE = 1069 LANGID_BELGIAN_DUTCH = 2067 LANGID_BELGIAN_FRENCH = 2060 LANGID_BENGALI = 1093 LANGID_BULGARIAN = 1026 LANGID_BURMESE = 1109 LANGID_BYELORUSSIAN = 1059 LANGID_CATALAN = 1027 LANGID_CHEROKEE = 1116 LANGID_CHINESE_HONG_KONG_SAR = 3076 LANGID_CHINESE_MACAO_SAR = 5124 LANGID_CHINESE_SINGAPORE = 4100 LANGID_CROATIAN = 1050 LANGID_CZECH = 1029 LANGID_DANISH = 1030 LANGID_DIVEHI = 1125 LANGID_DUTCH = 1043 LANGID_EDO = 1126 LANGID_ENGLISH_AUS = 3081 LANGID_ENGLISH_BELIZE = 10249 LANGID_ENGLISH_CANADIAN = 4105 LANGID_ENGLISH_CARIBBEAN = 9225 LANGID_ENGLISH_INDONESIA = 14345 LANGID_ENGLISH_IRELAND = 6153 LANGID_ENGLISH_JAMAICA = 8201 LANGID_ENGLISH_NEW_ZEALAND = 5129 LANGID_ENGLISH_PHILIPPINES = 13321 LANGID_ENGLISH_SOUTH_AFRICA = 7177 LANGID_ENGLISH_TRINIDAD_TOBAGO = 11273 LANGID_ENGLISH_UK = 2057 LANGID_ENGLISH_US = 1033 LANGID_ENGLISH_ZIMBABWE = 12297 LANGID_ESTONIAN = 1061 LANGID_FAEROESE = 1080 LANGID_FILIPINO = 1124 LANGID_FINNISH = 1035 LANGID_FRENCH = 1036 LANGID_FRENCH_CAMEROON = 11276 LANGID_FRENCH_CANADIAN = 3084 LANGID_FRENCH_CONGO_D_R_C = 9228 LANGID_FRENCH_COTED_IVOIRE = 12300 LANGID_FRENCH_HAITI = 15372 LANGID_FRENCH_LUXEMBOURG = 5132 LANGID_FRENCH_MALI = 13324 LANGID_FRENCH_MONACO = 6156 LANGID_FRENCH_MOROCCO = 14348 LANGID_FRENCH_REUNION = 8204 LANGID_FRENCH_SENEGAL = 10252 LANGID_FRENCH_WEST_INDIES = 7180 LANGID_FRISIAN_NETHERLANDS = 1122 LANGID_FULFULDE = 1127 LANGID_GAELIC_IRELAND = 2108 LANGID_GAELIC_SCOTLAND = 1084 LANGID_GALICIAN = 1110 LANGID_GEORGIAN = 1079 LANGID_GERMAN = 1031 LANGID_GERMAN_AUSTRIA = 3079 LANGID_GERMAN_LIECHTENSTEIN = 5127 LANGID_GERMAN_LUXEMBOURG = 4103 LANGID_GREEK = 1032 LANGID_GUARANI = 1140 LANGID_GUJARATI = 1095 LANGID_HAUSA = 1128 LANGID_HAWAIIAN = 1141 LANGID_HEBREW = 1037 LANGID_HINDI = 1081 LANGID_HUNGARIAN = 1038 LANGID_IBIBIO = 1129 LANGID_ICELANDIC = 1039 LANGID_IGBO = 1136 LANGID_INDONESIAN = 1057 LANGID_INUKTITUT = 1117 LANGID_ITALIAN = 1040 LANGID_JAPANESE = 1041 LANGID_KANNADA = 1099 LANGID_KANURI = 1137 LANGID_KASHMIRI = 1120 LANGID_KAZAKH = 1087 LANGID_KHMER = 1107 LANGID_KIRGHIZ = 1088 LANGID_KONKANI = 1111 LANGID_KOREAN = 1042 LANGID_KYRGYZ = 1088 LANGID_LANGUAGE_NONE = 0 LANGID_LAO = 1108 LANGID_LATIN = 1142 LANGID_LATVIAN = 1062 LANGID_LITHUANIAN = 1063 LANGID_MACEDONIAN_FYROM = 1071 LANGID_MALAYALAM = 1100 LANGID_MALAY_BRUNEI_DARUSSALAM = 2110 LANGID_MALAYSIAN = 1086 LANGID_MALTESE = 1082 LANGID_MANIPURI = 1112 LANGID_MARATHI = 1102 LANGID_MEXICAN_SPANISH = 2058 LANGID_MONGOLIAN = 1104 LANGID_NEPALI = 1121 LANGID_NO_PROOFING = 1024 LANGID_NORWEGIAN_BOKMOL = 1044 LANGID_NORWEGIAN_NYNORSK = 2068 LANGID_ORIYA = 1096 LANGID_OROMO = 1138 LANGID_PASHTO = 1123 LANGID_PERSIAN = 1065 LANGID_POLISH = 1045 LANGID_PORTUGUESE = 2070 LANGID_PORTUGUESE_BRAZIL = 1046 LANGID_PUNJABI = 1094 LANGID_RHAETO_ROMANIC = 1047 LANGID_ROMANIAN = 1048 LANGID_ROMANIAN_MOLDOVA = 2072 LANGID_RUSSIAN = 1049 LANGID_RUSSIAN_MOLDOVA = 2073 LANGID_SAMI_LAPPISH = 1083 LANGID_SANSKRIT = 1103 LANGID_SERBIAN_CYRILLIC = 3098 LANGID_SERBIAN_LATIN = 2074 LANGID_SESOTHO = 1072 LANGID_SIMPLIFIED_CHINESE = 2052 LANGID_SINDHI = 1113 LANGID_SINDHI_PAKISTAN = 2137 LANGID_SINHALESE = 1115 LANGID_SLOVAK = 1051 LANGID_SLOVENIAN = 1060 LANGID_SOMALI = 1143 LANGID_SORBIAN = 1070 LANGID_SPANISH = 1034 LANGID_SPANISH_ARGENTINA = 11274 LANGID_SPANISH_BOLIVIA = 16394 LANGID_SPANISH_CHILE = 13322 LANGID_SPANISH_COLOMBIA = 9226 LANGID_SPANISH_COSTA_RICA = 5130 LANGID_SPANISH_DOMINICAN_REPUBLIC = 7178 LANGID_SPANISH_ECUADOR = 12298 LANGID_SPANISH_EL_SALVADOR = 17418 LANGID_SPANISH_GUATEMALA = 4106 LANGID_SPANISH_HONDURAS = 18442 LANGID_SPANISH_MODERN_SORT = 3082 LANGID_SPANISH_NICARAGUA = 19466 LANGID_SPANISH_PANAMA = 6154 LANGID_SPANISH_PARAGUAY = 15370 LANGID_SPANISH_PERU = 10250 LANGID_SPANISH_PUERTO_RICO = 20490 LANGID_SPANISH_URUGUAY = 14346 LANGID_SPANISH_VENEZUELA = 8202 LANGID_SUTU = 1072 LANGID_SWAHILI = 1089 LANGID_SWEDISH = 1053 LANGID_SWEDISH_FINLAND = 2077 LANGID_SWISS_FRENCH = 4108 LANGID_SWISS_GERMAN = 2055 LANGID_SWISS_ITALIAN = 2064 LANGID_SYRIAC = 1114 LANGID_TAJIK = 1064 LANGID_TAMAZIGHT = 1119 LANGID_TAMAZIGHT_LATIN = 2143 LANGID_TAMIL = 1097 LANGID_TATAR = 1092 LANGID_TELUGU = 1098 LANGID_THAI = 1054 LANGID_TIBETAN = 1105 LANGID_TIGRIGNA_ERITREA = 2163 LANGID_TIGRIGNA_ETHIOPIC = 1139 LANGID_TRADITIONAL_CHINESE = 1028 LANGID_TSONGA = 1073 LANGID_TSWANA = 1074 LANGID_TURKISH = 1055 LANGID_TURKMEN = 1090 LANGID_UKRAINIAN = 1058 LANGID_URDU = 1056 LANGID_UZBEK_CYRILLIC = 2115 LANGID_UZBEK_LATIN = 1091 LANGID_VENDA = 1075 LANGID_VIETNAMESE = 1066 LANGID_WELSH = 1106 LANGID_XHOSA = 1076 LANGID_YI = 1144 LANGID_YIDDISH = 1085 LANGID_YORUBA = 1130 LANGID_ZULU = 1077 lang_id_names = { LANGID_AFRIKAANS: "African", LANGID_ALBANIAN: "Albanian", LANGID_AMHARIC: "Amharic", LANGID_ARABIC: "Arabic", LANGID_ARABIC_ALGERIA: "Arabic Algerian", LANGID_ARABIC_BAHRAIN: "Arabic Bahraini", LANGID_ARABIC_EGYPT: "Arabic Egyptian", LANGID_ARABIC_IRAQ: "Arabic Iraqi", LANGID_ARABIC_JORDAN: "Arabic Jordanian", LANGID_ARABIC_KUWAIT: "Arabic Kuwaiti", LANGID_ARABIC_LEBANON: "Arabic Lebanese", LANGID_ARABIC_LIBYA: "Arabic Libyan", LANGID_ARABIC_MOROCCO: "Arabic Moroccan", LANGID_ARABIC_OMAN: "Arabic Omani", LANGID_ARABIC_QATAR: "Arabic Qatari", LANGID_ARABIC_SYRIA: "Arabic Syrian", LANGID_ARABIC_TUNISIA: "Arabic Tunisian", LANGID_ARABIC_UAE: "Arabic United Arab Emirates", LANGID_ARABIC_YEMEN: "Arabic Yemeni", LANGID_ARMENIAN: "Armenian", LANGID_ASSAMESE: "Assamese", LANGID_AZERI_CYRILLIC: "Azeri Cyrillic", LANGID_AZERI_LATIN: "Azeri Latin", LANGID_BASQUE: "Basque", LANGID_BELGIAN_DUTCH: "Belgian Dutch", LANGID_BELGIAN_FRENCH: "Belgian French", LANGID_BENGALI: "Bengali", LANGID_BULGARIAN: "Bulgarian", LANGID_BURMESE: "Burmese", LANGID_BYELORUSSIAN: "Byelorussian", LANGID_CATALAN: "Catalan", LANGID_CHEROKEE: "Cherokee", LANGID_CHINESE_HONG_KONG_SAR: "Chinese Hong Kong SAR", LANGID_CHINESE_MACAO_SAR: "Chinese Macao SAR", LANGID_CHINESE_SINGAPORE: "Chinese Singapore", LANGID_CROATIAN: "Croatian", LANGID_CZECH: "Czech", LANGID_DANISH: "Danish", LANGID_DIVEHI: "Divehi", LANGID_DUTCH: "Dutch", LANGID_EDO: "Edo", LANGID_ENGLISH_AUS: "Australian English", LANGID_ENGLISH_BELIZE: "Belize English", LANGID_ENGLISH_CANADIAN: "Canadian English", LANGID_ENGLISH_CARIBBEAN: "Caribbean English", LANGID_ENGLISH_INDONESIA: "Indonesian English", LANGID_ENGLISH_IRELAND: "Irish English", LANGID_ENGLISH_JAMAICA: "Jamaican English", LANGID_ENGLISH_NEW_ZEALAND: "New Zealand English", LANGID_ENGLISH_PHILIPPINES: "Filipino English", LANGID_ENGLISH_SOUTH_AFRICA: "South African English", LANGID_ENGLISH_TRINIDAD_TOBAGO: "Tobago Trinidad English", LANGID_ENGLISH_UK: "United Kingdom English", LANGID_ENGLISH_US: "United States English", LANGID_ENGLISH_ZIMBABWE: "Zimbabwe English", LANGID_ESTONIAN: "Estonian", LANGID_FAEROESE: "Faeroese", LANGID_FILIPINO: "Filipino", LANGID_FINNISH: "Finnish", LANGID_FRENCH: "French", LANGID_FRENCH_CAMEROON: "French Cameroon", LANGID_FRENCH_CANADIAN: "French Canadian", LANGID_FRENCH_CONGO_D_R_C: "French (Congo (DRC))", LANGID_FRENCH_COTED_IVOIRE: "French Cote d'Ivoire", LANGID_FRENCH_HAITI: "French Haiti", LANGID_FRENCH_LUXEMBOURG: "French Luxembourg", LANGID_FRENCH_MALI: "French Mali", LANGID_FRENCH_MONACO: "French Monaco", LANGID_FRENCH_MOROCCO: "French Morocco", LANGID_FRENCH_REUNION: "French Reunion", LANGID_FRENCH_SENEGAL: "French Senegal", LANGID_FRENCH_WEST_INDIES: "French West Indies", LANGID_FRISIAN_NETHERLANDS: "Frisian Netherlands", LANGID_FULFULDE: "Fulfulde", LANGID_GAELIC_IRELAND: "Gaelic Irish", LANGID_GAELIC_SCOTLAND: "Gaelic Scottish", LANGID_GALICIAN: "Galician", LANGID_GEORGIAN: "Georgian", LANGID_GERMAN: "German", LANGID_GERMAN_AUSTRIA: "German Austrian", LANGID_GERMAN_LIECHTENSTEIN: "German Liechtenstein", LANGID_GERMAN_LUXEMBOURG: "German Luxembourg", LANGID_GREEK: "Greek", LANGID_GUARANI: "Guarani", LANGID_GUJARATI: "Gujarati", LANGID_HAUSA: "Hausa", LANGID_HAWAIIAN: "Hawaiian", LANGID_HEBREW: "Hebrew", LANGID_HINDI: "Hindi", LANGID_HUNGARIAN: "Hungarian", LANGID_IBIBIO: "Ibibio", LANGID_ICELANDIC: "Icelandic", LANGID_IGBO: "Igbo", LANGID_INDONESIAN: "Indonesian", LANGID_INUKTITUT: "Inuktitut", LANGID_ITALIAN: "Italian", LANGID_JAPANESE: "Japanese", LANGID_KANNADA: "Kannada", LANGID_KANURI: "Kanuri", LANGID_KASHMIRI: "Kashmiri", LANGID_KAZAKH: "Kazakh", LANGID_KHMER: "Khmer", LANGID_KIRGHIZ: "Kirghiz", LANGID_KONKANI: "Konkani", LANGID_KOREAN: "Korean", LANGID_KYRGYZ: "Kyrgyz", LANGID_LANGUAGE_NONE: "No specified", LANGID_LAO: "Lao", LANGID_LATIN: "Latin", LANGID_LATVIAN: "Latvian", LANGID_LITHUANIAN: "Lithuanian", LANGID_MACEDONIAN_FYROM: "Macedonian (FYROM)", LANGID_MALAYALAM: "Malayalam", LANGID_MALAY_BRUNEI_DARUSSALAM: "Malay Brunei Darussalam", LANGID_MALAYSIAN: "Malaysian", LANGID_MALTESE: "Maltese", LANGID_MANIPURI: "Manipuri", LANGID_MARATHI: "Marathi", LANGID_MEXICAN_SPANISH: "Mexican Spanish", LANGID_MONGOLIAN: "Mongolian", LANGID_NEPALI: "Nepali", LANGID_NO_PROOFING: "Disables proofing", LANGID_NORWEGIAN_BOKMOL: "Norwegian Bokmol", LANGID_NORWEGIAN_NYNORSK: "Norwegian Nynorsk", LANGID_ORIYA: "Oriya", LANGID_OROMO: "Oromo", LANGID_PASHTO: "Pashto", LANGID_PERSIAN: "Persian", LANGID_POLISH: "Polish", LANGID_PORTUGUESE: "Portuguese", LANGID_PORTUGUESE_BRAZIL: "Portuguese (Brazil)", LANGID_PUNJABI: "Punjabi", LANGID_RHAETO_ROMANIC: "Rhaeto Romanic", LANGID_ROMANIAN: "Romanian", LANGID_ROMANIAN_MOLDOVA: "Romanian Moldova", LANGID_RUSSIAN: "Russian", LANGID_RUSSIAN_MOLDOVA: "Russian Moldova", LANGID_SAMI_LAPPISH: "Sami Lappish", LANGID_SANSKRIT: "Sanskrit", LANGID_SERBIAN_CYRILLIC: "Serbian Cyrillic", LANGID_SERBIAN_LATIN: "Serbian Latin", LANGID_SESOTHO: "Sesotho", LANGID_SIMPLIFIED_CHINESE: "Simplified Chinese", LANGID_SINDHI: "Sindhi", LANGID_SINDHI_PAKISTAN: "Sindhi (Pakistan)", LANGID_SINHALESE: "Sinhalese", LANGID_SLOVAK: "Slovakian", LANGID_SLOVENIAN: "Slovenian", LANGID_SOMALI: "Somali", LANGID_SORBIAN: "Sorbian", LANGID_SPANISH: "Spanish", LANGID_SPANISH_ARGENTINA: "Spanish Argentina", LANGID_SPANISH_BOLIVIA: "Spanish Bolivian", LANGID_SPANISH_CHILE: "Spanish Chilean", LANGID_SPANISH_COLOMBIA: "Spanish Colombian", LANGID_SPANISH_COSTA_RICA: "Spanish Costa Rican", LANGID_SPANISH_DOMINICAN_REPUBLIC: "Spanish Dominican Republic", LANGID_SPANISH_ECUADOR: "Spanish Ecuadorian", LANGID_SPANISH_EL_SALVADOR: "Spanish El Salvadorian", LANGID_SPANISH_GUATEMALA: "Spanish Guatemala", LANGID_SPANISH_HONDURAS: "Spanish Honduran", LANGID_SPANISH_MODERN_SORT: "Spanish Modern Sort", LANGID_SPANISH_NICARAGUA: "Spanish Nicaraguan", LANGID_SPANISH_PANAMA: "Spanish Panamanian", LANGID_SPANISH_PARAGUAY: "Spanish Paraguayan", LANGID_SPANISH_PERU: "Spanish Peruvian", LANGID_SPANISH_PUERTO_RICO: "Spanish Puerto Rican", LANGID_SPANISH_URUGUAY: "Spanish Uruguayan", LANGID_SPANISH_VENEZUELA: "Spanish Venezuelan", LANGID_SUTU: "Sutu", LANGID_SWAHILI: "Swahili", LANGID_SWEDISH: "Swedish", LANGID_SWEDISH_FINLAND: "Swedish Finnish", LANGID_SWISS_FRENCH: "Swiss French", LANGID_SWISS_GERMAN: "Swiss German", LANGID_SWISS_ITALIAN: "Swiss Italian", LANGID_SYRIAC: "Syriac", LANGID_TAJIK: "Tajik", LANGID_TAMAZIGHT: "Tamazight", LANGID_TAMAZIGHT_LATIN: "Tamazight Latin", LANGID_TAMIL: "Tamil", LANGID_TATAR: "Tatar", LANGID_TELUGU: "Telugu", LANGID_THAI: "Thai", LANGID_TIBETAN: "Tibetan", LANGID_TIGRIGNA_ERITREA: "Tigrigna Eritrea", LANGID_TIGRIGNA_ETHIOPIC: "Tigrigna Ethiopic", LANGID_TRADITIONAL_CHINESE: "Traditional Chinese", LANGID_TSONGA: "Tsonga", LANGID_TSWANA: "Tswana", LANGID_TURKISH: "Turkish", LANGID_TURKMEN: "Turkmen", LANGID_UKRAINIAN: "Ukrainian", LANGID_URDU: "Urdu", LANGID_UZBEK_CYRILLIC: "Uzbek Cyrillic", LANGID_UZBEK_LATIN: "Uzbek Latin", LANGID_VENDA: "Venda", LANGID_VIETNAMESE: "Vietnamese", LANGID_WELSH: "Welsh", LANGID_XHOSA: "Xhosa", LANGID_YI: "Yi", LANGID_YIDDISH: "Yiddish", LANGID_YORUBA: "Yoruba", LANGID_ZULU: "Zulu", } pytds-1.15.0/src/pytds/login.py000066400000000000000000000161721456567501500164020ustar00rootroot00000000000000# vim: set fileencoding=utf8 : """ .. module:: login :platform: Unix, Windows, MacOSX :synopsis: Login classes .. moduleauthor:: Mikhail Denisenko """ from __future__ import annotations import base64 import ctypes import logging import socket from pytds.tds_base import AuthProtocol logger = logging.getLogger(__name__) class SspiAuth(AuthProtocol): """SSPI authentication :platform: Windows Required parameters are server_name and port or spn :keyword user_name: User name, if not provided current security context will be used :type user_name: str :keyword password: User password, if not provided current security context will be used :type password: str :keyword server_name: MSSQL server host name :type server_name: str :keyword port: MSSQL server port :type port: int :keyword spn: Service name :type spn: str """ def __init__( self, user_name: str = "", password: str = "", server_name: str = "", port: int | None = None, spn: str | None = None, ) -> None: from . import sspi # parse username/password informations if "\\" in user_name: domain, user_name = user_name.split("\\") else: domain = "" if domain and user_name: self._identity = sspi.make_winnt_identity(domain, user_name, password) else: self._identity = None # build SPN if spn: self._sname = spn else: primary_host_name, _, _ = socket.gethostbyname_ex(server_name) self._sname = f"MSSQLSvc/{primary_host_name}:{port}" # using Negotiate system will use proper protocol (either NTLM or Kerberos) self._cred = sspi.SspiCredentials( package="Negotiate", use=sspi.SECPKG_CRED_OUTBOUND, identity=self._identity ) self._flags = ( sspi.ISC_REQ_CONFIDENTIALITY | sspi.ISC_REQ_REPLAY_DETECT | sspi.ISC_REQ_CONNECTION ) self._ctx = None def create_packet(self) -> bytes: from . import sspi buf = ctypes.create_string_buffer(4096) ctx, status, bufs = self._cred.create_context( flags=self._flags, byte_ordering="network", target_name=self._sname, output_buffers=[(sspi.SECBUFFER_TOKEN, buf)], ) self._ctx = ctx if status == sspi.Status.SEC_I_COMPLETE_AND_CONTINUE: ctx.complete_auth_token(bufs) return bufs[0][1] def handle_next(self, packet: bytes) -> bytes | None: from . import sspi if self._ctx: buf = ctypes.create_string_buffer(4096) status, buffers = self._ctx.next( flags=self._flags, byte_ordering="network", target_name=self._sname, input_buffers=[(sspi.SECBUFFER_TOKEN, packet)], output_buffers=[(sspi.SECBUFFER_TOKEN, buf)], ) return buffers[0][1] else: return None def close(self) -> None: if self._ctx: self._ctx.close() class NtlmAuth(AuthProtocol): """ This class is deprecated since `ntlm-auth` package, on which it depends, is deprecated. Instead use :class:`.SpnegoAuth`. NTLM authentication, uses Python implementation (ntlm-auth) For more information about NTLM authentication see https://github.com/jborean93/ntlm-auth :param user_name: User name :type user_name: str :param password: User password :type password: str :param ntlm_compatibility: NTLM compatibility level, default is 3(NTLMv2) :type ntlm_compatibility: int """ def __init__(self, user_name: str, password: str, ntlm_compatibility: int = 3) -> None: self._user_name = user_name if "\\" in user_name: domain, self._user = user_name.split("\\", 1) self._domain = domain.upper() else: self._domain = "WORKSPACE" self._user = user_name self._password = password self._workstation = socket.gethostname().upper() try: from ntlm_auth.ntlm import NtlmContext # type: ignore # fix later except ImportError: raise ImportError( "To use NTLM authentication you need to install ntlm-auth module" ) self._ntlm_context = NtlmContext( self._user, self._password, self._domain, self._workstation, ntlm_compatibility=ntlm_compatibility, ) def create_packet(self) -> bytes: return self._ntlm_context.step() def handle_next(self, packet: bytes) -> bytes | None: return self._ntlm_context.step(packet) def close(self) -> None: pass class SpnegoAuth(AuthProtocol): """Authentication using Negotiate protocol, uses implementation provided pyspnego package Takes same parameters as spnego.client function. """ def __init__(self, *args, **kwargs) -> None: try: import spnego except ImportError: raise ImportError( "To use spnego authentication you need to install pyspnego package" ) self._context = spnego.client(*args, **kwargs) def create_packet(self) -> bytes: result = self._context.step() if not result: raise RuntimeError("spnego did not create initial packet") return result def handle_next(self, packet: bytes) -> bytes | None: return self._context.step(packet) def close(self) -> None: pass class KerberosAuth(AuthProtocol): def __init__(self, server_principal: str) -> None: try: import kerberos # type: ignore # fix later except ImportError: import winkerberos as kerberos # type: ignore # fix later self._kerberos = kerberos res, context = kerberos.authGSSClientInit(server_principal) if res < 0: raise RuntimeError(f"authGSSClientInit failed with code {res}") logger.info("Initialized GSS context") self._context = context def create_packet(self) -> bytes: res = self._kerberos.authGSSClientStep(self._context, "") if res < 0: raise RuntimeError(f"authGSSClientStep failed with code {res}") data = self._kerberos.authGSSClientResponse(self._context) logger.info("created first client GSS packet %s", data) return base64.b64decode(data) def handle_next(self, packet: bytes) -> bytes | None: res = self._kerberos.authGSSClientStep( self._context, base64.b64encode(packet).decode("ascii") ) if res < 0: raise RuntimeError(f"authGSSClientStep failed with code {res}") if res == self._kerberos.AUTH_GSS_COMPLETE: logger.info("GSS authentication completed") return b"" else: data = self._kerberos.authGSSClientResponse(self._context) logger.info("created client GSS packet %s", data) return base64.b64decode(data) def close(self) -> None: pass pytds-1.15.0/src/pytds/row_strategies.py000066400000000000000000000067451456567501500203400ustar00rootroot00000000000000""" This module implements various row strategies. E.g. row strategy that generated dictionaries or named tuples for rows. """ import collections import keyword import re from typing import Iterable, Callable, Any, Tuple, NamedTuple, Dict, List # RowGenerator is a callable which takes a list of column values and # returns an object representing that row RowGenerator = Callable[[Iterable[Any]], Any] # RowStrategy is a callable that takes a list of column names # and returns a row generator RowStrategy = Callable[[Iterable[str]], RowGenerator] def tuple_row_strategy( column_names: Iterable[str] ) -> Callable[[Iterable[Any]], Tuple[Any, ...]]: """Tuple row strategy, rows returned as tuples, default""" return tuple def list_row_strategy( column_names: Iterable[str] ) -> Callable[[Iterable[Any]], List[Any]]: """List row strategy, rows returned as lists""" return list def dict_row_strategy( column_names: Iterable[str] ) -> Callable[[Iterable[Any]], Dict[str, Any]]: """Dict row strategy, rows returned as dictionaries""" # replace empty column names with indices column_names = [(name or str(idx)) for idx, name in enumerate(column_names)] def row_factory(row: Iterable[Any]) -> Dict[str, Any]: return dict(zip(column_names, row)) return row_factory def is_valid_identifier(name: str) -> bool: """Returns true if given name can be used as an identifier in Python, otherwise returns false.""" return bool( name and re.match("^[_A-Za-z][_a-zA-Z0-9]*$", name) and not keyword.iskeyword(name) ) def namedtuple_row_strategy( column_names: Iterable[str] ) -> Callable[[Iterable[Any]], NamedTuple]: """Namedtuple row strategy, rows returned as named tuples Column names that are not valid Python identifiers will be replaced with col_ """ # replace empty column names with placeholders clean_column_names = [ name if is_valid_identifier(name) else f"col{idx}_" for idx, name in enumerate(column_names) ] row_class = collections.namedtuple("Row", clean_column_names) # type: ignore # needs fixing def row_factory(row: Iterable[Any]) -> NamedTuple: return row_class(*row) return row_factory def recordtype_row_strategy( column_names: Iterable[str] ) -> Callable[[Iterable[Any]], Any]: """Recordtype row strategy, rows returned as recordtypes Column names that are not valid Python identifiers will be replaced with col_ """ try: from namedlist import namedlist as recordtype # type: ignore # needs fixing # optional dependency except ImportError: from recordtype import recordtype # type: ignore # needs fixing # optional dependency # replace empty column names with placeholders column_names = [ name if is_valid_identifier(name) else "col%s_" % idx for idx, name in enumerate(column_names) ] recordtype_row_class = recordtype("Row", column_names) # custom extension class that supports indexing class Row(recordtype_row_class): # type: ignore # needs fixing def __getitem__(self, index): if isinstance(index, slice): return tuple(getattr(self, x) for x in self.__slots__[index]) return getattr(self, self.__slots__[index]) def __setitem__(self, index, value): setattr(self, self.__slots__[index], value) def row_factory(row: Iterable[Any]) -> Row: return Row(*row) return row_factory pytds-1.15.0/src/pytds/smp.py000066400000000000000000000305551456567501500160720ustar00rootroot00000000000000""" This file implements Session Multiplex Protocol used by MARS connections Protocol documentation https://msdn.microsoft.com/en-us/library/cc219643.aspx """ from __future__ import annotations import struct import logging import threading import socket import errno from typing import Dict, Tuple from . import tds_base try: from bitarray import bitarray # type: ignore # fix typing later except ImportError: class BitArray(list): def __init__(self, size: int): super(BitArray, self).__init__() self[:] = [False] * size def setall(self, val: bool) -> None: for i in range(len(self)): self[i] = val bitarray = BitArray from .tds_base import Error, skipall, TransportProtocol logger = logging.getLogger(__name__) SMP_HEADER = struct.Struct(" int | None: return self._state def close(self) -> None: self._mgr.close_smp_session(self) def sendall(self, data: bytes, flags: int = 0) -> None: self._mgr.send_packet(self, data) def _recv_internal(self, size: int) -> Tuple[int, int]: if not self._curr_buf[self._curr_buf_pos :]: self._curr_buf = self._mgr.recv_packet(self) self._curr_buf_pos = 0 if not self._curr_buf: return 0, 0 to_read = min(size, len(self._curr_buf) - self._curr_buf_pos) offset = self._curr_buf_pos self._curr_buf_pos += to_read return offset, to_read def recv_into( self, buffer: bytearray | memoryview, size: int = 0, flags: int = 0 ) -> int: if size == 0: size = len(buffer) offset, to_read = self._recv_internal(size) buffer[:to_read] = self._curr_buf[offset : offset + to_read] return to_read def recv(self, size: int) -> bytes: offset, to_read = self._recv_internal(size) return self._curr_buf[offset : offset + to_read] def is_connected(self) -> bool: return self._state == SessionState.SESSION_ESTABLISHED def gettimeout(self) -> float | None: return self._mgr._transport.gettimeout() def settimeout(self, timeout: float | None) -> None: self._mgr._transport.settimeout(timeout) class PacketTypes: SYN = 0x1 ACK = 0x2 FIN = 0x4 DATA = 0x8 # @staticmethod # def type_to_str(t): # if t == PacketTypes.SYN: # return 'SYN' # elif t == PacketTypes.ACK: # return 'ACK' # elif t == PacketTypes.DATA: # return 'DATA' # elif t == PacketTypes.FIN: # return 'FIN' class SessionState: SESSION_ESTABLISHED = 1 CLOSED = 2 FIN_SENT = 3 FIN_RECEIVED = 4 @staticmethod def to_str(st: int) -> str: if st == SessionState.SESSION_ESTABLISHED: return "SESSION ESTABLISHED" elif st == SessionState.CLOSED: return "CLOSED" elif st == SessionState.FIN_SENT: return "FIN SENT" elif st == SessionState.FIN_RECEIVED: return "FIN RECEIVED" else: raise RuntimeError(f"invalid session state: {st}") class SmpManager: def __init__(self, transport: TransportProtocol, max_sessions: int = 2**16): self._transport = transport self._sessions: Dict[int, _SmpSession] = {} self._used_ids_ba = bitarray(max_sessions) self._used_ids_ba.setall(False) self._lock = threading.RLock() self._hdr_buf = memoryview(bytearray(b"\x00" * SMP_HEADER.size)) def __repr__(self): return "".format(self._sessions) def create_session(self) -> _SmpSession: try: session_id = self._used_ids_ba.index(False) except ValueError: raise Error( "Can't create more MARS sessions, close some sessions and try again" ) session = _SmpSession(self, session_id) with self._lock: self._sessions[session_id] = session self._used_ids_ba[session_id] = True hdr = SMP_HEADER.pack( SMP_ID, PacketTypes.SYN, session_id, SMP_HEADER.size, 0, session.high_water_for_recv, ) self._transport.sendall(hdr) session._state = SessionState.SESSION_ESTABLISHED return session def close_all_sessions(self, keep): for sess in list(self._sessions.values()): if sess is not keep: self.close_smp_session(sess) def close_smp_session(self, session: _SmpSession) -> None: if session._state in (SessionState.CLOSED, SessionState.FIN_SENT): return elif session._state == SessionState.SESSION_ESTABLISHED: with self._lock: hdr = SMP_HEADER.pack( SMP_ID, PacketTypes.FIN, session.session_id, SMP_HEADER.size, session.seq_num_for_send, session.high_water_for_recv, ) session._state = SessionState.FIN_SENT try: self._transport.sendall(hdr) self.recv_packet(session) except (socket.error, OSError) as ex: if ex.errno in (errno.ECONNRESET, errno.EPIPE): session._state = SessionState.CLOSED else: raise ex def send_queued_packets(self, session: _SmpSession) -> None: with self._lock: while ( session.send_queue and session.seq_num_for_send < session.high_water_for_send ): data = session.send_queue.pop(0) self.send_packet(session, data) @staticmethod def _add_one_wrap(val: int) -> int: return 0 if val == 2**32 - 1 else val + 1 def send_packet(self, session: _SmpSession, data: bytes) -> None: with self._lock: if ( session._state == SessionState.CLOSED or session._state == SessionState.FIN_SENT ): raise Error("Stream closed") if session.seq_num_for_send < session.high_water_for_send: size = SMP_HEADER.size + len(data) seq_num = self._add_one_wrap(session.seq_num_for_send) hdr = SMP_HEADER.pack( SMP_ID, PacketTypes.DATA, session.session_id, size, seq_num, session.high_water_for_recv, ) session._last_high_water_for_recv = session.high_water_for_recv self._transport.sendall(hdr + data) session.seq_num_for_send = self._add_one_wrap(session.seq_num_for_send) else: session.send_queue.append(data) self._read_smp_message() def recv_packet(self, session: _SmpSession) -> bytes: with self._lock: if session._state == SessionState.CLOSED: return b"" while not session.recv_queue: self._read_smp_message() if session._state in (SessionState.CLOSED, SessionState.FIN_RECEIVED): return b"" session.high_water_for_recv = self._add_one_wrap( session.high_water_for_recv ) if session.high_water_for_recv - session._last_high_water_for_recv >= 2: hdr = SMP_HEADER.pack( SMP_ID, PacketTypes.ACK, session.session_id, SMP_HEADER.size, session.seq_num_for_send, session.high_water_for_recv, ) self._transport.sendall(hdr) session._last_high_water_for_recv = session.high_water_for_recv return session.recv_queue.pop(0) def _bad_stm(self, message: str) -> None: self.close() raise Error(message) def _read_smp_message(self) -> None: # caller should acquire lock before calling this function buf_pos = 0 while buf_pos < SMP_HEADER.size: read = self._transport.recv_into(self._hdr_buf[buf_pos:]) buf_pos += read if read == 0: self._bad_stm("Unexpected EOF while reading SMP header") smid, flags, sid, length, seq_num, wnd = SMP_HEADER.unpack(self._hdr_buf) if smid != SMP_ID: self._bad_stm("Invalid SMP packet signature") try: session = self._sessions[sid] except KeyError: self._bad_stm("Invalid SMP packet session id") if wnd < session.high_water_for_send: self._bad_stm("Invalid WNDW in packet from server") if seq_num > session.high_water_for_recv: self._bad_stm("Invalid SEQNUM in packet from server") if length < SMP_HEADER.size: self._bad_stm("Invalid LENGTH in packet from server") session._last_recv_seq_num = seq_num if flags == PacketTypes.DATA: if session._state == SessionState.SESSION_ESTABLISHED: if seq_num != self._add_one_wrap(session._seq_num_for_recv): self._bad_stm("Invalid SEQNUM in DATA packet from server") session._seq_num_for_recv = seq_num remains = length - SMP_HEADER.size while remains: data = self._transport.recv(remains) session.recv_queue.append(data) remains -= len(data) if wnd > session.high_water_for_send: session.high_water_for_send = wnd self.send_queued_packets(session) elif session._state == SessionState.FIN_SENT: skipall(self._transport, length - SMP_HEADER.size) else: self._bad_stm("Unexpected DATA packet from server") elif flags == PacketTypes.ACK: if session._state in (SessionState.FIN_RECEIVED, SessionState.CLOSED): self._bad_stm("Unexpected ACK packet from server") if seq_num != session._seq_num_for_recv: self._bad_stm("Invalid SEQNUM in ACK packet from server") session.high_water_for_send = wnd self.send_queued_packets(session) elif flags == PacketTypes.FIN: assert session._state in ( SessionState.SESSION_ESTABLISHED, SessionState.FIN_SENT, SessionState.FIN_RECEIVED, ) if session._state == SessionState.SESSION_ESTABLISHED: session._state = SessionState.FIN_RECEIVED elif session._state == SessionState.FIN_SENT: session._state = SessionState.CLOSED del self._sessions[session.session_id] self._used_ids_ba[session.session_id] = False elif session._state == SessionState.FIN_RECEIVED: self._bad_stm("Unexpected FIN packet from server") elif flags == PacketTypes.SYN: self._bad_stm("Unexpected SYN packet from server") else: self._bad_stm("Unexpected FLAGS in packet from server") def close(self) -> None: self._transport.close() def transport_closed(self) -> None: for session in self._sessions.values(): session._state = SessionState.CLOSED pytds-1.15.0/src/pytds/sspi.py000066400000000000000000000400121456567501500162360ustar00rootroot00000000000000""" This module implements wrapper for Windows SSPI API """ import logging from ctypes import ( # type: ignore # needs fixing c_ulong, c_ushort, c_void_p, c_ulonglong, POINTER, Structure, c_wchar_p, WINFUNCTYPE, windll, byref, cast, ) logger = logging.getLogger(__name__) class Status(object): SEC_E_OK = 0 SEC_I_CONTINUE_NEEDED = 0x00090312 SEC_I_COMPLETE_AND_CONTINUE = 0x00090314 SEC_I_INCOMPLETE_CREDENTIALS = 0x00090320 SEC_E_INSUFFICIENT_MEMORY = 0x80090300 - 0x100000000 SEC_E_INVALID_HANDLE = 0x80090301 - 0x100000000 SEC_E_UNSUPPORTED_FUNCTION = 0x80090302 - 0x100000000 SEC_E_INTERNAL_ERROR = 0x80090304 - 0x100000000 SEC_E_SECPKG_NOT_FOUND = 0x80090305 - 0x100000000 SEC_E_NOT_OWNER = 0x80090306 - 0x100000000 SEC_E_INVALID_TOKEN = 0x80090308 - 0x100000000 SEC_E_NO_IMPERSONATION = 0x8009030B - 0x100000000 SEC_E_LOGON_DENIED = 0x8009030C - 0x100000000 SEC_E_UNKNOWN_CREDENTIALS = 0x8009030D - 0x100000000 SEC_E_NO_CREDENTIALS = 0x8009030E - 0x100000000 SEC_E_OUT_OF_SEQUENCE = 0x80090310 - 0x100000000 SEC_E_NO_AUTHENTICATING_AUTHORITY = 0x80090311 - 0x100000000 SEC_E_BUFFER_TOO_SMALL = 0x80090321 - 0x100000000 SEC_E_WRONG_PRINCIPAL = 0x80090322 - 0x100000000 SEC_E_ALGORITHM_MISMATCH = 0x80090331 - 0x100000000 @classmethod def getname(cls, value): for name in dir(cls): if name.startswith("SEC_E_") and getattr(cls, name) == value: return name return "unknown value {0:x}".format(0x100000000 + value) # define SECBUFFER_EMPTY 0 // Undefined, replaced by provider # define SECBUFFER_DATA 1 // Packet data SECBUFFER_TOKEN = 2 # define SECBUFFER_PKG_PARAMS 3 // Package specific parameters # define SECBUFFER_MISSING 4 // Missing Data indicator # define SECBUFFER_EXTRA 5 // Extra data # define SECBUFFER_STREAM_TRAILER 6 // Security Trailer # define SECBUFFER_STREAM_HEADER 7 // Security Header # define SECBUFFER_NEGOTIATION_INFO 8 // Hints from the negotiation pkg # define SECBUFFER_PADDING 9 // non-data padding # define SECBUFFER_STREAM 10 // whole encrypted message # define SECBUFFER_MECHLIST 11 # define SECBUFFER_MECHLIST_SIGNATURE 12 # define SECBUFFER_TARGET 13 // obsolete # define SECBUFFER_CHANNEL_BINDINGS 14 # define SECBUFFER_CHANGE_PASS_RESPONSE 15 # define SECBUFFER_TARGET_HOST 16 # define SECBUFFER_ALERT 17 SECPKG_CRED_INBOUND = 0x00000001 SECPKG_CRED_OUTBOUND = 0x00000002 SECPKG_CRED_BOTH = 0x00000003 SECPKG_CRED_DEFAULT = 0x00000004 SECPKG_CRED_RESERVED = 0xF0000000 SECBUFFER_VERSION = 0 # define ISC_REQ_DELEGATE 0x00000001 # define ISC_REQ_MUTUAL_AUTH 0x00000002 ISC_REQ_REPLAY_DETECT = 4 # define ISC_REQ_SEQUENCE_DETECT 0x00000008 ISC_REQ_CONFIDENTIALITY = 0x10 ISC_REQ_USE_SESSION_KEY = 0x00000020 ISC_REQ_PROMPT_FOR_CREDS = 0x00000040 ISC_REQ_USE_SUPPLIED_CREDS = 0x00000080 ISC_REQ_ALLOCATE_MEMORY = 0x00000100 ISC_REQ_USE_DCE_STYLE = 0x00000200 ISC_REQ_DATAGRAM = 0x00000400 ISC_REQ_CONNECTION = 0x00000800 # define ISC_REQ_CALL_LEVEL 0x00001000 # define ISC_REQ_FRAGMENT_SUPPLIED 0x00002000 # define ISC_REQ_EXTENDED_ERROR 0x00004000 # define ISC_REQ_STREAM 0x00008000 # define ISC_REQ_INTEGRITY 0x00010000 # define ISC_REQ_IDENTIFY 0x00020000 # define ISC_REQ_NULL_SESSION 0x00040000 # define ISC_REQ_MANUAL_CRED_VALIDATION 0x00080000 # define ISC_REQ_RESERVED1 0x00100000 # define ISC_REQ_FRAGMENT_TO_FIT 0x00200000 # // This exists only in Windows Vista and greater # define ISC_REQ_FORWARD_CREDENTIALS 0x00400000 # define ISC_REQ_NO_INTEGRITY 0x00800000 // honored only by SPNEGO # define ISC_REQ_USE_HTTP_STYLE 0x01000000 # define ISC_REQ_UNVERIFIED_TARGET_NAME 0x20000000 # define ISC_REQ_CONFIDENTIALITY_ONLY 0x40000000 // honored by SPNEGO/Kerberos SECURITY_NETWORK_DREP = 0 SECURITY_NATIVE_DREP = 0x10 SECPKG_CRED_ATTR_NAMES = 1 ULONG = c_ulong USHORT = c_ushort PULONG = POINTER(ULONG) PVOID = c_void_p TimeStamp = c_ulonglong PTimeStamp = POINTER(c_ulonglong) PLUID = POINTER(c_ulonglong) class SecHandle(Structure): _fields_ = [ ("lower", c_void_p), ("upper", c_void_p), ] PSecHandle = POINTER(SecHandle) CredHandle = SecHandle PCredHandle = PSecHandle PCtxtHandle = PSecHandle class SecBuffer(Structure): _fields_ = [ ("cbBuffer", ULONG), ("BufferType", ULONG), ("pvBuffer", PVOID), ] PSecBuffer = POINTER(SecBuffer) class SecBufferDesc(Structure): _fields_ = [ ("ulVersion", ULONG), ("cBuffers", ULONG), ("pBuffers", PSecBuffer), ] PSecBufferDesc = POINTER(SecBufferDesc) class SEC_WINNT_AUTH_IDENTITY(Structure): _fields_ = [ ("User", c_wchar_p), ("UserLength", c_ulong), ("Domain", c_wchar_p), ("DomainLength", c_ulong), ("Password", c_wchar_p), ("PasswordLength", c_ulong), ("Flags", c_ulong), ] class SecPkgInfo(Structure): _fields_ = [ ("fCapabilities", ULONG), ("wVersion", USHORT), ("wRPCID", USHORT), ("cbMaxToken", ULONG), ("Name", c_wchar_p), ("Comment", c_wchar_p), ] PSecPkgInfo = POINTER(SecPkgInfo) class SecPkgCredentials_Names(Structure): _fields_ = [("UserName", c_wchar_p)] def ret_val(value): if value < 0: raise Exception("SSPI Error {0}".format(Status.getname(value))) return value ENUMERATE_SECURITY_PACKAGES_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing POINTER(c_ulong), POINTER(POINTER(SecPkgInfo)), ) ACQUIRE_CREDENTIALS_HANDLE_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing c_wchar_p, # principal c_wchar_p, # package ULONG, # fCredentialUse PLUID, # pvLogonID PVOID, # pAuthData PVOID, # pGetKeyFn PVOID, # pvGetKeyArgument PCredHandle, # phCredential PTimeStamp, # ptsExpiry ) FREE_CREDENTIALS_HANDLE_FN = WINFUNCTYPE(ret_val, POINTER(SecHandle)) # type: ignore # needs fixing INITIALIZE_SECURITY_CONTEXT_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing PCredHandle, PCtxtHandle, # phContext, c_wchar_p, # pszTargetName, ULONG, # fContextReq, ULONG, # Reserved1, ULONG, # TargetDataRep, PSecBufferDesc, # pInput, ULONG, # Reserved2, PCtxtHandle, # phNewContext, PSecBufferDesc, # pOutput, PULONG, # pfContextAttr, PTimeStamp, # ptsExpiry ) COMPLETE_AUTH_TOKEN_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing PCtxtHandle, # phContext PSecBufferDesc, # pToken ) FREE_CONTEXT_BUFFER_FN = WINFUNCTYPE(ret_val, PVOID) # type: ignore # needs fixing QUERY_CREDENTIAL_ATTRIBUTES_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing PCredHandle, # cred ULONG, # attribute PVOID, # out buffer ) ACCEPT_SECURITY_CONTEXT_FN = PVOID DELETE_SECURITY_CONTEXT_FN = WINFUNCTYPE(ret_val, PCtxtHandle) # type: ignore # needs fixing APPLY_CONTROL_TOKEN_FN = PVOID QUERY_CONTEXT_ATTRIBUTES_FN = PVOID IMPERSONATE_SECURITY_CONTEXT_FN = PVOID REVERT_SECURITY_CONTEXT_FN = PVOID MAKE_SIGNATURE_FN = PVOID VERIFY_SIGNATURE_FN = PVOID QUERY_SECURITY_PACKAGE_INFO_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing c_wchar_p, # package name POINTER(PSecPkgInfo), ) EXPORT_SECURITY_CONTEXT_FN = PVOID IMPORT_SECURITY_CONTEXT_FN = PVOID ADD_CREDENTIALS_FN = PVOID QUERY_SECURITY_CONTEXT_TOKEN_FN = PVOID ENCRYPT_MESSAGE_FN = PVOID DECRYPT_MESSAGE_FN = PVOID SET_CONTEXT_ATTRIBUTES_FN = PVOID class SECURITY_FUNCTION_TABLE(Structure): _fields_ = [ ("dwVersion", c_ulong), ("EnumerateSecurityPackages", ENUMERATE_SECURITY_PACKAGES_FN), ("QueryCredentialsAttributes", QUERY_CREDENTIAL_ATTRIBUTES_FN), ("AcquireCredentialsHandle", ACQUIRE_CREDENTIALS_HANDLE_FN), ("FreeCredentialsHandle", FREE_CREDENTIALS_HANDLE_FN), ("Reserved2", c_void_p), ("InitializeSecurityContext", INITIALIZE_SECURITY_CONTEXT_FN), ("AcceptSecurityContext", ACCEPT_SECURITY_CONTEXT_FN), ("CompleteAuthToken", COMPLETE_AUTH_TOKEN_FN), ("DeleteSecurityContext", DELETE_SECURITY_CONTEXT_FN), ("ApplyControlToken", APPLY_CONTROL_TOKEN_FN), ("QueryContextAttributes", QUERY_CONTEXT_ATTRIBUTES_FN), ("ImpersonateSecurityContext", IMPERSONATE_SECURITY_CONTEXT_FN), ("RevertSecurityContext", REVERT_SECURITY_CONTEXT_FN), ("MakeSignature", MAKE_SIGNATURE_FN), ("VerifySignature", VERIFY_SIGNATURE_FN), ("FreeContextBuffer", FREE_CONTEXT_BUFFER_FN), ("QuerySecurityPackageInfo", QUERY_SECURITY_PACKAGE_INFO_FN), ("Reserved3", c_void_p), ("Reserved4", c_void_p), ("ExportSecurityContext", EXPORT_SECURITY_CONTEXT_FN), ("ImportSecurityContext", IMPORT_SECURITY_CONTEXT_FN), ("AddCredentials", ADD_CREDENTIALS_FN), ("Reserved8", c_void_p), ("QuerySecurityContextToken", QUERY_SECURITY_CONTEXT_TOKEN_FN), ("EncryptMessage", ENCRYPT_MESSAGE_FN), ("DecryptMessage", DECRYPT_MESSAGE_FN), ("SetContextAttributes", SET_CONTEXT_ATTRIBUTES_FN), ] _PInitSecurityInterface = WINFUNCTYPE(POINTER(SECURITY_FUNCTION_TABLE)) InitSecurityInterface = _PInitSecurityInterface( ("InitSecurityInterfaceW", windll.secur32) ) sec_fn = InitSecurityInterface() if not sec_fn: raise Exception("InitSecurityInterface failed") sec_fn = sec_fn.contents class _SecContext(object): def __init__(self, cred: "SspiCredentials") -> None: self._cred = cred self._handle = SecHandle() self._ts = TimeStamp() self._attrs = ULONG() def close(self) -> None: if self._handle.lower and self._handle.upper: sec_fn.DeleteSecurityContext(self._handle) self._handle.lower = self._handle.upper = 0 def __del__(self) -> None: self.close() def complete_auth_token(self, bufs): sec_fn.CompleteAuthToken(byref(self._handle), byref(_make_buffers_desc(bufs))) def next( self, flags, target_name=None, byte_ordering="network", input_buffers=None, output_buffers=None, ): input_buffers_desc = ( _make_buffers_desc(input_buffers) if input_buffers else None ) output_buffers_desc = ( _make_buffers_desc(output_buffers) if output_buffers else None ) status = sec_fn.InitializeSecurityContext( byref(self._cred._handle), byref(self._handle), target_name, flags, 0, SECURITY_NETWORK_DREP if byte_ordering == "network" else SECURITY_NATIVE_DREP, byref(input_buffers_desc) if input_buffers_desc else None, 0, byref(self._handle), byref(output_buffers_desc) if input_buffers_desc else None, byref(self._attrs), byref(self._ts), ) result_buffers = [] for i, (type, buf) in enumerate(output_buffers): buf = buf[: output_buffers_desc.pBuffers[i].cbBuffer] result_buffers.append((type, buf)) return status, result_buffers class SspiCredentials(object): def __init__(self, package, use, identity=None): self._handle = SecHandle() self._ts = TimeStamp() logger.debug("Acquiring credentials handle") sec_fn.AcquireCredentialsHandle( None, package, use, None, byref(identity) if identity and identity.Domain else None, None, None, byref(self._handle), byref(self._ts), ) def close(self): if self._handle and (self._handle.lower or self._handle.upper): logger.debug("Releasing credentials handle") sec_fn.FreeCredentialsHandle(byref(self._handle)) self._handle = None def __del__(self): self.close() def query_user_name(self): names = SecPkgCredentials_Names() try: sec_fn.QueryCredentialsAttributes( byref(self._handle), SECPKG_CRED_ATTR_NAMES, byref(names) ) user_name = str(names.UserName) finally: p = c_wchar_p.from_buffer(names, SecPkgCredentials_Names.UserName.offset) sec_fn.FreeContextBuffer(p) return user_name def create_context( self, flags: int, target_name=None, byte_ordering="network", input_buffers=None, output_buffers=None, ): if self._handle is None: raise RuntimeError("Using closed SspiCredentials object") ctx = _SecContext(cred=self) input_buffers_desc = ( _make_buffers_desc(input_buffers) if input_buffers else None ) output_buffers_desc = ( _make_buffers_desc(output_buffers) if output_buffers else None ) logger.debug("Initializing security context") status = sec_fn.InitializeSecurityContext( byref(self._handle), None, target_name, flags, 0, SECURITY_NETWORK_DREP if byte_ordering == "network" else SECURITY_NATIVE_DREP, byref(input_buffers_desc) if input_buffers_desc else None, 0, byref(ctx._handle), byref(output_buffers_desc) if output_buffers_desc else None, byref(ctx._attrs), byref(ctx._ts), ) result_buffers = [] for i, (type, buf) in enumerate(output_buffers): buf = buf[: output_buffers_desc.pBuffers[i].cbBuffer] result_buffers.append((type, buf)) return ctx, status, result_buffers def _make_buffers_desc(buffers): desc = SecBufferDesc() desc.ulVersion = SECBUFFER_VERSION bufs_array = (SecBuffer * len(buffers))() for i, (type, buf) in enumerate(buffers): bufs_array[i].BufferType = type bufs_array[i].cbBuffer = len(buf) bufs_array[i].pvBuffer = cast(buf, PVOID) desc.pBuffers = bufs_array desc.cBuffers = len(buffers) return desc def make_winnt_identity(domain, user_name, password): identity = SEC_WINNT_AUTH_IDENTITY() identity.Flags = 2 # SEC_WINNT_AUTH_IDENTITY_UNICODE identity.Password = password identity.PasswordLength = len(password) identity.Domain = domain identity.DomainLength = len(domain) identity.User = user_name identity.UserLength = len(user_name) return identity # class SspiSecBuffer(object): # def __init__(self, type, buflen=4096): # self._buf = create_string_buffer(int(buflen)) # self._desc = SecBuffer() # self._desc.cbBuffer = buflen # self._desc.BufferType = type # self._desc.pvBuffer = cast(self._buf, PVOID) # # class SspiSecBuffers(object): # def __init__(self): # self._desc = SecBufferDesc() # self._desc.ulVersion = SECBUFFER_VERSION # self._descrs = (SecBuffer * 8)() # self._desc.pBuffers = self._descrs # # def append(self, buf): # if len(self._descrs) <= self._desc.cBuffers: # newdescrs = (SecBuffer * (len(self._descrs) * 2))(*self._descrs) # self._descrs = newdescrs # self._desc.pBuffers = newdescrs # self._descrs[self._desc.cBuffers] = buf._desc # self._desc.cBuffers += 1 def enum_security_packages(): num = ULONG() infos = POINTER(SecPkgInfo)() sec_fn.EnumerateSecurityPackages(byref(num), byref(infos)) try: return [ { "caps": infos[i].fCapabilities, "version": infos[i].wVersion, "rpcid": infos[i].wRPCID, "max_token": infos[i].cbMaxToken, "name": infos[i].Name, "comment": infos[i].Comment, } for i in range(num.value) ] finally: sec_fn.FreeContextBuffer(infos) pytds-1.15.0/src/pytds/tds.py000066400000000000000000000004551456567501500160610ustar00rootroot00000000000000""" This module provides backward compatibility """ # _token_map is needed by sqlalchemy_pytds connector from .tds_session import ( _token_map, # noqa: F401 # _token_map is needed by sqlalchemy_pytds connector ) from . import tds_base # noqa: F401 # this is needed by sqlalchemy_pytds connector pytds-1.15.0/src/pytds/tds_base.py000066400000000000000000000627001456567501500170540ustar00rootroot00000000000000""" .. module:: tds_base :platform: Unix, Windows, MacOSX :synopsis: Various internal stuff .. moduleauthor:: Mikhail Denisenko """ from __future__ import annotations import datetime import logging import socket import struct import typing from collections import deque from typing import Protocol, Iterable, TypedDict, Tuple, Any import pytds from pytds.collate import ucs2_codec logger = logging.getLogger("pytds") # tds protocol versions TDS70 = 0x70000000 TDS71 = 0x71000000 TDS71rev1 = 0x71000001 TDS72 = 0x72090002 TDS73A = 0x730A0003 TDS73 = TDS73A TDS73B = 0x730B0003 TDS74 = 0x74000004 if typing.TYPE_CHECKING: from pytds.tds_session import _TdsSession def IS_TDS7_PLUS(x: _TdsSession): return x.tds_version >= TDS70 def IS_TDS71_PLUS(x: _TdsSession): return x.tds_version >= TDS71 def IS_TDS72_PLUS(x: _TdsSession): return x.tds_version >= TDS72 def IS_TDS73_PLUS(x: _TdsSession): return x.tds_version >= TDS73A # https://msdn.microsoft.com/en-us/library/dd304214.aspx class PacketType: QUERY = 1 OLDLOGIN = 2 RPC = 3 REPLY = 4 CANCEL = 6 BULK = 7 FEDAUTHTOKEN = 8 TRANS = 14 # transaction management LOGIN = 16 AUTH = 17 PRELOGIN = 18 # mssql login options flags # option_flag1_values TDS_BYTE_ORDER_X86 = 0 TDS_CHARSET_ASCII = 0 TDS_DUMPLOAD_ON = 0 TDS_FLOAT_IEEE_754 = 0 TDS_INIT_DB_WARN = 0 TDS_SET_LANG_OFF = 0 TDS_USE_DB_SILENT = 0 TDS_BYTE_ORDER_68000 = 0x01 TDS_CHARSET_EBDDIC = 0x02 TDS_FLOAT_VAX = 0x04 TDS_FLOAT_ND5000 = 0x08 TDS_DUMPLOAD_OFF = 0x10 # prevent BCP TDS_USE_DB_NOTIFY = 0x20 TDS_INIT_DB_FATAL = 0x40 TDS_SET_LANG_ON = 0x80 # enum option_flag2_values TDS_INIT_LANG_WARN = 0 TDS_INTEGRATED_SECURTY_OFF = 0 TDS_ODBC_OFF = 0 TDS_USER_NORMAL = 0 # SQL Server login TDS_INIT_LANG_REQUIRED = 0x01 TDS_ODBC_ON = 0x02 TDS_TRANSACTION_BOUNDARY71 = 0x04 # removed in TDS 7.2 TDS_CACHE_CONNECT71 = 0x08 # removed in TDS 7.2 TDS_USER_SERVER = 0x10 # reserved TDS_USER_REMUSER = 0x20 # DQ login TDS_USER_SQLREPL = 0x40 # replication login TDS_INTEGRATED_SECURITY_ON = 0x80 # enum option_flag3_values TDS 7.3+ TDS_RESTRICTED_COLLATION = 0 TDS_CHANGE_PASSWORD = 0x01 TDS_SEND_YUKON_BINARY_XML = 0x02 TDS_REQUEST_USER_INSTANCE = 0x04 TDS_UNKNOWN_COLLATION_HANDLING = 0x08 TDS_ANY_COLLATION = 0x10 TDS5_PARAMFMT2_TOKEN = 32 # 0x20 TDS_LANGUAGE_TOKEN = 33 # 0x20 TDS 5.0 only TDS_ORDERBY2_TOKEN = 34 # 0x22 TDS_ROWFMT2_TOKEN = 97 # 0x61 TDS 5.0 only TDS_LOGOUT_TOKEN = 113 # 0x71 TDS 5.0 only? TDS_RETURNSTATUS_TOKEN = 121 # 0x79 TDS_PROCID_TOKEN = 124 # 0x7C TDS 4.2 only TDS7_RESULT_TOKEN = 129 # 0x81 TDS 7.0 only TDS7_COMPUTE_RESULT_TOKEN = 136 # 0x88 TDS 7.0 only TDS_COLNAME_TOKEN = 160 # 0xA0 TDS 4.2 only TDS_COLFMT_TOKEN = 161 # 0xA1 TDS 4.2 only TDS_DYNAMIC2_TOKEN = 163 # 0xA3 TDS_TABNAME_TOKEN = 164 # 0xA4 TDS_COLINFO_TOKEN = 165 # 0xA5 TDS_OPTIONCMD_TOKEN = 166 # 0xA6 TDS_COMPUTE_NAMES_TOKEN = 167 # 0xA7 TDS_COMPUTE_RESULT_TOKEN = 168 # 0xA8 TDS_ORDERBY_TOKEN = 169 # 0xA9 TDS_ERROR_TOKEN = 170 # 0xAA TDS_INFO_TOKEN = 171 # 0xAB TDS_PARAM_TOKEN = 172 # 0xAC TDS_LOGINACK_TOKEN = 173 # 0xAD TDS_CONTROL_TOKEN = 174 # 0xAE TDS_ROW_TOKEN = 209 # 0xD1 TDS_NBC_ROW_TOKEN = 210 # 0xD2 as of TDS 7.3.B TDS_CMP_ROW_TOKEN = 211 # 0xD3 TDS5_PARAMS_TOKEN = 215 # 0xD7 TDS 5.0 only TDS_CAPABILITY_TOKEN = 226 # 0xE2 TDS_ENVCHANGE_TOKEN = 227 # 0xE3 TDS_DBRPC_TOKEN = 230 # 0xE6 TDS5_DYNAMIC_TOKEN = 231 # 0xE7 TDS 5.0 only TDS5_PARAMFMT_TOKEN = 236 # 0xEC TDS 5.0 only TDS_AUTH_TOKEN = 237 # 0xED TDS 7.0 only TDS_RESULT_TOKEN = 238 # 0xEE TDS_DONE_TOKEN = 253 # 0xFD TDS_DONEPROC_TOKEN = 254 # 0xFE TDS_DONEINPROC_TOKEN = 255 # 0xFF # CURSOR support: TDS 5.0 only TDS_CURCLOSE_TOKEN = 128 # 0x80 TDS 5.0 only TDS_CURDELETE_TOKEN = 129 # 0x81 TDS 5.0 only TDS_CURFETCH_TOKEN = 130 # 0x82 TDS 5.0 only TDS_CURINFO_TOKEN = 131 # 0x83 TDS 5.0 only TDS_CUROPEN_TOKEN = 132 # 0x84 TDS 5.0 only TDS_CURDECLARE_TOKEN = 134 # 0x86 TDS 5.0 only # environment type field TDS_ENV_DATABASE = 1 TDS_ENV_LANG = 2 TDS_ENV_CHARSET = 3 TDS_ENV_PACKSIZE = 4 TDS_ENV_LCID = 5 TDS_ENV_UNICODE_DATA_SORT_COMP_FLAGS = 6 TDS_ENV_SQLCOLLATION = 7 TDS_ENV_BEGINTRANS = 8 TDS_ENV_COMMITTRANS = 9 TDS_ENV_ROLLBACKTRANS = 10 TDS_ENV_ENLIST_DTC_TRANS = 11 TDS_ENV_DEFECT_TRANS = 12 TDS_ENV_DB_MIRRORING_PARTNER = 13 TDS_ENV_PROMOTE_TRANS = 15 TDS_ENV_TRANS_MANAGER_ADDR = 16 TDS_ENV_TRANS_ENDED = 17 TDS_ENV_RESET_COMPLETION_ACK = 18 TDS_ENV_INSTANCE_INFO = 19 TDS_ENV_ROUTING = 20 # Microsoft internal stored procedure id's TDS_SP_CURSOR = 1 TDS_SP_CURSOROPEN = 2 TDS_SP_CURSORPREPARE = 3 TDS_SP_CURSOREXECUTE = 4 TDS_SP_CURSORPREPEXEC = 5 TDS_SP_CURSORUNPREPARE = 6 TDS_SP_CURSORFETCH = 7 TDS_SP_CURSOROPTION = 8 TDS_SP_CURSORCLOSE = 9 TDS_SP_EXECUTESQL = 10 TDS_SP_PREPARE = 11 TDS_SP_EXECUTE = 12 TDS_SP_PREPEXEC = 13 TDS_SP_PREPEXECRPC = 14 TDS_SP_UNPREPARE = 15 # Flags returned in TDS_DONE token TDS_DONE_FINAL = 0 TDS_DONE_MORE_RESULTS = 0x01 # more results follow TDS_DONE_ERROR = 0x02 # error occurred TDS_DONE_INXACT = 0x04 # transaction in progress TDS_DONE_PROC = 0x08 # results are from a stored procedure TDS_DONE_COUNT = 0x10 # count field in packet is valid TDS_DONE_CANCELLED = 0x20 # acknowledging an attention command (usually a cancel) TDS_DONE_EVENT = 0x40 # part of an event notification. TDS_DONE_SRVERROR = 0x100 # SQL server server error SYBVOID = 31 # 0x1F IMAGETYPE = SYBIMAGE = 34 # 0x22 TEXTTYPE = SYBTEXT = 35 # 0x23 SYBVARBINARY = 37 # 0x25 INTNTYPE = SYBINTN = 38 # 0x26 SYBVARCHAR = 39 # 0x27 BINARYTYPE = SYBBINARY = 45 # 0x2D SYBCHAR = 47 # 0x2F INT1TYPE = SYBINT1 = 48 # 0x30 BITTYPE = SYBBIT = 50 # 0x32 INT2TYPE = SYBINT2 = 52 # 0x34 INT4TYPE = SYBINT4 = 56 # 0x38 DATETIM4TYPE = SYBDATETIME4 = 58 # 0x3A FLT4TYPE = SYBREAL = 59 # 0x3B MONEYTYPE = SYBMONEY = 60 # 0x3C DATETIMETYPE = SYBDATETIME = 61 # 0x3D FLT8TYPE = SYBFLT8 = 62 # 0x3E NTEXTTYPE = SYBNTEXT = 99 # 0x63 SYBNVARCHAR = 103 # 0x67 BITNTYPE = SYBBITN = 104 # 0x68 NUMERICNTYPE = SYBNUMERIC = 108 # 0x6C DECIMALNTYPE = SYBDECIMAL = 106 # 0x6A FLTNTYPE = SYBFLTN = 109 # 0x6D MONEYNTYPE = SYBMONEYN = 110 # 0x6E DATETIMNTYPE = SYBDATETIMN = 111 # 0x6F MONEY4TYPE = SYBMONEY4 = 122 # 0x7A INT8TYPE = SYBINT8 = 127 # 0x7F BIGCHARTYPE = XSYBCHAR = 175 # 0xAF BIGVARCHRTYPE = XSYBVARCHAR = 167 # 0xA7 NVARCHARTYPE = XSYBNVARCHAR = 231 # 0xE7 NCHARTYPE = XSYBNCHAR = 239 # 0xEF BIGVARBINTYPE = XSYBVARBINARY = 165 # 0xA5 BIGBINARYTYPE = XSYBBINARY = 173 # 0xAD GUIDTYPE = SYBUNIQUE = 36 # 0x24 SSVARIANTTYPE = SYBVARIANT = 98 # 0x62 UDTTYPE = SYBMSUDT = 240 # 0xF0 XMLTYPE = SYBMSXML = 241 # 0xF1 TVPTYPE = 243 # 0xF3 DATENTYPE = SYBMSDATE = 40 # 0x28 TIMENTYPE = SYBMSTIME = 41 # 0x29 DATETIME2NTYPE = SYBMSDATETIME2 = 42 # 0x2a DATETIMEOFFSETNTYPE = SYBMSDATETIMEOFFSET = 43 # 0x2b # TDS type flag TDS_FSQLTYPE_SQL_DFLT = 0x00 TDS_FSQLTYPE_SQL_TSQL = 0x01 TDS_FOLEDB = 0x10 TDS_FREADONLY_INTENT = 0x20 # # Sybase only types # SYBLONGBINARY = 225 # 0xE1 SYBUINT1 = 64 # 0x40 SYBUINT2 = 65 # 0x41 SYBUINT4 = 66 # 0x42 SYBUINT8 = 67 # 0x43 SYBBLOB = 36 # 0x24 SYBBOUNDARY = 104 # 0x68 SYBDATE = 49 # 0x31 SYBDATEN = 123 # 0x7B SYB5INT8 = 191 # 0xBF SYBINTERVAL = 46 # 0x2E SYBLONGCHAR = 175 # 0xAF SYBSENSITIVITY = 103 # 0x67 SYBSINT1 = 176 # 0xB0 SYBTIME = 51 # 0x33 SYBTIMEN = 147 # 0x93 SYBUINTN = 68 # 0x44 SYBUNITEXT = 174 # 0xAE SYBXML = 163 # 0xA3 TDS_UT_TIMESTAMP = 80 # compute operator SYBAOPCNT = 0x4B SYBAOPCNTU = 0x4C SYBAOPSUM = 0x4D SYBAOPSUMU = 0x4E SYBAOPAVG = 0x4F SYBAOPAVGU = 0x50 SYBAOPMIN = 0x51 SYBAOPMAX = 0x52 # mssql2k compute operator SYBAOPCNT_BIG = 0x09 SYBAOPSTDEV = 0x30 SYBAOPSTDEVP = 0x31 SYBAOPVAR = 0x32 SYBAOPVARP = 0x33 SYBAOPCHECKSUM_AGG = 0x72 # param flags fByRefValue = 1 fDefaultValue = 2 TDS_IDLE = 0 TDS_QUERYING = 1 TDS_PENDING = 2 TDS_READING = 3 TDS_DEAD = 4 state_names = ["IDLE", "QUERYING", "PENDING", "READING", "DEAD"] TDS_ENCRYPTION_OFF = 0 TDS_ENCRYPTION_REQUEST = 1 TDS_ENCRYPTION_REQUIRE = 2 class PreLoginToken: """ PRELOGIN token option identifiers, corresponds to PL_OPTION_TOKEN in the spec. Spec link: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/60f56408-0188-4cd5-8b90-25c6f2423868 """ VERSION = 0 ENCRYPTION = 1 INSTOPT = 2 THREADID = 3 MARS = 4 TRACEID = 5 FEDAUTHREQUIRED = 6 NONCEOPT = 7 TERMINATOR = 0xFF class PreLoginEnc: """ PRELOGIN encryption parameter. Spec link: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/60f56408-0188-4cd5-8b90-25c6f2423868 """ ENCRYPT_OFF = 0 # Encryption available but off ENCRYPT_ON = 1 # Encryption available and on ENCRYPT_NOT_SUP = 2 # Encryption not available ENCRYPT_REQ = 3 # Encryption required PLP_MARKER = 0xFFFF PLP_NULL = 0xFFFFFFFFFFFFFFFF PLP_UNKNOWN = 0xFFFFFFFFFFFFFFFE TDS_NO_COUNT = -1 TVP_NULL_TOKEN = 0xFFFF # TVP COLUMN FLAGS TVP_COLUMN_DEFAULT_FLAG = 0x200 TVP_END_TOKEN = 0x00 TVP_ROW_TOKEN = 0x01 TVP_ORDER_UNIQUE_TOKEN = 0x10 TVP_COLUMN_ORDERING_TOKEN = 0x11 class CommonEqualityMixin(object): def __eq__(self, other): return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) def iterdecode(iterable, codec): """Uses an incremental decoder to decode each chunk of string in iterable. This function is a generator. :param iterable: Iterable object which yields raw data to be decoded. :param codec: An instance of a codec which will be used for decoding. """ decoder = codec.incrementaldecoder() for chunk in iterable: yield decoder.decode(chunk) yield decoder.decode(b"", True) def force_unicode(s): """ Convert input into a string. If input is a byte array, it will be decoded using UTF8 decoder. """ if isinstance(s, bytes): try: return s.decode("utf8") except UnicodeDecodeError as e: raise DatabaseError(e) elif isinstance(s, str): return s else: return str(s) def tds_quote_id(ident): """Quote an identifier according to MSSQL rules :param ident: identifier to quote :returns: Quoted identifier """ return "[{0}]".format(ident.replace("]", "]]")) # store a tuple of programming error codes prog_errors = ( 102, # syntax error 207, # invalid column name 208, # invalid object name 2812, # unknown procedure 4104, # multi-part identifier could not be bound ) # store a tuple of integrity error codes integrity_errors = ( 515, # NULL insert 547, # FK related 2601, # violate unique index 2627, # violate UNIQUE KEY constraint ) def my_ord(val): return val def join_bytearrays(ba): return b"".join(ba) # exception hierarchy class Warning(Exception): pass class Error(Exception): """ Base class for all error classes, except TimeoutError """ pass TimeoutError = socket.timeout class InterfaceError(Error): """ TODO add documentation """ pass class DatabaseError(Error): """ This error is raised when MSSQL server returns an error which includes error number """ def __init__(self, msg: str, exc: typing.Any | None = None): super().__init__(msg, exc) self.msg_no = 0 self.text = msg self.srvname = "" self.procname = "" self.number = 0 self.severity = 0 self.state = 0 self.line = 0 @property def message(self): if self.procname: return ( "SQL Server message %d, severity %d, state %d, " "procedure %s, line %d:\n%s" % ( self.number, self.severity, self.state, self.procname, self.line, self.text, ) ) else: return "SQL Server message %d, severity %d, state %d, " "line %d:\n%s" % ( self.number, self.severity, self.state, self.line, self.text, ) class ClosedConnectionError(InterfaceError): """ This error is raised when MSSQL server closes connection. """ def __init__(self): super(ClosedConnectionError, self).__init__("Server closed connection") class DataError(Error): """ This error is raised when input parameter contains data which cannot be converted to acceptable data type. """ pass class OperationalError(DatabaseError): """ TODO add documentation """ pass class LoginError(OperationalError): """ This error is raised if provided login credentials are invalid """ pass class IntegrityError(DatabaseError): """ TODO add documentation """ pass class InternalError(DatabaseError): """ TODO add documentation """ pass class ProgrammingError(DatabaseError): """ TODO add documentation """ pass class NotSupportedError(DatabaseError): """ TODO add documentation """ pass # DB-API type definitions class DBAPITypeObject: """ TODO add documentation """ def __init__(self, *values): self.values = set(values) def __eq__(self, other): return other in self.values def __cmp__(self, other): if other in self.values: return 0 if other < self.values: return 1 else: return -1 # standard dbapi type objects STRING = DBAPITypeObject( SYBVARCHAR, SYBCHAR, SYBTEXT, XSYBNVARCHAR, XSYBNCHAR, SYBNTEXT, XSYBVARCHAR, XSYBCHAR, SYBMSXML, ) BINARY = DBAPITypeObject(SYBIMAGE, SYBBINARY, SYBVARBINARY, XSYBVARBINARY, XSYBBINARY) NUMBER = DBAPITypeObject( SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN, SYBREAL, SYBFLT8, SYBFLTN, ) DATETIME = DBAPITypeObject(SYBDATETIME, SYBDATETIME4, SYBDATETIMN) DECIMAL = DBAPITypeObject(SYBMONEY, SYBMONEY4, SYBMONEYN, SYBNUMERIC, SYBDECIMAL) ROWID = DBAPITypeObject() # non-standard, but useful type objects INTEGER = DBAPITypeObject(SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN) REAL = DBAPITypeObject(SYBREAL, SYBFLT8, SYBFLTN) XML = DBAPITypeObject(SYBMSXML) class InternalProc(object): """ TODO add documentation """ def __init__(self, proc_id, name): self.proc_id = proc_id self.name = name def __unicode__(self): return self.name SP_EXECUTESQL = InternalProc(TDS_SP_EXECUTESQL, "sp_executesql") SP_PREPARE = InternalProc(TDS_SP_PREPARE, "sp_prepare") SP_EXECUTE = InternalProc(TDS_SP_EXECUTE, "sp_execute") def skipall(stm, size): """Skips exactly size bytes in stm If EOF is reached before size bytes are skipped will raise :class:`ClosedConnectionError` :param stm: Stream to skip bytes in, should have read method this read method can return less than requested number of bytes. :param size: Number of bytes to skip. """ res = stm.recv(size) if len(res) == size: return elif len(res) == 0: raise ClosedConnectionError() left = size - len(res) while left: buf = stm.recv(left) if len(buf) == 0: raise ClosedConnectionError() left -= len(buf) def read_chunks(stm, size): """Reads exactly size bytes from stm and produces chunks May call stm.read multiple times until required number of bytes is read. If EOF is reached before size bytes are read will raise :class:`ClosedConnectionError` :param stm: Stream to read bytes from, should have read method, this read method can return less than requested number of bytes. :param size: Number of bytes to read. """ if size == 0: yield b"" return res = stm.recv(size) if len(res) == 0: raise ClosedConnectionError() yield res left = size - len(res) while left: buf = stm.recv(left) if len(buf) == 0: raise ClosedConnectionError() yield buf left -= len(buf) def readall(stm, size): """Reads exactly size bytes from stm May call stm.read multiple times until required number of bytes read. If EOF is reached before size bytes are read will raise :class:`ClosedConnectionError` :param stm: Stream to read bytes from, should have read method this read method can return less than requested number of bytes. :param size: Number of bytes to read. :returns: Bytes buffer of exactly given size. """ return join_bytearrays(read_chunks(stm, size)) def readall_fast(stm, size): """ Slightly faster version of readall, it reads no more than two chunks. Meaning that it can only be used to read small data that doesn't span more that two packets. :param stm: Stream to read from, should have read method. :param size: Number of bytes to read. :return: """ buf, offset = stm.read_fast(size) if len(buf) - offset < size: # slow case buf = buf[offset:] buf += stm.recv(size - len(buf)) return buf, 0 return buf, offset def total_seconds(td): """Total number of seconds in timedelta object Python 2.6 doesn't have total_seconds method, this function provides a backport """ return td.days * 24 * 60 * 60 + td.seconds class Param: """ Describes typed parameter. Can be used to explicitly specify type of the parameter in the parametrized query. :param name: Optional name of the parameter :type name: str :param type: Type of the parameter, e.g. :class:`pytds.tds_types.IntType` """ def __init__(self, name: str = "", type=None, value=None, flags: int = 0): self.name = name self.type = type self.value = value self.flags = flags class Column(CommonEqualityMixin): """ Describes table column. Can be used to define schema for bulk insert. Following flags can be used for columns in `flags` parameter: * :const:`.fNullable` - column can contain `NULL` values * :const:`.fCaseSen` - column is case-sensitive * :const:`.fReadWrite` - TODO document * :const:`.fIdentity` - TODO document * :const:`.fComputed` - TODO document :param name: Name of the column :type name: str :param type: Type of a column, e.g. :class:`pytds.tds_types.IntType` :param flags: Combination of flags for the column, multiple flags can be combined using binary or operator. Possible flags are described above. """ fNullable = 1 fCaseSen = 2 fReadWrite = 8 fIdentity = 0x10 fComputed = 0x20 def __init__(self, name="", type=None, flags=fNullable, value=None): self.char_codec = None self.column_name = name self.column_usertype = 0 self.flags = flags self.type = type self.value = value self.serializer = None def __repr__(self): val = self.value if isinstance(val, bytes) and len(self.value) > 100: val = self.value[:100] + b"... len is " + str(len(val)).encode("ascii") if isinstance(val, str) and len(self.value) > 100: val = self.value[:100] + "... len is " + str(len(val)) return ( "".format( repr(self.column_name), repr(self.type), repr(val), repr(self.flags), repr(self.column_usertype), repr(self.char_codec), ) ) def choose_serializer(self, type_factory, collation): """ Chooses appropriate data type serializer for column's data type. """ return type_factory.serializer_by_type(sql_type=self.type, collation=collation) class TransportProtocol(Protocol): """ This protocol mimics socket protocol """ # def is_connected(self) -> bool: # ... def close(self) -> None: ... def gettimeout(self) -> float | None: ... def settimeout(self, timeout: float | None) -> None: ... def sendall(self, buf: bytes, flags: int = 0) -> None: ... def recv(self, size: int) -> bytes: ... def recv_into( self, buf: bytearray | memoryview, size: int = 0, flags: int = 0 ) -> int: ... class LoadBalancer(Protocol): def choose(self) -> Iterable[str]: ... class AuthProtocol(Protocol): def create_packet(self) -> bytes: ... def handle_next(self, packet: bytes) -> bytes | None: ... def close(self) -> None: ... # packet header # https://msdn.microsoft.com/en-us/library/dd340948.aspx _header = struct.Struct(">BBHHBx") _byte = struct.Struct("B") _smallint_le = struct.Struct("h") _usmallint_le = struct.Struct("H") _int_le = struct.Struct("l") _uint_le = struct.Struct("L") _int8_le = struct.Struct("q") _uint8_le = struct.Struct("Q") logging_enabled = False # stored procedure output parameter class output: @property def type(self): """ This is either the sql type declaration or python type instance of the parameter. """ return self._type @property def value(self): """ This is the value of the parameter. """ return self._value def __init__(self, value: Any = None, param_type=None): """Creates procedure output parameter. :param param_type: either sql type declaration or python type :param value: value to pass into procedure """ if param_type is None: if value is None or value is default: raise ValueError("Output type cannot be autodetected") elif isinstance(param_type, type) and value is not None: if value is not default and not isinstance(value, param_type): raise ValueError( "value should match param_type, value is {}, param_type is '{}'".format( repr(value), param_type.__name__ ) ) self._type = param_type self._value = value class _Default: pass default = _Default() def tds7_crypt_pass(password: str) -> bytearray: """Mangle password according to tds rules :param password: Password str :returns: Byte-string with encoded password """ encoded = bytearray(ucs2_codec.encode(password)[0]) for i, ch in enumerate(encoded): encoded[i] = ((ch << 4) & 0xFF | (ch >> 4)) ^ 0xA5 return encoded class _TdsLogin: def __init__(self) -> None: self.client_host_name = "" self.library = "" self.server_name = "" self.instance_name = "" self.user_name = "" self.password = "" self.app_name = "" self.port: int | None = None self.language = "" self.attach_db_file = "" self.tds_version = TDS74 self.database = "" self.bulk_copy = False self.client_lcid = 0 self.use_mars = False self.pid = 0 self.change_password = "" self.client_id = 0 self.cafile: str | None = None self.validate_host = True self.enc_login_only = False self.enc_flag = 0 self.tls_ctx = None self.client_tz: datetime.tzinfo = pytds.tz.local self.option_flag2 = 0 self.connect_timeout = 0.0 self.query_timeout: float | None = None self.blocksize = 4096 self.readonly = False self.load_balancer: LoadBalancer | None = None self.bytes_to_unicode = False self.auth: AuthProtocol | None = None self.servers: deque[Tuple[Any, int | None, str]] = deque() self.server_enc_flag = 0 class _TdsEnv: def __init__(self): self.database = None self.language = None self.charset = None self.autocommit = False # Transaction isolation level self.isolation_level = 0 def _create_exception_by_message( msg: Message, custom_error_msg: str | None = None ) -> ProgrammingError | IntegrityError | OperationalError: msg_no = msg["msgno"] if custom_error_msg is not None: error_msg = custom_error_msg else: error_msg = msg["message"] ex: ProgrammingError | IntegrityError | OperationalError if msg_no in prog_errors: ex = ProgrammingError(error_msg) elif msg_no in integrity_errors: ex = IntegrityError(error_msg) else: ex = OperationalError(error_msg) ex.msg_no = msg["msgno"] ex.text = msg["message"] ex.srvname = msg["server"] ex.procname = msg["proc_name"] ex.number = msg["msgno"] ex.severity = msg["severity"] ex.state = msg["state"] ex.line = msg["line_number"] return ex class Message(TypedDict): marker: int msgno: int state: int severity: int sql_state: int | None priv_msg_type: int message: str server: str proc_name: str line_number: int class Route(TypedDict): server: str port: int class _Results(object): def __init__(self) -> None: self.columns: list[Column] = [] self.row_count = 0 self.description: tuple[tuple[str, Any, None, int, int, int, int], ...] = () pytds-1.15.0/src/pytds/tds_reader.py000066400000000000000000000165241456567501500174070ustar00rootroot00000000000000""" This module implements TdsReader class """ from __future__ import annotations import struct import typing from typing import Tuple, Any from pytds import tds_base from pytds.collate import Collation, ucs2_codec from pytds.tds_base import ( readall, readall_fast, _header, _int_le, _uint_be, _uint_le, _uint8_le, _int8_le, _byte, _smallint_le, _usmallint_le, ) if typing.TYPE_CHECKING: from pytds.tds_session import _TdsSession class ResponseMetadata: """ This class represents response metadata extracted from first response packet. This includes response type and session ID """ def __init__(self): self.type = 0 self.spid = 0 class _TdsReader: """TDS stream reader Provides stream-like interface for TDS packeted stream. Also provides convinience methods to decode primitive data like different kinds of integers etc. """ def __init__( self, tds_session: _TdsSession, transport: tds_base.TransportProtocol, bufsize: int = 4096, ): self._buf = bytearray(b"\x00" * bufsize) self._bufview = memoryview(self._buf) self._pos = len(self._buf) # position in the buffer self._have = 0 # number of bytes read from packet self._size = 0 # size of current packet self._transport = transport self._session = tds_session self._type: int | None = None # value of status field from packet header # 0 - means not last packet # 1 - means last packet self._status = 1 self._spid = 0 @property def session(self): return self._session def set_block_size(self, size: int) -> None: self._buf = bytearray(b"\x00" * size) self._bufview = memoryview(self._buf) def get_block_size(self) -> int: return len(self._buf) @property def packet_type(self) -> int | None: """Type of current packet Possible values are TDS_QUERY, TDS_LOGIN, etc. """ return self._type def stream_finished(self) -> bool: """ Verifies whether the current response packet stream has finished reading. If the function returns True, it indicates that you should invoke the begin_response method to initiate the reading of the next stream. :return: """ if self._pos >= self._size: return self._status == 1 else: return False def read_fast(self, size: int) -> Tuple[bytes, int]: """Faster version of read Instead of returning sliced buffer it returns reference to internal buffer and the offset to this buffer. :param size: Number of bytes to read :returns: Tuple of bytes buffer, and offset in this buffer """ # Current response stream finished if self._pos >= self._size: if self._status == 1: return b"", 0 self._read_packet() offset = self._pos to_read = min(size, self._size - self._pos) self._pos += to_read return self._buf, offset def recv(self, size: int) -> bytes: if self._pos >= self._size: # Current response stream finished if self._status == 1: return b"" self._read_packet() offset = self._pos to_read = min(size, self._size - self._pos) self._pos += to_read return self._buf[offset : offset + to_read] def unpack(self, struc: struct.Struct) -> Tuple[Any, ...]: """Unpacks given structure from stream :param struc: A struct.Struct instance :returns: Result of unpacking """ buf, offset = readall_fast(self, struc.size) return struc.unpack_from(buf, offset) def get_byte(self) -> int: """Reads one byte from stream""" return self.unpack(_byte)[0] def get_smallint(self) -> int: """Reads 16bit signed integer from the stream""" return self.unpack(_smallint_le)[0] def get_usmallint(self) -> int: """Reads 16bit unsigned integer from the stream""" return self.unpack(_usmallint_le)[0] def get_int(self) -> int: """Reads 32bit signed integer from the stream""" return self.unpack(_int_le)[0] def get_uint(self) -> int: """Reads 32bit unsigned integer from the stream""" return self.unpack(_uint_le)[0] def get_uint_be(self) -> int: """Reads 32bit unsigned big-endian integer from the stream""" return self.unpack(_uint_be)[0] def get_uint8(self) -> int: """Reads 64bit unsigned integer from the stream""" return self.unpack(_uint8_le)[0] def get_int8(self) -> int: """Reads 64bit signed integer from the stream""" return self.unpack(_int8_le)[0] def read_ucs2(self, num_chars: int) -> str: """Reads num_chars UCS2 string from the stream""" buf = readall(self, num_chars * 2) return ucs2_codec.decode(buf)[0] def read_str(self, size: int, codec) -> str: """Reads byte string from the stream and decodes it :param size: Size of string in bytes :param codec: Instance of codec to decode string :returns: Unicode string """ return codec.decode(readall(self, size))[0] def get_collation(self) -> Collation: """Reads :class:`Collation` object from stream""" buf = readall(self, Collation.wire_size) return Collation.unpack(buf) def begin_response(self) -> ResponseMetadata: """ This method should be called first before reading anything. It will read first response packet and return its metadata, after that read methods can be called to read contents of the response packet stream until it ends. """ if self._status != 1 or self._pos < self._size: raise RuntimeError( "begin_response was called before previous response was fully consumed" ) self._read_packet() res = ResponseMetadata() res.type = self._type res.spid = self._spid return res def _read_packet(self) -> None: """Reads next TDS packet from the underlying transport Can only be called when transport's read pointer is at the beginning of the packet. """ pos = 0 while pos < _header.size: received = self._transport.recv_into( self._bufview[pos:], _header.size - pos ) if received == 0: raise tds_base.ClosedConnectionError() pos += received self._pos = _header.size self._type, self._status, self._size, self._spid, _ = _header.unpack_from( self._bufview, 0 ) self._have = pos while pos < self._size: received = self._transport.recv_into(self._bufview[pos:], self._size - pos) if received == 0: raise tds_base.ClosedConnectionError() pos += received self._have += received def read_whole_packet(self) -> bytes: """Reads single packet and returns bytes payload of the packet Can only be called when transport's read pointer is at the beginning of the packet. """ # self._read_packet() return readall(self, self._size - _header.size) pytds-1.15.0/src/pytds/tds_session.py000066400000000000000000001773261456567501500176400ustar00rootroot00000000000000""" This module implements TdsSession class """ from __future__ import annotations import codecs import collections.abc import contextlib import datetime import struct import typing import warnings from typing import Callable, Iterable, Any, List from pytds import tds_base, tds_types from pytds.collate import lcid2charset, raw_collation from pytds.tds_base import ( readall, skipall, PreLoginToken, PreLoginEnc, Message, logging_enabled, _create_exception_by_message, output, default, _TdsLogin, tds7_crypt_pass, logger, _Results, _TdsEnv, ) from pytds.tds_reader import _TdsReader, ResponseMetadata from pytds.tds_writer import _TdsWriter from pytds.row_strategies import list_row_strategy, RowStrategy, RowGenerator if typing.TYPE_CHECKING: from pytds.tds_socket import _TdsSocket class _TdsSession: """TDS session This class has the following responsibilities: * Track state of a single TDS session if MARS enabled there could be multiple TDS sessions within one connection. * Provides API to send requests and receive responses * Does serialization of requests and deserialization of responses """ def __init__( self, tds: _TdsSocket, transport: tds_base.TransportProtocol, tzinfo_factory: tds_types.TzInfoFactoryType | None, env: _TdsEnv, bufsize: int, row_strategy: RowStrategy = list_row_strategy, ): self.out_pos = 8 self.res_info: _Results | None = None self.in_cancel = False self.wire_mtx = None self.param_info = None self.has_status = False self.ret_status: int | None = None self.skipped_to_status = False self._transport = transport self._reader = _TdsReader( transport=transport, bufsize=bufsize, tds_session=self ) self._writer = _TdsWriter( transport=transport, bufsize=bufsize, tds_session=self ) self.in_buf_max = 0 self.state = tds_base.TDS_IDLE self._tds = tds self.messages: list[Message] = [] self.rows_affected = -1 self.use_tz = tds.use_tz self._spid = 0 self.tzinfo_factory = tzinfo_factory self.more_rows = False self.done_flags = 0 self.internal_sp_called = 0 self.output_params: dict[int, tds_base.Column] = {} self.authentication: tds_base.AuthProtocol | None = None self.return_value_index = 0 self._out_params_indexes: list[int] = [] self.row: list[Any] | None = None self.end_marker = 0 self._row_strategy = row_strategy self._env = env self._row_convertor: RowGenerator = list @property def autocommit(self): return self._env.autocommit @autocommit.setter def autocommit(self, value: bool): if self._env.autocommit != value: if value: if self._tds.tds72_transaction: self.rollback(cont=False) else: self.begin_tran() self._env.autocommit = value @property def isolation_level(self): return self._env.isolation_level @isolation_level.setter def isolation_level(self, value: int): """ Set transaction isolation level. Will roll back current transaction if it has different isolation level. """ if self._env.isolation_level != value: if self._tds.tds72_transaction: # Setting cont=False to delay reopening of new transaction until # next command execution in case isolation_level changes again self.rollback(cont=False) self._env.isolation_level = value @property def row_strategy(self) -> Callable[[Iterable[str]], Callable[[Iterable[Any]], Any]]: return self._row_strategy @row_strategy.setter def row_strategy( self, value: Callable[[Iterable[str]], Callable[[Iterable[Any]], Any]] ) -> None: warnings.warn( "Changing row_strategy on live connection is now deprecated, you should set it when creating new connection", DeprecationWarning, ) self._row_strategy = value def log_response_message(self, msg): # logging is disabled by default if logging_enabled: logger.info("[%d] %s", self._spid, msg) def __repr__(self): fmt = "<_TdsSession state={} tds={} messages={} rows_affected={} use_tz={} spid={} in_cancel={}>" res = fmt.format( repr(self.state), repr(self._tds), repr(self.messages), repr(self.rows_affected), repr(self.use_tz), repr(self._spid), self.in_cancel, ) return res def raise_db_exception(self) -> None: """Raises exception from last server message This function will skip messages: The statement has been terminated """ if not self.messages: raise tds_base.Error("Request failed, server didn't send error message") msg = None while True: msg = self.messages[-1] if msg["msgno"] == 3621: # the statement has been terminated self.messages = self.messages[:-1] else: break error_msg = " ".join(m["message"] for m in self.messages) ex = _create_exception_by_message(msg, error_msg) raise ex def get_type_info(self, curcol): """Reads TYPE_INFO structure (http://msdn.microsoft.com/en-us/library/dd358284.aspx) :param curcol: An instance of :class:`Column` that will receive read information """ r = self._reader # User defined data type of the column if tds_base.IS_TDS72_PLUS(self): user_type = r.get_uint() else: user_type = r.get_usmallint() curcol.column_usertype = user_type curcol.flags = r.get_usmallint() # Flags type_id = r.get_byte() serializer_class = self._tds.type_factory.get_type_serializer(type_id) curcol.serializer = serializer_class.from_stream(r) def tds7_process_result(self): """Reads and processes COLMETADATA stream This stream contains a list of returned columns. Stream format link: http://msdn.microsoft.com/en-us/library/dd357363.aspx """ self.log_response_message("got COLMETADATA") r = self._reader # read number of columns and allocate the columns structure num_cols = r.get_smallint() # This can be a DUMMY results token from a cursor fetch if num_cols == -1: return self.param_info = None self.has_status = False self.ret_status = None self.skipped_to_status = False self.rows_affected = tds_base.TDS_NO_COUNT self.more_rows = True self.row = [None] * num_cols self.res_info = info = _Results() # # loop through the columns populating COLINFO struct from # server response # header_tuple = [] for col in range(num_cols): curcol = tds_base.Column() info.columns.append(curcol) self.get_type_info(curcol) curcol.column_name = r.read_ucs2(r.get_byte()) precision = curcol.serializer.precision scale = curcol.serializer.scale size = curcol.serializer.size header_tuple.append( ( curcol.column_name, curcol.serializer.get_typeid(), None, size, precision, scale, curcol.flags & tds_base.Column.fNullable, ) ) info.description = tuple(header_tuple) self._setup_row_factory() return info def process_param(self): """Reads and processes RETURNVALUE stream. This stream is used to send OUTPUT parameters from RPC to client. Stream format url: http://msdn.microsoft.com/en-us/library/dd303881.aspx """ self.log_response_message("got RETURNVALUE message") r = self._reader if tds_base.IS_TDS72_PLUS(self): ordinal = r.get_usmallint() else: r.get_usmallint() # ignore size ordinal = self._out_params_indexes[self.return_value_index] name = r.read_ucs2(r.get_byte()) r.get_byte() # 1 - OUTPUT of sp, 2 - result of udf param = tds_base.Column() param.column_name = name self.get_type_info(param) param.value = param.serializer.read(r) self.output_params[ordinal] = param self.return_value_index += 1 def process_cancel(self): """ Process the incoming token stream until it finds an end token DONE with the cancel flag set. At that point the connection should be ready to handle a new query. In case when no cancel request is pending this function does nothing. """ self.log_response_message("got CANCEL message") # silly cases, nothing to do if not self.in_cancel: return while True: while not self._reader.stream_finished(): token_id = self.get_token_id() self.process_token(token_id) if not self.in_cancel: return self.begin_response() def process_msg(self, marker: int) -> None: """Reads and processes ERROR/INFO streams Stream formats: - ERROR: http://msdn.microsoft.com/en-us/library/dd304156.aspx - INFO: http://msdn.microsoft.com/en-us/library/dd303398.aspx :param marker: TDS_ERROR_TOKEN or TDS_INFO_TOKEN """ self.log_response_message("got ERROR/INFO message") r = self._reader r.get_smallint() # size msg: Message = { "marker": marker, "msgno": r.get_int(), "state": r.get_byte(), "severity": r.get_byte(), "sql_state": None, "priv_msg_type": 0, "message": "", "server": "", "proc_name": "", "line_number": 0, } if marker == tds_base.TDS_INFO_TOKEN: msg["priv_msg_type"] = 0 elif marker == tds_base.TDS_ERROR_TOKEN: msg["priv_msg_type"] = 1 else: logger.error('tds_process_msg() called with unknown marker "%d"', marker) msg["message"] = r.read_ucs2(r.get_smallint()) # server name msg["server"] = r.read_ucs2(r.get_byte()) # stored proc name if available msg["proc_name"] = r.read_ucs2(r.get_byte()) msg["line_number"] = ( r.get_int() if tds_base.IS_TDS72_PLUS(self) else r.get_smallint() ) # in case extended error data is sent, we just try to discard it # special case self.messages.append(msg) def process_row(self): """Reads and handles ROW stream. This stream contains list of values of one returned row. Stream format url: http://msdn.microsoft.com/en-us/library/dd357254.aspx """ self.log_response_message("got ROW message") r = self._reader info = self.res_info info.row_count += 1 for i, curcol in enumerate(info.columns): curcol.value = self.row[i] = curcol.serializer.read(r) def process_nbcrow(self): """Reads and handles NBCROW stream. This stream contains list of values of one returned row in a compressed way, introduced in TDS 7.3.B Stream format url: http://msdn.microsoft.com/en-us/library/dd304783.aspx """ self.log_response_message("got NBCROW message") r = self._reader info = self.res_info if not info: self.bad_stream("got row without info") assert len(info.columns) > 0 info.row_count += 1 # reading bitarray for nulls, 1 represent null values for # corresponding fields nbc = readall(r, (len(info.columns) + 7) // 8) for i, curcol in enumerate(info.columns): if tds_base.my_ord(nbc[i // 8]) & (1 << (i % 8)): value = None else: value = curcol.serializer.read(r) self.row[i] = value def process_orderby(self): """Reads and processes ORDER stream Used to inform client by which column dataset is ordered. Stream format url: http://msdn.microsoft.com/en-us/library/dd303317.aspx """ r = self._reader skipall(r, r.get_smallint()) def process_end(self, marker): """Reads and processes DONE/DONEINPROC/DONEPROC streams Stream format urls: - DONE: http://msdn.microsoft.com/en-us/library/dd340421.aspx - DONEINPROC: http://msdn.microsoft.com/en-us/library/dd340553.aspx - DONEPROC: http://msdn.microsoft.com/en-us/library/dd340753.aspx :param marker: Can be TDS_DONE_TOKEN or TDS_DONEINPROC_TOKEN or TDS_DONEPROC_TOKEN """ code_to_str = { tds_base.TDS_DONE_TOKEN: "DONE", tds_base.TDS_DONEINPROC_TOKEN: "DONEINPROC", tds_base.TDS_DONEPROC_TOKEN: "DONEPROC", } self.end_marker = marker self.more_rows = False r = self._reader status = r.get_usmallint() r.get_usmallint() # cur_cmd more_results = status & tds_base.TDS_DONE_MORE_RESULTS != 0 was_cancelled = status & tds_base.TDS_DONE_CANCELLED != 0 done_count_valid = status & tds_base.TDS_DONE_COUNT != 0 if self.res_info: self.res_info.more_results = more_results rows_affected = r.get_int8() if tds_base.IS_TDS72_PLUS(self) else r.get_int() self.log_response_message( "got {} message, more_res={}, cancelled={}, rows_affected={}".format( code_to_str[marker], more_results, was_cancelled, rows_affected ) ) if was_cancelled or (not more_results and not self.in_cancel): self.in_cancel = False self.set_state(tds_base.TDS_IDLE) if done_count_valid: self.rows_affected = rows_affected else: self.rows_affected = -1 self.done_flags = status if ( self.done_flags & tds_base.TDS_DONE_ERROR and not was_cancelled and not self.in_cancel ): self.raise_db_exception() def _ensure_transaction(self) -> None: if not self._env.autocommit and not self._tds.tds72_transaction: self.begin_tran() def process_env_chg(self): """Reads and processes ENVCHANGE stream. Stream info url: http://msdn.microsoft.com/en-us/library/dd303449.aspx """ self.log_response_message("got ENVCHANGE message") r = self._reader size = r.get_smallint() type_id = r.get_byte() if type_id == tds_base.TDS_ENV_SQLCOLLATION: size = r.get_byte() self.conn.collation = r.get_collation() logger.info("switched collation to %s", self.conn.collation) skipall(r, size - 5) # discard old one skipall(r, r.get_byte()) elif type_id == tds_base.TDS_ENV_BEGINTRANS: size = r.get_byte() assert size == 8 self.conn.tds72_transaction = r.get_uint8() # old val, should be 0 skipall(r, r.get_byte()) elif ( type_id == tds_base.TDS_ENV_COMMITTRANS or type_id == tds_base.TDS_ENV_ROLLBACKTRANS ): self.conn.tds72_transaction = 0 # new val, should be 0 skipall(r, r.get_byte()) # old val, should have previous transaction id skipall(r, r.get_byte()) elif type_id == tds_base.TDS_ENV_PACKSIZE: newval = r.read_ucs2(r.get_byte()) r.read_ucs2(r.get_byte()) new_block_size = int(newval) if new_block_size >= 512: # Is possible to have a shrink if server limits packet # size more than what we specified # # Reallocate buffer if possible (strange values from server or out of memory) use older buffer */ self._writer.bufsize = new_block_size elif type_id == tds_base.TDS_ENV_DATABASE: newval = r.read_ucs2(r.get_byte()) logger.info("switched to database %s", newval) r.read_ucs2(r.get_byte()) self.conn.env.database = newval elif type_id == tds_base.TDS_ENV_LANG: newval = r.read_ucs2(r.get_byte()) logger.info("switched language to %s", newval) r.read_ucs2(r.get_byte()) self.conn.env.language = newval elif type_id == tds_base.TDS_ENV_CHARSET: newval = r.read_ucs2(r.get_byte()) logger.info("switched charset to %s", newval) r.read_ucs2(r.get_byte()) self.conn.env.charset = newval remap = {"iso_1": "iso8859-1"} self.conn.server_codec = codecs.lookup(remap.get(newval, newval)) elif type_id == tds_base.TDS_ENV_DB_MIRRORING_PARTNER: newval = r.read_ucs2(r.get_byte()) logger.info("got mirroring partner %s", newval) r.read_ucs2(r.get_byte()) elif type_id == tds_base.TDS_ENV_LCID: lcid = int(r.read_ucs2(r.get_byte())) logger.info("switched lcid to %s", lcid) self.conn.server_codec = codecs.lookup(lcid2charset(lcid)) r.read_ucs2(r.get_byte()) elif type_id == tds_base.TDS_ENV_UNICODE_DATA_SORT_COMP_FLAGS: r.read_ucs2(r.get_byte()) comp_flags = r.read_ucs2(r.get_byte()) self.conn.comp_flags = comp_flags elif type_id == 20: # routing r.get_usmallint() protocol = r.get_byte() protocol_property = r.get_usmallint() alt_server = r.read_ucs2(r.get_usmallint()) logger.info( "got routing info proto=%d proto_prop=%d alt_srv=%s", protocol, protocol_property, alt_server, ) self.conn.route = { "server": alt_server, "port": protocol_property, } # OLDVALUE = 0x00, 0x00 r.get_usmallint() else: logger.warning("unknown env type: %d, skipping", type_id) # discard byte values, not still supported skipall(r, size - 1) def process_auth(self) -> None: """Reads and processes SSPI stream. Stream info: http://msdn.microsoft.com/en-us/library/dd302844.aspx """ r = self._reader w = self._writer pdu_size = r.get_smallint() if not self.authentication: raise tds_base.Error("Got unexpected token") packet = self.authentication.handle_next(readall(r, pdu_size)) if packet: w.write(packet) w.flush() def is_connected(self) -> bool: """ :return: True if transport is connected """ return self._transport.is_connected() # type: ignore # needs fixing def bad_stream(self, msg) -> None: """Called when input stream contains unexpected data. Will close stream and raise :class:`InterfaceError` :param msg: Message for InterfaceError exception. :return: Never returns, always raises exception. """ self.close() raise tds_base.InterfaceError(msg) @property def tds_version(self) -> int: """Returns integer encoded current TDS protocol version""" return self._tds.tds_version @property def conn(self) -> _TdsSocket: """Reference to owning :class:`_TdsSocket`""" return self._tds def close(self) -> None: self._transport.close() def set_state(self, state: int) -> int: """Switches state of the TDS session. It also does state transitions checks. :param state: New state, one of TDS_PENDING/TDS_READING/TDS_IDLE/TDS_DEAD/TDS_QUERING """ prior_state = self.state if state == prior_state: return state if state == tds_base.TDS_PENDING: if prior_state in (tds_base.TDS_READING, tds_base.TDS_QUERYING): self.state = tds_base.TDS_PENDING else: raise tds_base.InterfaceError( "logic error: cannot chage query state from {0} to {1}".format( tds_base.state_names[prior_state], tds_base.state_names[state] ) ) elif state == tds_base.TDS_READING: # transition to READING are valid only from PENDING if self.state != tds_base.TDS_PENDING: raise tds_base.InterfaceError( "logic error: cannot change query state from {0} to {1}".format( tds_base.state_names[prior_state], tds_base.state_names[state] ) ) else: self.state = state elif state == tds_base.TDS_IDLE: if prior_state == tds_base.TDS_DEAD: raise tds_base.InterfaceError( "logic error: cannot change query state from {0} to {1}".format( tds_base.state_names[prior_state], tds_base.state_names[state] ) ) self.state = state elif state == tds_base.TDS_DEAD: self.state = state elif state == tds_base.TDS_QUERYING: if self.state == tds_base.TDS_DEAD: raise tds_base.InterfaceError( "logic error: cannot change query state from {0} to {1}".format( tds_base.state_names[prior_state], tds_base.state_names[state] ) ) elif self.state != tds_base.TDS_IDLE: raise tds_base.InterfaceError( "logic error: cannot change query state from {0} to {1}".format( tds_base.state_names[prior_state], tds_base.state_names[state] ) ) else: self.rows_affected = tds_base.TDS_NO_COUNT self.internal_sp_called = 0 self.state = state else: assert False return self.state @contextlib.contextmanager def querying_context(self, packet_type: int) -> typing.Iterator[None]: """Context manager for querying. Sets state to TDS_QUERYING, and reverts it to TDS_IDLE if exception happens inside managed block, and to TDS_PENDING if managed block succeeds and flushes buffer. """ if self.set_state(tds_base.TDS_QUERYING) != tds_base.TDS_QUERYING: raise tds_base.Error("Couldn't switch to state") self._writer.begin_packet(packet_type) try: yield except: if self.state != tds_base.TDS_DEAD: self.set_state(tds_base.TDS_IDLE) raise else: self.set_state(tds_base.TDS_PENDING) self._writer.flush() def make_param(self, name: str, value: Any) -> tds_base.Param: """Generates instance of :class:`Param` from value and name Value can also be of a special types: - An instance of :class:`Param`, in which case it is just returned. - An instance of :class:`output`, in which case parameter will become an output parameter. - A singleton :var:`default`, in which case default value will be passed into a stored proc. :param name: Name of the parameter, will populate column_name property of returned column. :param value: Value of the parameter, also used to guess the type of parameter. :return: An instance of :class:`Column` """ if isinstance(value, tds_base.Param): value.name = name return value if isinstance(value, tds_base.Column): warnings.warn( "Usage of Column class as parameter is deprecated, use Param class instead", DeprecationWarning, ) return tds_base.Param( name=name, type=value.type, value=value.value, ) param_type = None param_flags = 0 if isinstance(value, output): param_flags |= tds_base.fByRefValue if isinstance(value.type, str): param_type = tds_types.sql_type_by_declaration(value.type) elif value.type: param_type = self.conn.type_inferrer.from_class(value.type) value = value.value if value is default: param_flags |= tds_base.fDefaultValue value = None param_value = value if param_type is None: param_type = self.conn.type_inferrer.from_value(value) param = tds_base.Param( name=name, type=param_type, flags=param_flags, value=param_value ) return param def _convert_params( self, parameters: dict[str, Any] | typing.Iterable[Any] ) -> List[tds_base.Param]: """Converts a dict of list of parameters into a list of :class:`Column` instances. :param parameters: Can be a list of parameter values, or a dict of parameter names to values. :return: A list of :class:`Column` instances. """ if isinstance(parameters, dict): return [self.make_param(name, value) for name, value in parameters.items()] else: params = [] for parameter in parameters: params.append(self.make_param("", parameter)) return params def cancel_if_pending(self) -> None: """Cancels current pending request. Does nothing if no request is pending, otherwise sends cancel request, and waits for response. """ if self.state == tds_base.TDS_IDLE: return if not self.in_cancel: self.put_cancel() self.process_cancel() def submit_rpc( self, rpc_name: tds_base.InternalProc | str, params: List[tds_base.Param], flags: int = 0, ) -> None: """Sends an RPC request. This call will transition session into pending state. If some operation is currently pending on the session, it will be cancelled before sending this request. Spec: http://msdn.microsoft.com/en-us/library/dd357576.aspx :param rpc_name: Name of the RPC to call, can be an instance of :class:`InternalProc` :param params: Stored proc parameters, should be a list of :class:`Column` instances. :param flags: See spec for possible flags. """ logger.info("Sending RPC %s flags=%d", rpc_name, flags) self.messages = [] self.output_params = {} self.cancel_if_pending() self.res_info = None w = self._writer with self.querying_context(tds_base.PacketType.RPC): if tds_base.IS_TDS72_PLUS(self): self._start_query() if tds_base.IS_TDS71_PLUS(self) and isinstance( rpc_name, tds_base.InternalProc ): w.put_smallint(-1) w.put_smallint(rpc_name.proc_id) else: if isinstance(rpc_name, tds_base.InternalProc): proc_name = rpc_name.name else: proc_name = rpc_name w.put_smallint(len(proc_name)) w.write_ucs2(proc_name) # # TODO support flags # bit 0 (1 as flag) in TDS7/TDS5 is "recompile" # bit 1 (2 as flag) in TDS7+ is "no metadata" bit this will prevent sending of column infos # w.put_usmallint(flags) self._out_params_indexes = [] for i, param in enumerate(params): if param.flags & tds_base.fByRefValue: self._out_params_indexes.append(i) w.put_byte(len(param.name)) w.write_ucs2(param.name) # # TODO support other flags (use defaul null/no metadata) # bit 1 (2 as flag) in TDS7+ is "default value" bit # (what's the meaning of "default value" ?) # w.put_byte(param.flags) # TYPE_INFO structure: https://msdn.microsoft.com/en-us/library/dd358284.aspx serializer = self._tds.type_factory.serializer_by_type( sql_type=param.type, collation=self._tds.collation or raw_collation ) type_id = serializer.type w.put_byte(type_id) serializer.write_info(w) serializer.write(w, param.value) def _setup_row_factory(self) -> None: self._row_convertor = list if self.res_info: column_names = [col[0] for col in self.res_info.description] self._row_convertor = self._row_strategy(column_names) def callproc( self, procname: tds_base.InternalProc | str, parameters: dict[str, Any] | typing.Iterable[Any], ) -> list[Any]: self._ensure_transaction() results = list(parameters) conv_parameters = self._convert_params(parameters) self.submit_rpc(procname, conv_parameters, 0) self.begin_response() self.process_rpc() for key, param in self.output_params.items(): results[key] = param.value return results def get_proc_outputs(self) -> list[Any]: """ If stored procedure has result sets and OUTPUT parameters use this method after you processed all result sets to get values of the OUTPUT parameters. :return: A list of output parameter values. """ self.complete_rpc() results = [None] * len(self.output_params.items()) for key, param in self.output_params.items(): results[key] = param.value return results def get_proc_return_status(self) -> int | None: """Last executed stored procedure's return value Returns integer value returned by `RETURN` statement from last executed stored procedure. If no value was not returned or no stored procedure was executed return `None`. """ if not self.has_status: self.find_return_status() return self.ret_status if self.has_status else None def executemany( self, operation: str, params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]], ) -> None: """ Execute same SQL query multiple times for each parameter set in the `params_seq` list. """ counts = [] for params in params_seq: self.execute(operation, params) if self.rows_affected != -1: counts.append(self.rows_affected) if counts: self.rows_affected = sum(counts) def execute( self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None, ) -> None: self._ensure_transaction() if params: named_params = {} if isinstance(params, (list, tuple)): names = [] pid = 1 for val in params: if val is None: names.append("NULL") else: name = f"@P{pid}" names.append(name) named_params[name] = val pid += 1 if len(names) == 1: operation = operation % names[0] else: operation = operation % tuple(names) elif isinstance(params, dict): # rename parameters rename: dict[str, Any] = {} pid = 1 for name, value in params.items(): if value is None: rename[name] = "NULL" else: mssql_name = f"@P{pid}" rename[name] = mssql_name named_params[mssql_name] = value pid += 1 operation = operation % rename if named_params: list_named_params = self._convert_params(named_params) param_definition = ",".join( f"{p.name} {p.type.get_declaration()}" for p in list_named_params ) self.submit_rpc( tds_base.SP_EXECUTESQL, [ self.make_param("", operation), self.make_param("", param_definition), ] + list_named_params, 0, ) else: self.submit_plain_query(operation) else: self.submit_plain_query(operation) self.begin_response() self.find_result_or_done() def execute_scalar( self, query_string: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None, ) -> Any: """ This method executes SQL query then returns first column of first row or the result. Query can be parametrized, see :func:`execute` method for details. This method is useful if you want just a single value, as in: .. code-block:: conn.execute_scalar('SELECT COUNT(*) FROM employees') This method works in the same way as ``iter(conn).next()[0]``. Remaining rows, if any, can still be iterated after calling this method. """ self.execute(operation=query_string, params=params) row = self._fetchone() if not row: return None return row[0] def submit_plain_query(self, operation: str) -> None: """Sends a plain query to server. This call will transition session into pending state. If some operation is currently pending on the session, it will be cancelled before sending this request. Spec: http://msdn.microsoft.com/en-us/library/dd358575.aspx :param operation: A string representing sql statement. """ self.messages = [] self.cancel_if_pending() self.res_info = None logger.info("Sending query %s", operation[:100]) w = self._writer with self.querying_context(tds_base.PacketType.QUERY): if tds_base.IS_TDS72_PLUS(self): self._start_query() w.write_ucs2(operation) def submit_bulk( self, metadata: list[tds_base.Column], rows: Iterable[collections.abc.Sequence[Any]], ) -> None: """Sends insert bulk command. Spec: http://msdn.microsoft.com/en-us/library/dd358082.aspx :param metadata: A list of :class:`Column` instances. :param rows: A collection of rows, each row is a collection of values. :return: """ logger.info("Sending INSERT BULK") num_cols = len(metadata) w = self._writer serializers = [] with self.querying_context(tds_base.PacketType.BULK): w.put_byte(tds_base.TDS7_RESULT_TOKEN) w.put_usmallint(num_cols) for col in metadata: if tds_base.IS_TDS72_PLUS(self): w.put_uint(col.column_usertype) else: w.put_usmallint(col.column_usertype) w.put_usmallint(col.flags) serializer = col.choose_serializer( type_factory=self._tds.type_factory, collation=self._tds.collation, ) type_id = serializer.type w.put_byte(type_id) serializers.append(serializer) serializer.write_info(w) w.put_byte(len(col.column_name)) w.write_ucs2(col.column_name) for row in rows: w.put_byte(tds_base.TDS_ROW_TOKEN) for i, col in enumerate(metadata): serializers[i].write(w, row[i]) # https://msdn.microsoft.com/en-us/library/dd340421.aspx w.put_byte(tds_base.TDS_DONE_TOKEN) w.put_usmallint(tds_base.TDS_DONE_FINAL) w.put_usmallint(0) # curcmd # row count if tds_base.IS_TDS72_PLUS(self): w.put_int8(0) else: w.put_int(0) def put_cancel(self) -> None: """Sends a cancel request to the server. Switches connection to IN_CANCEL state. """ logger.info("Sending CANCEL") self._writer.begin_packet(tds_base.PacketType.CANCEL) self._writer.flush() self.in_cancel = True _begin_tran_struct_72 = struct.Struct(" None: logger.info("Sending BEGIN TRAN il=%x", self._env.isolation_level) self.submit_begin_tran(isolation_level=self._env.isolation_level) self.process_simple_request() def submit_begin_tran(self, isolation_level: int = 0) -> None: if tds_base.IS_TDS72_PLUS(self): self.messages = [] self.cancel_if_pending() w = self._writer with self.querying_context(tds_base.PacketType.TRANS): self._start_query() w.pack( self._begin_tran_struct_72, 5, # TM_BEGIN_XACT isolation_level, 0, # new transaction name ) else: self.submit_plain_query("BEGIN TRANSACTION") self.conn.tds72_transaction = 1 _commit_rollback_tran_struct72_hdr = struct.Struct(" None: """ Rollback current transaction if it exists. If `cont` parameter is set to true, new transaction will start immediately after current transaction is rolled back """ if self._env.autocommit: return # if not self._conn or not self._conn.is_connected(): # return if not self._tds.tds72_transaction: return logger.info("Sending ROLLBACK TRAN") self.submit_rollback(cont, isolation_level=self._env.isolation_level) prev_timeout = self._tds.sock.gettimeout() self._tds.sock.settimeout(None) try: self.process_simple_request() finally: self._tds.sock.settimeout(prev_timeout) def submit_rollback(self, cont: bool, isolation_level: int = 0) -> None: """ Send transaction rollback request. If `cont` parameter is set to true, new transaction will start immediately after current transaction is rolled back """ if tds_base.IS_TDS72_PLUS(self): self.messages = [] self.cancel_if_pending() w = self._writer with self.querying_context(tds_base.PacketType.TRANS): self._start_query() flags = 0 if cont: flags |= 1 w.pack( self._commit_rollback_tran_struct72_hdr, 8, # TM_ROLLBACK_XACT 0, # transaction name flags, ) if cont: w.pack( self._continue_tran_struct72, isolation_level, 0, # new transaction name ) else: self.submit_plain_query( "IF @@TRANCOUNT > 0 ROLLBACK BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 ROLLBACK" ) self.conn.tds72_transaction = 1 if cont else 0 def commit(self, cont: bool) -> None: if self._env.autocommit: return if not self._tds.tds72_transaction: return logger.info("Sending COMMIT TRAN") self.submit_commit(cont, isolation_level=self._env.isolation_level) prev_timeout = self._tds.sock.gettimeout() self._tds.sock.settimeout(None) try: self.process_simple_request() finally: self._tds.sock.settimeout(prev_timeout) def submit_commit(self, cont: bool, isolation_level: int = 0) -> None: if tds_base.IS_TDS72_PLUS(self): self.messages = [] self.cancel_if_pending() w = self._writer with self.querying_context(tds_base.PacketType.TRANS): self._start_query() flags = 0 if cont: flags |= 1 w.pack( self._commit_rollback_tran_struct72_hdr, 7, # TM_COMMIT_XACT 0, # transaction name flags, ) if cont: w.pack( self._continue_tran_struct72, isolation_level, 0, # new transaction name ) else: self.submit_plain_query( "IF @@TRANCOUNT > 0 COMMIT BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 COMMIT" ) self.conn.tds72_transaction = 1 if cont else 0 _tds72_query_start = struct.Struct(" None: w = self._writer w.pack( _TdsSession._tds72_query_start, 0x16, # total length 0x12, # length 2, # type self.conn.tds72_transaction, 1, # request count ) def send_prelogin(self, login: _TdsLogin) -> None: from . import intversion # https://msdn.microsoft.com/en-us/library/dd357559.aspx instance_name = login.instance_name or "MSSQLServer" instance_name_encoded = instance_name.encode("ascii") if len(instance_name_encoded) > 65490: raise ValueError("Instance name is too long") if tds_base.IS_TDS72_PLUS(self): start_pos = 26 buf = struct.pack( b">BHHBHHBHHBHHBHHB", # netlib version PreLoginToken.VERSION, start_pos, 6, # encryption PreLoginToken.ENCRYPTION, start_pos + 6, 1, # instance PreLoginToken.INSTOPT, start_pos + 6 + 1, len(instance_name_encoded) + 1, # thread id PreLoginToken.THREADID, start_pos + 6 + 1 + len(instance_name_encoded) + 1, 4, # MARS enabled PreLoginToken.MARS, start_pos + 6 + 1 + len(instance_name_encoded) + 1 + 4, 1, # end PreLoginToken.TERMINATOR, ) else: start_pos = 21 buf = struct.pack( b">BHHBHHBHHBHHB", # netlib version PreLoginToken.VERSION, start_pos, 6, # encryption PreLoginToken.ENCRYPTION, start_pos + 6, 1, # instance PreLoginToken.INSTOPT, start_pos + 6 + 1, len(instance_name_encoded) + 1, # thread id PreLoginToken.THREADID, start_pos + 6 + 1 + len(instance_name_encoded) + 1, 4, # end PreLoginToken.TERMINATOR, ) assert start_pos == len(buf) w = self._writer w.begin_packet(tds_base.PacketType.PRELOGIN) w.write(buf) w.put_uint_be(intversion) w.put_usmallint_be(0) # build number # encryption flag w.put_byte(login.enc_flag) w.write(instance_name_encoded) w.put_byte(0) # zero terminate instance_name w.put_int(0) # TODO: change this to thread id attribs: dict[str, str | int | bool] = { "lib_ver": f"{intversion:x}", "enc_flag": f"{login.enc_flag:x}", "inst_name": instance_name, } if tds_base.IS_TDS72_PLUS(self): # MARS (1 enabled) w.put_byte(1 if login.use_mars else 0) attribs["mars"] = login.use_mars logger.info( "Sending PRELOGIN %s", " ".join(f"{n}={v!r}" for n, v in attribs.items()) ) w.flush() def begin_response(self) -> ResponseMetadata: """Begins reading next response from server. If timeout happens during reading of first packet will send cancellation message. """ try: return self._reader.begin_response() except tds_base.TimeoutError: self.put_cancel() raise def process_prelogin(self, login: _TdsLogin) -> None: # https://msdn.microsoft.com/en-us/library/dd357559.aspx resp_header = self.begin_response() p = self._reader.read_whole_packet() size = len(p) if size <= 0 or resp_header.type != tds_base.PacketType.REPLY: self.bad_stream( "Invalid packet type: {0}, expected REPLY(4)".format( self._reader.packet_type ) ) self.parse_prelogin(octets=p, login=login) def parse_prelogin(self, octets: bytes, login: _TdsLogin) -> None: from . import tls # https://msdn.microsoft.com/en-us/library/dd357559.aspx size = len(octets) p = octets # default 2, no certificate, no encryptption crypt_flag = 2 i = 0 byte_struct = struct.Struct("B") off_len_struct = struct.Struct(">HH") prod_version_struct = struct.Struct(">LH") while True: if i >= size: self.bad_stream("Invalid size of PRELOGIN structure") (type_id,) = byte_struct.unpack_from(p, i) if type_id == PreLoginToken.TERMINATOR: break if i + 4 > size: self.bad_stream("Invalid size of PRELOGIN structure") off, length = off_len_struct.unpack_from(p, i + 1) if off > size or off + length > size: self.bad_stream("Invalid offset in PRELOGIN structure") if type_id == PreLoginToken.VERSION: self.conn.server_library_version = prod_version_struct.unpack_from( p, off ) elif type_id == PreLoginToken.ENCRYPTION and length >= 1: (crypt_flag,) = byte_struct.unpack_from(p, off) elif type_id == PreLoginToken.MARS: self.conn._mars_enabled = bool(byte_struct.unpack_from(p, off)[0]) elif type_id == PreLoginToken.INSTOPT: # ignore instance name mismatch pass i += 5 logger.info( "Got PRELOGIN response crypt=%x mars=%d", crypt_flag, self.conn._mars_enabled, ) # if server do not has certificate do normal login login.server_enc_flag = crypt_flag if crypt_flag == PreLoginEnc.ENCRYPT_OFF: if login.enc_flag == PreLoginEnc.ENCRYPT_ON: self.bad_stream("Server returned unexpected ENCRYPT_ON value") else: # encrypt login packet only tls.establish_channel(self) elif crypt_flag == PreLoginEnc.ENCRYPT_ON: # encrypt entire connection tls.establish_channel(self) elif crypt_flag == PreLoginEnc.ENCRYPT_REQ: if login.enc_flag == PreLoginEnc.ENCRYPT_NOT_SUP: # connection terminated by server and client raise tds_base.Error( "Client does not have encryption enabled but it is required by server, " "enable encryption and try connecting again" ) else: # encrypt entire connection tls.establish_channel(self) elif crypt_flag == PreLoginEnc.ENCRYPT_NOT_SUP: if login.enc_flag == PreLoginEnc.ENCRYPT_ON: # connection terminated by server and client raise tds_base.Error( "You requested encryption but it is not supported by server" ) # do not encrypt anything else: self.bad_stream( "Unexpected value of enc_flag returned by server: {}".format(crypt_flag) ) def tds7_send_login(self, login: _TdsLogin) -> None: # https://msdn.microsoft.com/en-us/library/dd304019.aspx option_flag2 = login.option_flag2 user_name = login.user_name if len(user_name) > 128: raise ValueError("User name should be no longer that 128 characters") if len(login.password) > 128: raise ValueError("Password should be not longer than 128 characters") if len(login.change_password) > 128: raise ValueError("Password should be not longer than 128 characters") if len(login.client_host_name) > 128: raise ValueError("Host name should be not longer than 128 characters") if len(login.app_name) > 128: raise ValueError("App name should be not longer than 128 characters") if len(login.server_name) > 128: raise ValueError("Server name should be not longer than 128 characters") if len(login.database) > 128: raise ValueError("Database name should be not longer than 128 characters") if len(login.language) > 128: raise ValueError("Language should be not longer than 128 characters") if len(login.attach_db_file) > 260: raise ValueError("File path should be not longer than 260 characters") w = self._writer w.begin_packet(tds_base.PacketType.LOGIN) self.authentication = None current_pos = 86 + 8 if tds_base.IS_TDS72_PLUS(self) else 86 client_host_name = login.client_host_name login.client_host_name = client_host_name packet_size = ( current_pos + ( len(client_host_name) + len(login.app_name) + len(login.server_name) + len(login.library) + len(login.language) + len(login.database) ) * 2 ) if login.auth: self.authentication = login.auth auth_packet = login.auth.create_packet() packet_size += len(auth_packet) else: auth_packet = b"" packet_size += (len(user_name) + len(login.password)) * 2 w.put_int(packet_size) w.put_uint(login.tds_version) w.put_int(login.blocksize) from . import intversion w.put_uint(intversion) w.put_int(login.pid) w.put_uint(0) # connection id option_flag1 = ( tds_base.TDS_SET_LANG_ON | tds_base.TDS_USE_DB_NOTIFY | tds_base.TDS_INIT_DB_FATAL ) if not login.bulk_copy: option_flag1 |= tds_base.TDS_DUMPLOAD_OFF w.put_byte(option_flag1) if self.authentication: option_flag2 |= tds_base.TDS_INTEGRATED_SECURITY_ON w.put_byte(option_flag2) type_flags = 0 if login.readonly: type_flags |= tds_base.TDS_FREADONLY_INTENT w.put_byte(type_flags) option_flag3 = tds_base.TDS_UNKNOWN_COLLATION_HANDLING w.put_byte(option_flag3 if tds_base.IS_TDS73_PLUS(self) else 0) mins_fix = ( int( ( login.client_tz.utcoffset(datetime.datetime.now()) or datetime.timedelta() ).total_seconds() ) // 60 ) logger.info( "Sending LOGIN tds_ver=%x bufsz=%d pid=%d opt1=%x opt2=%x opt3=%x cli_tz=%d cli_lcid=%s " "cli_host=%s lang=%s db=%s", login.tds_version, w.bufsize, login.pid, option_flag1, option_flag2, option_flag3, mins_fix, login.client_lcid, client_host_name, login.language, login.database, ) w.put_int(mins_fix) w.put_int(login.client_lcid) w.put_smallint(current_pos) w.put_smallint(len(client_host_name)) current_pos += len(client_host_name) * 2 if self.authentication: w.put_smallint(0) w.put_smallint(0) w.put_smallint(0) w.put_smallint(0) else: w.put_smallint(current_pos) w.put_smallint(len(user_name)) current_pos += len(user_name) * 2 w.put_smallint(current_pos) w.put_smallint(len(login.password)) current_pos += len(login.password) * 2 w.put_smallint(current_pos) w.put_smallint(len(login.app_name)) current_pos += len(login.app_name) * 2 # server name w.put_smallint(current_pos) w.put_smallint(len(login.server_name)) current_pos += len(login.server_name) * 2 # reserved w.put_smallint(0) w.put_smallint(0) # library name w.put_smallint(current_pos) w.put_smallint(len(login.library)) current_pos += len(login.library) * 2 # language w.put_smallint(current_pos) w.put_smallint(len(login.language)) current_pos += len(login.language) * 2 # database name w.put_smallint(current_pos) w.put_smallint(len(login.database)) current_pos += len(login.database) * 2 # ClientID client_id = struct.pack(">Q", login.client_id)[2:] w.write(client_id) # authentication w.put_smallint(current_pos) w.put_smallint(len(auth_packet)) current_pos += len(auth_packet) # db file w.put_smallint(current_pos) w.put_smallint(len(login.attach_db_file)) current_pos += len(login.attach_db_file) * 2 if tds_base.IS_TDS72_PLUS(self): # new password w.put_smallint(current_pos) w.put_smallint(len(login.change_password)) # sspi long w.put_int(0) w.write_ucs2(client_host_name) if not self.authentication: w.write_ucs2(user_name) w.write(tds7_crypt_pass(login.password)) w.write_ucs2(login.app_name) w.write_ucs2(login.server_name) w.write_ucs2(login.library) w.write_ucs2(login.language) w.write_ucs2(login.database) if self.authentication: w.write(auth_packet) w.write_ucs2(login.attach_db_file) w.write_ucs2(login.change_password) w.flush() _SERVER_TO_CLIENT_MAPPING = { 0x07000000: tds_base.TDS70, 0x07010000: tds_base.TDS71, 0x71000001: tds_base.TDS71rev1, tds_base.TDS72: tds_base.TDS72, tds_base.TDS73A: tds_base.TDS73A, tds_base.TDS73B: tds_base.TDS73B, tds_base.TDS74: tds_base.TDS74, } def process_login_tokens(self) -> bool: r = self._reader succeed = False while True: # When handling login requests that involve special mechanisms such as SSPI, # it's crucial to be aware that multiple response streams may be generated. # Therefore, it becomes necessary to iterate through these streams during # the response processing phase. if r.stream_finished(): r.begin_response() marker = r.get_byte() if marker == tds_base.TDS_LOGINACK_TOKEN: # https://msdn.microsoft.com/en-us/library/dd340651.aspx succeed = True size = r.get_smallint() r.get_byte() # interface version = r.get_uint_be() self.conn.tds_version = self._SERVER_TO_CLIENT_MAPPING.get( version, version ) if not tds_base.IS_TDS7_PLUS(self): self.bad_stream("Only TDS 7.0 and higher are supported") # get server product name # ignore product name length, some servers seem to set it incorrectly r.get_byte() size -= 10 self.conn.product_name = r.read_ucs2(size // 2) product_version = r.get_uint_be() logger.info( "Got LOGINACK tds_ver=%x srv_name=%s srv_ver=%x", self.conn.tds_version, self.conn.product_name, product_version, ) # MSSQL 6.5 and 7.0 seem to return strange values for this # using TDS 4.2, something like 5F 06 32 FF for 6.50 self.conn.product_version = product_version if self.authentication: self.authentication.close() self.authentication = None else: self.process_token(marker) if marker == tds_base.TDS_DONE_TOKEN: break return succeed def process_returnstatus(self) -> None: self.log_response_message("got RETURNSTATUS message") self.ret_status = self._reader.get_int() self.has_status = True def process_token(self, marker: int) -> Any: handler = _token_map.get(marker) if not handler: self.bad_stream(f"Invalid TDS marker: {marker}({marker:x})") return return handler(self) def get_token_id(self) -> int: self.set_state(tds_base.TDS_READING) try: marker = self._reader.get_byte() except tds_base.TimeoutError: self.set_state(tds_base.TDS_PENDING) raise except: self._tds.close() raise return marker def process_simple_request(self) -> None: self.begin_response() while True: marker = self.get_token_id() if marker in ( tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN, tds_base.TDS_DONEINPROC_TOKEN, ): self.process_end(marker) if not self.done_flags & tds_base.TDS_DONE_MORE_RESULTS: return else: self.process_token(marker) def next_set(self) -> bool | None: while self.more_rows: self.next_row() if self.state == tds_base.TDS_IDLE: return False if self.find_result_or_done(): return True return None def fetchone(self) -> Any | None: row = self._fetchone() if row is None: return None else: return self._row_convertor(row) def _fetchone(self) -> list[Any] | None: if self.res_info is None: raise tds_base.ProgrammingError( "Previous statement didn't produce any results" ) if self.skipped_to_status: raise tds_base.ProgrammingError( "Unable to fetch any rows after accessing return_status" ) if not self.next_row(): return None return self.row def next_row(self) -> bool: if not self.more_rows: return False while True: marker = self.get_token_id() if marker in (tds_base.TDS_ROW_TOKEN, tds_base.TDS_NBC_ROW_TOKEN): self.process_token(marker) return True elif marker in ( tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN, tds_base.TDS_DONEINPROC_TOKEN, ): self.process_end(marker) return False else: self.process_token(marker) def find_result_or_done(self) -> bool: self.done_flags = 0 while True: marker = self.get_token_id() if marker == tds_base.TDS7_RESULT_TOKEN: self.process_token(marker) return True elif marker in ( tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN, tds_base.TDS_DONEINPROC_TOKEN, ): self.process_end(marker) if self.done_flags & tds_base.TDS_DONE_MORE_RESULTS: if self.done_flags & tds_base.TDS_DONE_COUNT: return True else: return False else: self.process_token(marker) def process_rpc(self) -> bool: self.done_flags = 0 self.return_value_index = 0 while True: marker = self.get_token_id() if marker == tds_base.TDS7_RESULT_TOKEN: self.process_token(marker) return True elif marker in (tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN): self.process_end(marker) if ( self.done_flags & tds_base.TDS_DONE_MORE_RESULTS and not self.done_flags & tds_base.TDS_DONE_COUNT ): # skip results that don't event have rowcount continue return False else: self.process_token(marker) def complete_rpc(self) -> None: # go through all result sets while self.next_set(): pass def find_return_status(self) -> None: self.skipped_to_status = True while True: marker = self.get_token_id() self.process_token(marker) if marker == tds_base.TDS_RETURNSTATUS_TOKEN: return def process_tabname(self): """ Processes TABNAME token Ref: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/140e3348-da08-409a-b6c3-f0fc9cee2d6e """ r = self._reader total_length = r.get_smallint() if not tds_base.IS_TDS71_PLUS(self): r.get_smallint() # name length tds_base.skipall(r, total_length) def process_colinfo(self): """ Process COLNAME token Ref: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/aa8466c5-ca3d-48ca-a638-7c1becebe754 """ r = self._reader total_length = r.get_smallint() tds_base.skipall(r, total_length) _token_map = { tds_base.TDS_AUTH_TOKEN: _TdsSession.process_auth, tds_base.TDS_ENVCHANGE_TOKEN: _TdsSession.process_env_chg, tds_base.TDS_DONE_TOKEN: lambda self: self.process_end(tds_base.TDS_DONE_TOKEN), tds_base.TDS_DONEPROC_TOKEN: lambda self: self.process_end( tds_base.TDS_DONEPROC_TOKEN ), tds_base.TDS_DONEINPROC_TOKEN: lambda self: self.process_end( tds_base.TDS_DONEINPROC_TOKEN ), tds_base.TDS_ERROR_TOKEN: lambda self: self.process_msg(tds_base.TDS_ERROR_TOKEN), tds_base.TDS_INFO_TOKEN: lambda self: self.process_msg(tds_base.TDS_INFO_TOKEN), tds_base.TDS_CAPABILITY_TOKEN: lambda self: self.process_msg( tds_base.TDS_CAPABILITY_TOKEN ), tds_base.TDS_PARAM_TOKEN: lambda self: self.process_param(), tds_base.TDS7_RESULT_TOKEN: lambda self: self.tds7_process_result(), tds_base.TDS_ROW_TOKEN: lambda self: self.process_row(), tds_base.TDS_NBC_ROW_TOKEN: lambda self: self.process_nbcrow(), tds_base.TDS_ORDERBY_TOKEN: lambda self: self.process_orderby(), tds_base.TDS_RETURNSTATUS_TOKEN: lambda self: self.process_returnstatus(), tds_base.TDS_TABNAME_TOKEN: lambda self: self.process_tabname(), tds_base.TDS_COLINFO_TOKEN: lambda self: self.process_colinfo(), } pytds-1.15.0/src/pytds/tds_socket.py000066400000000000000000000136141456567501500174320ustar00rootroot00000000000000""" This module implements TdsSocket class """ from __future__ import annotations import logging import datetime from . import tds_base from . import tds_types from . import tls from .tds_base import PreLoginEnc, _TdsEnv, _TdsLogin, Route from .row_strategies import list_row_strategy from .smp import SmpManager # _token_map is needed by sqlalchemy_pytds connector from .tds_session import ( _TdsSession, ) logger = logging.getLogger(__name__) class _TdsSocket: """ This class represents root TDS connection if MARS is used it can have multiple sessions represented by _TdsSession class if MARS is not used it would have single _TdsSession instance """ def __init__( self, sock: tds_base.TransportProtocol, login: _TdsLogin, tzinfo_factory: tds_types.TzInfoFactoryType | None = None, row_strategy=list_row_strategy, use_tz: datetime.tzinfo | None = None, autocommit=False, isolation_level=0, ): self._is_connected = False self.env = _TdsEnv() self.env.isolation_level = isolation_level self.collation = None self.tds72_transaction = 0 self._mars_enabled = False self.sock = sock self.bufsize = login.blocksize self.use_tz = use_tz self.tds_version = login.tds_version self.type_factory = tds_types.SerializerFactory(self.tds_version) self._tzinfo_factory = tzinfo_factory self._smp_manager: SmpManager | None = None self._main_session = _TdsSession( tds=self, transport=sock, tzinfo_factory=tzinfo_factory, row_strategy=row_strategy, env=self.env, # initially we use fixed bufsize # it may be updated later if server specifies different block size bufsize=4096, ) self._login = login self.route: Route | None = None self._row_strategy = row_strategy self.env.autocommit = autocommit self.query_timeout = login.query_timeout self.type_inferrer = tds_types.TdsTypeInferrer( type_factory=self.type_factory, collation=self.collation, bytes_to_unicode=self._login.bytes_to_unicode, allow_tz=not self.use_tz, ) self.server_library_version = (0, 0) self.product_name = "" self.product_version = 0 def __repr__(self) -> str: fmt = "<_TdsSocket tran={} mars={} tds_version={} use_tz={}>" return fmt.format( self.tds72_transaction, self._mars_enabled, self.tds_version, self.use_tz ) def login(self) -> Route | None: self._login.server_enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP if tds_base.IS_TDS71_PLUS(self._main_session): self._main_session.send_prelogin(self._login) self._main_session.process_prelogin(self._login) self._main_session.tds7_send_login(self._login) if self._login.server_enc_flag == PreLoginEnc.ENCRYPT_OFF: tls.revert_to_clear(self._main_session) self._main_session.begin_response() if not self._main_session.process_login_tokens(): self._main_session.raise_db_exception() if self.route is not None: return self.route # update block size if server returned different one if ( self._main_session._writer.bufsize != self._main_session._reader.get_block_size() ): self._main_session._reader.set_block_size( self._main_session._writer.bufsize ) self.type_factory = tds_types.SerializerFactory(self.tds_version) self.type_inferrer = tds_types.TdsTypeInferrer( type_factory=self.type_factory, collation=self.collation, bytes_to_unicode=self._login.bytes_to_unicode, allow_tz=not self.use_tz, ) if self._mars_enabled: self._smp_manager = SmpManager(self.sock) self._main_session = _TdsSession( tds=self, bufsize=self.bufsize, transport=self._smp_manager.create_session(), tzinfo_factory=self._tzinfo_factory, row_strategy=self._row_strategy, env=self.env, ) self._is_connected = True q = [] if self._login.database and self.env.database != self._login.database: q.append("use " + tds_base.tds_quote_id(self._login.database)) if q: self._main_session.submit_plain_query("".join(q)) self._main_session.process_simple_request() return None @property def mars_enabled(self) -> bool: return self._mars_enabled @property def main_session(self) -> _TdsSession: return self._main_session def create_session(self) -> _TdsSession: if not self._smp_manager: raise RuntimeError( "Calling create_session on a non-MARS connection does not work" ) return _TdsSession( tds=self, transport=self._smp_manager.create_session(), tzinfo_factory=self._tzinfo_factory, row_strategy=self._row_strategy, bufsize=self.bufsize, env=self.env, ) def is_connected(self) -> bool: return self._is_connected def close(self) -> None: self._is_connected = False if self.sock is not None: self.sock.close() if self._smp_manager: self._smp_manager.transport_closed() self._main_session.state = tds_base.TDS_DEAD if self._main_session.authentication: self._main_session.authentication.close() self._main_session.authentication = None def close_all_mars_sessions(self) -> None: if self._smp_manager: self._smp_manager.close_all_sessions(keep=self.main_session._transport) pytds-1.15.0/src/pytds/tds_types.py000066400000000000000000002460031456567501500173060ustar00rootroot00000000000000""" This module implements various data types supported by Microsoft SQL Server """ from __future__ import annotations import itertools import datetime import decimal import struct import re import uuid import functools from io import StringIO, BytesIO from typing import Callable from pytds.tds_base import read_chunks from . import tds_base from .collate import ucs2_codec, raw_collation from . import tz _flt4_struct = struct.Struct("f") _flt8_struct = struct.Struct("d") _utc = tz.utc TzInfoFactoryType = Callable[[int], datetime.tzinfo] def _applytz(dt, tzinfo): if not tzinfo: return dt dt = dt.replace(tzinfo=tzinfo) return dt def _decode_num(buf): """Decodes little-endian integer from buffer Buffer can be of any size """ return functools.reduce( lambda acc, val: acc * 256 + tds_base.my_ord(val), reversed(buf), 0 ) class PlpReader(object): """Partially length prefixed reader Spec: http://msdn.microsoft.com/en-us/library/dd340469.aspx """ def __init__(self, r): """ :param r: An instance of :class:`_TdsReader` """ self._rdr = r size = r.get_uint8() self._size = size def is_null(self): """ :return: True if stored value is NULL """ return self._size == tds_base.PLP_NULL def is_unknown_len(self): """ :return: True if total size is unknown upfront """ return self._size == tds_base.PLP_UNKNOWN def size(self): """ :return: Total size in bytes if is_uknown_len and is_null are both False """ return self._size def chunks(self): """Generates chunks from stream, each chunk is an instace of bytes.""" if self.is_null(): return total = 0 while True: chunk_len = self._rdr.get_uint() if chunk_len == 0: if not self.is_unknown_len() and total != self._size: msg = ( "PLP actual length (%d) doesn't match reported length (%d)" % (total, self._size) ) self._rdr.session.bad_stream(msg) return total += chunk_len left = chunk_len while left: buf = self._rdr.recv(left) yield buf left -= len(buf) class _StreamChunkedHandler(object): def __init__(self, stream): self.stream = stream def add_chunk(self, val): self.stream.write(val) def end(self): return self.stream class _DefaultChunkedHandler(object): def __init__(self, stream): self.stream = stream def add_chunk(self, val): self.stream.write(val) def end(self): value = self.stream.getvalue() self.stream.seek(0) self.stream.truncate() return value def __eq__(self, other): return self.stream.getvalue() == other.stream.getvalue() def __ne__(self, other): return not self.__eq__(other) class SqlTypeMetaclass(tds_base.CommonEqualityMixin): def __repr__(self): return "".format(self.get_declaration()) def get_declaration(self): raise NotImplementedError() class ImageType(SqlTypeMetaclass): def get_declaration(self): return "IMAGE" class BinaryType(SqlTypeMetaclass): def __init__(self, size=30): self._size = size @property def size(self): return self._size def get_declaration(self): return "BINARY({})".format(self._size) class VarBinaryType(SqlTypeMetaclass): def __init__(self, size=30): self._size = size @property def size(self): return self._size def get_declaration(self): return "VARBINARY({})".format(self._size) class VarBinaryMaxType(SqlTypeMetaclass): def get_declaration(self): return "VARBINARY(MAX)" class CharType(SqlTypeMetaclass): def __init__(self, size=30): self._size = size @property def size(self): return self._size def get_declaration(self): return "CHAR({})".format(self._size) class VarCharType(SqlTypeMetaclass): def __init__(self, size=30): self._size = size @property def size(self): return self._size def get_declaration(self): return "VARCHAR({})".format(self._size) class VarCharMaxType(SqlTypeMetaclass): def get_declaration(self): return "VARCHAR(MAX)" class NCharType(SqlTypeMetaclass): def __init__(self, size=30): self._size = size @property def size(self): return self._size def get_declaration(self): return "NCHAR({})".format(self._size) class NVarCharType(SqlTypeMetaclass): def __init__(self, size=30): self._size = size @property def size(self): return self._size def get_declaration(self): return "NVARCHAR({})".format(self._size) class NVarCharMaxType(SqlTypeMetaclass): def get_declaration(self): return "NVARCHAR(MAX)" class TextType(SqlTypeMetaclass): def get_declaration(self): return "TEXT" class NTextType(SqlTypeMetaclass): def get_declaration(self): return "NTEXT" class XmlType(SqlTypeMetaclass): def get_declaration(self): return "XML" class SmallMoneyType(SqlTypeMetaclass): def get_declaration(self): return "SMALLMONEY" class MoneyType(SqlTypeMetaclass): def get_declaration(self): return "MONEY" class DecimalType(SqlTypeMetaclass): def __init__(self, precision=18, scale=0): self._precision = precision self._scale = scale @classmethod def from_value(cls, value): if not (-(10**38) + 1 <= value <= 10**38 - 1): raise tds_base.DataError("Decimal value is out of range") with decimal.localcontext() as context: context.prec = 38 value = value.normalize() _, digits, exp = value.as_tuple() if exp > 0: scale = 0 prec = len(digits) + exp else: scale = -exp prec = max(len(digits), scale) return cls(precision=prec, scale=scale) @property def precision(self): return self._precision @property def scale(self): return self._scale def get_declaration(self): return "DECIMAL({}, {})".format(self._precision, self._scale) class UniqueIdentifierType(SqlTypeMetaclass): def get_declaration(self): return "UNIQUEIDENTIFIER" class VariantType(SqlTypeMetaclass): def get_declaration(self): return "SQL_VARIANT" class SqlValueMetaclass(tds_base.CommonEqualityMixin): pass class BaseTypeSerializer(tds_base.CommonEqualityMixin): """Base type for TDS data types. All TDS types should derive from it. In addition actual types should provide the following: - type - class variable storing type identifier """ type = 0 def __init__(self, precision=None, scale=None, size=None): self._precision = precision self._scale = scale self._size = size @property def precision(self): return self._precision @property def scale(self): return self._scale @property def size(self): return self._size def get_typeid(self): """Returns type identifier of type.""" return self.type @classmethod def from_stream(cls, r): """Class method that reads and returns a type instance. :param r: An instance of :class:`_TdsReader` to read type from. Should be implemented in actual types. """ raise NotImplementedError def write_info(self, w): """Writes type info into w stream. :param w: An instance of :class:`_TdsWriter` to write into. Should be symmetrical to from_stream method. Should be implemented in actual types. """ raise NotImplementedError def write(self, w, value): """Writes type's value into stream :param w: An instance of :class:`_TdsWriter` to write into. :param value: A value to be stored, should be compatible with the type Should be implemented in actual types. """ raise NotImplementedError def read(self, r): """Reads value from the stream. :param r: An instance of :class:`_TdsReader` to read value from. :return: A read value. Should be implemented in actual types. """ raise NotImplementedError def set_chunk_handler(self, chunk_handler): raise ValueError("Column type does not support chunk handler") class BasePrimitiveTypeSerializer(BaseTypeSerializer): """Base type for primitive TDS data types. Primitive type is a fixed size type with no type arguments. All primitive TDS types should derive from it. In addition actual types should provide the following: - type - class variable storing type identifier - declaration - class variable storing name of sql type - isntance - class variable storing instance of class """ def write(self, w, value): raise NotImplementedError def read(self, r): raise NotImplementedError instance: BaseTypeSerializer | None = None @classmethod def from_stream(cls, r): return cls.instance def write_info(self, w): pass class BaseTypeSerializerN(BaseTypeSerializer): """Base type for nullable TDS data types. All nullable TDS types should derive from it. In addition actual types should provide the following: - type - class variable storing type identifier - subtypes - class variable storing dict {subtype_size: subtype_instance} """ subtypes: dict[int, BaseTypeSerializer] = {} def __init__(self, size): super(BaseTypeSerializerN, self).__init__(size=size) assert size in self.subtypes self._current_subtype = self.subtypes[size] def get_typeid(self): return self._current_subtype.get_typeid() @classmethod def from_stream(cls, r): size = r.get_byte() if size not in cls.subtypes: raise tds_base.InterfaceError("Invalid %s size" % cls.type, size) return cls(size) def write_info(self, w): w.put_byte(self.size) def read(self, r): size = r.get_byte() if size == 0: return None if size not in self.subtypes: raise r.session.bad_stream("Invalid %s size" % self.type, size) return self.subtypes[size].read(r) def write(self, w, val): if val is None: w.put_byte(0) return w.put_byte(self.size) self._current_subtype.write(w, val) class BitType(SqlTypeMetaclass): type = tds_base.SYBBITN def get_declaration(self): return "BIT" class TinyIntType(SqlTypeMetaclass): type = tds_base.SYBINTN size = 1 def get_declaration(self): return "TINYINT" class SmallIntType(SqlTypeMetaclass): type = tds_base.SYBINTN size = 2 def get_declaration(self): return "SMALLINT" class IntType(SqlTypeMetaclass): """ Integer type, corresponds to `INT `_ type in the MSSQL server. """ type = tds_base.SYBINTN size = 4 def get_declaration(self): return "INT" class BigIntType(SqlTypeMetaclass): type = tds_base.SYBINTN size = 8 def get_declaration(self): return "BIGINT" class RealType(SqlTypeMetaclass): def get_declaration(self): return "REAL" class FloatType(SqlTypeMetaclass): def get_declaration(self): return "FLOAT" class BitSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBBIT declaration = "BIT" def write(self, w, value): w.put_byte(1 if value else 0) def read(self, r): return bool(r.get_byte()) BitSerializer.instance = bit_serializer = BitSerializer() class BitNSerializer(BaseTypeSerializerN): type = tds_base.SYBBITN subtypes = {1: bit_serializer} def __init__(self, typ): super(BitNSerializer, self).__init__(size=1) self._typ = typ def __repr__(self): return "BitNSerializer({})".format(self._typ) # BitNSerializer.instance = BitNSerializer(BitType()) class TinyIntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT1 declaration = "TINYINT" def write(self, w, val): w.put_byte(val) def read(self, r): return r.get_byte() TinyIntSerializer.instance = tiny_int_serializer = TinyIntSerializer() class SmallIntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT2 declaration = "SMALLINT" def write(self, w, val): w.put_smallint(val) def read(self, r): return r.get_smallint() SmallIntSerializer.instance = small_int_serializer = SmallIntSerializer() class IntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT4 declaration = "INT" def write(self, w, val): w.put_int(val) def read(self, r): return r.get_int() IntSerializer.instance = int_serializer = IntSerializer() class BigIntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT8 declaration = "BIGINT" def write(self, w, val): w.put_int8(val) def read(self, r): return r.get_int8() BigIntSerializer.instance = big_int_serializer = BigIntSerializer() class IntNSerializer(BaseTypeSerializerN): type = tds_base.SYBINTN subtypes = { 1: tiny_int_serializer, 2: small_int_serializer, 4: int_serializer, 8: big_int_serializer, } type_by_size = { 1: TinyIntType(), 2: SmallIntType(), 4: IntType(), 8: BigIntType(), } def __init__(self, typ): super(IntNSerializer, self).__init__(size=typ.size) self._typ = typ @classmethod def from_stream(cls, r): size = r.get_byte() if size not in cls.subtypes: raise tds_base.InterfaceError("Invalid %s size" % cls.type, size) return cls(cls.type_by_size[size]) def __repr__(self): return "IntN({})".format(self.size) class RealSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBREAL declaration = "REAL" def write(self, w, val): w.pack(_flt4_struct, val) def read(self, r): return r.unpack(_flt4_struct)[0] RealSerializer.instance = real_serializer = RealSerializer() class FloatSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBFLT8 declaration = "FLOAT" def write(self, w, val): w.pack(_flt8_struct, val) def read(self, r): return r.unpack(_flt8_struct)[0] FloatSerializer.instance = float_serializer = FloatSerializer() class FloatNSerializer(BaseTypeSerializerN): type = tds_base.SYBFLTN subtypes = { 4: real_serializer, 8: float_serializer, } class VarChar(SqlValueMetaclass): def __init__(self, val, collation=raw_collation): self._val = val self._collation = collation @property def collation(self): return self._collation @property def val(self): return self._val def __str__(self): return self._val class VarChar70Serializer(BaseTypeSerializer): type = tds_base.XSYBVARCHAR def __init__(self, size, collation=raw_collation, codec=None): super(VarChar70Serializer, self).__init__(size=size) self._collation = collation if codec: self._codec = codec else: self._codec = collation.get_codec() @classmethod def from_stream(cls, r): size = r.get_smallint() return cls(size, codec=r.session.conn.server_codec) def write_info(self, w): w.put_smallint(self.size) def write(self, w, val): if val is None: w.put_smallint(-1) else: if w._tds._tds._login.bytes_to_unicode: val = tds_base.force_unicode(val) if isinstance(val, str): val, _ = self._codec.encode(val) w.put_smallint(len(val)) w.write(val) def read(self, r): size = r.get_smallint() if size < 0: return None if r._session._tds._login.bytes_to_unicode: return r.read_str(size, self._codec) else: return tds_base.readall(r, size) class VarChar71Serializer(VarChar70Serializer): @classmethod def from_stream(cls, r): size = r.get_smallint() collation = r.get_collation() return cls(size, collation) def write_info(self, w): super(VarChar71Serializer, self).write_info(w) w.put_collation(self._collation) class VarChar72Serializer(VarChar71Serializer): @classmethod def from_stream(cls, r): size = r.get_usmallint() collation = r.get_collation() if size == 0xFFFF: return VarCharMaxSerializer(collation) return cls(size, collation) class VarCharMaxSerializer(VarChar72Serializer): def __init__(self, collation=raw_collation): super(VarChar72Serializer, self).__init__(0, collation) self._chunk_handler = None def write_info(self, w): w.put_usmallint(tds_base.PLP_MARKER) w.put_collation(self._collation) def write(self, w, val): if val is None: w.put_uint8(tds_base.PLP_NULL) else: if w._tds._tds._login.bytes_to_unicode: val = tds_base.force_unicode(val) if isinstance(val, str): val, _ = self._codec.encode(val) # Putting the actual length here causes an error when bulk inserting: # # While reading current row from host, a premature end-of-message # was encountered--an incoming data stream was interrupted when # the server expected to see more data. The host program may have # terminated. Ensure that you are using a supported client # application programming interface (API). # # See https://github.com/tediousjs/tedious/issues/197 # It is not known why this happens, but Microsoft's bcp tool # uses PLP_UNKNOWN for varchar(max) as well. w.put_uint8(tds_base.PLP_UNKNOWN) if len(val) > 0: w.put_uint(len(val)) w.write(val) w.put_uint(0) def read(self, r): login = r._session._tds._login r = PlpReader(r) if r.is_null(): return None if self._chunk_handler is None: if login.bytes_to_unicode: self._chunk_handler = _DefaultChunkedHandler(StringIO()) else: self._chunk_handler = _DefaultChunkedHandler(BytesIO()) if login.bytes_to_unicode: for chunk in tds_base.iterdecode(r.chunks(), self._codec): self._chunk_handler.add_chunk(chunk) else: for chunk in r.chunks(): self._chunk_handler.add_chunk(chunk) return self._chunk_handler.end() def set_chunk_handler(self, chunk_handler): self._chunk_handler = chunk_handler class NVarChar70Serializer(BaseTypeSerializer): type = tds_base.XSYBNVARCHAR def __init__(self, size, collation=raw_collation): super(NVarChar70Serializer, self).__init__(size=size) self._collation = collation @classmethod def from_stream(cls, r): size = r.get_usmallint() return cls(size / 2) def write_info(self, w): w.put_usmallint(self.size * 2) def write(self, w, val): if val is None: w.put_usmallint(0xFFFF) else: if isinstance(val, bytes): val = tds_base.force_unicode(val) buf, _ = ucs2_codec.encode(val) length = len(buf) w.put_usmallint(length) w.write(buf) def read(self, r): size = r.get_usmallint() if size == 0xFFFF: return None return r.read_str(size, ucs2_codec) class NVarChar71Serializer(NVarChar70Serializer): @classmethod def from_stream(cls, r): size = r.get_usmallint() collation = r.get_collation() return cls(size / 2, collation) def write_info(self, w): super(NVarChar71Serializer, self).write_info(w) w.put_collation(self._collation) class NVarChar72Serializer(NVarChar71Serializer): @classmethod def from_stream(cls, r): size = r.get_usmallint() collation = r.get_collation() if size == 0xFFFF: return NVarCharMaxSerializer(collation=collation) return cls(size / 2, collation=collation) class NVarCharMaxSerializer(NVarChar72Serializer): def __init__(self, collation=raw_collation): super(NVarCharMaxSerializer, self).__init__(size=-1, collation=collation) self._chunk_handler = _DefaultChunkedHandler(StringIO()) def __repr__(self): return "NVarCharMax(s={},c={})".format(self.size, repr(self._collation)) def get_typeid(self): return tds_base.SYBNTEXT def write_info(self, w): w.put_usmallint(tds_base.PLP_MARKER) w.put_collation(self._collation) def write(self, w, val): if val is None: w.put_uint8(tds_base.PLP_NULL) else: if isinstance(val, bytes): val = tds_base.force_unicode(val) val, _ = ucs2_codec.encode(val) # Putting the actual length here causes an error when bulk inserting: # # While reading current row from host, a premature end-of-message # was encountered--an incoming data stream was interrupted when # the server expected to see more data. The host program may have # terminated. Ensure that you are using a supported client # application programming interface (API). # # See https://github.com/tediousjs/tedious/issues/197 # It is not known why this happens, but Microsoft's bcp tool # uses PLP_UNKNOWN for nvarchar(max) as well. w.put_uint8(tds_base.PLP_UNKNOWN) if len(val) > 0: w.put_uint(len(val)) w.write(val) w.put_uint(0) def read(self, r): r = PlpReader(r) if r.is_null(): return None for chunk in tds_base.iterdecode(r.chunks(), ucs2_codec): self._chunk_handler.add_chunk(chunk) return self._chunk_handler.end() def set_chunk_handler(self, chunk_handler): self._chunk_handler = chunk_handler class XmlSerializer(NVarCharMaxSerializer): type = tds_base.SYBMSXML declaration = "XML" def __init__(self, schema=None): super(XmlSerializer, self).__init__(0) self._schema = schema or {} def __repr__(self): return "XmlSerializer(schema={})".format(repr(self._schema)) def get_typeid(self): return self.type @classmethod def from_stream(cls, r): has_schema = r.get_byte() schema = {} if has_schema: schema["dbname"] = r.read_ucs2(r.get_byte()) schema["owner"] = r.read_ucs2(r.get_byte()) schema["collection"] = r.read_ucs2(r.get_smallint()) return cls(schema) def write_info(self, w): if self._schema: w.put_byte(1) w.put_byte(len(self._schema["dbname"])) w.write_ucs2(self._schema["dbname"]) w.put_byte(len(self._schema["owner"])) w.write_ucs2(self._schema["owner"]) w.put_usmallint(len(self._schema["collection"])) w.write_ucs2(self._schema["collection"]) else: w.put_byte(0) class Text70Serializer(BaseTypeSerializer): type = tds_base.SYBTEXT declaration = "TEXT" def __init__(self, size=0, table_name="", collation=raw_collation, codec=None): super(Text70Serializer, self).__init__(size=size) self._table_name = table_name self._collation = collation if codec: self._codec = codec else: self._codec = collation.get_codec() self._chunk_handler = None def __repr__(self): return "Text70(size={},table_name={},codec={})".format( self.size, self._table_name, self._codec ) @classmethod def from_stream(cls, r): size = r.get_int() table_name = r.read_ucs2(r.get_smallint()) return cls(size, table_name, codec=r.session.conn.server_codec) def write_info(self, w): w.put_int(self.size) def write(self, w, val): if val is None: w.put_int(-1) else: if w._tds._tds._login.bytes_to_unicode: val = tds_base.force_unicode(val) if isinstance(val, str): val, _ = self._codec.encode(val) w.put_int(len(val)) w.write(val) def read(self, r): size = r.get_byte() if size == 0: return None tds_base.readall(r, size) # textptr tds_base.readall(r, 8) # timestamp colsize = r.get_int() if self._chunk_handler is None: if r._session._tds._login.bytes_to_unicode: self._chunk_handler = _DefaultChunkedHandler(StringIO()) else: self._chunk_handler = _DefaultChunkedHandler(BytesIO()) if r._session._tds._login.bytes_to_unicode: for chunk in tds_base.iterdecode(read_chunks(r, colsize), self._codec): self._chunk_handler.add_chunk(chunk) else: for chunk in read_chunks(r, colsize): self._chunk_handler.add_chunk(chunk) return self._chunk_handler.end() def set_chunk_handler(self, chunk_handler): self._chunk_handler = chunk_handler class Text71Serializer(Text70Serializer): def __repr__(self): return "Text71(size={}, table_name={}, collation={})".format( self.size, self._table_name, repr(self._collation) ) @classmethod def from_stream(cls, r): size = r.get_int() collation = r.get_collation() table_name = r.read_ucs2(r.get_smallint()) return cls(size, table_name, collation) def write_info(self, w): w.put_int(self.size) w.put_collation(self._collation) class Text72Serializer(Text71Serializer): def __init__(self, size=0, table_name_parts=(), collation=raw_collation): super(Text72Serializer, self).__init__( size=size, table_name=".".join(table_name_parts), collation=collation ) self._table_name_parts = table_name_parts @classmethod def from_stream(cls, r): size = r.get_int() collation = r.get_collation() num_parts = r.get_byte() parts = [] for _ in range(num_parts): parts.append(r.read_ucs2(r.get_smallint())) return cls(size, parts, collation) class NText70Serializer(BaseTypeSerializer): type = tds_base.SYBNTEXT declaration = "NTEXT" def __init__(self, size=0, table_name="", collation=raw_collation): super(NText70Serializer, self).__init__(size=size) self._collation = collation self._table_name = table_name self._chunk_handler = _DefaultChunkedHandler(StringIO()) def __repr__(self): return "NText70(size={}, table_name={})".format(self.size, self._table_name) @classmethod def from_stream(cls, r): size = r.get_int() table_name = r.read_ucs2(r.get_smallint()) return cls(size, table_name) def read(self, r): textptr_size = r.get_byte() if textptr_size == 0: return None tds_base.readall(r, textptr_size) # textptr tds_base.readall(r, 8) # timestamp colsize = r.get_int() for chunk in tds_base.iterdecode(read_chunks(r, colsize), ucs2_codec): self._chunk_handler.add_chunk(chunk) return self._chunk_handler.end() def write_info(self, w): w.put_int(self.size * 2) def write(self, w, val): if val is None: w.put_int(-1) else: w.put_int(len(val) * 2) w.write_ucs2(val) def set_chunk_handler(self, chunk_handler): self._chunk_handler = chunk_handler class NText71Serializer(NText70Serializer): def __repr__(self): return "NText71(size={}, table_name={}, collation={})".format( self.size, self._table_name, repr(self._collation) ) @classmethod def from_stream(cls, r): size = r.get_int() collation = r.get_collation() table_name = r.read_ucs2(r.get_smallint()) return cls(size, table_name, collation) def write_info(self, w): w.put_int(self.size) w.put_collation(self._collation) class NText72Serializer(NText71Serializer): def __init__(self, size=0, table_name_parts=(), collation=raw_collation): super(NText72Serializer, self).__init__(size=size, collation=collation) self._table_name_parts = table_name_parts def __repr__(self): return "NText72Serializer(s={},table_name={},coll={})".format( self.size, self._table_name_parts, self._collation ) @classmethod def from_stream(cls, r): size = r.get_int() collation = r.get_collation() num_parts = r.get_byte() parts = [] for _ in range(num_parts): parts.append(r.read_ucs2(r.get_smallint())) return cls(size, parts, collation) class Binary(bytes, SqlValueMetaclass): def __repr__(self): return "Binary({0})".format(super(Binary, self).__repr__()) class VarBinarySerializer(BaseTypeSerializer): type = tds_base.XSYBVARBINARY def __init__(self, size): super(VarBinarySerializer, self).__init__(size=size) def __repr__(self): return "VarBinary({})".format(self.size) @classmethod def from_stream(cls, r): size = r.get_usmallint() return cls(size) def write_info(self, w): w.put_usmallint(self.size) def write(self, w, val): if val is None: w.put_usmallint(0xFFFF) else: w.put_usmallint(len(val)) w.write(val) def read(self, r): size = r.get_usmallint() if size == 0xFFFF: return None return tds_base.readall(r, size) class VarBinarySerializer72(VarBinarySerializer): def __repr__(self): return "VarBinary72({})".format(self.size) @classmethod def from_stream(cls, r): size = r.get_usmallint() if size == 0xFFFF: return VarBinarySerializerMax() return cls(size) class VarBinarySerializerMax(VarBinarySerializer): def __init__(self): super(VarBinarySerializerMax, self).__init__(0) self._chunk_handler = _DefaultChunkedHandler(BytesIO()) def __repr__(self): return "VarBinaryMax()" def write_info(self, w): w.put_usmallint(tds_base.PLP_MARKER) def write(self, w, val): if val is None: w.put_uint8(tds_base.PLP_NULL) else: w.put_uint8(len(val)) if val: w.put_uint(len(val)) w.write(val) w.put_uint(0) def read(self, r): r = PlpReader(r) if r.is_null(): return None for chunk in r.chunks(): self._chunk_handler.add_chunk(chunk) return self._chunk_handler.end() def set_chunk_handler(self, chunk_handler): self._chunk_handler = chunk_handler class UDT72Serializer(BaseTypeSerializer): # Data type definition stream used for UDT_INFO in TYPE_INFO # https://msdn.microsoft.com/en-us/library/a57df60e-d0a6-4e7e-a2e5-ccacd277c673/ def __init__( self, max_byte_size, db_name, schema_name, type_name, assembly_qualified_name ): self.max_byte_size = max_byte_size self.db_name = db_name self.schema_name = schema_name self.type_name = type_name self.assembly_qualified_name = assembly_qualified_name super(UDT72Serializer, self).__init__() def __repr__(self): return ( "UDT72Serializer(max_byte_size={}, db_name={}, " "schema_name={}, type_name={}, " "assembly_qualified_name={})".format( *map( repr, ( self.max_byte_size, self.db_name, self.schema_name, self.type_name, self.assembly_qualified_name, ), ) ) ) @classmethod def from_stream(cls, r): # MAX_BYTE_SIZE max_byte_size = r.get_usmallint() assert max_byte_size == 0xFFFF or 1 < max_byte_size < 8000 # DB_NAME -- B_VARCHAR db_name = r.read_ucs2(r.get_byte()) # SCHEMA_NAME -- B_VARCHAR schema_name = r.read_ucs2(r.get_byte()) # TYPE_NAME -- B_VARCHAR type_name = r.read_ucs2(r.get_byte()) # UDT_METADATA -- # a US_VARCHAR (2 bytes length prefix) # containing ASSEMBLY_QUALIFIED_NAME assembly_qualified_name = r.read_ucs2(r.get_smallint()) return cls( max_byte_size, db_name, schema_name, type_name, assembly_qualified_name ) def read(self, r): r = PlpReader(r) if r.is_null(): return None return b"".join(r.chunks()) class UDT72SerializerMax(UDT72Serializer): def __init__(self, *args, **kwargs): super(UDT72SerializerMax, self).__init__(0, *args, **kwargs) class Image70Serializer(BaseTypeSerializer): type = tds_base.SYBIMAGE declaration = "IMAGE" def __init__(self, size=0, table_name=""): super(Image70Serializer, self).__init__(size=size) self._table_name = table_name self._chunk_handler = _DefaultChunkedHandler(BytesIO()) def __repr__(self): return "Image70(tn={},s={})".format(repr(self._table_name), self.size) @classmethod def from_stream(cls, r): size = r.get_int() table_name = r.read_ucs2(r.get_smallint()) return cls(size, table_name) def read(self, r): size = r.get_byte() if size == 16: # Jeff's hack tds_base.readall(r, 16) # textptr tds_base.readall(r, 8) # timestamp colsize = r.get_int() for chunk in read_chunks(r, colsize): self._chunk_handler.add_chunk(chunk) return self._chunk_handler.end() else: return None def write(self, w, val): if val is None: w.put_int(-1) return w.put_int(len(val)) w.write(val) def write_info(self, w): w.put_int(self.size) def set_chunk_handler(self, chunk_handler): self._chunk_handler = chunk_handler class Image72Serializer(Image70Serializer): def __init__(self, size=0, parts=()): super(Image72Serializer, self).__init__(size=size, table_name=".".join(parts)) self._parts = parts def __repr__(self): return "Image72(p={},s={})".format(self._parts, self.size) @classmethod def from_stream(cls, r): size = r.get_int() num_parts = r.get_byte() parts = [] for _ in range(num_parts): parts.append(r.read_ucs2(r.get_usmallint())) return Image72Serializer(size, parts) _datetime_base_date = datetime.datetime(1900, 1, 1) class SmallDateTimeType(SqlTypeMetaclass): def get_declaration(self): return "SMALLDATETIME" class DateTimeType(SqlTypeMetaclass): def get_declaration(self): return "DATETIME" class SmallDateTime(SqlValueMetaclass): """Corresponds to MSSQL smalldatetime""" def __init__(self, days, minutes): """ @param days: Days since 1900-01-01 @param minutes: Minutes since 00:00:00 """ self._days = days self._minutes = minutes @property def days(self): return self._days @property def minutes(self): return self._minutes def to_pydatetime(self): return _datetime_base_date + datetime.timedelta( days=self._days, minutes=self._minutes ) @classmethod def from_pydatetime(cls, dt): days = (dt - _datetime_base_date).days minutes = dt.hour * 60 + dt.minute return cls(days=days, minutes=minutes) class BaseDateTimeSerializer(BaseTypeSerializer): def write(self, w, value): raise NotImplementedError def write_info(self, w): raise NotImplementedError def read(self, r): raise NotImplementedError @classmethod def from_stream(cls, r): raise NotImplementedError class SmallDateTimeSerializer(BasePrimitiveTypeSerializer, BaseDateTimeSerializer): type = tds_base.SYBDATETIME4 declaration = "SMALLDATETIME" _struct = struct.Struct(" 0 assure no # core if for some bug it's 0... # 1, 5, 5, 5, 5, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 13, 13, 13, 13, 13, 13, 13, 13, 13, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, ] _info_struct = struct.Struct("BBB") def __init__(self, precision=18, scale=0): super(MsDecimalSerializer, self).__init__( precision=precision, scale=scale, size=self._bytes_per_prec[precision] ) if precision > 38: raise tds_base.DataError("Precision of decimal value is out of range") def __repr__(self): return "MsDecimal(scale={}, prec={})".format(self.scale, self.precision) @classmethod def from_value(cls, value): sql_type = DecimalType.from_value(value) return cls(scale=sql_type.scale, prec=sql_type.precision) @classmethod def from_stream(cls, r): size, prec, scale = r.unpack(cls._info_struct) return cls(scale=scale, precision=prec) def write_info(self, w): w.pack(self._info_struct, self.size, self.precision, self.scale) def write(self, w, value): with decimal.localcontext() as context: context.prec = 38 if value is None: w.put_byte(0) return if not isinstance(value, decimal.Decimal): value = decimal.Decimal(value) value = value.normalize() scale = self.scale size = self.size w.put_byte(size) val = value positive = 1 if val > 0 else 0 w.put_byte(positive) # sign if not positive: val *= -1 size -= 1 val *= 10**scale for i in range(size): w.put_byte(int(val % 256)) val //= 256 assert val == 0 def _decode(self, positive, buf): val = _decode_num(buf) val = decimal.Decimal(val) with decimal.localcontext() as ctx: ctx.prec = 38 if not positive: val *= -1 val /= 10**self._scale return val def read_fixed(self, r, size): positive = r.get_byte() buf = tds_base.readall(r, size - 1) return self._decode(positive, buf) def read(self, r): size = r.get_byte() if size <= 0: return None return self.read_fixed(r, size) class Money4Serializer(BasePrimitiveTypeSerializer): type = tds_base.SYBMONEY4 declaration = "SMALLMONEY" def read(self, r): return decimal.Decimal(r.get_int()) / 10000 def write(self, w, val): val = int(val * 10000) w.put_int(val) Money4Serializer.instance = money4_serializer = Money4Serializer() class Money8Serializer(BasePrimitiveTypeSerializer): type = tds_base.SYBMONEY declaration = "MONEY" _struct = struct.Struct(" 128: raise ValueError( "Schema part of TVP name should be no longer than 128 characters" ) if len(typ_name) > 128: raise ValueError( "Name part of TVP name should be no longer than 128 characters" ) if columns is not None: if len(columns) > 1024: raise ValueError("TVP cannot have more than 1024 columns") if len(columns) < 1: raise ValueError("TVP must have at least one column") self._typ_dbname = ( "" # dbname should always be empty string for TVP according to spec ) self._typ_schema = typ_schema self._typ_name = typ_name self._columns = columns def __repr__(self): return "TableType(s={},n={},cols={})".format( self._typ_schema, self._typ_name, repr(self._columns) ) def get_declaration(self): assert not self._typ_dbname if self._typ_schema: full_name = "{}.{}".format(self._typ_schema, self._typ_name) else: full_name = self._typ_name return "{} READONLY".format(full_name) @property def typ_schema(self): return self._typ_schema @property def typ_name(self): return self._typ_name @property def columns(self): return self._columns class TableValuedParam(SqlValueMetaclass): """ Used to represent a value of table-valued parameter """ def __init__(self, type_name=None, columns=None, rows=None): # parsing type name self._typ_schema = "" self._typ_name = "" if type_name: parts = type_name.split(".") if len(parts) > 2: raise ValueError( "Type name should consist of at most 2 parts, e.g. dbo.MyType" ) self._typ_name = parts[-1] if len(parts) > 1: self._typ_schema = parts[0] self._columns = columns self._rows = rows @property def typ_name(self): return self._typ_name @property def typ_schema(self): return self._typ_schema @property def columns(self): return self._columns @property def rows(self): return self._rows def is_null(self): return self._rows is None def peek_row(self): try: rows = iter(self._rows) except TypeError: raise tds_base.DataError("rows should be iterable") try: row = next(rows) except StopIteration: # no rows raise tds_base.DataError( "Cannot infer columns from rows for TVP because there are no rows" ) else: # put row back self._rows = itertools.chain([row], rows) return row class TableSerializer(BaseTypeSerializer): """ Used to serialize table valued parameters spec: https://msdn.microsoft.com/en-us/library/dd304813.aspx """ type = tds_base.TVPTYPE def read(self, r): """According to spec TDS does not support output TVP values""" raise NotImplementedError @classmethod def from_stream(cls, r): """According to spec TDS does not support output TVP values""" raise NotImplementedError def __init__(self, table_type, columns_serializers): super(TableSerializer, self).__init__() self._table_type = table_type self._columns_serializers = columns_serializers @property def table_type(self): return self._table_type def __repr__(self): return "TableSerializer(t={},c={})".format( repr(self._table_type), repr(self._columns_serializers) ) def write_info(self, w): """ Writes TVP_TYPENAME structure spec: https://msdn.microsoft.com/en-us/library/dd302994.aspx @param w: TdsWriter @return: """ w.write_b_varchar("") # db_name, should be empty w.write_b_varchar(self._table_type.typ_schema) w.write_b_varchar(self._table_type.typ_name) def write(self, w, val): """ Writes remaining part of TVP_TYPE_INFO structure, resuming from TVP_COLMETADATA specs: https://msdn.microsoft.com/en-us/library/dd302994.aspx https://msdn.microsoft.com/en-us/library/dd305261.aspx https://msdn.microsoft.com/en-us/library/dd303230.aspx @param w: TdsWriter @param val: TableValuedParam or None @return: """ if val.is_null(): w.put_usmallint(tds_base.TVP_NULL_TOKEN) else: columns = self._table_type.columns w.put_usmallint(len(columns)) for i, column in enumerate(columns): w.put_uint(column.column_usertype) w.put_usmallint(column.flags) # TYPE_INFO structure: https://msdn.microsoft.com/en-us/library/dd358284.aspx serializer = self._columns_serializers[i] type_id = serializer.type w.put_byte(type_id) serializer.write_info(w) w.write_b_varchar("") # ColName, must be empty in TVP according to spec # here can optionally send TVP_ORDER_UNIQUE and TVP_COLUMN_ORDERING # https://msdn.microsoft.com/en-us/library/dd305261.aspx # terminating optional metadata w.put_byte(tds_base.TVP_END_TOKEN) # now sending rows using TVP_ROW # https://msdn.microsoft.com/en-us/library/dd305261.aspx if val.rows: for row in val.rows: w.put_byte(tds_base.TVP_ROW_TOKEN) for i, col in enumerate(self._table_type.columns): if not col.flags & tds_base.TVP_COLUMN_DEFAULT_FLAG: self._columns_serializers[i].write(w, row[i]) # terminating rows w.put_byte(tds_base.TVP_END_TOKEN) _type_map = { tds_base.SYBINT1: TinyIntSerializer, tds_base.SYBINT2: SmallIntSerializer, tds_base.SYBINT4: IntSerializer, tds_base.SYBINT8: BigIntSerializer, tds_base.SYBINTN: IntNSerializer, tds_base.SYBBIT: BitSerializer, tds_base.SYBBITN: BitNSerializer, tds_base.SYBREAL: RealSerializer, tds_base.SYBFLT8: FloatSerializer, tds_base.SYBFLTN: FloatNSerializer, tds_base.SYBMONEY4: Money4Serializer, tds_base.SYBMONEY: Money8Serializer, tds_base.SYBMONEYN: MoneyNSerializer, tds_base.XSYBCHAR: VarChar70Serializer, tds_base.XSYBVARCHAR: VarChar70Serializer, tds_base.XSYBNCHAR: NVarChar70Serializer, tds_base.XSYBNVARCHAR: NVarChar70Serializer, tds_base.SYBTEXT: Text70Serializer, tds_base.SYBNTEXT: NText70Serializer, tds_base.SYBMSXML: XmlSerializer, tds_base.XSYBBINARY: VarBinarySerializer, tds_base.XSYBVARBINARY: VarBinarySerializer, tds_base.SYBIMAGE: Image70Serializer, tds_base.SYBNUMERIC: MsDecimalSerializer, tds_base.SYBDECIMAL: MsDecimalSerializer, tds_base.SYBVARIANT: VariantSerializer, tds_base.SYBMSDATE: MsDateSerializer, tds_base.SYBMSTIME: MsTimeSerializer, tds_base.SYBMSDATETIME2: DateTime2Serializer, tds_base.SYBMSDATETIMEOFFSET: DateTimeOffsetSerializer, tds_base.SYBDATETIME4: SmallDateTimeSerializer, tds_base.SYBDATETIME: DateTimeSerializer, tds_base.SYBDATETIMN: DateTimeNSerializer, tds_base.SYBUNIQUE: MsUniqueSerializer, } _type_map71 = _type_map.copy() _type_map71.update( { tds_base.XSYBCHAR: VarChar71Serializer, tds_base.XSYBNCHAR: NVarChar71Serializer, tds_base.XSYBVARCHAR: VarChar71Serializer, tds_base.XSYBNVARCHAR: NVarChar71Serializer, tds_base.SYBTEXT: Text71Serializer, tds_base.SYBNTEXT: NText71Serializer, } ) _type_map72 = _type_map.copy() _type_map72.update( { tds_base.XSYBCHAR: VarChar72Serializer, tds_base.XSYBNCHAR: NVarChar72Serializer, tds_base.XSYBVARCHAR: VarChar72Serializer, tds_base.XSYBNVARCHAR: NVarChar72Serializer, tds_base.SYBTEXT: Text72Serializer, tds_base.SYBNTEXT: NText72Serializer, tds_base.XSYBBINARY: VarBinarySerializer72, tds_base.XSYBVARBINARY: VarBinarySerializer72, tds_base.SYBIMAGE: Image72Serializer, tds_base.UDTTYPE: UDT72Serializer, } ) _type_map73 = _type_map72.copy() _type_map73.update( { tds_base.TVPTYPE: TableSerializer, } ) def sql_type_by_declaration(declaration): return _declarations_parser.parse(declaration) class SerializerFactory(object): """ Factory class for TDS data types """ def __init__(self, tds_ver): self._tds_ver = tds_ver if self._tds_ver >= tds_base.TDS73: self._type_map = _type_map73 elif self._tds_ver >= tds_base.TDS72: self._type_map = _type_map72 elif self._tds_ver >= tds_base.TDS71: self._type_map = _type_map71 else: self._type_map = _type_map def get_type_serializer(self, tds_type_id): type_class = self._type_map.get(tds_type_id) if not type_class: raise tds_base.InterfaceError("Invalid type id {}".format(tds_type_id)) return type_class def long_binary_type(self): if self._tds_ver >= tds_base.TDS72: return VarBinaryMaxType() else: return ImageType() def long_varchar_type(self): if self._tds_ver >= tds_base.TDS72: return VarCharMaxType() else: return TextType() def long_string_type(self): if self._tds_ver >= tds_base.TDS72: return NVarCharMaxType() else: return NTextType() def datetime(self, precision): if self._tds_ver >= tds_base.TDS72: return DateTime2Type(precision=precision) else: return DateTimeType() def has_datetime_with_tz(self): return self._tds_ver >= tds_base.TDS72 def datetime_with_tz(self, precision): if self._tds_ver >= tds_base.TDS72: return DateTimeOffsetType(precision=precision) else: raise tds_base.DataError( "Given TDS version does not support DATETIMEOFFSET type" ) def date(self): if self._tds_ver >= tds_base.TDS72: return DateType() else: return DateTimeType() def time(self, precision): if self._tds_ver >= tds_base.TDS72: return TimeType(precision=precision) else: raise tds_base.DataError("Given TDS version does not support TIME type") def serializer_by_declaration(self, declaration, connection): sql_type = sql_type_by_declaration(declaration) return self.serializer_by_type( sql_type=sql_type, collation=connection.collation ) def serializer_by_type(self, sql_type, collation=raw_collation): typ = sql_type if isinstance(typ, BitType): return BitNSerializer(typ) elif isinstance(typ, TinyIntType): return IntNSerializer(typ) elif isinstance(typ, SmallIntType): return IntNSerializer(typ) elif isinstance(typ, IntType): return IntNSerializer(typ) elif isinstance(typ, BigIntType): return IntNSerializer(typ) elif isinstance(typ, RealType): return FloatNSerializer(size=4) elif isinstance(typ, FloatType): return FloatNSerializer(size=8) elif isinstance(typ, SmallMoneyType): return self._type_map[tds_base.SYBMONEYN](size=4) elif isinstance(typ, MoneyType): return self._type_map[tds_base.SYBMONEYN](size=8) elif isinstance(typ, CharType): return self._type_map[tds_base.XSYBCHAR](size=typ.size, collation=collation) elif isinstance(typ, VarCharType): return self._type_map[tds_base.XSYBVARCHAR]( size=typ.size, collation=collation ) elif isinstance(typ, VarCharMaxType): return VarCharMaxSerializer(collation=collation) elif isinstance(typ, NCharType): return self._type_map[tds_base.XSYBNCHAR]( size=typ.size, collation=collation ) elif isinstance(typ, NVarCharType): return self._type_map[tds_base.XSYBNVARCHAR]( size=typ.size, collation=collation ) elif isinstance(typ, NVarCharMaxType): return NVarCharMaxSerializer(collation=collation) elif isinstance(typ, TextType): return self._type_map[tds_base.SYBTEXT](collation=collation) elif isinstance(typ, NTextType): return self._type_map[tds_base.SYBNTEXT](collation=collation) elif isinstance(typ, XmlType): return self._type_map[tds_base.SYBMSXML]() elif isinstance(typ, BinaryType): return self._type_map[tds_base.XSYBBINARY]() elif isinstance(typ, VarBinaryType): return self._type_map[tds_base.XSYBVARBINARY](size=typ.size) elif isinstance(typ, VarBinaryMaxType): return VarBinarySerializerMax() elif isinstance(typ, ImageType): return self._type_map[tds_base.SYBIMAGE]() elif isinstance(typ, DecimalType): return self._type_map[tds_base.SYBDECIMAL]( scale=typ.scale, precision=typ.precision ) elif isinstance(typ, VariantType): return self._type_map[tds_base.SYBVARIANT](size=0) elif isinstance(typ, SmallDateTimeType): return self._type_map[tds_base.SYBDATETIMN](size=4) elif isinstance(typ, DateTimeType): return self._type_map[tds_base.SYBDATETIMN](size=8) elif isinstance(typ, DateType): return self._type_map[tds_base.SYBMSDATE](typ) elif isinstance(typ, TimeType): return self._type_map[tds_base.SYBMSTIME](typ) elif isinstance(typ, DateTime2Type): return self._type_map[tds_base.SYBMSDATETIME2](typ) elif isinstance(typ, DateTimeOffsetType): return self._type_map[tds_base.SYBMSDATETIMEOFFSET](typ) elif isinstance(typ, UniqueIdentifierType): return self._type_map[tds_base.SYBUNIQUE]() elif isinstance(typ, TableType): columns_serializers = None if typ.columns is not None: columns_serializers = [ self.serializer_by_type(col.type) for col in typ.columns ] return TableSerializer( table_type=typ, columns_serializers=columns_serializers ) else: raise ValueError("Cannot map type {} to serializer.".format(typ)) class DeclarationsParser(object): def __init__(self): declaration_parsers = [ ("bit", BitType), ("tinyint", TinyIntType), ("smallint", SmallIntType), ("(?:int|integer)", IntType), ("bigint", BigIntType), ("real", RealType), ("(?:float|double precision)", FloatType), ("(?:char|character)", CharType), ( r"(?:char|character)\((\d+)\)", lambda size_str: CharType(size=int(size_str)), ), (r"(?:varchar|char(?:|acter)\s+varying)", VarCharType), ( r"(?:varchar|char(?:|acter)\s+varying)\((\d+)\)", lambda size_str: VarCharType(size=int(size_str)), ), (r"varchar\(max\)", VarCharMaxType), (r"(?:nchar|national\s+(?:char|character))", NCharType), ( r"(?:nchar|national\s+(?:char|character))\((\d+)\)", lambda size_str: NCharType(size=int(size_str)), ), (r"(?:nvarchar|national\s+(?:char|character)\s+varying)", NVarCharType), ( r"(?:nvarchar|national\s+(?:char|character)\s+varying)\((\d+)\)", lambda size_str: NVarCharType(size=int(size_str)), ), (r"nvarchar\(max\)", NVarCharMaxType), ("xml", XmlType), ("text", TextType), (r"(?:ntext|national\s+text)", NTextType), ("binary", BinaryType), (r"binary\((\d+)\)", lambda size_str: BinaryType(size=int(size_str))), ("(?:varbinary|binary varying)", VarBinaryType), ( r"(?:varbinary|binary varying)\((\d+)\)", lambda size_str: VarBinaryType(size=int(size_str)), ), (r"varbinary\(max\)", VarBinaryMaxType), ("image", ImageType), ("smalldatetime", SmallDateTimeType), ("datetime", DateTimeType), ("date", DateType), (r"time", TimeType), ( r"time\((\d+)\)", lambda precision_str: TimeType(precision=int(precision_str)), ), ("datetime2", DateTime2Type), ( r"datetime2\((\d+)\)", lambda precision_str: DateTime2Type(precision=int(precision_str)), ), ("datetimeoffset", DateTimeOffsetType), ( r"datetimeoffset\((\d+)\)", lambda precision_str: DateTimeOffsetType(precision=int(precision_str)), ), ("(?:decimal|dec|numeric)", DecimalType), ( r"(?:decimal|dec|numeric)\((\d+)\)", lambda precision_str: DecimalType(precision=int(precision_str)), ), ( r"(?:decimal|dec|numeric)\((\d+), ?(\d+)\)", lambda precision_str, scale_str: DecimalType( precision=int(precision_str), scale=int(scale_str) ), ), ("smallmoney", SmallMoneyType), ("money", MoneyType), ("uniqueidentifier", UniqueIdentifierType), ("sql_variant", VariantType), ] self._compiled = [ (re.compile(r"^" + regex + "$", re.IGNORECASE), constructor) for regex, constructor in declaration_parsers ] def parse(self, declaration): """ Parse sql type declaration, e.g. varchar(10) and return instance of corresponding type class, e.g. VarCharType(10) @param declaration: Sql declaration to parse, e.g. varchar(10) @return: instance of SqlTypeMetaclass """ declaration = declaration.strip() for regex, constructor in self._compiled: m = regex.match(declaration) if m: return constructor(*m.groups()) raise ValueError("Unable to parse type declaration", declaration) _declarations_parser = DeclarationsParser() class TdsTypeInferrer(object): def __init__( self, type_factory, collation=None, bytes_to_unicode=False, allow_tz=False ): """ Class used to do TDS type inference :param type_factory: Instance of TypeFactory :param collation: Collation to use for strings :param bytes_to_unicode: Treat bytes type as unicode string :param allow_tz: Allow usage of DATETIMEOFFSET type """ self._type_factory = type_factory self._collation = collation self._bytes_to_unicode = bytes_to_unicode self._allow_tz = allow_tz def from_value(self, value): """Function infers TDS type from Python value. :param value: value from which to infer TDS type :return: An instance of subclass of :class:`BaseType` """ if value is None: sql_type = NVarCharType(size=1) else: sql_type = self._from_class_value(value, type(value)) return sql_type def from_class(self, cls): """Function infers TDS type from Python class. :param cls: Class from which to infer type :return: An instance of subclass of :class:`BaseType` """ return self._from_class_value(None, cls) def _from_class_value(self, value, value_type): type_factory = self._type_factory bytes_to_unicode = self._bytes_to_unicode allow_tz = self._allow_tz if issubclass(value_type, bool): return BitType() elif issubclass(value_type, int): if value is None: return IntType() if -(2**31) <= value <= 2**31 - 1: return IntType() elif -(2**63) <= value <= 2**63 - 1: return BigIntType() elif -(10**38) + 1 <= value <= 10**38 - 1: return DecimalType(precision=38) else: return VarCharMaxType() elif issubclass(value_type, float): return FloatType() elif issubclass(value_type, Binary): if value is None or len(value) <= 8000: return VarBinaryType(size=8000) else: return type_factory.long_binary_type() elif issubclass(value_type, bytes): if bytes_to_unicode: return type_factory.long_string_type() else: return type_factory.long_varchar_type() elif issubclass(value_type, str): return type_factory.long_string_type() elif issubclass(value_type, datetime.datetime): if value and value.tzinfo and allow_tz: return type_factory.datetime_with_tz(precision=6) else: return type_factory.datetime(precision=6) elif issubclass(value_type, datetime.date): return type_factory.date() elif issubclass(value_type, datetime.time): return type_factory.time(precision=6) elif issubclass(value_type, decimal.Decimal): if value is None: return DecimalType() else: return DecimalType.from_value(value) elif issubclass(value_type, uuid.UUID): return UniqueIdentifierType() elif issubclass(value_type, TableValuedParam): columns = value.columns rows = value.rows if columns is None: # trying to auto detect columns using data from first row if rows is None: # rows are not present too, this means # entire tvp has value of NULL pass else: # use first row to infer types of columns row = value.peek_row() columns = [] try: cell_iter = iter(row) except TypeError: raise tds_base.DataError( "Each row in table should be an iterable" ) for cell in cell_iter: if isinstance(cell, TableValuedParam): raise tds_base.DataError( "TVP type cannot have nested TVP types" ) col_type = self.from_value(cell) col = tds_base.Column(type=col_type) columns.append(col) return TableType( typ_schema=value.typ_schema, typ_name=value.typ_name, columns=columns ) else: raise tds_base.DataError( "Cannot infer TDS type from Python value: {!r}".format(value) ) pytds-1.15.0/src/pytds/tds_writer.py000066400000000000000000000115661456567501500174620ustar00rootroot00000000000000""" This module implements TdsWriter class """ import struct from pytds import tds_base from pytds.collate import Collation, ucs2_codec from pytds.tds_base import ( _uint_be, _byte, _smallint_le, _usmallint_le, _usmallint_be, _int_le, _uint_le, _int8_le, _uint8_le, _header, ) class _TdsWriter: """TDS stream writer Handles splitting of incoming data into TDS packets according to TDS protocol. Provides convinience methods for writing primitive data types. """ def __init__( self, transport: tds_base.TransportProtocol, bufsize: int, tds_session ): self._transport = transport self._tds = tds_session self._pos = 0 self._buf = bytearray(bufsize) self._packet_no = 0 self._type = 0 @property def session(self): return self._tds @property def bufsize(self) -> int: """Size of the buffer""" return len(self._buf) @bufsize.setter def bufsize(self, bufsize: int) -> None: if len(self._buf) == bufsize: return if bufsize > len(self._buf): self._buf.extend(b"\0" * (bufsize - len(self._buf))) else: self._buf = self._buf[0:bufsize] def begin_packet(self, packet_type: int) -> None: """Starts new packet stream :param packet_type: Type of TDS stream, e.g. TDS_PRELOGIN, TDS_QUERY etc. """ self._type = packet_type self._pos = 8 def pack(self, struc: struct.Struct, *args) -> None: """Packs and writes structure into stream""" self.write(struc.pack(*args)) def put_byte(self, value: int) -> None: """Writes single byte into stream""" self.pack(_byte, value) def put_smallint(self, value: int) -> None: """Writes 16-bit signed integer into the stream""" self.pack(_smallint_le, value) def put_usmallint(self, value: int) -> None: """Writes 16-bit unsigned integer into the stream""" self.pack(_usmallint_le, value) def put_usmallint_be(self, value: int) -> None: """Writes 16-bit unsigned big-endian integer into the stream""" self.pack(_usmallint_be, value) def put_int(self, value: int) -> None: """Writes 32-bit signed integer into the stream""" self.pack(_int_le, value) def put_uint(self, value: int) -> None: """Writes 32-bit unsigned integer into the stream""" self.pack(_uint_le, value) def put_uint_be(self, value: int) -> None: """Writes 32-bit unsigned big-endian integer into the stream""" self.pack(_uint_be, value) def put_int8(self, value: int) -> None: """Writes 64-bit signed integer into the stream""" self.pack(_int8_le, value) def put_uint8(self, value: int) -> None: """Writes 64-bit unsigned integer into the stream""" self.pack(_uint8_le, value) def put_collation(self, collation: Collation) -> None: """Writes :class:`Collation` structure into the stream""" self.write(collation.pack()) def write(self, data: bytes) -> None: """Writes given bytes buffer into the stream Function returns only when entire buffer is written """ data_off = 0 while data_off < len(data): left = len(self._buf) - self._pos if left <= 0: self._write_packet(final=False) else: to_write = min(left, len(data) - data_off) self._buf[self._pos : self._pos + to_write] = data[ data_off : data_off + to_write ] self._pos += to_write data_off += to_write def write_b_varchar(self, s: str) -> None: self.put_byte(len(s)) self.write_ucs2(s) def write_ucs2(self, s: str) -> None: """Write string encoding it in UCS2 into stream""" self.write_string(s, ucs2_codec) def write_string(self, s: str, codec) -> None: """Write string encoding it with codec into stream""" for i in range(0, len(s), self.bufsize): chunk = s[i : i + self.bufsize] buf, consumed = codec.encode(chunk) assert consumed == len(chunk) self.write(buf) def flush(self) -> None: """Closes current packet stream""" return self._write_packet(final=True) def _write_packet(self, final: bool) -> None: """Writes single TDS packet into underlying transport. Data for the packet is taken from internal buffer. :param final: True means this is the final packet in substream. """ status = 1 if final else 0 _header.pack_into( self._buf, 0, self._type, status, self._pos, 0, self._packet_no ) self._packet_no = (self._packet_no + 1) % 256 self._transport.sendall(self._buf[: self._pos]) self._pos = 8 pytds-1.15.0/src/pytds/tls.py000066400000000000000000000165761456567501500161040ustar00rootroot00000000000000from __future__ import annotations import logging from typing import Any import typing try: import OpenSSL.SSL # type: ignore # needs fixing except ImportError: OPENSSL_AVAILABLE = False else: OPENSSL_AVAILABLE = True from . import tds_base BUFSIZE = 65536 logger = logging.getLogger(__name__) if typing.TYPE_CHECKING: from pytds.tds_session import _TdsSession class EncryptedSocket(tds_base.TransportProtocol): def __init__( self, transport: tds_base.TransportProtocol, tls_conn: OpenSSL.SSL.Connection ) -> None: super().__init__() self._transport = transport self._tls_conn = tls_conn def gettimeout(self) -> float | None: return self._transport.gettimeout() def settimeout(self, timeout: float | None) -> None: self._transport.settimeout(timeout) def sendall(self, data: Any, flags: int = 0) -> None: # TLS.Connection does not support bytearrays, need to convert to bytes first if isinstance(data, bytearray): data = bytes(data) self._tls_conn.sendall(data) buf = self._tls_conn.bio_read(BUFSIZE) self._transport.sendall(buf) # def send(self, data): # while True: # try: # return self._tls_conn.send(data) # except OpenSSL.SSL.WantWriteError: # buf = self._tls_conn.bio_read(BUFSIZE) # self._transport.sendall(buf) def recv_into( self, buffer: bytearray | memoryview, size: int = 0, flags: int = 0 ) -> int: if size == 0: size = len(buffer) res = self.recv(size) buffer[0 : len(res)] = res return len(res) def recv(self, bufsize: int, flags: int = 0) -> bytes: while True: try: buf = self._tls_conn.bio_read(bufsize) except OpenSSL.SSL.WantReadError: pass else: self._transport.sendall(buf) try: return self._tls_conn.recv(bufsize) except OpenSSL.SSL.WantReadError: buf = self._transport.recv(BUFSIZE) if buf: self._tls_conn.bio_write(buf) else: return b"" def close(self) -> None: self._tls_conn.shutdown() self._transport.close() def shutdown(self, how: int = 0) -> None: self._tls_conn.shutdown() def verify_cb(conn, cert, err_num, err_depth, ret_code: int) -> bool: return ret_code == 1 def is_san_matching(san: str, host_name: str) -> bool: for item in san.split(','): dnsentry = item.strip().lstrip('DNS:').strip() # SANs are usually have form like: DNS:hostname if dnsentry == host_name: return True if ( dnsentry[0:2] == "*." ): # support for wildcards, but only at the first position afterstar_parts = dnsentry[2:] afterstar_parts_sname = ".".join( host_name.split(".")[1:] ) # remove first part of dns name if afterstar_parts == afterstar_parts_sname: return True return False def validate_host(cert, name: bytes) -> bool: """ Validates host name against certificate @param cert: Certificate returned by host @param name: Actual host name used for connection @return: Returns true if host name matches certificate """ cn = None for t, v in cert.get_subject().get_components(): if t == b"CN": cn = v break if cn == name: return True # checking SAN s_name = name.decode("ascii") for i in range(cert.get_extension_count()): ext = cert.get_extension(i) if ext.get_short_name() == b"subjectAltName": s = str(ext) if is_san_matching(s, s_name): return True # TODO check if wildcard is needed in CN as well return False def create_context(cafile: str) -> OpenSSL.SSL.Context: ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD) ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2) ctx.set_options(OpenSSL.SSL.OP_NO_SSLv3) ctx.set_verify(OpenSSL.SSL.VERIFY_PEER, verify_cb) # print("verify depth:", ctx.get_verify_depth()) # print("verify mode:", ctx.get_verify_mode()) # print("openssl version:", cryptography.hazmat.backends.openssl.backend.openssl_version_text()) ctx.load_verify_locations(cafile=cafile) return ctx # https://msdn.microsoft.com/en-us/library/dd357559.aspx def establish_channel(tds_sock: _TdsSession) -> None: w = tds_sock._writer r = tds_sock._reader login = tds_sock.conn._login bhost = login.server_name.encode("ascii") conn = OpenSSL.SSL.Connection(login.tls_ctx) conn.set_tlsext_host_name(bhost) # change connection to client mode conn.set_connect_state() logger.info("doing TLS handshake") while True: try: logger.debug("calling do_handshake") conn.do_handshake() except OpenSSL.SSL.WantReadError: logger.debug( "got WantReadError, getting data from the write end of the TLS connection buffer" ) try: req = conn.bio_read(BUFSIZE) except OpenSSL.SSL.WantReadError: # PyOpenSSL - https://github.com/pyca/pyopenssl/issues/887 logger.debug("got WantReadError again, waiting for response...") else: logger.debug( "sending %d bytes of the handshake data to the server", len(req) ) w.begin_packet(tds_base.PacketType.PRELOGIN) w.write(req) w.flush() logger.debug("receiving response from the server") resp_meta = r.begin_response() if resp_meta.type != tds_base.PacketType.PRELOGIN: raise tds_base.Error( f"Invalid packet type was received from server, expected PRELOGIN(18) got {resp_meta.type}" ) while not r.stream_finished(): resp = r.recv(4096) logger.debug( "adding %d bytes of the response into the TLS connection buffer", len(resp), ) conn.bio_write(resp) else: logger.info("TLS handshake is complete") if login.validate_host: if not validate_host(cert=conn.get_peer_certificate(), name=bhost): raise tds_base.Error( "Certificate does not match host name '{}'".format( login.server_name ) ) enc_sock = EncryptedSocket(transport=tds_sock.conn.sock, tls_conn=conn) tds_sock.conn.sock = enc_sock tds_sock._writer._transport = enc_sock tds_sock._reader._transport = enc_sock return def revert_to_clear(tds_sock: _TdsSession) -> None: """ Reverts connection back to non-encrypted mode Used when client sent ENCRYPT_OFF flag @param tds_sock: @return: """ enc_conn = tds_sock.conn.sock if isinstance(enc_conn, EncryptedSocket): clear_conn = enc_conn._transport enc_conn.shutdown() tds_sock.conn.sock = clear_conn tds_sock._writer._transport = clear_conn tds_sock._reader._transport = clear_conn pytds-1.15.0/src/pytds/tz.py000066400000000000000000000036411456567501500157240ustar00rootroot00000000000000from __future__ import annotations import datetime import time as _time from datetime import tzinfo, timedelta ZERO = timedelta(0) HOUR = timedelta(hours=1) # A class building tzinfo objects for fixed-offset time zones. # Note that FixedOffset(0, "UTC") is a different way to build a # UTC tzinfo object. class FixedOffsetTimezone(tzinfo): """Fixed offset in minutes east from UTC.""" def __init__(self, offset: float, name: str | None=None) -> None: self.__offset = timedelta(minutes=offset) self.__name = name def utcoffset(self, dt: datetime.datetime | None) -> timedelta: return self.__offset def tzname(self, dt: datetime.datetime | None) -> str | None: return self.__name def dst(self, dt: datetime.datetime | None) -> timedelta: return ZERO utc = FixedOffsetTimezone(offset=0, name="UTC") STDOFFSET = timedelta(seconds=-_time.timezone) if _time.daylight: DSTOFFSET = timedelta(seconds=-_time.altzone) else: DSTOFFSET = STDOFFSET DSTDIFF = DSTOFFSET - STDOFFSET class LocalTimezone(tzinfo): def utcoffset(self, dt: datetime.datetime | None) -> timedelta: if self._isdst(dt): return DSTOFFSET else: return STDOFFSET def dst(self, dt: datetime.datetime | None) -> timedelta: if self._isdst(dt): return DSTDIFF else: return ZERO def tzname(self, dt: datetime.datetime | None) -> str: return _time.tzname[self._isdst(dt)] def _isdst(self, dt: datetime.datetime | None) -> bool: if not dt: return False tt = ( dt.year, dt.month, dt.day, dt.hour, dt.minute, dt.second, dt.weekday(), 0, 0, ) stamp = _time.mktime(tt) tt = _time.localtime(stamp) return tt.tm_isdst > 0 local = LocalTimezone() pytds-1.15.0/src/pytds/utils.py000066400000000000000000000055061456567501500164310ustar00rootroot00000000000000""" This module contains generic utility functions which don't have dependencies on any other modules. """ from __future__ import annotations import logging import sys import time import typing from collections.abc import Callable if sys.version_info < (3, 8): import pkg_resources else: from importlib import metadata logger = logging.getLogger("pytds") T = typing.TypeVar("T") def exponential_backoff( work: Callable[[float], T], ex_handler: Callable[[Exception], None], max_time_sec: float, first_attempt_time_sec: float, backoff_factor: float = 2, ) -> T: """ Perform work with exponential backoff if work fails. Will raise TimeoutError if `max_time_sec` is exceeded. Work handler receives time limit in seconds for the attempt, it should use it to setup timeout for it's operations. The `ex_handler` is called every time work raises exception, if `ex_handler` can raise exception itself to stop """ try_time = first_attempt_time_sec end_time = time.time() + max_time_sec while True: try_start_time = time.time() try: return work(try_time) except Exception as ex: logger.info("Work attempt failed", exc_info=ex) work_actual_time = time.time() - try_start_time if work_actual_time > try_time: logger.warning( "Work attempt exceeded it's allocated time %f, actual time was %f.", try_time, work_actual_time, ) ex_handler(ex) if time.time() >= end_time: raise TimeoutError() from ex remaining_attempt_time = try_time - (time.time() - try_start_time) logger.info("Will retry after %f seconds", remaining_attempt_time) if remaining_attempt_time > 0: time.sleep(remaining_attempt_time) try_time *= backoff_factor def parse_server(server: str) -> tuple[str, str]: """ Split server name in MSSQL format (host\\instance) into server host and instance """ instance = "" if "\\" in server: server, instance = server.split("\\") # support MS methods of connecting locally if server in (".", "(local)"): server = "localhost" return server, instance.upper() def ver_to_int(ver: str) -> int: """ Convert version string into 32-bit integer format """ res = ver.split(".") if len(res) < 2: logger.warning( 'Invalid version %s, it should have 2 parts at least separated by "."', ver ) return 0 maj, minor, _ = ver.split(".") return (int(maj) << 24) + (int(minor) << 16) def package_version(name: str) -> str: if sys.version_info < (3, 8): return pkg_resources.get_distribution(name).version return metadata.version(name) pytds-1.15.0/test_requirements.txt000066400000000000000000000011411456567501500172770ustar00rootroot00000000000000pytest>=3.3.2 pytest-cov codecov # pyOpenSSL 23.0.0 fails with error: # TypeError: deprecated() got an unexpected keyword argument 'name' # Example failing build: https://ci.appveyor.com/project/denisenkom/pytds/builds/46539355/job/aq6d65ej1oi0i59p pyOpenSSL<22.1.0 pyDes ntlm-auth pyspnego namedlist # cryptography 3.4.5 fails build, requires rust compiler # see example failure: https://ci.appveyor.com/project/denisenkom/pytds/builds/37803561/job/lln9d25ye5vnljbr # requiring older cryptography library to avoid this error cryptography < 3.3 sqlalchemy-pytds==1.0.0 SQLAlchemy==2.0.22 mypy pytest-mypy ruffpytds-1.15.0/tests/000077500000000000000000000000001456567501500141215ustar00rootroot00000000000000pytds-1.15.0/tests/all_test.py000066400000000000000000000723021456567501500163060ustar00rootroot00000000000000# vim: set fileencoding=utf8 : from __future__ import with_statement from __future__ import unicode_literals import os import random import string import codecs import logging import socket from io import StringIO import utils from pytds.tds_types import ( TimeType, DateTime2Type, DateType, DateTimeOffsetType, BitType, TinyIntType, SmallIntType, IntType, BigIntType, RealType, FloatType, NVarCharType, VarBinaryType, SmallDateTimeType, DateTimeType, DecimalType, MoneyType, UniqueIdentifierType, VariantType, ImageType, VarBinaryMaxType, VarCharType, TextType, NTextType, NVarCharMaxType, VarCharMaxType, XmlType, ) try: import unittest2 as unittest except: import unittest import sys from decimal import Decimal, getcontext import logging from time import sleep from datetime import datetime, date, time import uuid import pytest import pytds.tz import pytds.login import pytds.smp tzoffset = pytds.tz.FixedOffsetTimezone utc = pytds.tz.utc import pytds.extensions from pytds import ( connect, ProgrammingError, TimeoutError, Time, Error, IntegrityError, Timestamp, DataError, Date, Binary, output, default, STRING, BINARY, NUMBER, DATETIME, DECIMAL, INTEGER, REAL, XML, ) from pytds.tds_types import DateTimeSerializer, SmallMoneyType from pytds.tds_base import ( Param, IS_TDS73_PLUS, IS_TDS71_PLUS, ) import dbapi20 import pytds import settings logger = logging.getLogger(__name__) LIVE_TEST = getattr(settings, "LIVE_TEST", True) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_timeout_with_mars(): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["timeout"] = 1 kwargs["use_mars"] = True with connect(*settings.CONNECT_ARGS, **kwargs) as conn: cur = conn.cursor() with pytest.raises(TimeoutError): cur.execute("waitfor delay '00:00:05'") cur.execute("select 1") @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_no_mars_autocommit(): kwargs = settings.CONNECT_KWARGS.copy() kwargs.update( { "use_mars": False, "timeout": 1, "pooling": True, "autocommit": True, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: # test execute scalar with empty response cur.execute_scalar("declare @tbl table(f int); select * from @tbl") cur.execute("print 'hello'") messages = cur.messages assert len(messages) == 1 assert len(messages[0]) == 2 # in following assert exception class does not have to be exactly as specified assert messages[0][0] == pytds.OperationalError assert messages[0][1].text == "hello" assert messages[0][1].line == 1 assert messages[0][1].severity == 0 assert messages[0][1].number == 0 assert messages[0][1].state == 1 assert "hello" in messages[0][1].message # test cursor usage after close, should raise exception cur = conn.cursor() cur.execute_scalar("select 1") cur.close() with pytest.raises(Error) as ex: cur.execute("select 1") assert "Cursor is closed" in str(ex.value) # calling get_proc_return_status on closed cursor works # this test does not have to pass assert cur.get_proc_return_status() is None # calling rowcount on closed cursor works # this test does not have to pass assert cur.rowcount == -1 # calling description on closed cursor works # this test does not have to pass assert cur.description is None # calling messages on closed cursor works # this test does not have to pass assert cur.messages is None # calling description on closed cursor works # this test does not have to pass assert cur.native_description is None @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_timeout_no_mars(): kwargs = settings.CONNECT_KWARGS.copy() kwargs.update( { "use_mars": False, "timeout": 1, "pooling": True, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: with pytest.raises(TimeoutError): cur.execute("waitfor delay '00:00:05'") with conn.cursor() as cur: cur.execute("select 1") cur.fetchall() # test cancelling with conn.cursor() as cur: cur.execute("select 1") cur.cancel() assert cur.fetchall() == [] cur.execute("select 2") assert cur.fetchall() == [(2,)] # test rollback conn.rollback() # test callproc on non-mars connection with conn.cursor() as cur: cur.callproc("sp_reset_connection") with conn.cursor() as cur: # test spid property on non-mars cursor assert cur.spid is not None # test tzinfo_factory property r/w cur.tzinfo_factory = cur.tzinfo_factory # test non-mars cursor with connection pool enabled with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1") assert cur.fetchall() == [(1,)] @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_no_mars_no_pooling(): kwargs = settings.CONNECT_KWARGS.copy() kwargs.update( { "use_mars": False, "pooling": False, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1") assert cur.fetchall() == [(1,)] @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_row_strategies(): kwargs = settings.CONNECT_KWARGS.copy() kwargs.update( { "row_strategy": pytds.list_row_strategy, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1") assert cur.fetchall() == [[1]] kwargs.update( { "row_strategy": pytds.namedtuple_row_strategy, } ) import collections with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1 as f") assert cur.fetchall() == [collections.namedtuple("Row", ["f"])(1)] kwargs.update( { "row_strategy": pytds.recordtype_row_strategy, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1 as e, 2 as f") (row,) = cur.fetchall() assert row.e == 1 assert row.f == 2 assert row[0] == 1 assert row[:] == (1, 2) row[0] = 3 assert row[:] == (3, 2) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_get_instances(): if not hasattr(settings, "BROWSER_ADDRESS"): return unittest.skip("BROWSER_ADDRESS setting is not defined") pytds.tds.tds7_get_instances(settings.BROWSER_ADDRESS) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class ConnectionTestCase(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = settings.DATABASE self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def tearDown(self): self.conn.close() @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class NoMarsTestCase(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["use_mars"] = False self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def tearDown(self): self.conn.close() class TestVariant(ConnectionTestCase): def _t(self, result, sql): with self.conn.cursor() as cur: cur.execute("select cast({0} as sql_variant)".format(sql)) (val,) = cur.fetchone() self.assertEqual(result, val) def test_new_datetime(self): if not IS_TDS73_PLUS(self.conn): self.skipTest("Requires TDS7.3+") import pytds.tz self._t( datetime(2011, 2, 3, 10, 11, 12, 3000), "cast('2011-02-03T10:11:12.003000' as datetime2)", ) self._t(time(10, 11, 12, 3000), "cast('10:11:12.003000' as time)") self._t(date(2011, 2, 3), "cast('2011-02-03' as date)") self._t( datetime( 2011, 2, 3, 10, 11, 12, 3000, pytds.tz.FixedOffsetTimezone(3 * 60) ), "cast('2011-02-03T10:11:12.003000+03:00' as datetimeoffset)", ) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class BadConnection(unittest.TestCase): def test_invalid_parameters(self): with self.assertRaises(Error): with connect( server=settings.HOST + "bad", database="master", user="baduser", password=settings.PASSWORD, login_timeout=1, ) as conn: with conn.cursor() as cur: cur.execute("select 1") with self.assertRaises(Error): with connect( server=settings.HOST, database="doesnotexist", user=settings.USER, password=settings.PASSWORD, ) as conn: with conn.cursor() as cur: cur.execute("select 1") with self.assertRaises(Error): with connect( server=settings.HOST, database="master", user="baduser", password=None ) as conn: with conn.cursor() as cur: cur.execute("select 1") def test_instance_and_port(self): host = settings.HOST if "\\" in host: host, _ = host.split("\\") with self.assertRaisesRegex( ValueError, "Both instance and port shouldn't be specified" ): with connect( server=host + "\\badinstancename", database="master", user=settings.USER, password=settings.PASSWORD, port=1212, ) as conn: with conn.cursor() as cur: cur.execute("select 1") # class EncryptionTest(unittest.TestCase): # def runTest(self): # conn = connect(server=settings.HOST, database='master', user=settings.USER, password=settings.PASSWORD, encryption_level=TDS_ENCRYPTION_REQUIRE) # cur = conn.cursor() # cur.execute('select 1') @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class SmallDateTimeTest(ConnectionTestCase): def _testval(self, val): with self.conn.cursor() as cur: cur.execute("select cast(%s as smalldatetime)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) def runTest(self): self._testval(Timestamp(2010, 2, 1, 10, 12, 0)) self._testval(Timestamp(1900, 1, 1, 0, 0, 0)) self._testval(Timestamp(2079, 6, 6, 23, 59, 0)) with self.assertRaises(Error): self._testval(Timestamp(1899, 1, 1, 0, 0, 0)) with self.assertRaises(Error): self._testval(Timestamp(2080, 1, 1, 0, 0, 0)) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class DateTimeTest(ConnectionTestCase): def _testencdec(self, val): self.assertEqual( val, DateTimeSerializer.decode( *DateTimeSerializer._struct.unpack(DateTimeSerializer.encode(val)) ), ) def _testval(self, val): with self.conn.cursor() as cur: cur.execute("select cast(%s as datetime)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) def runTest(self): with self.conn.cursor() as cur: cur.execute("select cast('9999-12-31T23:59:59.997' as datetime)") self.assertEqual( cur.fetchall(), [(Timestamp(9999, 12, 31, 23, 59, 59, 997000),)] ) self._testencdec(Timestamp(2010, 1, 2, 10, 11, 12)) self._testval(Timestamp(2010, 1, 2, 0, 0, 0)) self._testval(Timestamp(2010, 1, 2, 10, 11, 12)) self._testval(Timestamp(1753, 1, 1, 0, 0, 0)) self._testval(Timestamp(9999, 12, 31, 0, 0, 0)) with self.conn.cursor() as cur: cur.execute("select cast(null as datetime)") self.assertEqual(cur.fetchall(), [(None,)]) self._testval(Timestamp(9999, 12, 31, 23, 59, 59, 997000)) with self.assertRaises(Error): self._testval(Timestamp(1752, 1, 1, 0, 0, 0)) with self.conn.cursor() as cur: cur.execute( """ if object_id('testtable') is not null drop table testtable """ ) cur.execute("create table testtable (col datetime not null)") dt = Timestamp(2010, 1, 2, 20, 21, 22, 123000) cur.execute("insert into testtable values (%s)", (dt,)) cur.execute("select col from testtable") self.assertEqual(cur.fetchone(), (dt,)) class NewDateTimeTest(ConnectionTestCase): def test_datetimeoffset(self): if not IS_TDS73_PLUS(self.conn): self.skipTest("Requires TDS7.3+") def _testval(val): with self.conn.cursor() as cur: import pytds.tz cur.tzinfo_factory = pytds.tz.FixedOffsetTimezone cur.execute("select cast(%s as datetimeoffset)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) with self.conn.cursor() as cur: import pytds.tz cur.tzinfo_factory = pytds.tz.FixedOffsetTimezone cur.execute( "select cast('2010-01-02T20:21:22.1234567+05:00' as datetimeoffset)" ) self.assertEqual( datetime(2010, 1, 2, 20, 21, 22, 123456, tzoffset(5 * 60)), cur.fetchone()[0], ) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, utc)) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, tzoffset(5 * 60))) _testval(Timestamp(1, 1, 1, 0, 0, 0, 0, utc)) _testval(Timestamp(9999, 12, 31, 23, 59, 59, 999999, utc)) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, tzoffset(14))) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, tzoffset(-14))) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, tzoffset(-15))) def test_time(self): if not IS_TDS73_PLUS(self.conn): self.skipTest("Requires TDS7.3+") def testval(val): with self.conn.cursor() as cur: cur.execute("select cast(%s as time)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) testval(Time(14, 16, 18, 123456)) testval(Time(0, 0, 0, 0)) testval(Time(0, 0, 0, 0)) testval(Time(0, 0, 0, 0)) testval(Time(23, 59, 59, 999999)) testval(Time(0, 0, 0, 0)) testval(Time(0, 0, 0, 0)) testval(Time(0, 0, 0, 0)) def test_datetime2(self): if not IS_TDS73_PLUS(self.conn): self.skipTest("Requires TDS7.3+") def testval(val): with self.conn.cursor() as cur: cur.execute("select cast(%s as datetime2)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) testval(Timestamp(2010, 1, 2, 20, 21, 22, 345678)) testval(Timestamp(2010, 1, 2, 0, 0, 0)) testval(Timestamp(1, 1, 1, 0, 0, 0)) testval(Timestamp(9999, 12, 31, 23, 59, 59, 999999)) def test_date(self): if not IS_TDS73_PLUS(self.conn): self.skipTest("Requires TDS7.3+") def testval(val): with self.conn.cursor() as cur: cur.execute("select cast(%s as date)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) testval(Date(2010, 1, 2)) testval(Date(2010, 1, 2)) testval(Date(1, 1, 1)) testval(Date(9999, 12, 31)) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class Auth(unittest.TestCase): @unittest.skipUnless( os.getenv("NTLM_USER") and os.getenv("NTLM_PASSWORD"), "requires NTLM_USER and NTLM_PASSWORD environment variables to be set", ) def test_ntlm(self): conn = connect( settings.HOST, auth=pytds.login.NtlmAuth( user_name=os.getenv("NTLM_USER"), password=os.getenv("NTLM_PASSWORD") ), ) with conn.cursor() as cursor: cursor.execute("select 1") cursor.fetchall() @unittest.skipUnless( os.getenv("NTLM_USER") and os.getenv("NTLM_PASSWORD"), "requires NTLM_USER and NTLM_PASSWORD environment variables to be set", ) def test_spnego(self): conn = connect( settings.HOST, auth=pytds.login.SpnegoAuth( os.getenv("NTLM_USER"), os.getenv("NTLM_PASSWORD") ), ) with conn.cursor() as cursor: cursor.execute("select 1") cursor.fetchall() @unittest.skipUnless(sys.platform.startswith("win"), "requires Windows") def test_sspi(self): from pytds.login import SspiAuth with connect(**{ **settings.CONNECT_KWARGS, "auth": SspiAuth() }) as conn: with conn.cursor() as cursor: cursor.execute("select 1") cursor.fetchall() @unittest.skipIf(getattr(settings, "SKIP_SQL_AUTH", False), "SKIP_SQL_AUTH is set") def test_sqlauth(self): with connect(**{ **settings.CONNECT_KWARGS, "user": settings.USER, "password": settings.PASSWORD, }) as conn: with conn.cursor() as cursor: cursor.execute("select 1") cursor.fetchall() class CloseCursorTwice(ConnectionTestCase): def runTest(self): cursor = self.conn.cursor() cursor.close() cursor.close() class RegressionSuite(ConnectionTestCase): def test_cancel(self): self.conn.cursor().cancel() @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TimezoneTests(unittest.TestCase): def check_val(self, conn, sql, input, output): with conn.cursor() as cur: cur.execute("select " + sql, (input,)) rows = cur.fetchall() self.assertEqual(rows[0][0], output) def runTest(self): kwargs = settings.CONNECT_KWARGS.copy() use_tz = utc kwargs["use_tz"] = use_tz kwargs["database"] = "master" with connect(*settings.CONNECT_ARGS, **kwargs) as conn: # Naive time should be interpreted as use_tz self.check_val( conn, "%s", datetime(2011, 2, 3, 10, 11, 12, 3000), datetime(2011, 2, 3, 10, 11, 12, 3000, utc), ) # Aware time shoule be passed as-is dt = datetime(2011, 2, 3, 10, 11, 12, 3000, tzoffset(1)) self.check_val(conn, "%s", dt, dt) # Aware time should be converted to use_tz if not using datetimeoffset type dt = datetime(2011, 2, 3, 10, 11, 12, 3000, tzoffset(1)) if IS_TDS73_PLUS(conn): self.check_val(conn, "cast(%s as datetime2)", dt, dt.astimezone(use_tz)) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class DbapiTestSuite(dbapi20.DatabaseAPI20Test, ConnectionTestCase): driver = pytds connect_args = settings.CONNECT_ARGS connect_kw_args = settings.CONNECT_KWARGS # def _connect(self): # return connection def _try_run(self, *args): with self._connect() as con: with con.cursor() as cur: for arg in args: cur.execute(arg) def _try_run2(self, cur, *args): for arg in args: cur.execute(arg) # This should create the "lower" sproc. def _callproc_setup(self, cur): self._try_run2( cur, """IF OBJECT_ID(N'[dbo].[to_lower]', N'P') IS NOT NULL DROP PROCEDURE [dbo].[to_lower]""", """ CREATE PROCEDURE to_lower @input nvarchar(max) AS BEGIN select LOWER(@input) END """, ) # This should create a sproc with a return value. def _retval_setup(self, cur): self._try_run2( cur, """IF OBJECT_ID(N'[dbo].[add_one]', N'P') IS NOT NULL DROP PROCEDURE [dbo].[add_one]""", """ CREATE PROCEDURE add_one (@input int) AS BEGIN return @input+1 END """, ) def test_retval(self): with self._connect() as con: cur = con.cursor() self._retval_setup(cur) values = cur.callproc("add_one", (1,)) self.assertEqual( values[0], 1, "input parameter should be left unchanged: %s" % (values[0],), ) self.assertEqual(cur.description, None, "No resultset was expected.") self.assertEqual( cur.return_value, 2, "Invalid return value: %s" % (cur.return_value,) ) # This should create a sproc with a return value. def _retval_select_setup(self, cur): self._try_run2( cur, """IF OBJECT_ID(N'[dbo].[add_one_select]', N'P') IS NOT NULL DROP PROCEDURE [dbo].[add_one_select]""", """ CREATE PROCEDURE add_one_select (@input int) AS BEGIN select 'a' as a select 'b' as b return @input+1 END """, ) def test_retval_select(self): with self._connect() as con: cur = con.cursor() self._retval_select_setup(cur) values = cur.callproc("add_one_select", (1,)) self.assertEqual( values[0], 1, "input parameter should be left unchanged: %s" % (values[0],), ) self.assertEqual(len(cur.description), 1, "Unexpected resultset.") self.assertEqual(cur.description[0][0], "a", "Unexpected resultset.") self.assertEqual(cur.fetchall(), [("a",)], "Unexpected resultset.") self.assertTrue(cur.nextset(), "No second resultset found.") self.assertEqual(len(cur.description), 1, "Unexpected resultset.") self.assertEqual(cur.description[0][0], "b", "Unexpected resultset.") self.assertEqual( cur.return_value, 2, "Invalid return value: %s" % (cur.return_value,) ) with self.assertRaises(Error): cur.fetchall() # This should create a sproc with an output parameter. def _outparam_setup(self, cur): self._try_run2( cur, """IF OBJECT_ID(N'[dbo].[add_one_out]', N'P') IS NOT NULL DROP PROCEDURE [dbo].[add_one_out]""", """ CREATE PROCEDURE add_one_out (@input int, @output int OUTPUT) AS BEGIN SET @output = @input+1 END """, ) def test_outparam(self): with self._connect() as con: cur = con.cursor() self._outparam_setup(cur) values = cur.callproc("add_one_out", (1, output(value=1))) self.assertEqual(len(values), 2, "expected 2 parameters") self.assertEqual(values[0], 1, "input parameter should be unchanged") self.assertEqual(values[1], 2, "output parameter should get new values") values = cur.callproc("add_one_out", (None, output(value=1))) self.assertEqual(len(values), 2, "expected 2 parameters") self.assertEqual(values[0], None, "input parameter should be unchanged") self.assertEqual(values[1], None, "output parameter should get new values") def test_assigning_select(self): # test that assigning select does not interfere with result sets with self._connect() as con: cur = con.cursor() cur.execute( """ declare @var1 int select @var1 = 1 select @var1 = 2 select 'value' """ ) self.assertFalse(cur.description) self.assertTrue(cur.nextset()) self.assertFalse(cur.description) self.assertTrue(cur.nextset()) self.assertTrue(cur.description) self.assertEqual([("value",)], cur.fetchall()) self.assertFalse(cur.nextset()) cur.execute( """ set nocount on declare @var1 int select @var1 = 1 select @var1 = 2 select 'value' """ ) self.assertTrue(cur.description) self.assertEqual([("value",)], cur.fetchall()) self.assertFalse(cur.nextset()) # Don't need setoutputsize tests. def test_setoutputsize(self): pass def help_nextset_setUp(self, cur): self._try_run2( cur, """IF OBJECT_ID(N'[dbo].[deleteme]', N'P') IS NOT NULL DROP PROCEDURE [dbo].[deleteme]""", """ create procedure deleteme as begin select count(*) from %sbooze select name from %sbooze end """ % (self.table_prefix, self.table_prefix), ) def help_nextset_tearDown(self, cur): cur.execute("drop procedure deleteme") def test_ExceptionsAsConnectionAttributes(self): pass def test_select_decimal_zero(self): with self._connect() as con: expected = (Decimal("0.00"), Decimal("0.0"), Decimal("-0.00")) cur = con.cursor() cur.execute("SELECT %s as A, %s as B, %s as C", expected) result = cur.fetchall() self.assertEqual(result[0], expected) def test_type_objects(self): with self._connect() as con: cur = con.cursor() cur.execute( """ select cast(0 as varchar), cast(1 as binary), cast(2 as int), cast(3 as real), cast(4 as decimal), cast('2005' as datetime), cast('6' as xml) """ ) self.assertTrue(cur.description) col_types = [col[1] for col in cur.description] self.assertEqual(col_types[0], STRING) self.assertEqual(col_types[1], BINARY) self.assertEqual(col_types[2], NUMBER) self.assertEqual(col_types[2], INTEGER) self.assertEqual(col_types[3], NUMBER) self.assertEqual(col_types[3], REAL) # self.assertEqual(col_types[4], NUMBER) ? self.assertEqual(col_types[4], DECIMAL) self.assertEqual(col_types[5], DATETIME) self.assertEqual(col_types[6], XML) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestBug4(unittest.TestCase): def test_as_dict(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" with connect( *settings.CONNECT_ARGS, **kwargs, row_strategy=pytds.dict_row_strategy ) as conn: with conn.cursor() as cur: cur.execute("select 1 as a, 2 as b") self.assertDictEqual({"a": 1, "b": 2}, cur.fetchone()) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestRawBytes(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["bytes_to_unicode"] = False kwargs["database"] = "master" self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_fetch(self): cur = self.conn.cursor() self.assertIsInstance( cur.execute_scalar("select cast('abc' as nvarchar(max))"), str ) self.assertIsInstance( cur.execute_scalar("select cast('abc' as varchar(max))"), bytes ) self.assertIsInstance(cur.execute_scalar("select cast('abc' as text)"), bytes) self.assertIsInstance(cur.execute_scalar("select %s", ["abc"]), str) self.assertIsInstance(cur.execute_scalar("select %s", [b"abc"]), bytes) rawBytes = b"\x01\x02\x03" self.assertEqual( rawBytes, cur.execute_scalar("select cast(0x010203 as varchar(max))") ) self.assertEqual(rawBytes, cur.execute_scalar("select %s", [rawBytes])) utf8char = b"\xee\xb4\xba" self.assertEqual(utf8char, cur.execute_scalar("select %s", [utf8char])) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_invalid_block_size(): """ Test buffer size changing. Initially buffer should start at 4096 according to TDS spec and then it should upgrade to buffer size that was provided in login request. """ kwargs = settings.CONNECT_KWARGS.copy() kwargs.update( { "blocksize": 4000, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute_scalar("select '{}'".format("x" * 8000)) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_readonly_connection(): kwargs = settings.CONNECT_KWARGS.copy() kwargs.update( { "readonly": True, } ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute_scalar("select 1") pytds-1.15.0/tests/connected_test.py000066400000000000000000000521651456567501500175050ustar00rootroot00000000000000# vim: set fileencoding=utf8 : from __future__ import with_statement from __future__ import unicode_literals from decimal import Decimal, Context import uuid import datetime import logging import random import string from io import BytesIO, StringIO import pytest import os import pytds import pytds.extensions import pytds.login import settings from fixtures import * from pytds import Column from pytds.tds_types import BitType from tests.utils import tran_count logger = logging.getLogger(__name__) LIVE_TEST = getattr(settings, "LIVE_TEST", True) pytds.tds_base.logging_enabled = True def test_integrity_error(cursor): cursor.execute("create table testtable_pk(pk int primary key)") cursor.execute("insert into testtable_pk values (1)") with pytest.raises(pytds.IntegrityError): cursor.execute("insert into testtable_pk values (1)") def test_bulk_insert(cursor): cur = cursor f = StringIO("42\tfoo\n74\tbar\n") cur.copy_to(f, "bulk_insert_table", schema="myschema", columns=("num", "data")) cur.execute("select num, data from myschema.bulk_insert_table") assert [(42, "foo"), (74, "bar")] == cur.fetchall() def test_bug2(cursor): cur = cursor cur.execute( """ create procedure testproc_bug2 (@param int) as begin set transaction isolation level read uncommitted -- that will produce very empty result (even no rowcount) select @param return @param + 1 end """ ) val = 45 cur.execute("exec testproc_bug2 @param = 45") assert cur.fetchall() == [(val,)] assert val + 1 == cur.get_proc_return_status() def test_stored_proc(cursor): cur = cursor val = 45 # params = {'@param': val, '@outparam': output(None), '@add': 1} values = cur.callproc("testproc", (val, pytds.default, pytds.output(value=1))) # self.assertEqual(cur.fetchall(), [(val,)]) assert val + 2 == values[2] assert val + 2 == cur.get_proc_return_status() # after calling stored proc which does not have RETURN statement get_proc_return_status() should return 0 # since in this case SQL server issues RETURN STATUS token with 0 value cur.callproc("test_proc_no_return", (val,)) assert cur.fetchall() == [(val,)] assert cur.get_proc_return_status() == 0 # TODO fix this part, currently it fails # assert cur.execute_scalar("select 1") == 1 # assert cur.get_proc_return_status() == 0 def test_table_selects(db_connection): cur = db_connection.cursor() cur.execute( """ create table #testtable (id int, _text text, _xml xml, vcm varchar(max), vc varchar(10)) """ ) cur.execute( """ insert into #testtable (id, _text, _xml, vcm, vc) values (1, 'text', '', '', NULL) """ ) cur.execute("select id from #testtable order by id") assert [(1,)] == cur.fetchall() cur = db_connection.cursor() cur.execute("select _text from #testtable order by id") assert [("text",)] == cur.fetchall() cur = db_connection.cursor() cur.execute("select _xml from #testtable order by id") assert [("",)] == cur.fetchall() cur = db_connection.cursor() cur.execute("select id, _text, _xml, vcm, vc from #testtable order by id") assert (1, "text", "", "", None) == cur.fetchone() cur = db_connection.cursor() cur.execute("select vc from #testtable order by id") assert [(None,)] == cur.fetchall() cur = db_connection.cursor() cur.execute("insert into #testtable (_xml) values (%s)", ("",)) cur = db_connection.cursor() cur.execute("drop table #testtable") def test_decimals(cursor): cur = cursor assert Decimal(12) == cur.execute_scalar("select cast(12 as decimal) as fieldname") assert Decimal(-12) == cur.execute_scalar( "select cast(-12 as decimal) as fieldname" ) assert Decimal("123456.12345") == cur.execute_scalar( "select cast('123456.12345'as decimal(20,5)) as fieldname" ) assert Decimal("-123456.12345") == cur.execute_scalar( "select cast('-123456.12345'as decimal(20,5)) as fieldname" ) def test_bulk_insert_with_special_chars_no_columns(cursor): cur = cursor cur.execute("create table [test]] table](num int not null, data varchar(100))") f = StringIO("42\tfoo\n74\tbar\n") cur.copy_to(f, "test] table") cur.execute("select num, data from [test]] table]") assert cur.fetchall() == [(42, "foo"), (74, "bar")] def test_bulk_insert_with_special_chars(cursor): cur = cursor cur.execute("create table [test]] table](num int, data varchar(100))") f = StringIO("42\tfoo\n74\tbar\n") cur.copy_to(f, "test] table", columns=("num", "data")) cur.execute("select num, data from [test]] table]") assert cur.fetchall() == [(42, "foo"), (74, "bar")] def test_bulk_insert_with_keyword_column_name(cursor): cur = cursor cur.execute("create table test_table(num int, [User] varchar(100))") f = StringIO("42\tfoo\n74\tbar\n") cur.copy_to(f, "test_table") cur.execute("select num, [User] from test_table") assert cur.fetchall() == [(42, "foo"), (74, "bar")] def test_bulk_insert_with_direct_data(cursor): cur = cursor cur.execute("create table test_table(num int, data nvarchar(max))") data = [[42, "foo"], [57, ""], [66, None], [74, "bar"]] column_types = [ pytds.tds_base.Column("num", type=pytds.tds_types.IntType()), pytds.tds_base.Column("data", type=pytds.tds_types.NVarCharMaxType()), ] cur.copy_to(data=data, table_or_view="test_table", columns=column_types) cur.execute("select num, data from test_table") assert cur.fetchall() == [(42, "foo"), (57, ""), (66, None), (74, "bar")] def test_table_valued_type_autodetect(cursor): def rows_gen(): yield 1, "test1" yield 2, "test2" tvp = pytds.TableValuedParam(type_name="dbo.CategoryTableType", rows=rows_gen()) cursor.execute("SELECT * FROM %s", (tvp,)) assert cursor.fetchall() == [(1, "test1"), (2, "test2")] def test_table_valued_type_explicit(cursor): def rows_gen(): yield 1, "test1" yield 2, "test2" tvp = pytds.TableValuedParam( type_name="dbo.CategoryTableType", columns=( pytds.Column(type=pytds.tds_types.IntType()), pytds.Column(type=pytds.tds_types.NVarCharType(size=30)), ), rows=rows_gen(), ) cursor.execute("SELECT * FROM %s", (tvp,)) assert cursor.fetchall() == [(1, "test1"), (2, "test2")] def test_minimal(cursor): cursor.execute("select 1") assert [(1,)] == cursor.fetchall() def test_empty_query(cursor): cursor.execute("") assert cursor.description is None @pytest.mark.parametrize( "typ,values", [ (pytds.tds_types.BitType(), [True, False]), (pytds.tds_types.IntType(), [2**31 - 1, None]), (pytds.tds_types.IntType(), [-(2**31), None]), (pytds.tds_types.SmallIntType(), [-(2**15), 2**15 - 1]), (pytds.tds_types.TinyIntType(), [255, 0]), (pytds.tds_types.BigIntType(), [2**63 - 1, -(2**63)]), (pytds.tds_types.IntType(), [None, 2**31 - 1]), (pytds.tds_types.IntType(), [None, -(2**31)]), (pytds.tds_types.RealType(), [0.25, None]), (pytds.tds_types.FloatType(), [0.25, None]), (pytds.tds_types.VarCharType(size=10), ["", "testtest12", None, "foo"]), (pytds.tds_types.VarCharType(size=4000), ["x" * 4000, "foo"]), ( pytds.tds_types.VarCharMaxType(), ["x" * 10000, "foo", "", "testtest", None, "bar"], ), (pytds.tds_types.NVarCharType(size=10), ["", "testtest12", None, "foo"]), (pytds.tds_types.NVarCharType(size=4000), ["x" * 4000, "foo"]), ( pytds.tds_types.NVarCharMaxType(), ["x" * 10000, "foo", "", "testtest", None, "bar"], ), (pytds.tds_types.VarBinaryType(size=10), [b"testtest12", b"", None]), (pytds.tds_types.VarBinaryType(size=8000), [b"x" * 8000, b""]), ( pytds.tds_types.SmallDateTimeType(), [ datetime.datetime(1900, 1, 1, 0, 0, 0), None, datetime.datetime(2079, 6, 6, 23, 59, 0), ], ), ( pytds.tds_types.DateTimeType(), [ datetime.datetime(1753, 1, 1, 0, 0, 0), None, datetime.datetime(9999, 12, 31, 23, 59, 59, 990000), ], ), ( pytds.tds_types.DateType(), [datetime.date(1, 1, 1), None, datetime.date(9999, 12, 31)], ), (pytds.tds_types.TimeType(precision=0), [datetime.time(0, 0, 0), None]), ( pytds.tds_types.TimeType(precision=6), [datetime.time(23, 59, 59, 999999), None], ), (pytds.tds_types.TimeType(precision=0), [None]), ( pytds.tds_types.DateTime2Type(precision=0), [datetime.datetime(1, 1, 1, 0, 0, 0), None], ), ( pytds.tds_types.DateTime2Type(precision=6), [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999), None], ), (pytds.tds_types.DateTime2Type(precision=0), [None]), ( pytds.tds_types.DateTimeOffsetType(precision=6), [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, pytds.tz.utc), None], ), ( pytds.tds_types.DateTimeOffsetType(precision=6), [ datetime.datetime( 9999, 12, 31, 23, 59, 59, 999999, pytds.tz.FixedOffsetTimezone(14) ), None, ], ), ( pytds.tds_types.DateTimeOffsetType(precision=0), [ datetime.datetime( 1, 1, 1, 0, 0, 0, tzinfo=pytds.tz.FixedOffsetTimezone(-14) ) ], ), ( pytds.tds_types.DateTimeOffsetType(precision=0), [ datetime.datetime( 1, 1, 1, 0, 14, 0, tzinfo=pytds.tz.FixedOffsetTimezone(14) ) ], ), (pytds.tds_types.DateTimeOffsetType(precision=6), [None]), ( pytds.tds_types.DecimalType(scale=6, precision=38), [Decimal("123.456789"), None], ), ( pytds.tds_types.SmallMoneyType(), [Decimal("214748.3647"), None, Decimal("-214748.3648")], ), ( pytds.tds_types.MoneyType(), [Decimal("922337203685477.5807"), None, Decimal("-922337203685477.5808")], ), (pytds.tds_types.SmallMoneyType(), [Decimal("214748.3647")]), (pytds.tds_types.MoneyType(), [Decimal("922337203685477.5807")]), (pytds.tds_types.MoneyType(), [None]), (pytds.tds_types.UniqueIdentifierType(), [None, uuid.uuid4()]), (pytds.tds_types.VariantType(), [None]), # (pytds.tds_types.VariantType(), [100]), # (pytds.tds_types.ImageType(), [None]), (pytds.tds_types.VarBinaryMaxType(), [None]), # (pytds.tds_types.NTextType(), [None]), # (pytds.tds_types.TextType(), [None]), # (pytds.tds_types.ImageType(), [b'']), # (self.conn._conn.type_factory.long_binary_type(), [b'testtest12']), # (self.conn._conn.type_factory.long_string_type(), [None]), # (self.conn._conn.type_factory.long_varchar_type(), [None]), # (self.conn._conn.type_factory.long_string_type(), ['test']), # (pytds.tds_types.ImageType(), [None]), # (pytds.tds_types.ImageType(), [None]), # (pytds.tds_types.ImageType(), [b'test']), ], ) def test_bulk_insert_type(cursor, typ, values): cur = cursor cur.execute( "create table bulk_insert_table_ll(c1 {0})".format(typ.get_declaration()) ) cur._session.submit_plain_query( "insert bulk bulk_insert_table_ll (c1 {0})".format(typ.get_declaration()) ) cur._session.process_simple_request() col1 = pytds.Column(name="c1", type=typ, flags=pytds.Column.fNullable) metadata = [col1] cur._session.submit_bulk(metadata, [[value] for value in values]) cur._session.process_simple_request() cur.execute("select c1 from bulk_insert_table_ll") assert cur.fetchall() == [(value,) for value in values] assert cur.fetchone() is None def test_streaming(cursor): val = "x" * 10000 # test nvarchar(max) cursor.execute("select N'{}', 1".format(val)) with pytest.raises(ValueError): cursor.set_stream(1, StringIO()) with pytest.raises(ValueError): cursor.set_stream(2, StringIO()) with pytest.raises(ValueError): cursor.set_stream(-1, StringIO()) cursor.set_stream(0, StringIO()) row = cursor.fetchone() assert isinstance(row[0], StringIO) assert row[0].getvalue() == val # test nvarchar(max) with NULL value cursor.execute("select cast(NULL as nvarchar(max)), 1".format(val)) cursor.set_stream(0, StringIO()) row = cursor.fetchone() assert row[0] is None # test varchar(max) cursor.execute("select '{}', 1".format(val)) cursor.set_stream(0, StringIO()) row = cursor.fetchone() assert isinstance(row[0], StringIO) assert row[0].getvalue() == val # test varbinary(max) cursor.execute("select cast('{}' as varbinary(max)), 1".format(val)) cursor.set_stream(0, BytesIO()) row = cursor.fetchone() assert isinstance(row[0], BytesIO) assert row[0].getvalue().decode("ascii") == val # test image type cursor.execute("select cast('{}' as image), 1".format(val)) cursor.set_stream(0, BytesIO()) row = cursor.fetchone() assert isinstance(row[0], BytesIO) assert row[0].getvalue().decode("ascii") == val # test ntext type cursor.execute("select cast('{}' as ntext), 1".format(val)) cursor.set_stream(0, StringIO()) row = cursor.fetchone() assert isinstance(row[0], StringIO) assert row[0].getvalue() == val # test text type cursor.execute("select cast('{}' as text), 1".format(val)) cursor.set_stream(0, StringIO()) row = cursor.fetchone() assert isinstance(row[0], StringIO) assert row[0].getvalue() == val # test xml type xml_val = "{}".format(val) cursor.execute("select cast('{}' as xml), 1".format(xml_val)) cursor.set_stream(0, StringIO()) row = cursor.fetchone() assert isinstance(row[0], StringIO) assert row[0].getvalue() == xml_val def test_properties(separate_db_connection): conn = separate_db_connection # this property is provided for compatibility with pymssql assert conn.autocommit_state == conn.autocommit # test set_autocommit which is provided for compatibility with ADO dbapi conn.set_autocommit(conn.autocommit) # test isolation_level property read/write conn.isolation_level = conn.isolation_level # test product_version property read logger.info("Product version %s", conn.product_version) conn.as_dict = conn.as_dict def test_fetch_on_empty_dataset(cursor): cursor.execute("declare @x int") with pytest.raises(pytds.ProgrammingError): cursor.fetchall() def test_bad_collation(cursor): # exception can be different with pytest.raises(UnicodeDecodeError): cursor.execute_scalar("select cast(0x90 as varchar)") # check that connection is still usable assert 1 == cursor.execute_scalar("select 1") def test_description(cursor): cursor.execute("select cast(12.65 as decimal(4,2)) as testname") assert cursor.description[0][0] == "testname" assert cursor.description[0][1] == pytds.DECIMAL assert cursor.description[0][4] == 4 assert cursor.description[0][5] == 2 def test_bug4(separate_db_connection): with separate_db_connection.cursor() as cursor: cursor.execute( """ set transaction isolation level read committed select 1 """ ) assert cursor.fetchall() == [(1,)] def test_row_strategies(separate_db_connection): conn = pytds.connect( *settings.CONNECT_ARGS, **settings.CONNECT_KWARGS, row_strategy=pytds.dict_row_strategy, ) with conn.cursor() as cur: cur.execute("select 1 as f") assert cur.fetchall() == [{"f": 1}] conn = pytds.connect( *settings.CONNECT_ARGS, **settings.CONNECT_KWARGS, row_strategy=pytds.tuple_row_strategy, ) with conn.cursor() as cur: cur.execute("select 1 as f") assert cur.fetchall() == [(1,)] def test_fetchone(cursor): cur = cursor cur.execute("select 10; select 12") assert (10,) == cur.fetchone() assert cur.nextset() assert (12,) == cur.fetchone() assert not cur.nextset() def test_fetchall(cursor): cur = cursor cur.execute("select 10; select 12") assert [(10,)] == cur.fetchall() assert cur.nextset() assert [(12,)] == cur.fetchall() assert not cur.nextset() def test_cursor_closing(db_connection): with db_connection.cursor() as cur: cur.execute("select 10; select 12") cur.fetchone() with db_connection.cursor() as cur2: cur2.execute("select 20") cur2.fetchone() def test_multi_packet(cursor): cur = cursor param = "x" * (cursor._connection._tds_socket.main_session._writer.bufsize * 3) cur.execute("select %s", (param,)) assert [(param,)] == cur.fetchall() def test_big_request(cursor): cur = cursor param = "".join( random.choice(string.ascii_uppercase + string.digits) for _ in range(5000) ) params = (10, datetime.datetime(2012, 11, 19, 1, 21, 37, 3000), param, "test") cur.execute("select %s, %s, %s, %s", params) assert [params] == cur.fetchall() def test_row_count(cursor): cur = cursor cur.execute( """ create table testtable_row_cnt (field int) """ ) cur.execute("insert into testtable_row_cnt (field) values (1)") assert cur.rowcount == 1 cur.execute("insert into testtable_row_cnt (field) values (2)") assert cur.rowcount == 1 cur.execute("select * from testtable_row_cnt") cur.fetchall() assert cur.rowcount == 2 def test_no_rows(cursor): cur = cursor cur.execute( """ create table testtable_no_rows (field int) """ ) cur.execute("select * from testtable_no_rows") assert [] == cur.fetchall() def test_fixed_size_data(cursor): cur = cursor cur.execute( """ create table testtable_fixed_size_dt (chr char(5), nchr nchar(5), bfld binary(5)) insert into testtable_fixed_size_dt values ('1', '2', cast('3' as binary(5))) """ ) cur.execute("select * from testtable_fixed_size_dt") assert cur.fetchall() == [("1 ", "2 ", b"3\x00\x00\x00\x00")] def test_closing_cursor_in_context(db_connection): with db_connection.cursor() as cur: cur.close() def test_cursor_connection_property(db_connection): with db_connection.cursor() as cur: assert cur.connection is db_connection def test_invalid_ntlm_creds(): if not LIVE_TEST: pytest.skip("LIVE_TEST is not set") with pytest.raises(pytds.OperationalError): pytds.connect( settings.HOST, auth=pytds.login.NtlmAuth(user_name="bad", password="bad") ) def test_open_with_different_blocksize(): if not LIVE_TEST: pytest.skip("LIVE_TEST is not set") kwargs = settings.CONNECT_KWARGS.copy() # test very small block size kwargs["blocksize"] = 100 with pytds.connect(*settings.CONNECT_ARGS, **kwargs): pass # test very large block size kwargs["blocksize"] = 1000000 with pytds.connect(*settings.CONNECT_ARGS, **kwargs): pass def test_nvarchar_multiple_rows(cursor): cursor.execute( """ set nocount on declare @tbl table (id int primary key, fld nvarchar(max)) insert into @tbl values(1, 'foo') insert into @tbl values(2, 'bar') select fld from @tbl order by id """ ) assert cursor.fetchall() == [("foo",), ("bar",)] def test_no_metadata_request(cursor): cursor._session.submit_rpc( rpc_name=pytds.tds_base.SP_PREPARE, params=cursor._session._convert_params( (pytds.output(param_type=int), "@p1 int", "select @p1") ), ) cursor._session.begin_response() cursor._session.process_rpc() while cursor.nextset(): pass res = cursor.get_proc_outputs() handle = res[0] logger.info("got handle %s", handle) cursor._session.submit_rpc( rpc_name=pytds.tds_base.SP_EXECUTE, params=cursor._session._convert_params((handle, 1)), ) cursor._session.begin_response() cursor._session.process_rpc() assert cursor.fetchall() == [(1,)] while cursor.nextset(): pass cursor._session.submit_rpc( rpc_name=pytds.tds_base.SP_EXECUTE, params=cursor._session._convert_params((handle, 2)), flags=0x02, # no metadata ) cursor._session.begin_response() cursor._session.process_rpc() # for some reason SQL server still sends metadata back assert cursor.fetchall() == [(2,)] while cursor.nextset(): pass def test_with_sso(): if not LIVE_TEST: pytest.skip("LIVE_TEST is not set") with pytds.connect(settings.HOST, use_sso=True) as conn: with conn.cursor() as cursor: cursor.execute("select 1") cursor.fetchall() pytds-1.15.0/tests/connection_closing_tests.py000066400000000000000000000047171456567501500216030ustar00rootroot00000000000000""" Testing various ways of closing connection """ from time import sleep import pytest import pytds import settings def get_spid(conn): with conn.cursor() as cur: return cur.spid def kill(conn, spid): with conn.cursor() as cur: cur.execute("kill {0}".format(spid)) def test_cursor_use_after_connection_closing(): """ Check that cursor is not usable after it's parent connection is closed """ conn = pytds.connect(*settings.CONNECT_ARGS, **settings.CONNECT_KWARGS) cur = conn.cursor() conn.close() with pytest.raises(pytds.Error): cur.execute("select 1") # now create new connection which should reuse previous connection from the pool # and verify that it still works new_conn = pytds.connect(*settings.CONNECT_ARGS, **settings.CONNECT_KWARGS) with new_conn: assert 1 == new_conn.cursor().execute_scalar("select 1") def test_open_close(): for x in range(3): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" pytds.connect(**kwargs).close() def test_closing_after_closed_by_server(): """ You should be able to call close on connection closed by server """ kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["autocommit"] = True with pytds.connect(**kwargs) as master_conn: kwargs["autocommit"] = False with pytds.connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1") conn.commit() kill(master_conn, get_spid(conn)) sleep(0.2) conn.close() def _test_using_connection_after_closed(use_mars: bool) -> None: conn = pytds.connect(**{**settings.CONNECT_KWARGS, "use_mars": use_mars}) cursor = conn.cursor() assert 1 == cursor.execute_scalar("select 1") conn.close() with pytest.raises(pytds.InterfaceError, match="Cursor is closed"): cursor.execute("select 1") with pytest.raises(pytds.InterfaceError, match="Cursor is closed"): cursor.callproc("someproc") def test_using_connection_after_closed_for_mars(): """ Using cursors from closed connection should throw correct errors. """ _test_using_connection_after_closed(use_mars=True) def test_using_connection_after_closed_for_non_mars(): """ Using cursors from closed connection should throw correct errors. """ _test_using_connection_after_closed(use_mars=False) pytds-1.15.0/tests/connection_pool_tests.py000066400000000000000000000022721456567501500211100ustar00rootroot00000000000000import unittest import settings import pytds LIVE_TEST = getattr(settings, "LIVE_TEST", True) def test_broken_connection_in_pool(): """ Broken connection in the pool should not cause issues when it is attempted to be reused """ # first clear pool of any connections pytds.connection_pool._pool.clear() # create extra connection, it is needed to be able to kill other connections extra_conn = pytds.connect(**settings.CONNECT_KWARGS, autocommit=True) # Now create one connection and get underlying connection with pytds.connect(**settings.CONNECT_KWARGS, autocommit=True) as conn: sess = conn._tds_socket.main_session # kill this connection, need to use another connection to do that spid = sess.execute_scalar("select @@spid") with extra_conn.cursor() as cur: cur.execute(f"kill {spid}") # create new connection, it should attempt to use connection from the pool # it should detect that connection is bad and create new one with pytds.connect(**settings.CONNECT_KWARGS, autocommit=True) as conn: with conn.cursor() as cur: assert 1 == cur.execute_scalar("select 1") # cleanup extra_conn.close() pytds-1.15.0/tests/dbapi20.py000066400000000000000000000772111456567501500157240ustar00rootroot00000000000000#!/usr/bin/env python """ Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. "Now we have booze and barflies entering the discussion, plus rumours of DBAs on drugs... and I won't tell you what flashes through my mind each time I read the subject line with 'Anal Compliance' in it. All around this is turning out to be a thoroughly unwholesome unit test." -- Ian Bicking """ __rcs_id__ = "$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $" __version__ = "$Revision: 1.12 $"[11:-2] __author__ = "Stuart Bishop " import unittest import time import sys # Revision 1.12 2009/02/06 03:35:11 kf7xm # Tested okay with Python 3.0, includes last minute patches from Mark H. # # Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole # Include latest changes from main branch # Updates for py3k # # Revision 1.11 2005/01/02 02:41:01 zenzen # Update author email address # # Revision 1.10 2003/10/09 03:14:14 zenzen # Add test for DB API 2.0 optional extension, where database exceptions # are exposed as attributes on the Connection object. # # Revision 1.9 2003/08/13 01:16:36 zenzen # Minor tweak from Stefan Fleiter # # Revision 1.8 2003/04/10 00:13:25 zenzen # Changes, as per suggestions by M.-A. Lemburg # - Add a table prefix, to ensure namespace collisions can always be avoided # # Revision 1.7 2003/02/26 23:33:37 zenzen # Break out DDL into helper functions, as per request by David Rushby # # Revision 1.6 2003/02/21 03:04:33 zenzen # Stuff from Henrik Ekelund: # added test_None # added test_nextset & hooks # # Revision 1.5 2003/02/17 22:08:43 zenzen # Implement suggestions and code from Henrik Eklund - test that cursor.arraysize # defaults to 1 & generic cursor.callproc test added # # Revision 1.4 2003/02/15 00:16:33 zenzen # Changes, as per suggestions and bug reports by M.-A. Lemburg, # Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar # - Class renamed # - Now a subclass of TestCase, to avoid requiring the driver stub # to use multiple inheritance # - Reversed the polarity of buggy test in test_description # - Test exception heirarchy correctly # - self.populate is now self._populate(), so if a driver stub # overrides self.ddl1 this change propogates # - VARCHAR columns now have a width, which will hopefully make the # DDL even more portible (this will be reversed if it causes more problems) # - cursor.rowcount being checked after various execute and fetchXXX methods # - Check for fetchall and fetchmany returning empty lists after results # are exhausted (already checking for empty lists if select retrieved # nothing # - Fix bugs in test_setoutputsize_basic and test_setinputsizes # def str2bytes(sval): if sys.version_info < (3, 0) and isinstance(sval, str): sval = sval.decode("latin1") return sval.encode("latin1") class DatabaseAPI20Test(unittest.TestCase): """Test a database self.driver for DB API 2.0 compatibility. This implementation tests Gadfly, but the TestCase is structured so that other self.drivers can subclass this test case to ensure compiliance with the DB-API. It is expected that this TestCase may be expanded in the future if ambiguities or edge conditions are discovered. The 'Optional Extensions' are not yet being tested. self.drivers should subclass this test, overriding setUp, tearDown, self.driver, connect_args and connect_kw_args. Class specification should be as follows: import dbapi20 class mytest(dbapi20.DatabaseAPI20Test): [...] Don't 'import DatabaseAPI20Test from dbapi20', or you will confuse the unit tester - just 'import dbapi20'. """ # The self.driver module. This should be the module where the 'connect' # method is to be found driver = None connect_args = () # List of arguments to pass to connect connect_kw_args = {} # Keyword arguments for connect table_prefix = "dbapi20test_" # If you need to specify a prefix for tables ddl1 = "create table %sbooze (name varchar(20))" % table_prefix ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix xddl1 = "drop table %sbooze" % table_prefix xddl2 = "drop table %sbarflys" % table_prefix lowerfunc = "to_lower" # Name of stored procedure to convert string->lowercase # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. def executeDDL1(self, cursor): cursor.execute(self.ddl1) def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): """self.drivers should override this method to perform required setup if any is necessary, such as creating the database. """ pass def tearDown(self): """self.drivers should override this method to perform required cleanup if any is necessary, such as deleting the test database. The default drops the tables that may be created. """ con = self._connect() try: cur = con.cursor() for ddl in (self.xddl1, self.xddl2): try: cur.execute(ddl) con.commit() except self.driver.Error: # Assume table didn't exist. Other tests will check if # execute is busted. pass finally: con.close() def _connect(self): try: return self.driver.connect(*self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") def test_connect(self): con = self._connect() con.close() def test_apilevel(self): try: # Must exist apilevel = self.driver.apilevel # Must equal 2.0 self.assertEqual(apilevel, "2.0") except AttributeError: self.fail("Driver doesn't define apilevel") def test_threadsafety(self): try: # Must exist threadsafety = self.driver.threadsafety # Must be a valid value self.assertTrue(threadsafety in (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") def test_paramstyle(self): try: # Must exist paramstyle = self.driver.paramstyle # Must be a valid value self.assertTrue( paramstyle in ("qmark", "numeric", "named", "format", "pyformat") ) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined heirarchy. if sys.version[0] == "3": # under Python 3 StardardError no longer exists self.assertTrue(issubclass(self.driver.Warning, Exception)) self.assertTrue(issubclass(self.driver.Error, Exception)) else: self.assertTrue(issubclass(self.driver.Warning, StandardError)) self.assertTrue(issubclass(self.driver.Error, StandardError)) self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION # Test for the optional DB API 2.0 extension, where the exceptions # are exposed as attributes on the Connection object # I figure this optional extension will be implemented by any # driver author who is using this test suite, so it is enabled # by default. con = self._connect() try: drv = self.driver self.assertTrue(con.Warning is drv.Warning) self.assertTrue(con.Error is drv.Error) self.assertTrue(con.InterfaceError is drv.InterfaceError) self.assertTrue(con.DatabaseError is drv.DatabaseError) self.assertTrue(con.OperationalError is drv.OperationalError) self.assertTrue(con.IntegrityError is drv.IntegrityError) self.assertTrue(con.InternalError is drv.InternalError) self.assertTrue(con.ProgrammingError is drv.ProgrammingError) self.assertTrue(con.NotSupportedError is drv.NotSupportedError) finally: con.close() def test_commit(self): con = self._connect() try: # Commit must work, even if it doesn't do anything con.commit() finally: con.close() def test_rollback(self): con = self._connect() try: # If rollback is defined, it should either work or throw # the documented exception if hasattr(con, "rollback"): try: con.rollback() except self.driver.NotSupportedError: pass finally: con.close() def test_cursor(self): con = self._connect() try: cur = con.cursor() finally: con.close() def test_cursor_isolation(self): con = self._connect() try: # Make sure cursors created from the same connection have # the documented transaction isolation level cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) cur1.execute( "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) ) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() self.assertEqual(len(booze), 1) self.assertEqual(len(booze[0]), 1) self.assertEqual(booze[0][0], "Victoria Bitter") finally: con.close() def test_description(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) self.assertEqual( cur.description, None, "cursor.description should be none after executing a " "statement that can return no rows (such as DDL)", ) cur.execute("select name from %sbooze" % self.table_prefix) self.assertEqual( len(cur.description), 1, "cursor.description describes too many columns" ) self.assertEqual( len(cur.description[0]), 7, "cursor.description[x] tuples must have 7 elements", ) self.assertEqual( cur.description[0][0].lower(), "name", "cursor.description[x][0] must return column name", ) self.assertEqual( cur.description[0][1], self.driver.STRING, "cursor.description[x][1] must return column type. Got %r" % cur.description[0][1], ) # Make sure self.description gets reset self.executeDDL2(cur) self.assertEqual( cur.description, None, "cursor.description not being set to None when executing " "no-result statements (eg. DDL)", ) finally: con.close() def test_rowcount(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) self.assertEqual( cur.rowcount, -1, "cursor.rowcount should be -1 after executing no-result " "statements", ) cur.execute( "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) ) self.assertTrue( cur.rowcount in (-1, 1), "cursor.rowcount should == number or rows inserted, or " "set to -1 after executing an insert statement", ) cur.execute("select name from %sbooze" % self.table_prefix) self.assertTrue( cur.rowcount in (-1, 1), "cursor.rowcount should == number of rows returned, or " "set to -1 after executing a select statement", ) self.executeDDL2(cur) self.assertEqual( cur.rowcount, -1, "cursor.rowcount not being reset to -1 after executing " "no-result statements", ) finally: con.close() lower_func = "to_lower" def test_callproc(self): con = self._connect() try: cur = con.cursor() self._callproc_setup(cur) if self.lower_func and hasattr(cur, "callproc"): r = cur.callproc(self.lower_func, ("FOO",)) self.assertEqual(len(r), 1) self.assertEqual(r[0], "FOO") r = cur.fetchall() self.assertEqual(len(r), 1, "callproc produced no result set") self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") self.assertEqual(r[0][0], "foo", "callproc produced invalid results") finally: con.close() def test_close(self): con = self._connect() try: cur = con.cursor() finally: con.close() # cursor.execute should raise an Error if called after connection # closed self.assertRaises(self.driver.Error, self.executeDDL1, cur) # connection.commit should raise an Error if called after connection' # closed.' self.assertRaises(self.driver.Error, con.commit) # connection.close should raise an Error if called more than once # # disabled, there is no such requirement in DBAPI PEP-0249 # self.assertRaises(self.driver.Error,con.close) def test_execute(self): con = self._connect() try: cur = con.cursor() self._paraminsert(cur) finally: con.close() def _paraminsert(self, cur): self.executeDDL1(cur) cur.execute( "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) ) self.assertTrue(cur.rowcount in (-1, 1)) if self.driver.paramstyle == "qmark": cur.execute( "insert into %sbooze values (?)" % self.table_prefix, ("Cooper's",) ) elif self.driver.paramstyle == "numeric": cur.execute( "insert into %sbooze values (:1)" % self.table_prefix, ("Cooper's",) ) elif self.driver.paramstyle == "named": cur.execute( "insert into %sbooze values (:beer)" % self.table_prefix, {"beer": "Cooper's"}, ) elif self.driver.paramstyle == "format": cur.execute( "insert into %sbooze values (%%s)" % self.table_prefix, ("Cooper's",) ) elif self.driver.paramstyle == "pyformat": cur.execute( "insert into %sbooze values (%%(beer)s)" % self.table_prefix, {"beer": "Cooper's"}, ) else: self.fail("Invalid paramstyle") self.assertTrue(cur.rowcount in (-1, 1)) cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows") beers = [res[0][0], res[1][0]] beers.sort() self.assertEqual( beers[0], "Cooper's", "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", ) self.assertEqual( beers[1], "Victoria Bitter", "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", ) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) largs = [("Cooper's",), ("Boag's",)] margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] if self.driver.paramstyle == "qmark": cur.executemany( "insert into %sbooze values (?)" % self.table_prefix, largs ) elif self.driver.paramstyle == "numeric": cur.executemany( "insert into %sbooze values (:1)" % self.table_prefix, largs ) elif self.driver.paramstyle == "named": cur.executemany( "insert into %sbooze values (:beer)" % self.table_prefix, margs ) elif self.driver.paramstyle == "format": cur.executemany( "insert into %sbooze values (%%s)" % self.table_prefix, largs ) elif self.driver.paramstyle == "pyformat": cur.executemany( "insert into %sbooze values (%%(beer)s)" % (self.table_prefix), margs, ) else: self.fail("Unknown paramstyle") self.assertTrue( cur.rowcount in (-1, 2), "insert using cursor.executemany set cursor.rowcount to " "incorrect value %r" % cur.rowcount, ) cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() self.assertEqual( len(res), 2, "cursor.fetchall retrieved incorrect number of rows" ) beers = [res[0][0], res[1][0]] beers.sort() self.assertEqual(beers[0], "Boag's", "incorrect data retrieved") self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved") finally: con.close() def test_fetchone(self): con = self._connect() try: cur = con.cursor() # cursor.fetchone should raise an Error if called before # executing a select-type query self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows self.executeDDL1(cur) self.assertRaises(self.driver.Error, cur.fetchone) cur.execute("select name from %sbooze" % self.table_prefix) self.assertEqual( cur.fetchone(), None, "cursor.fetchone should return None if a query retrieves " "no rows", ) self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows cur.execute( "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) ) self.assertRaises(self.driver.Error, cur.fetchone) cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() self.assertEqual( len(r), 1, "cursor.fetchone should have retrieved a single row" ) self.assertEqual( r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" ) self.assertEqual( cur.fetchone(), None, "cursor.fetchone should return None if no more rows available", ) self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() samples = [ "Carlton Cold", "Carlton Draft", "Mountain Goat", "Redback", "Victoria Bitter", "XXXX", ] def _populate(self): """Return a list of sql commands to setup the DB for the fetch tests. """ populate = [ "insert into %sbooze values ('%s')" % (self.table_prefix, s) for s in self.samples ] return populate def test_fetchmany(self): con = self._connect() try: cur = con.cursor() # cursor.fetchmany should raise an Error if called without # issuing a query self.assertRaises(self.driver.Error, cur.fetchmany, 4) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchmany() self.assertEqual( len(r), 1, "cursor.fetchmany retrieved incorrect number of rows, " "default of arraysize is one.", ) cur.arraysize = 10 r = cur.fetchmany(3) # Should get 3 rows self.assertEqual( len(r), 3, "cursor.fetchmany retrieved incorrect number of rows" ) r = cur.fetchmany(4) # Should get 2 more self.assertEqual( len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" ) r = cur.fetchmany(4) # Should be an empty sequence self.assertEqual( len(r), 0, "cursor.fetchmany should return an empty sequence after " "results are exhausted", ) self.assertTrue(cur.rowcount in (-1, 6)) # Same as above, using cursor.arraysize cur.arraysize = 4 cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchmany() # Should get 4 rows self.assertEqual( len(r), 4, "cursor.arraysize not being honoured by fetchmany" ) r = cur.fetchmany() # Should get 2 more self.assertEqual(len(r), 2) r = cur.fetchmany() # Should be an empty sequence self.assertEqual(len(r), 0) self.assertTrue(cur.rowcount in (-1, 6)) cur.arraysize = 6 cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchmany() # Should get all rows self.assertTrue(cur.rowcount in (-1, 6)) self.assertEqual(len(rows), 6) self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() # Make sure we get the right data back out for i in range(0, 6): self.assertEqual( rows[i], self.samples[i], "incorrect data retrieved by cursor.fetchmany", ) rows = cur.fetchmany() # Should return an empty list self.assertEqual( len(rows), 0, "cursor.fetchmany should return an empty sequence if " "called after the whole result set has been fetched", ) self.assertTrue(cur.rowcount in (-1, 6)) self.executeDDL2(cur) cur.execute("select name from %sbarflys" % self.table_prefix) r = cur.fetchmany() # Should get empty sequence self.assertEqual( len(r), 0, "cursor.fetchmany should return an empty sequence if " "query retrieved no rows", ) self.assertTrue(cur.rowcount in (-1, 0)) finally: con.close() def test_fetchall(self): con = self._connect() try: cur = con.cursor() # cursor.fetchall should raise an Error if called # without executing a query that may return rows (such # as a select) self.assertRaises(self.driver.Error, cur.fetchall) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows self.assertRaises(self.driver.Error, cur.fetchall) cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.assertEqual( len(rows), len(self.samples), "cursor.fetchall did not retrieve all rows", ) rows = [r[0] for r in rows] rows.sort() for i in range(0, len(self.samples)): self.assertEqual( rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( len(rows), 0, "cursor.fetchall should return an empty list if called " "after the whole result set has been fetched", ) self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() self.assertTrue(cur.rowcount in (-1, 0)) self.assertEqual( len(rows), 0, "cursor.fetchall should return an empty list if " "a select query returns no rows", ) finally: con.close() def test_mixedfetch(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) cur.execute("select name from %sbooze" % self.table_prefix) rows1 = cur.fetchone() rows23 = cur.fetchmany(2) rows4 = cur.fetchone() rows56 = cur.fetchall() self.assertTrue(cur.rowcount in (-1, 6)) self.assertEqual( len(rows23), 2, "fetchmany returned incorrect number of rows" ) self.assertEqual( len(rows56), 2, "fetchall returned incorrect number of rows" ) rows = [rows1[0]] rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() for i in range(0, len(self.samples)): self.assertEqual( rows[i], self.samples[i], "incorrect data retrieved or inserted" ) finally: con.close() def help_nextset_setUp(self, cur): """Should create a procedure called deleteme that returns two result sets, first the number of rows in booze then "name from booze" """ raise NotImplementedError("Helper not implemented") # sql=""" # create procedure deleteme as # begin # select count(*) from booze # select name from booze # end # """ # cur.execute(sql) def help_nextset_tearDown(self, cur): "If cleaning up is needed after nextSetTest" raise NotImplementedError("Helper not implemented") # cur.execute("drop procedure deleteme") def test_nextset(self): con = self._connect() try: cur = con.cursor() if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) cur.callproc("deleteme") numberofrows = cur.fetchone() self.assertEqual(numberofrows[0], len(self.samples)) assert cur.nextset() names = cur.fetchall() assert len(names) == len(self.samples) s = cur.nextset() assert s == None, "No more return sets, should return None" finally: self.help_nextset_tearDown(cur) finally: con.close() # def test_nextset(self): # raise NotImplementedError('Drivers need to override this test') def test_arraysize(self): # Not much here - rest of the tests for this are in test_fetchmany con = self._connect() try: cur = con.cursor() self.assertTrue( hasattr(cur, "arraysize"), "cursor.arraysize must be defined" ) finally: con.close() def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() cur.setinputsizes((25,)) self._paraminsert(cur) # Make sure cursor still works finally: con.close() def test_setoutputsize_basic(self): # Basic test is to make sure setoutputsize doesn't blow up con = self._connect() try: cur = con.cursor() cur.setoutputsize(1000) cur.setoutputsize(2000, 0) self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): # Real test for setoutputsize is driver dependant raise NotImplementedError("Driver needed to override this test") def test_None(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) cur.execute("insert into %sbooze values (NULL)" % self.table_prefix) cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchall() self.assertEqual(len(r), 1) self.assertEqual(len(r[0]), 1) self.assertEqual(r[0][0], None, "NULL value not returned as None") finally: con.close() def test_Date(self): d1 = self.driver.Date(2002, 12, 25) d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(d1),str(d2)) def test_Time(self): t1 = self.driver.Time(13, 45, 30) t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Binary(self): b = self.driver.Binary(str2bytes("Something")) b = self.driver.Binary(str2bytes("")) def test_STRING(self): self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined") def test_BINARY(self): self.assertTrue( hasattr(self.driver, "BINARY"), "module.BINARY must be defined." ) def test_NUMBER(self): self.assertTrue( hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined." ) def test_DATETIME(self): self.assertTrue( hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined." ) def test_ROWID(self): self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") pytds-1.15.0/tests/fixtures.py000066400000000000000000000034621456567501500163510ustar00rootroot00000000000000import logging import pytest import sqlalchemy.engine from sqlalchemy import create_engine import pytds.tds_socket import settings import sqlalchemy_schema import utils logger = logging.getLogger(__name__) LIVE_TEST = getattr(settings, "LIVE_TEST", True) pytds.tds_base.logging_enabled = True @pytest.fixture(scope="module") def db_connection(sqlalchemy_engine): if not LIVE_TEST: pytest.skip("LIVE_TEST is not set") kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = settings.DATABASE conn = pytds.connect(*settings.CONNECT_ARGS, **kwargs) utils.create_test_database(connection=conn) conn.commit() return conn @pytest.fixture def cursor(db_connection): with db_connection.cursor() as cursor: yield cursor db_connection.rollback() @pytest.fixture def separate_db_connection(): if not LIVE_TEST: pytest.skip("LIVE_TEST is not set") kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = settings.DATABASE conn = pytds.connect(*settings.CONNECT_ARGS, **kwargs) yield conn conn.close() @pytest.fixture(scope="module") def collation_set(db_connection): with db_connection.cursor() as cursor: cursor.execute( "SELECT Name, Description, COLLATIONPROPERTY(Name, 'LCID') FROM ::fn_helpcollations()" ) collations_list = cursor.fetchall() return set(coll_name for coll_name, _, _ in collations_list) @pytest.fixture(scope="module") def sqlalchemy_engine() -> sqlalchemy.engine.Engine: host = settings.HOST hostname, _, instance = host.partition("\\") url = f"mssql+pytds://{settings.USER}:{settings.PASSWORD}@/{settings.DATABASE}?host={host}" engine = create_engine( url, echo=True, ) sqlalchemy_schema.Base.metadata.create_all(engine) return engine pytds-1.15.0/tests/instance_browser_client_test.py000066400000000000000000000010061456567501500224340ustar00rootroot00000000000000import pytds.instance_browser_client def test_get_instances(): data = b"\x05[\x00ServerName;MISHA-PC;InstanceName;SQLEXPRESS;IsClustered;No;Version;10.0.1600.22;tcp;49849;;" ref = { "SQLEXPRESS": { "ServerName": "MISHA-PC", "InstanceName": "SQLEXPRESS", "IsClustered": "No", "Version": "10.0.1600.22", "tcp": "49849", }, } instances = pytds.instance_browser_client.parse_instances_response(data) assert instances == ref pytds-1.15.0/tests/params_tests.py000066400000000000000000000351351456567501500172070ustar00rootroot00000000000000""" Testing various ways of passing parameters to queries """ import unittest import uuid from datetime import datetime, date, time from decimal import Decimal from io import StringIO from pytds import Column, connect from pytds.tds_base import Param, default, output from pytds.tds_types import ( BitType, TinyIntType, SmallIntType, IntType, BigIntType, RealType, FloatType, SmallDateTimeType, DateTimeType, DateType, TimeType, DateTime2Type, DateTimeOffsetType, DecimalType, SmallMoneyType, MoneyType, UniqueIdentifierType, VariantType, VarBinaryType, VarCharType, NVarCharType, TextType, NTextType, ImageType, VarBinaryMaxType, NVarCharMaxType, VarCharMaxType, XmlType, ) from fixtures import * from pytds.tz import utc from tests.all_test import tzoffset def test_param_as_column_backward_compat(cursor): """ For backward compatibility need to support passing parameters as Column objects New way to pass such parameters is to use Param object. """ param = Column(type=BitType(), value=True) result = cursor.execute_scalar("select %s", [param]) assert result is True def test_param_with_spaces(cursor): """ For backward compatibility need to support passing parameters as Column objects New way to pass such parameters is to use Param object. """ result = cursor.execute_scalar("select %(param name)s", {"param name": "abc"}) assert result == "abc" def test_param_with_slashes(cursor): """ For backward compatibility need to support passing parameters as Column objects New way to pass such parameters is to use Param object. """ result = cursor.execute_scalar("select %(param/name)s", {"param/name": "abc"}) assert result == "abc" def test_dictionary_params(cursor): assert cursor.execute_scalar("select %(param)s", {"param": None}) == None assert cursor.execute_scalar("select %(param)s", {"param": 1}) == 1 def test_overlimit(cursor): def test_val(val): cursor.execute("select %s", (val,)) assert cursor.fetchone() == (val,) assert cursor.fetchone() is None ##cur.execute('select %s', '\x00'*(2**31)) with pytest.raises(pytds.DataError): test_val(Decimal("1" + "0" * 38)) with pytest.raises(pytds.DataError): test_val(Decimal("-1" + "0" * 38)) with pytest.raises(pytds.DataError): test_val(Decimal("1E38")) val = -(10**38) cursor.execute("select %s", (val,)) assert cursor.fetchone() == (str(val),) assert cursor.fetchone() is None def test_outparam_and_result_set(cursor): """ Test stored procedure which has output parameters and also result set """ cur = cursor logger.info("creating stored procedure") cur.execute( """ CREATE PROCEDURE P_OutParam_ResultSet(@A INT OUTPUT) AS BEGIN SET @A = 3; SELECT 4 AS C; SELECT 5 AS C; END; """ ) logger.info("executing stored procedure") cur.callproc("P_OutParam_ResultSet", [pytds.output(value=1)]) assert [(4,)] == cur.fetchall() assert [3] == cur.get_proc_outputs() logger.info("execurint query after stored procedure") cur.execute("select 5") assert [(5,)] == cur.fetchall() def test_outparam_null_default(cursor): with pytest.raises(ValueError): pytds.output(None, None) cur = cursor cur.execute( """ create procedure outparam_null_testproc (@inparam int, @outint int = 8 output, @outstr varchar(max) = 'defstr' output) as begin set nocount on set @outint = isnull(@outint, -10) + @inparam set @outstr = isnull(@outstr, 'null') + cast(@inparam as varchar(max)) set @inparam = 8 end """ ) values = cur.callproc( "outparam_null_testproc", (1, pytds.output(value=4), pytds.output(value="str")) ) assert [1, 5, "str1"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=None, param_type="int"), pytds.output(value=None, param_type="varchar(max)"), ), ) assert [1, -9, "null1"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=pytds.default, param_type="int"), pytds.output(value=pytds.default, param_type="varchar(max)"), ), ) assert [1, 9, "defstr1"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=pytds.default, param_type="bit"), pytds.output(value=pytds.default, param_type="varchar(5)"), ), ) assert [1, 1, "defst"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=pytds.default, param_type=int), pytds.output(value=pytds.default, param_type=str), ), ) assert [1, 9, "defstr1"] == values def _params_tests(self): def test_val(typ, val): with self.conn.cursor() as cur: param = Param(type=typ, value=val) logger.info("Testing with %s", repr(param)) cur.execute("select %s", [param]) self.assertTupleEqual(cur.fetchone(), (val,)) self.assertIs(cur.fetchone(), None) test_val(BitType(), True) test_val(BitType(), False) test_val(BitType(), None) test_val(TinyIntType(), 255) test_val(SmallIntType(), 2**15 - 1) test_val(IntType(), 2**31 - 1) test_val(BigIntType(), 2**63 - 1) test_val(IntType(), None) test_val(RealType(), 0.25) test_val(FloatType(), 0.25) test_val(RealType(), None) test_val(SmallDateTimeType(), datetime(1900, 1, 1, 0, 0, 0)) test_val(SmallDateTimeType(), datetime(2079, 6, 6, 23, 59, 0)) test_val(DateTimeType(), datetime(1753, 1, 1, 0, 0, 0)) test_val(DateTimeType(), datetime(9999, 12, 31, 23, 59, 59, 990000)) test_val(DateTimeType(), None) if pytds.tds_base.IS_TDS73_PLUS(self.conn._tds_socket): test_val(DateType(), date(1, 1, 1)) test_val(DateType(), date(9999, 12, 31)) test_val(DateType(), None) test_val(TimeType(precision=0), time(0, 0, 0)) test_val(TimeType(precision=6), time(23, 59, 59, 999999)) test_val(TimeType(precision=0), None) test_val(DateTime2Type(precision=0), datetime(1, 1, 1, 0, 0, 0)) test_val(DateTime2Type(precision=6), datetime(9999, 12, 31, 23, 59, 59, 999999)) test_val(DateTime2Type(precision=0), None) test_val( DateTimeOffsetType(precision=6), datetime(9999, 12, 31, 23, 59, 59, 999999, utc), ) test_val( DateTimeOffsetType(precision=6), datetime(9999, 12, 31, 23, 59, 59, 999999, tzoffset(14)), ) test_val( DateTimeOffsetType(precision=0), datetime(1, 1, 1, 0, 0, 0, tzinfo=tzoffset(-14)), ) # test_val(DateTimeOffsetType(precision=0), datetime(1, 1, 1, 0, 0, 0, tzinfo=tzoffset(14))) test_val(DateTimeOffsetType(precision=6), None) test_val(DecimalType(scale=6, precision=38), Decimal("123.456789")) test_val(DecimalType(scale=6, precision=38), None) test_val(SmallMoneyType(), Decimal("-214748.3648")) test_val(SmallMoneyType(), Decimal("214748.3647")) test_val(MoneyType(), Decimal("922337203685477.5807")) test_val(MoneyType(), Decimal("-922337203685477.5808")) test_val(MoneyType(), None) test_val(UniqueIdentifierType(), None) test_val(UniqueIdentifierType(), uuid.uuid4()) if pytds.tds_base.IS_TDS71_PLUS(self.conn._tds_socket): test_val(VariantType(), None) # test_val(self.conn._conn.type_factory.SqlVariant(10), 100) test_val(VarBinaryType(size=10), b"") test_val(VarBinaryType(size=10), b"testtest12") test_val(VarBinaryType(size=10), None) test_val(VarBinaryType(size=8000), b"x" * 8000) test_val(VarCharType(size=10), None) test_val(VarCharType(size=10), "") test_val(VarCharType(size=10), "test") test_val(VarCharType(size=8000), "x" * 8000) test_val(NVarCharType(size=10), "") test_val(NVarCharType(size=10), "testtest12") test_val(NVarCharType(size=10), None) test_val(NVarCharType(size=4000), "x" * 4000) test_val(TextType(), None) test_val(TextType(), "") test_val(TextType(), "hello") test_val(NTextType(), None) test_val(NTextType(), "") test_val(NTextType(), "hello") test_val(ImageType(), None) test_val(ImageType(), b"") test_val(ImageType(), b"test") if pytds.tds_base.IS_TDS72_PLUS(self.conn._tds_socket): test_val(VarBinaryMaxType(), None) test_val(VarBinaryMaxType(), b"") test_val(VarBinaryMaxType(), b"testtest12") test_val(VarBinaryMaxType(), b"x" * (10**6)) test_val(NVarCharMaxType(), None) test_val(NVarCharMaxType(), "test") test_val(NVarCharMaxType(), "x" * (10**6)) test_val(VarCharMaxType(), None) test_val(VarCharMaxType(), "test") test_val(VarCharMaxType(), "x" * (10**6)) test_val(XmlType(), "") @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestTds70(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["tds_version"] = pytds.tds_base.TDS70 self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): _params_tests(self) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestTds71(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = settings.DATABASE kwargs["tds_version"] = pytds.tds_base.TDS71 self.conn = connect(*settings.CONNECT_ARGS, **kwargs) utils.create_test_database(self.conn) self.conn.commit() def test_parsing(self): _params_tests(self) def test_bulk(self): f = StringIO("42\tfoo\n74\tbar\n") with self.conn.cursor() as cur: cur.copy_to( f, "bulk_insert_table", schema="myschema", columns=("num", "data") ) cur.execute("select num, data from myschema.bulk_insert_table") self.assertListEqual(cur.fetchall(), [(42, "foo"), (74, "bar")]) def test_call_proc(self): with self.conn.cursor() as cur: val = 45 values = cur.callproc("testproc", (val, default, output(value=1))) # self.assertEqual(cur.fetchall(), [(val,)]) self.assertEqual(val + 2, values[2]) self.assertEqual(val + 2, cur.get_proc_return_status()) def test_outparam_and_result_set(cursor): """ Test stored procedure which has output parameters and also result set """ cur = cursor logger.info("creating stored procedure") cur.execute( """ CREATE PROCEDURE P_OutParam_ResultSet(@A INT OUTPUT) AS BEGIN SET @A = 3; SELECT 4 AS C; SELECT 5 AS C; END; """ ) logger.info("executing stored procedure") cur.callproc("P_OutParam_ResultSet", [pytds.output(value=1)]) assert [(4,)] == cur.fetchall() assert [3] == cur.get_proc_outputs() logger.info("execurint query after stored procedure") cur.execute("select 5") assert [(5,)] == cur.fetchall() def test_outparam_null_default(cursor): with pytest.raises(ValueError): pytds.output(None, None) cur = cursor cur.execute( """ create procedure outparam_null_testproc (@inparam int, @outint int = 8 output, @outstr varchar(max) = 'defstr' output) as begin set nocount on set @outint = isnull(@outint, -10) + @inparam set @outstr = isnull(@outstr, 'null') + cast(@inparam as varchar(max)) set @inparam = 8 end """ ) values = cur.callproc( "outparam_null_testproc", (1, pytds.output(value=4), pytds.output(value="str")) ) assert [1, 5, "str1"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=None, param_type="int"), pytds.output(value=None, param_type="varchar(max)"), ), ) assert [1, -9, "null1"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=pytds.default, param_type="int"), pytds.output(value=pytds.default, param_type="varchar(max)"), ), ) assert [1, 9, "defstr1"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=pytds.default, param_type="bit"), pytds.output(value=pytds.default, param_type="varchar(5)"), ), ) assert [1, 1, "defst"] == values values = cur.callproc( "outparam_null_testproc", ( 1, pytds.output(value=pytds.default, param_type=int), pytds.output(value=pytds.default, param_type=str), ), ) assert [1, 9, "defstr1"] == values @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestTds72(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["tds_version"] = pytds.tds_base.TDS72 self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): _params_tests(self) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestTds73A(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["tds_version"] = pytds.tds_base.TDS73A self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): _params_tests(self) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestTds73B(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" kwargs["tds_version"] = pytds.tds_base.TDS73B self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): _params_tests(self) class TestCaseWithCursor(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() kwargs["database"] = "master" self.conn = connect(*settings.CONNECT_ARGS, **kwargs) # def test_mars_sessions_recycle_ids(self): # if not self.conn.mars_enabled: # self.skipTest('Only relevant to mars') # for _ in range(2 ** 16 + 1): # cur = self.conn.cursor() # cur.close() def test_parameters_ll(self): _params_tests(self) pytds-1.15.0/tests/settings.py000066400000000000000000000035301456567501500163340ustar00rootroot00000000000000import os import json CONNECT_ARGS = [] CONNECT_KWARGS = {} connection_json_path = os.path.join(os.path.dirname(__file__), ".connection.json") if os.path.exists(connection_json_path): conf = json.load(open(connection_json_path, "rb")) default_host = conf["host"] default_database = conf["database"] default_user = conf["sqluser"] default_password = conf["sqlpassword"] default_use_mars = conf["use_mars"] default_auth = conf.get("auth") default_cafile = conf.get("cafile") else: default_host = None default_database = "test" default_user = "sa" default_password = "sa" default_use_mars = True default_auth = None default_cafile = None LIVE_TEST = "HOST" in os.environ or default_host if LIVE_TEST: HOST = os.environ.get("HOST", default_host) DATABASE = os.environ.get("DATABASE", default_database) USER = os.environ.get("SQLUSER", default_user) PASSWORD = os.environ.get("SQLPASSWORD", default_password) USE_MARS = bool(os.environ.get("USE_MARS", default_use_mars)) SKIP_SQL_AUTH = bool(os.environ.get("SKIP_SQL_AUTH")) import pytds CONNECT_KWARGS = { "server": HOST, "database": DATABASE, "user": USER, "password": PASSWORD, "use_mars": USE_MARS, "bytes_to_unicode": True, "pooling": True, "timeout": 30, "cafile": default_cafile, } if default_auth: CONNECT_KWARGS["auth"] = getattr(pytds.login, default_auth)() if "tds_version" in os.environ: CONNECT_KWARGS["tds_version"] = getattr(pytds, os.environ["tds_version"]) if "auth" in os.environ: import pytds.login CONNECT_KWARGS["auth"] = getattr(pytds.login, os.environ["auth"])() if "bytes_to_unicode" in os.environ: CONNECT_KWARGS["bytes_to_unicode"] = bool(os.environ.get("bytes_to_unicode")) pytds-1.15.0/tests/simple_server.py000066400000000000000000000237271456567501500173650ustar00rootroot00000000000000import socket import socketserver import struct import logging import OpenSSL.SSL import pytds.tds_socket import pytds.tds_reader import pytds.tds_writer import pytds.collate _BYTE_STRUCT = struct.Struct("B") _OFF_LEN_STRUCT = struct.Struct(">HH") _PROD_VER_STRUCT = struct.Struct(">LH") logger = logging.getLogger(__name__) class TdsParser: def bad_stream(self, msg): # TODO use different exception class raise Exception(msg) def parse_prelogin(self, buf): # https://msdn.microsoft.com/en-us/library/dd357559.aspx size = len(buf) i = 0 result = {} while True: value = None if i >= size: self.bad_stream("Invalid size of PRELOGIN structure") (type_id,) = _BYTE_STRUCT.unpack_from(buf, i) if type_id == pytds.tds_base.PreLoginToken.TERMINATOR: break if i + 4 > size: self.bad_stream("Invalid size of PRELOGIN structure") off, l = _OFF_LEN_STRUCT.unpack_from(buf, i + 1) if off > size or off + l > size: self.bad_stream("Invalid offset in PRELOGIN structure") if type_id == pytds.tds_base.PreLoginToken.VERSION: value = _PROD_VER_STRUCT.unpack_from(buf, off) elif type_id == pytds.tds_base.PreLoginToken.ENCRYPTION: value = _BYTE_STRUCT.unpack_from(buf, off)[0] elif type_id == pytds.tds_base.PreLoginToken.MARS: value = bool(_BYTE_STRUCT.unpack_from(buf, off)[0]) elif type_id == pytds.tds_base.PreLoginToken.INSTOPT: value = buf[off : off + l].decode("ascii") i += 5 result[type_id] = value return result class TdsGenerator: def generate_prelogin(self, prelogin): hdr_size = (1 + _OFF_LEN_STRUCT.size) * len(prelogin) + 1 buf = bytearray([0] * hdr_size) hdr_offset = 0 data_offset = hdr_size for type_id, value in prelogin.items(): if type_id == pytds.tds_base.PreLoginToken.VERSION: packed = _PROD_VER_STRUCT.pack(value) elif type_id == pytds.tds_base.PreLoginToken.ENCRYPTION: packed = [value] elif type_id == pytds.tds_base.PreLoginToken.MARS: packed = [1 if value else 0] elif type_id == pytds.tds_base.PreLoginToken.INSTOPT: packed = value.encode("ascii") else: raise Exception( f"not implemented prelogin option {type_id} in prelogin message generator" ) data_size = len(packed) buf[hdr_offset] = type_id hdr_offset += 1 _OFF_LEN_STRUCT.pack_into(buf, hdr_offset, data_offset, data_size) hdr_offset += _OFF_LEN_STRUCT.size buf.extend(packed) data_offset += data_size buf[hdr_offset] = pytds.tds_base.PreLoginToken.TERMINATOR return buf class Sock: # wraps request in class compatible with TdsSocket def __init__(self, req): self._req = req def recv(self, size): return self._req.recv(size) def recv_into(self, buffer, size=0): return self._req.recv_into(buffer, size) def sendall(self, data, flags=0): return self._req.sendall(data, flags) class RequestHandler(socketserver.StreamRequestHandler): def handle(self): parser = TdsParser() gen = TdsGenerator() bufsize = 4096 # TdsReader expects this self._transport = Sock(self.request) r = pytds.tds_reader._TdsReader(tds_session=self, transport=self._transport) w = pytds.tds_writer._TdsWriter( tds_session=self, bufsize=bufsize, transport=self._transport ) resp_header = r.begin_response() buf = r.read_whole_packet() if resp_header.type != pytds.tds_base.PacketType.PRELOGIN: msg = "Invalid packet type: {0}, expected PRELOGIN({1})".format( r.packet_type, pytds.tds_base.PacketType.PRELOGIN ) self.bad_stream(msg) prelogin = parser.parse_prelogin(buf) logger.info(f"received prelogin message from client {prelogin}") srv_enc = self.server._enc cli_enc = prelogin[pytds.tds_base.PreLoginToken.ENCRYPTION] res_enc = None close_conn = False if srv_enc == pytds.PreLoginEnc.ENCRYPT_OFF: if cli_enc == pytds.PreLoginEnc.ENCRYPT_OFF: res_enc = pytds.PreLoginEnc.ENCRYPT_OFF elif cli_enc == pytds.PreLoginEnc.ENCRYPT_ON: res_enc = pytds.PreLoginEnc.ENCRYPT_ON elif cli_enc == pytds.PreLoginEnc.ENCRYPT_NOT_SUP: res_enc = pytds.PreLoginEnc.ENCRYPT_NOT_SUP elif srv_enc == pytds.PreLoginEnc.ENCRYPT_ON: if cli_enc == pytds.PreLoginEnc.ENCRYPT_OFF: res_enc = pytds.PreLoginEnc.ENCRYPT_REQ elif cli_enc == pytds.PreLoginEnc.ENCRYPT_ON: res_enc = pytds.PreLoginEnc.ENCRYPT_ON elif cli_enc == pytds.PreLoginEnc.ENCRYPT_NOT_SUP: res_enc = pytds.PreLoginEnc.ENCRYPT_REQ close_conn = True elif srv_enc == pytds.PreLoginEnc.ENCRYPT_NOT_SUP: if cli_enc == pytds.PreLoginEnc.ENCRYPT_OFF: res_enc = pytds.PreLoginEnc.ENCRYPT_NOT_SUP elif cli_enc == pytds.PreLoginEnc.ENCRYPT_ON: res_enc = pytds.PreLoginEnc.ENCRYPT_NOT_SUP close_conn = True elif cli_enc == pytds.PreLoginEnc.ENCRYPT_NOT_SUP: res_enc = pytds.PreLoginEnc.ENCRYPT_NOT_SUP elif srv_enc == pytds.PreLoginEnc.ENCRYPT_REQ: res_enc = pytds.PreLoginEnc.ENCRYPT_REQ # sending reply to client's prelogin packet prelogin_resp = gen.generate_prelogin( { pytds.tds_base.PreLoginToken.ENCRYPTION: res_enc, } ) w.begin_packet(pytds.tds_base.PacketType.REPLY) w.write(prelogin_resp) w.flush() if close_conn: return wrapped_socket = None if res_enc != pytds.PreLoginEnc.ENCRYPT_NOT_SUP: # setup TLS connection tlsconn = OpenSSL.SSL.Connection(self.server._tls_ctx) tlsconn.set_accept_state() done = False while not done: try: tlsconn.do_handshake() except OpenSSL.SSL.WantReadError: try: buf = tlsconn.bio_read(bufsize) except OpenSSL.SSL.WantReadError: pass else: w.begin_packet(pytds.tds_base.PacketType.PRELOGIN) w.write(buf) w.flush() r.begin_response() buf = r.read_whole_packet() tlsconn.bio_write(buf) else: done = True try: buf = tlsconn.bio_read(bufsize) except OpenSSL.SSL.WantReadError: pass else: w.begin_packet(pytds.tds_base.PacketType.PRELOGIN) w.write(buf) w.flush() wrapped_socket = pytds.tls.EncryptedSocket( transport=self.request, tls_conn=tlsconn ) r._transport = wrapped_socket w._transport = wrapped_socket try: r.begin_response() buf = r.read_whole_packet() except pytds.tds_base.ClosedConnectionError: logger.info( "client closed connection, probably did not like server certificate" ) return logger.info(f"received login packet from client {buf}") if res_enc == pytds.PreLoginEnc.ENCRYPT_OFF: wrapped_socket.shutdown() r._transport = self._transport w._transport = self._transport srv_name = "Simple TDS Server" srv_ver = (1, 0, 0, 0) tds_version = self.server._tds_version w.begin_packet(pytds.tds_base.PacketType.REPLY) # https://msdn.microsoft.com/en-us/library/dd340651.aspx srv_name_coded, _ = pytds.collate.ucs2_codec.encode(srv_name) srv_name_size = len(srv_name_coded) w.put_byte(pytds.tds_base.TDS_LOGINACK_TOKEN) size = 1 + 4 + 1 + srv_name_size + 4 w.put_usmallint(size) w.put_byte(1) # interface w.put_uint_be(tds_version) w.put_byte(len(srv_name)) w.write(srv_name_coded) w.put_byte(srv_ver[0]) w.put_byte(srv_ver[1]) w.put_byte(srv_ver[2]) w.put_byte(srv_ver[3]) # https://msdn.microsoft.com/en-us/library/dd340421.aspx w.put_byte(pytds.tds_base.TDS_DONE_TOKEN) w.put_usmallint(0) # status w.put_usmallint(0) # curcmd w.put_uint8(0) # done row count w.flush() def bad_stream(self, msg): raise Exception(msg) class SimpleServer(socketserver.TCPServer): allow_reuse_address = True def __init__( self, address, enc, cert=None, pkey=None, tds_version=pytds.tds_base.TDS74 ): self._enc = enc super().__init__(address, RequestHandler) ctx = None if cert and pkey: ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD) ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2) ctx.set_options(OpenSSL.SSL.OP_NO_SSLv3) ctx.use_certificate(cert) ctx.use_privatekey(pkey) self._tls_ctx = ctx self._tds_version = tds_version def set_ssl_context(self, ctx): self._tls_ctx = ctx def set_enc(self, enc): self._enc = enc def run(address): logger.info("Starting server...") with SimpleServer(address) as server: logger.info("Press Ctrl+C to stop the server") server.serve_forever() if __name__ == "__main__": run() pytds-1.15.0/tests/smp_test.py000066400000000000000000000135151456567501500163360ustar00rootroot00000000000000import unittest import struct import pytest import pytds from pytds.smp import * from utils import MockSock smp_hdr = struct.Struct(" str: return f"User(id={self.id!r}, name={self.name!r}, fullname={self.fullname!r})" class Address(Base): __tablename__ = "address" id: Mapped[int] = mapped_column(primary_key=True) email_address: Mapped[str] user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id")) user: Mapped["User"] = relationship(back_populates="addresses") def __repr__(self) -> str: return f"Address(id={self.id!r}, email_address={self.email_address!r})" pytds-1.15.0/tests/sqlalchemy_test.py000066400000000000000000000020701456567501500176730ustar00rootroot00000000000000""" Testing integration with sqlalchemy """ import sqlalchemy.engine from sqlalchemy.orm import Session from sqlalchemy import select from sqlalchemy_schema import User, Address from fixtures import sqlalchemy_engine def test_alchemy(sqlalchemy_engine: sqlalchemy.engine.Engine): with Session(sqlalchemy_engine) as session: spongebob = User( name="spongebob", fullname="Spongebob Squarepants", addresses=[Address(email_address="spongebob@sqlalchemy.org")], ) sandy = User( name="sandy", fullname="Sandy Cheeks", addresses=[ Address(email_address="sandy@sqlalchemy.org"), Address(email_address="sandy@squirrelpower.org"), ], ) patrick = User(name="patrick", fullname="Patrick Star") session.add_all([spongebob, sandy, patrick]) # Test select stmt = select(User).where(User.name.in_(["spongebob", "sandy"])) users = session.scalars(stmt) assert [spongebob, sandy] == list(users) pytds-1.15.0/tests/sspi_test.py000066400000000000000000000062551456567501500165200ustar00rootroot00000000000000try: import unittest2 as unittest except: import unittest import ctypes from ctypes import create_string_buffer import settings import socket import sys @unittest.skipUnless(sys.platform.startswith("win"), "requires Windows") class SspiTest(unittest.TestCase): def test_enum_security_packages(self): import pytds.sspi pytds.sspi.enum_security_packages() def test_credentials(self): import pytds.sspi cred = pytds.sspi.SspiCredentials("Negotiate", pytds.sspi.SECPKG_CRED_OUTBOUND) cred.query_user_name() cred.close() def test_make_buffers(self): import pytds.sspi buf = create_string_buffer(1000) bufs = [(pytds.sspi.SECBUFFER_TOKEN, buf)] desc = pytds.sspi._make_buffers_desc(bufs) self.assertEqual(desc.ulVersion, pytds.sspi.SECBUFFER_VERSION) self.assertEqual(desc.cBuffers, len(bufs)) self.assertEqual(desc.pBuffers[0].cbBuffer, len(bufs[0][1])) self.assertEqual(desc.pBuffers[0].BufferType, bufs[0][0]) self.assertEqual( desc.pBuffers[0].pvBuffer, ctypes.cast(bufs[0][1], pytds.sspi.PVOID).value ) def test_sec_context(self): import pytds.sspi cred = pytds.sspi.SspiCredentials("Negotiate", pytds.sspi.SECPKG_CRED_OUTBOUND) token_buf = create_string_buffer(10000) bufs = [(pytds.sspi.SECBUFFER_TOKEN, token_buf)] server = settings.HOST if "\\" in server: server, _ = server.split("\\") host, _, _ = socket.gethostbyname_ex(server) target_name = "MSSQLSvc/{0}:1433".format(host) ctx, status, bufs = cred.create_context( flags=pytds.sspi.ISC_REQ_CONFIDENTIALITY | pytds.sspi.ISC_REQ_REPLAY_DETECT | pytds.sspi.ISC_REQ_CONNECTION, byte_ordering="network", target_name=target_name, output_buffers=bufs, ) if ( status == pytds.sspi.Status.SEC_I_COMPLETE_AND_CONTINUE or status == pytds.sspi.Status.SEC_I_CONTINUE_NEEDED ): ctx.complete_auth_token(bufs) # realbuf = create_string_buffer(10000) # buf = SecBuffer() # buf.cbBuffer = len(realbuf) # buf.BufferType = SECBUFFER_TOKEN # buf.pvBuffer = cast(realbuf, PVOID) # bufs = SecBufferDesc() # bufs.ulVersion = SECBUFFER_VERSION # bufs.cBuffers = 1 # bufs.pBuffers = pointer(buf) # byte_ordering = 'network' # output_buffers = bufs # from pytds.sspi import _SecContext # ctx = _SecContext() # ctx._handle = SecHandle() # ctx._ts = TimeStamp() # ctx._attrs = ULONG() # status = sec_fn.InitializeSecurityContext( # ctypes.byref(cred._handle), # None, # 'MSSQLSvc/misha-pc:1433', # ISC_REQ_CONNECTION, # 0, # SECURITY_NETWORK_DREP if byte_ordering == 'network' else SECURITY_NATIVE_DREP, # None, # 0, # byref(ctx._handle), # byref(bufs), # byref(ctx._attrs), # byref(ctx._ts)); # pass pytds-1.15.0/tests/tds_reader_test.py000066400000000000000000000111341456567501500176460ustar00rootroot00000000000000import pytest from pytds.tds_base import PacketType, _header, ClosedConnectionError from pytds.tds_reader import _TdsReader from tests.utils import BytesSocket def test_reader(): """ Test normal flow for reader """ reader = _TdsReader( transport=BytesSocket( # Setup byte stream which contains two responses # First response consists of two packets _header.pack(PacketType.REPLY, 0, 8 + len(b"hello"), 123, 0) + b"hello" + # Second and last packet of first response _header.pack(PacketType.REPLY, 1, 8 + len(b"secondpacket"), 123, 0) + b"secondpacket" + # Second response consisting of single packet _header.pack(PacketType.TRANS, 1, 8 + len(b"secondresponse"), 123, 0) + b"secondresponse", ), tds_session=None, bufsize=200, ) # Reading without calling begin_response should return empty result indicating that stream is empty assert reader.recv(100) == b"" assert reader.get_block_size() == 200 response_header = reader.begin_response() assert response_header.type == PacketType.REPLY assert reader.packet_type == PacketType.REPLY assert response_header.spid == 123 assert b"hel" == reader.recv(3) assert b"lo" == reader.recv(2) assert b"secondpacket" == reader.recv(100) # should return empty byte array indicating end of stream once end is reached assert b"" == reader.recv(100) # Now start reading next response stream response_header2 = reader.begin_response() assert not reader.stream_finished() assert response_header2.type == PacketType.TRANS assert response_header2.spid == 123 assert reader.packet_type == PacketType.TRANS assert reader.recv(100) == b"secondresponse" assert reader.recv(100) == b"" assert reader.stream_finished() with pytest.raises(ClosedConnectionError): reader.begin_response() def test_read_fast(): """ Testing read_fast method """ reader = _TdsReader( transport=BytesSocket( # Setup byte stream which contains two responses # First response consists of two packets _header.pack(PacketType.REPLY, 0, 8 + len(b"hello"), 123, 0) + b"hello" + # Second and last packet of first response _header.pack(PacketType.REPLY, 1, 8 + len(b"secondpacket"), 123, 0) + b"secondpacket" ), tds_session=None, ) response_header2 = reader.begin_response() assert response_header2.type == PacketType.REPLY assert response_header2.spid == 123 # Testing fast_read functionality buf, offset = reader.read_fast(100) assert buf[offset : reader._pos] == b"hello" buf, offset = reader.read_fast(100) assert buf[offset : reader._pos] == b"secondpacket" assert reader.read_fast(100) == (b"", 0) assert reader.stream_finished() def test_begin_response_incorrectly(): """ Test that calling begin_response at wrong time issues an exception """ reader = _TdsReader( transport=BytesSocket( # First response consists of two packets _header.pack(PacketType.REPLY, 0, 8 + len(b"hello"), 123, 0) + b"hello" + # Second and last packet of first response _header.pack(PacketType.REPLY, 1, 8 + len(b"secondpacket"), 123, 0) + b"secondpacket" ), tds_session=None, ) response_header = reader.begin_response() # calling begin_response before consuming previous response stream should cause RuntimeError with pytest.raises( RuntimeError, match="begin_response was called before previous response was fully consumed", ): reader.begin_response() assert response_header.type == PacketType.REPLY assert response_header.spid == 123 # consume first packet of the response stream assert reader.recv(6) == b"hello" # calling begin_response before consuming previous response stream should cause RuntimeError with pytest.raises( RuntimeError, match="begin_response was called before previous response was fully consumed", ): reader.begin_response() # consume part of the second packet of the response stream assert reader.recv(3) == b"sec" # calling begin_response before consuming previous response stream should cause RuntimeError with pytest.raises( RuntimeError, match="begin_response was called before previous response was fully consumed", ): reader.begin_response() pytds-1.15.0/tests/tls_san_test.py000066400000000000000000000021301456567501500171710ustar00rootroot00000000000000from pytds.tls import is_san_matching def test_san(): assert not is_san_matching("", "host.com") assert is_san_matching("database.com", "database.com") assert not is_san_matching("notdatabase.com", "database.com") assert not is_san_matching("*.database.com", "database.com") assert is_san_matching("*.database.com", "test.database.com") assert not is_san_matching("database.com", "*.database.com") assert not is_san_matching("test.*.database.com", "test.subdomain.database.com") # That star should be at first position # test stripping DNS: assert is_san_matching("DNS:westus2-a.control.database.windows.net", "westus2-a.control.database.windows.net") assert is_san_matching("DNS:*.database.windows.net", "my-sql-server.database.windows.net") # test parsing multiple SANs assert is_san_matching("DNS:westus2-a.control.database.windows.net,DNS:*.database.windows.net", "my-sql-server.database.windows.net") assert is_san_matching("DNS:westus2-a.control.database.windows.net, DNS:*.database.windows.net", "my-sql-server.database.windows.net") pytds-1.15.0/tests/transaction_test.py000066400000000000000000000201571456567501500200640ustar00rootroot00000000000000import pytds import pytds.extensions import settings from fixtures import separate_db_connection from utils import tran_count, does_table_exist def test_rollback_commit(): """ Test calling rollback and commit with no changes """ conn = pytds.connect(*settings.CONNECT_ARGS, **settings.CONNECT_KWARGS) cursor = conn.cursor() cursor.execute("select 1") conn.rollback() conn.commit() def test_rollback_timeout_recovery(separate_db_connection): conn = separate_db_connection conn.autocommit = False with conn.cursor() as cur: cur.execute( """ create table testtable_rollback (field int) """ ) sql = "insert into testtable_rollback values " + ",".join(["(1)"] * 1000) for i in range(10): cur.execute(sql) conn._tds_socket.sock.settimeout(0.00001) try: conn.rollback() except: pass conn._tds_socket.sock.settimeout(10) cur = conn.cursor() cur.execute("select 1") cur.fetchall() def test_commit_timeout_recovery(separate_db_connection): conn = separate_db_connection conn.autocommit = False with conn.cursor() as cur: try: cur.execute("drop table testtable_commit_rec") except: pass cur.execute( """ create table testtable_commit_rec (field int) """ ) sql = "insert into testtable_commit_rec values " + ",".join(["(1)"] * 1000) for i in range(10): cur.execute(sql) conn._tds_socket.sock.settimeout(0.00001) try: conn.commit() except: pass conn._tds_socket.sock.settimeout(10) cur = conn.cursor() cur.execute("select 1") cur.fetchall() def test_autocommit_off(separate_db_connection): """ Testing autocommit off mode, making sure that new transaction is started immediately after previous one is committed or rolled back """ conn = separate_db_connection # using snapshot isolation level to prevent blocking between connections conn.isolation_level = pytds.extensions.ISOLATION_LEVEL_SNAPSHOT assert not conn.autocommit # second connection is used to observe effects of transaction on first connection conn2 = pytds.connect( **{ **settings.CONNECT_KWARGS, "isolation_level": pytds.extensions.ISOLATION_LEVEL_SNAPSHOT, "autocommit": True, } ) assert conn2.isolation_level == pytds.extensions.ISOLATION_LEVEL_SNAPSHOT assert conn2.autocommit # This connection can see changes which are made by other transactions and which are not yet committed conn_read_uncom = pytds.connect( **{ **settings.CONNECT_KWARGS, "isolation_level": pytds.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, "autocommit": False, } ) assert ( conn_read_uncom.isolation_level == pytds.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED ) assert not conn_read_uncom.autocommit with conn.cursor() as cur, conn2.cursor() as cur2, conn_read_uncom.cursor() as cur_read_uncom: try: cur.execute("drop table test_autocommit") except: pass conn.commit() cur.execute("create table test_autocommit(field int)") conn.commit() assert does_table_exist( cursor=cur2, name="test_autocommit", database="test" ), "table should exist now, since we committed creation" # New transaction should be started after committing previous transaction assert 1 == tran_count(cur) cur.execute("insert into test_autocommit(field) values(1)") assert 1 == tran_count(cur) cur.execute("select field from test_autocommit") assert cur.fetchall() == [(1,)] assert ( cur2.execute("select * from test_autocommit").fetchall() == [] ), "should not see created row from another connection since it is not committed yet" # Using read uncommitted level we should see changes from different connection assert cur_read_uncom.execute("select * from test_autocommit").fetchall() == [ (1,) ] # Now commit transaction, after that changes should be visible from other connections conn.commit() assert 1 == tran_count(cur) assert cur.execute("select * from test_autocommit").fetchall() == [(1,)] assert cur2.execute("select * from test_autocommit").fetchall() == [(1,)] # cleanup cur.execute("delete from test_autocommit") conn.commit() def test_autocommit_on(separate_db_connection): conn = separate_db_connection conn.autocommit = True # second connection is used to observe effects of transaction on first connection conn2 = pytds.connect(**settings.CONNECT_KWARGS) with conn.cursor() as cur, conn2.cursor() as cur2: # commit in autocommit mode should be a no-op conn.commit() # rollback in autocommit mode should be a no-op conn.rollback() # cleanup table before test cur.execute("delete from test_autocommit") # insert test data cur.execute("insert into test_autocommit(field) values(1)") assert 0 == tran_count(cur) # should see inserted record on other connection without calling commit on first connection assert cur2.execute("select * from test_autocommit").fetchall() == [(1,)] # cleanup table after test cur.execute("delete from test_autocommit") def test_isolation_level(separate_db_connection): """ Testing setting different isolation levels and verifying that they are set via querying MSSQL's sys.dm_exec_sessions view. """ conn = separate_db_connection conn.autocommit = False with conn.cursor() as cur: for level in [ pytds.extensions.ISOLATION_LEVEL_SERIALIZABLE, pytds.extensions.ISOLATION_LEVEL_SNAPSHOT, pytds.extensions.ISOLATION_LEVEL_READ_COMMITTED, pytds.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, pytds.extensions.ISOLATION_LEVEL_REPEATABLE_READ, ]: conn.isolation_level = level assert level == cur.execute_scalar( "select transaction_isolation_level " "from sys.dm_exec_sessions where session_id = @@SPID" ) def test_transactions(separate_db_connection): conn = separate_db_connection conn.autocommit = False with conn.cursor() as cur: cur.execute( """ create table testtable_trans (field datetime) """ ) cur.execute("select object_id('testtable_trans')") assert (None,) != cur.fetchone() assert 1 == tran_count(cur) conn.rollback() assert 1 == tran_count(cur) cur.execute("select object_id('testtable_trans')") assert (None,) == cur.fetchone() cur.execute( """ create table testtable_trans (field datetime) """ ) conn.commit() cur.execute("select object_id('testtable_trans')") assert (None,) != cur.fetchone() with conn.cursor() as cur: cur.execute( """ if object_id('testtable_trans') is not null drop table testtable_trans """ ) conn.commit() def test_manual_commit(separate_db_connection): conn = separate_db_connection conn.autocommit = False cur = conn.cursor() cur.execute("create table tbl(x int)") assert 1 == cur.execute_scalar( "select @@trancount" ), "Should be in transaction even after errors" assert conn._tds_socket.tds72_transaction try: cur.execute("create table tbl(x int)") except pytds.OperationalError: pass trancount = cur.execute_scalar("select @@trancount") assert 1 == trancount, "Should be in transaction even after errors" cur.execute("create table tbl(x int)") try: cur.execute("create table tbl(x int)") except: pass cur.callproc("sp_executesql", ("select @@trancount",)) (trancount,) = cur.fetchone() assert 1 == trancount, "Should be in transaction even after errors" pytds-1.15.0/tests/types_test.py000066400000000000000000000227671456567501500167140ustar00rootroot00000000000000# coding=utf-8 import datetime from decimal import Decimal, Context import uuid import pytest import pytds from fixtures import * @pytest.mark.parametrize( "sql_type", [ "tinyint", "smallint", "int", "bigint", "real", "float", "smallmoney", "money", "decimal", "varbinary(15)", "binary(15)", "nvarchar(15)", "nchar(15)", "varchar(15)", "char(15)", "bit", "smalldatetime", "date", "time", "datetime", "datetime2", "datetimeoffset", "uniqueidentifier", "sql_variant", ], ) def test_null_parameter(cursor, sql_type): cursor.execute( "set nocount on; declare @x {} = %s; select @x".format(sql_type), (None,) ) (val,) = cursor.fetchone() assert val is None def test_reading_values(cursor): cur = cursor with pytest.raises(pytds.ProgrammingError): cur.execute("select ") assert "abc" == cur.execute_scalar( "select cast('abc' as varchar(max)) as fieldname" ) assert "abc" == cur.execute_scalar( "select cast('abc' as nvarchar(max)) as fieldname" ) assert b"abc" == cur.execute_scalar( "select cast('abc' as varbinary(max)) as fieldname" ) # assert 12 == cur.execute_scalar('select cast(12 as bigint) as fieldname') assert 12 == cur.execute_scalar("select cast(12 as smallint) as fieldname") assert -12 == cur.execute_scalar("select -12 as fieldname") assert 12 == cur.execute_scalar("select cast(12 as tinyint) as fieldname") assert True == cur.execute_scalar("select cast(1 as bit) as fieldname") assert 5.1 == cur.execute_scalar("select cast(5.1 as float) as fieldname") cur.execute("select 'test', 20") assert ("test", 20) == cur.fetchone() assert "test" == cur.execute_scalar("select 'test' as fieldname") assert "test" == cur.execute_scalar("select N'test' as fieldname") assert "test" == cur.execute_scalar("select cast(N'test' as ntext) as fieldname") assert "test" == cur.execute_scalar("select cast(N'test' as text) as fieldname") assert "test " == cur.execute_scalar("select cast(N'test' as char(5)) as fieldname") assert "test " == cur.execute_scalar( "select cast(N'test' as nchar(5)) as fieldname" ) assert b"test" == cur.execute_scalar( "select cast('test' as varbinary(4)) as fieldname" ) assert b"test" == cur.execute_scalar("select cast('test' as image) as fieldname") assert None == cur.execute_scalar("select cast(NULL as image) as fieldname") assert None == cur.execute_scalar("select cast(NULL as varbinary(10)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as ntext) as fieldname") assert None == cur.execute_scalar("select cast(NULL as nvarchar(max)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as xml)") assert None is cur.execute_scalar("select cast(NULL as varchar(max)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as nvarchar(10)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as varchar(10)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as nchar(10)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as char(10)) as fieldname") assert None == cur.execute_scalar("select cast(NULL as char(10)) as fieldname") assert 5 == cur.execute_scalar("select 5 as fieldname") with pytest.raises(pytds.ProgrammingError) as ex: cur.execute_scalar("create table exec_scalar_empty(f int)") # message does not have to be exact match assert "Previous statement didn't produce any results" in str(ex.value) def test_money(cursor): cur = cursor assert Decimal("0") == cur.execute_scalar("select cast('0' as money) as fieldname") assert Decimal("1") == cur.execute_scalar("select cast('1' as money) as fieldname") assert Decimal("1.5555") == cur.execute_scalar( "select cast('1.5555' as money) as fieldname" ) assert Decimal("1234567.5555") == cur.execute_scalar( "select cast('1234567.5555' as money) as fieldname" ) assert Decimal("-1234567.5555") == cur.execute_scalar( "select cast('-1234567.5555' as money) as fieldname" ) assert Decimal("12345.55") == cur.execute_scalar( "select cast('12345.55' as smallmoney) as fieldname" ) def test_strs(cursor): cur = cursor assert isinstance(cur.execute_scalar("select 'test'"), str) @pytest.mark.parametrize( "val", [ "hello", "x" * 5000, "x" * 9000, 123, -123, 123.12, -123.12, 10**20, 10**38 - 1, -(10**38) + 1, datetime.datetime(2011, 2, 3, 10, 11, 12, 3000), Decimal("1234.567"), Decimal("1234000"), Decimal("9" * 38), Decimal("0." + "9" * 38), Decimal("-" + ("9" * 38), Context(prec=38)), Decimal("1E10"), Decimal("1E-10"), Decimal("0.{0}1".format("0" * 37)), None, "hello", "", pytds.Binary(b""), pytds.Binary(b"\x00\x01\x02"), pytds.Binary(b"x" * 9000), 2**63 - 1, False, True, uuid.uuid4(), "Iñtërnâtiônàlizætiøn1", "\U0001d6fc", ], ) def test_select_values(cursor, val): cursor.execute("select %s", (val,)) assert cursor.fetchone() == (val,) assert cursor.fetchone() is None uuid_val = uuid.uuid4() @pytest.mark.parametrize( "result,sql", [ (None, "cast(NULL as varchar)"), ("test", "cast('test' as varchar)"), ("test ", "cast('test' as char(5))"), ("test", "cast(N'test' as nvarchar)"), ("test ", "cast(N'test' as nchar(5))"), (Decimal("100.55555"), "cast(100.55555 as decimal(8,5))"), (Decimal("100.55555"), "cast(100.55555 as numeric(8,5))"), (b"test", "cast('test' as varbinary)"), (b"test\x00", "cast('test' as binary(5))"), ( datetime.datetime(2011, 2, 3, 10, 11, 12, 3000), "cast('2011-02-03T10:11:12.003' as datetime)", ), ( datetime.datetime(2011, 2, 3, 10, 11, 0), "cast('2011-02-03T10:11:00' as smalldatetime)", ), (uuid_val, "cast('{0}' as uniqueidentifier)".format(uuid_val)), (True, "cast(1 as bit)"), (128, "cast(128 as tinyint)"), (255, "cast(255 as tinyint)"), (-32000, "cast(-32000 as smallint)"), (2000000000, "cast(2000000000 as int)"), (2000000000000, "cast(2000000000000 as bigint)"), (0.12345, "cast(0.12345 as float)"), (0.25, "cast(0.25 as real)"), (Decimal("922337203685477.5807"), "cast('922,337,203,685,477.5807' as money)"), (Decimal("-214748.3648"), "cast('- 214,748.3648' as smallmoney)"), ], ) def test_sql_variant_round_trip(cursor, result, sql): if not pytds.tds_base.IS_TDS71_PLUS(cursor.connection): pytest.skip("Requires TDS7.1+") cursor.execute("select cast({0} as sql_variant)".format(sql)) (val,) = cursor.fetchone() assert result == val def test_collations(cursor, collation_set): coll_name_set = collation_set tests = [ ("Привет", "Cyrillic_General_BIN"), ("Привет", "Cyrillic_General_BIN2"), ("สวัสดี", "Thai_CI_AI"), ("你好", "Chinese_PRC_CI_AI"), ("こんにちは", "Japanese_CI_AI"), ("안녕하세요.", "Korean_90_CI_AI"), ("你好", "Chinese_Hong_Kong_Stroke_90_CI_AI"), ("cześć", "Polish_CI_AI"), ("Bonjour", "French_CI_AI"), ("Γεια σας", "Greek_CI_AI"), ("Merhaba", "Turkish_CI_AI"), ("שלום", "Hebrew_CI_AI"), ("مرحبا", "Arabic_CI_AI"), ("Sveiki", "Lithuanian_CI_AI"), ("chào", "Vietnamese_CI_AI"), ("ÄÅÆ", "SQL_Latin1_General_CP437_BIN"), ("ÁÂÀÃ", "SQL_Latin1_General_CP850_BIN"), ("ŠşĂ", "SQL_Slovak_CP1250_CS_AS_KI_WI"), ("ÁÂÀÃ", "SQL_Latin1_General_1251_BIN"), ("ÁÂÀÃ", "SQL_Latin1_General_Cp1_CS_AS_KI_WI"), ("ÁÂÀÃ", "SQL_Latin1_General_1253_BIN"), ("ÁÂÀÃ", "SQL_Latin1_General_1254_BIN"), ("ÁÂÀÃ", "SQL_Latin1_General_1255_BIN"), ("ÁÂÀÃ", "SQL_Latin1_General_1256_BIN"), ("ÁÂÀÃ", "SQL_Latin1_General_1257_BIN"), ("ÁÂÀÃ", "Latin1_General_100_BIN"), ] for s, coll in tests: if coll not in coll_name_set: logger.info("Skipping {}, not supported by current server".format(coll)) continue assert ( cursor.execute_scalar( "select cast(N'{}' collate {} as varchar(100))".format(s, coll) ) == s ) def skip_if_new_date_not_supported(conn): if not pytds.tds_base.IS_TDS73_PLUS(conn): pytest.skip( "Test requires new date types support, SQL 2008 or newer is required" ) def test_date(cursor): skip_if_new_date_not_supported(cursor.connection) date = pytds.Date(2012, 10, 6) cursor.execute("select %s", (date,)) assert cursor.fetchall() == [(date,)] def test_time(cursor): skip_if_new_date_not_supported(cursor.connection) time = pytds.Time(8, 7, 4, 123000) cursor.execute("select %s", (time,)) assert cursor.fetchall() == [(time,)] def test_datetime(cursor): time = pytds.Timestamp(2013, 7, 9, 8, 7, 4, 123000) cursor.execute("select %s", (time,)) assert cursor.fetchall() == [(time,)] pytds-1.15.0/tests/tz_test.py000066400000000000000000000005701456567501500161710ustar00rootroot00000000000000import datetime from pytds import tz def test_tz(): assert tz.FixedOffsetTimezone(0, "UTC").tzname(None) == "UTC" lz = tz.LocalTimezone() jan_1 = datetime.datetime(2010, 1, 1, 0, 0) july_1 = datetime.datetime(2010, 7, 1, 0, 0) assert isinstance(lz.tzname(jan_1), str) lz.dst(jan_1) lz.dst(july_1) lz.utcoffset(jan_1) lz.utcoffset(july_1) pytds-1.15.0/tests/unit_test.py000066400000000000000000001772471456567501500165330ustar00rootroot00000000000000# vim: set fileencoding=utf8 : import binascii import datetime import decimal import struct import unittest import uuid import socket import threading import logging import sys import os import pytest import OpenSSL.crypto import pytds from pytds.collate import raw_collation, Collation from pytds.tds_socket import ( _TdsSocket, _TdsLogin, ) from pytds.tds_base import ( TDS_ENCRYPTION_REQUIRE, Column, TDS70, TDS73, TDS71, TDS72, TDS73, TDS74, TDS_ENCRYPTION_OFF, PreLoginEnc, _TdsEnv, ) from pytds.tds_session import _TdsSession from pytds.tds_types import ( DateTimeSerializer, DateTime, DateTime2Type, DateType, TimeType, DateTimeOffsetType, IntType, BigIntType, TinyIntType, SmallIntType, VarChar72Serializer, XmlSerializer, Text72Serializer, NText72Serializer, Image72Serializer, MoneyNSerializer, VariantSerializer, BitType, DeclarationsParser, SmallDateTimeType, DateTimeType, ImageType, VarBinaryType, VarBinaryMaxType, SmallMoneyType, MoneyType, DecimalType, UniqueIdentifierType, VariantType, BinaryType, RealType, FloatType, CharType, VarCharType, VarCharMaxType, NCharType, NVarCharType, NVarCharMaxType, TextType, NTextType, ) from pytds.tds_types import ( BitNSerializer, TdsTypeInferrer, SerializerFactory, NVarChar72Serializer, IntNSerializer, MsDecimalSerializer, FloatNSerializer, VarBinarySerializerMax, NVarCharMaxSerializer, VarCharMaxSerializer, DateTime2Serializer, DateTimeOffsetSerializer, MsDateSerializer, MsTimeSerializer, MsUniqueSerializer, NVarChar71Serializer, Image70Serializer, NText71Serializer, Text71Serializer, DateTimeNSerializer, NVarChar70Serializer, NText70Serializer, Text70Serializer, VarBinarySerializer, VarBinarySerializer72, ) import pytds.login tzoffset = pytds.tz.FixedOffsetTimezone logger = logging.getLogger(__name__) class _FakeSock(object): def __init__(self, packets): self._packets = packets self._curr_packet = 0 self._packet_pos = 0 def recv(self, size): if self._curr_packet >= len(self._packets): return b"" if self._packet_pos >= len(self._packets[self._curr_packet]): self._curr_packet += 1 self._packet_pos = 0 if self._curr_packet >= len(self._packets): return b"" res = self._packets[self._curr_packet][ self._packet_pos : self._packet_pos + size ] self._packet_pos += len(res) return res def recv_into(self, buffer, size=0): if size == 0: size = len(buffer) res = self.recv(size) buffer[0 : len(res)] = res return len(res) def send(self, buf, flags=0): self._sent = buf return len(buf) def sendall(self, buf, flags=0): self._sent = buf def setsockopt(self, *args): pass def close(self): self._stream = b"" class TestMessages(unittest.TestCase): def _make_login(self): from pytds.tds_base import TDS74 login = _TdsLogin() login.blocksize = 4096 login.use_tz = None login.query_timeout = login.connect_timeout = 60 login.tds_version = TDS74 login.instance_name = None login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP login.use_mars = False login.option_flag2 = 0 login.user_name = "testname" login.password = "password" login.app_name = "appname" login.server_name = "servername" login.library = "library" login.language = "EN" login.database = "database" login.auth = None login.bulk_copy = False login.readonly = False login.client_lcid = 100 login.attach_db_file = "" login.text_size = 0 login.client_host_name = "clienthost" login.pid = 100 login.change_password = "" login.client_tz = tzoffset(5) login.client_id = 0xABCD login.bytes_to_unicode = True return login def test_login(self): sock = _FakeSock( [ # prelogin response b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', # login resopnse b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", # response to USE query b"\x04\x01\x00#\x00Z\x01\x00\xe3\x0b\x00\x08\x08\x01\x00\x00\x00Z\x00\x00\x00\x00\xfd\x00\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00", ] ) _TdsSocket(sock=sock, login=self._make_login()).login() # test connection close on first message sock = _FakeSock( [ b"\x04\x01\x00+\x00", ] ) with self.assertRaises(pytds.Error): _TdsSocket(sock=sock, login=self._make_login()).login() # test connection close on second message sock = _FakeSock( [ b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S", ] ) with self.assertRaises(pytds.Error): _TdsSocket(sock=sock, login=self._make_login()).login() # test connection close on third message # sock = _FakeSock([ # b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', # b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", # b'\x04\x01\x00#\x00Z\x01\x00\xe3\x0b\x00\x08\x08\x01\x00\x00\x00Z\x00\x00\x00\x00\xfd\x00\x00\xfd\x00\x00', # ]) # with self.assertRaises(pytds.Error): # _TdsSocket().login(self._make_login(), sock, None) def test_prelogin_parsing(self): # test good packet sock = _FakeSock( [ b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', ] ) # test repr on some objects login = _TdsLogin() login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP tds = _TdsSocket(sock=sock, login=login) repr(tds._main_session) repr(tds) tds._main_session.process_prelogin(login) self.assertFalse(tds._mars_enabled) self.assertTupleEqual(tds.server_library_version, (0xA001588, 0)) # test bad packet type sock = _FakeSock( [ b'\x03\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', ] ) login = self._make_login() tds = _TdsSocket(sock=sock, login=login) with self.assertRaises(pytds.InterfaceError): tds._main_session.process_prelogin(login) # test bad offset 1 sock = _FakeSock( [ b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\x00\n\x00\x15\x88\x00\x00\x02\x00\x00', ] ) login = self._make_login() tds = _TdsSocket(sock=sock, login=login) with self.assertRaises(pytds.InterfaceError): tds._main_session.process_prelogin(login) # test bad offset 2 sock = _FakeSock( [ b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00', ] ) login = self._make_login() tds = _TdsSocket(sock=sock, login=login) with self.assertRaises(pytds.InterfaceError): tds._main_session.process_prelogin(login) # test bad size with self.assertRaisesRegex( pytds.InterfaceError, "Invalid size of PRELOGIN structure" ): login = self._make_login() tds._main_session.parse_prelogin(login=login, octets=b"\x01") def make_tds(self): sock = _FakeSock([]) tds = _TdsSocket(sock=sock, login=_TdsLogin()) return tds def test_prelogin_unexpected_encrypt_on(self): tds = self.make_tds() with self.assertRaisesRegex( pytds.InterfaceError, "Server returned unexpected ENCRYPT_ON value" ): login = self._make_login() login.enc_flag = PreLoginEnc.ENCRYPT_ON tds._main_session.parse_prelogin( login=login, octets=b"\x01\x00\x06\x00\x01\xff\x00" ) def test_prelogin_unexpected_enc_flag(self): tds = self.make_tds() with self.assertRaisesRegex( pytds.InterfaceError, "Unexpected value of enc_flag returned by server: 5" ): login = self._make_login() tds._main_session.parse_prelogin( login=login, octets=b"\x01\x00\x06\x00\x01\xff\x05" ) def test_prelogin_generation(self): sock = _FakeSock("") login = _TdsLogin() login.instance_name = "MSSQLServer" login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP login.use_mars = False tds = _TdsSocket(sock=sock, login=login) tds._main_session.send_prelogin(login) template = ( b"\x12\x01\x00:\x00\x00\x00\x00\x00\x00" + b"\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x0c\x03" + b"\x00-\x00\x04\x04\x001\x00\x01\xff" + struct.pack(">l", pytds.intversion) + b"\x00\x00\x02MSSQLServer\x00\x00\x00\x00\x00\x00" ) self.assertEqual(sock._sent, template) login.instance_name = "x" * 65499 sock._sent = b"" with self.assertRaisesRegex(ValueError, "Instance name is too long"): tds._main_session.send_prelogin(login) self.assertEqual(sock._sent, b"") login.instance_name = "тест" with self.assertRaises(UnicodeEncodeError): tds._main_session.send_prelogin(login) self.assertEqual(sock._sent, b"") def test_login_parsing(self): sock = _FakeSock( [ b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", ] ) tds = _TdsSocket(sock=sock, login=_TdsLogin()) tds._main_session.begin_response() tds._main_session.process_login_tokens() # test invalid tds version sock = _FakeSock( [ b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01\x65\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", ] ) tds = _TdsSocket(sock=sock, login=_TdsLogin()) tds._main_session.begin_response() with self.assertRaises(pytds.InterfaceError): tds._main_session.process_login_tokens() # test for invalid env type sock = _FakeSock( [ b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\xab\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", ] ) tds = _TdsSocket(sock=sock, login=_TdsLogin()) tds._main_session.begin_response() tds._main_session.process_login_tokens() def test_login_generation(self): sock = _FakeSock(b"") login = _TdsLogin() login.option_flag2 = 0 login.user_name = "test" login.password = "testpwd" login.app_name = "appname" login.server_name = "servername" login.library = "library" login.language = "en" login.database = "database" login.auth = None login.tds_version = TDS73 login.bulk_copy = True login.client_lcid = 0x204 login.attach_db_file = "filepath" login.readonly = False login.client_host_name = "subdev1" login.pid = 100 login.change_password = "" login.client_tz = tzoffset(-4 * 60) login.client_id = 0x1234567890AB tds = _TdsSocket(sock=sock, login=login) tds._main_session.tds7_send_login(login) self.assertEqual( sock._sent, b"\x10\x01\x00\xde\x00\x00\x00\x00" # header + b"\xc6\x00\x00\x00" # size + b"\x03\x00\ns" # tds version + b"\x00\x10\x00\x00" # buf size + struct.pack(" int: return self._data.readinto(buffer[:size]) class MockSock: def __init__(self, input_packets=()): self.set_input(input_packets) self._out_packets = [] self._closed = False def recv(self, size): if not self.is_open(): raise Exception("Connection closed") if self._curr_packet >= len(self._packets): return b"" if self._packet_pos >= len(self._packets[self._curr_packet]): self._curr_packet += 1 self._packet_pos = 0 if self._curr_packet >= len(self._packets): return b"" res = self._packets[self._curr_packet][ self._packet_pos : self._packet_pos + size ] self._packet_pos += len(res) return res def recv_into(self, buffer, size=0): if not self.is_open(): raise Exception("Connection closed") if size == 0: size = len(buffer) res = self.recv(size) buffer[0 : len(res)] = res return len(res) def send(self, buf, flags=0): if not self.is_open(): raise Exception("Connection closed") self._out_packets.append(buf) return len(buf) def sendall(self, buf, flags=0): if not self.is_open(): raise Exception("Connection closed") self._out_packets.append(buf) def setsockopt(self, *args): pass def close(self): self._closed = True def is_open(self): return not self._closed def consume_output(self): """ Retrieve data from output queue and then clear output queue @return: bytes """ res = self._out_packets self._out_packets = [] return b"".join(res) def set_input(self, packets): """ Resets input queue @param packets: List of input packets """ self._packets = packets self._curr_packet = 0 self._packet_pos = 0 def does_database_exist(cursor: pytds.Cursor, name: str) -> bool: """ Checks if given database exist and returns true if it does """ db_id = cursor.execute_scalar("select db_id(%s)", (name,)) return db_id is not None def does_schema_exist(cursor: pytds.Cursor, name: str, database: str) -> bool: val = cursor.execute_scalar( f""" select count(*) from {database}.information_schema.schemata where schema_name = cast(%s as nvarchar(max)) """, (name,), ) return val > 0 def does_stored_proc_exist( cursor: pytds.Cursor, name: str, database: str, schema: str = "dbo" ) -> bool: val = cursor.execute_scalar( f""" select count(*) from {database}.information_schema.routines where routine_schema = cast(%s as nvarchar(max)) and routine_name = cast(%s as nvarchar(max)) """, (schema, name), ) return val > 0 def does_table_exist( cursor: pytds.Cursor, name: str, database: str, schema: str = "dbo" ) -> bool: val = cursor.execute_scalar( f""" select count(*) from {database}.information_schema.tables where table_schema = cast(%s as nvarchar(max)) and table_name = cast(%s as nvarchar(max)) """, (schema, name), ) return val > 0 def does_user_defined_type_exist(cursor: pytds.Cursor, name: str) -> bool: val = cursor.execute_scalar("select type_id(%s)", (name,)) return val is not None def create_test_database(connection: pytds.Connection): with connection.cursor() as cur: if not does_database_exist(cursor=cur, name=settings.DATABASE): cur.execute(f"create database [{settings.DATABASE}]") cur.execute(f"use [{settings.DATABASE}]") if not does_schema_exist( cursor=cur, name="myschema", database=settings.DATABASE ): cur.execute("create schema myschema") if not does_table_exist( cursor=cur, name="bulk_insert_table", schema="myschema", database=settings.DATABASE, ): cur.execute( "create table myschema.bulk_insert_table(num int, data varchar(100))" ) if not does_stored_proc_exist( cursor=cur, name="testproc", database=settings.DATABASE ): cur.execute( """ create procedure testproc (@param int, @add int = 2, @outparam int output) as begin set nocount on --select @param set @outparam = @param + @add return @outparam end """ ) # Stored procedure which does not have RETURN statement if not does_stored_proc_exist( cursor=cur, name="test_proc_no_return", database=settings.DATABASE ): cur.execute( """ create procedure test_proc_no_return(@param int) as begin select @param end """ ) if not does_user_defined_type_exist(cursor=cur, name="dbo.CategoryTableType"): cur.execute( "CREATE TYPE dbo.CategoryTableType AS TABLE ( CategoryID int, CategoryName nvarchar(50) )" ) def tran_count(cursor: pytds.Cursor) -> int: return cursor.execute_scalar("select @@trancount") pytds-1.15.0/tests/utils_35.py000066400000000000000000000074231456567501500161500ustar00rootroot00000000000000import os.path import shutil import datetime import pathlib import cryptography.hazmat.backends from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from cryptography import x509 import cryptography.x509.oid class TestCA: def __init__(self): self._key_cache = {} backend = cryptography.hazmat.backends.default_backend() self._test_cache_dir = os.path.join( os.path.dirname(__file__), "..", ".test-cache" ) os.makedirs(self._test_cache_dir, exist_ok=True) root_cert_path = self.cert_path("root") self._root_key = self.key("root") self._root_ca = generate_root_certificate(self._root_key) pathlib.Path(root_cert_path).write_bytes( self._root_ca.public_bytes(serialization.Encoding.PEM) ) def key_path(self, name): return os.path.join(self._test_cache_dir, name + "key.pem") def cert_path(self, name): return os.path.join(self._test_cache_dir, name + "cert.pem") def key(self, name) -> rsa.RSAPrivateKey: if name not in self._key_cache: backend = cryptography.hazmat.backends.default_backend() key_path = self.key_path(name) if os.path.exists(key_path): bin = pathlib.Path(key_path).read_bytes() key = serialization.load_pem_private_key( bin, password=None, backend=backend ) else: key = generate_rsa_key() bin = key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) pathlib.Path(key_path).write_bytes(bin) self._key_cache[name] = key return self._key_cache[name] def sign(self, name: str, cb: x509.CertificateBuilder) -> x509.Certificate: backend = cryptography.hazmat.backends.default_backend() cert = cb.issuer_name(self._root_ca.subject).sign( private_key=self._root_key, algorithm=hashes.SHA256(), backend=backend ) cert_path = self.cert_path(name) pathlib.Path(cert_path).write_bytes( cert.public_bytes(serialization.Encoding.PEM) ) return cert def generate_rsa_key() -> rsa.RSAPrivateKeyWithSerialization: backend = cryptography.hazmat.backends.default_backend() return rsa.generate_private_key( public_exponent=65537, key_size=2048, backend=backend ) def generate_root_certificate(private_key: rsa.RSAPrivateKey) -> x509.Certificate: backend = cryptography.hazmat.backends.default_backend() subject = x509.Name([x509.NameAttribute(x509.oid.NameOID.COMMON_NAME, "root")]) builder = x509.CertificateBuilder() return ( builder.subject_name(subject) .issuer_name(subject) .not_valid_before(datetime.datetime.utcnow()) .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=1)) .serial_number(1) .public_key(private_key.public_key()) .add_extension(x509.BasicConstraints(ca=True, path_length=1), critical=True) .add_extension( x509.KeyUsage( digital_signature=False, content_commitment=False, key_encipherment=False, data_encipherment=False, key_agreement=False, key_cert_sign=True, crl_sign=True, encipher_only=False, decipher_only=False, ), critical=True, ) .sign(private_key=private_key, algorithm=hashes.SHA256(), backend=backend) ) pytds-1.15.0/tests/utils_tests.py000066400000000000000000000002601456567501500170530ustar00rootroot00000000000000import pytds.utils def test_parse_server(): assert pytds.utils.parse_server(".") == ("localhost", "") assert pytds.utils.parse_server("(local)") == ("localhost", "") pytds-1.15.0/tox.ini000066400000000000000000000002241456567501500142700ustar00rootroot00000000000000[tox] envlist = py27,py33,py34,py35,pypy [testenv] deps=nose commands=nosetests passenv = HOST DATABASE SQLUSER SQLPASSWORD NTLM_USER NTLM_PASSWORD pytds-1.15.0/version.py000066400000000000000000000050651456567501500150240ustar00rootroot00000000000000# -*- coding: utf-8 -*- # Author: Douglas Creager # This file is placed into the public domain. # Calculates the current version number. If possible, this is the # output of “git describe”, modified to conform to the versioning # scheme that setuptools uses. If “git describe” returns an error # (most likely because we're in an unpacked copy of a release tarball, # rather than in a git working copy), then we fall back on reading the # contents of the RELEASE-VERSION file. # # To use this script, simply import it your setup.py file, and use the # results of get_git_version() as your package version: # # from version import * # # setup( # version=get_git_version(), # . # . # . # ) # # This will automatically update the RELEASE-VERSION file, if # necessary. Note that the RELEASE-VERSION file should *not* be # checked into git; please add it to your top-level .gitignore file. # # You'll probably want to distribute the RELEASE-VERSION file in your # sdist tarballs; to do this, just create a MANIFEST.in file that # contains the following line: # # include RELEASE-VERSION __all__ = "get_git_version" from subprocess import Popen, PIPE def call_git_describe(abbrev=4): try: p = Popen(["git", "describe", "--abbrev=%d" % abbrev], stdout=PIPE, stderr=PIPE) p.stderr.close() line = p.stdout.readlines()[0] return line.strip().decode("utf8") except: return None def read_release_version(): try: f = open("RELEASE-VERSION", "rb") try: version = f.readlines()[0] return version.strip().decode("utf8") finally: f.close() except: return None def write_release_version(version): f = open("RELEASE-VERSION", "w") f.write("%s\n" % version) f.close() def get_git_version(abbrev=4): # Read in the version that's currently in RELEASE-VERSION. release_version = read_release_version() # First try to get the current version using “git describe”. version = call_git_describe(abbrev) # If that doesn't work, fall back on the value that's in # RELEASE-VERSION. if version is None: version = release_version # If we still don't have anything, that's an error. if version is None: return "unknown" # If the current version is different from what's in the # RELEASE-VERSION file, update the file to be current. if version != release_version: write_release_version(version) # Finally, return the current version. return version